1#![cfg(feature = "std")]
19
20use crate::libp2p::read_write;
28
29use core::{
30 fmt, future, mem, ops,
31 pin::{self, Pin},
32 task::Poll,
33};
34use futures_util::{AsyncRead, AsyncWrite};
35use std::io;
36
/// Wraps a future that yields a socket, together with a read buffer and a list of write
/// buffers. The buffers are exposed through [`WithBuffers::read_write_access`], and the socket
/// is driven by [`WithBuffers::wait_read_write_again`].
#[pin_project::pin_project]
pub struct WithBuffers<TSocketFut, TSocket, TNow> {
    /// The socket: either the future that produces it, or the resolved socket.
    #[pin]
    socket: Socket<TSocketFut, TSocket>,
    /// I/O error that happened while resolving or using the socket, if any. Once set,
    /// [`WithBuffers::read_write_access`] returns it instead of giving access to the buffers.
    error: Option<io::Error>,
    /// Storage for data read from the socket. Only the first `read_buffer_valid` bytes are
    /// actual socket data; while waiting, the buffer is resized to its capacity so that the
    /// socket can fill the remainder.
    read_buffer: Vec<u8>,
    /// Number of bytes at the start of `read_buffer` that contain actual socket data.
    read_buffer_valid: usize,
    /// Capacity the read buffer is brought back to when it is empty and the consumer expects
    /// fewer bytes than this.
    read_buffer_reasonable_capacity: usize,
    /// True once a read from the socket has returned 0 bytes (end of stream).
    read_closed: bool,
    /// Data queued for writing to the socket, in order.
    write_buffers: Vec<Vec<u8>>,
    /// True once the consumer has closed its writing side.
    write_closed: bool,
    /// True if the socket's writing side must be closed and that close hasn't completed yet.
    close_pending: bool,
    /// True if data has been written to the socket since the last successful flush.
    flush_pending: bool,

    /// Value of `now` passed to the latest call to [`WithBuffers::read_write_access`], or
    /// `None` if it has never been called.
    read_write_now: Option<TNow>,
    /// Wake-up time produced by the consumer during the latest
    /// [`WithBuffers::read_write_access`], if any.
    read_write_wake_up_after: Option<TNow>,
}
74
/// State of the socket wrapped by a [`WithBuffers`].
#[pin_project::pin_project(project = SocketProj)]
enum Socket<TSocketFut, TSocket> {
    /// The socket is still being produced by this future.
    Pending(#[pin] TSocketFut),
    /// The future has successfully resolved into this socket.
    Resolved(#[pin] TSocket),
}
80
81impl<TSocketFut, TSocket, TNow> WithBuffers<TSocketFut, TSocket, TNow>
82where
83 TNow: Clone + Ord,
84{
85 pub fn new(socket: TSocketFut) -> Self {
87 let read_buffer_reasonable_capacity = 65536; WithBuffers {
90 socket: Socket::Pending(socket),
91 error: None,
92 read_buffer: Vec::with_capacity(read_buffer_reasonable_capacity),
93 read_buffer_valid: 0,
94 read_buffer_reasonable_capacity,
95 read_closed: false,
96 write_buffers: Vec::with_capacity(64),
97 write_closed: false,
98 close_pending: false,
99 flush_pending: false,
100 read_write_now: None,
101 read_write_wake_up_after: None,
102 }
103 }
104
105 pub fn read_write_access(
111 self: Pin<&mut Self>,
112 now: TNow,
113 ) -> Result<ReadWriteAccess<TNow>, &io::Error> {
114 let this = self.project();
115
116 debug_assert!(
117 this.read_write_now
118 .as_ref()
119 .map_or(true, |old_now| *old_now <= now)
120 );
121 *this.read_write_wake_up_after = None;
122 *this.read_write_now = Some(now.clone());
123
124 if let Some(error) = this.error.as_ref() {
125 return Err(error);
126 }
127
128 this.read_buffer.truncate(*this.read_buffer_valid);
129
130 let is_resolved = matches!(*this.socket, Socket::Resolved(_));
131
132 let write_bytes_queued = this.write_buffers.iter().map(Vec::len).sum();
133
134 Ok(ReadWriteAccess {
135 read_buffer_len_before: this.read_buffer.len(),
136 write_buffers_len_before: this.write_buffers.len(),
137 read_write: read_write::ReadWrite {
138 now,
139 incoming_buffer: mem::take(this.read_buffer),
140 expected_incoming_bytes: if !*this.read_closed { Some(0) } else { None },
141 read_bytes: 0,
142 write_bytes_queued,
143 write_buffers: mem::take(this.write_buffers),
144 write_bytes_queueable: if !is_resolved {
145 Some(0)
146 } else if !*this.write_closed {
147 Some((128 * 1024usize).saturating_sub(write_bytes_queued))
150 } else {
151 None
152 },
153 wake_up_after: this.read_write_wake_up_after.take(),
154 },
155 read_buffer: this.read_buffer,
156 read_buffer_valid: this.read_buffer_valid,
157 read_buffer_reasonable_capacity: *this.read_buffer_reasonable_capacity,
158 write_buffers: this.write_buffers,
159 write_closed: this.write_closed,
160 close_pending: this.close_pending,
161 read_write_wake_up_after: this.read_write_wake_up_after,
162 })
163 }
164}
165
166impl<TSocketFut, TSocket, TNow> WithBuffers<TSocketFut, TSocket, TNow>
167where
168 TSocket: AsyncRead + AsyncWrite,
169 TSocketFut: Future<Output = Result<TSocket, io::Error>>,
170 TNow: Clone + Ord,
171{
172 pub async fn wait_read_write_again<F>(
179 self: Pin<&mut Self>,
180 timer_builder: impl FnOnce(TNow) -> F,
181 ) where
182 F: Future<Output = ()>,
183 {
184 let mut this = self.project();
185
186 match (&*this.read_write_wake_up_after, &*this.read_write_now) {
188 (_, None) => return,
189 (Some(when_wake_up), Some(now)) if *when_wake_up <= *now => {
190 return;
191 }
192 _ => {}
193 }
194
195 let mut timer = pin::pin!({
196 let fut = this
197 .read_write_wake_up_after
198 .as_ref()
199 .map(|when| timer_builder(when.clone()));
200 async {
201 if let Some(fut) = fut {
202 fut.await;
203 } else {
204 future::pending::<()>().await;
205 }
206 }
207 });
208
209 this.read_buffer.resize(this.read_buffer.capacity(), 0);
211
212 future::poll_fn(move |cx| {
213 if this.error.is_some() {
214 return Poll::Pending;
216 }
217
218 let mut pending = true;
220
221 match Future::poll(Pin::new(&mut timer), cx) {
222 Poll::Pending => {}
223 Poll::Ready(()) => {
224 pending = false;
225 }
226 }
227
228 match this.socket.as_mut().project() {
229 SocketProj::Pending(future) => match Future::poll(future, cx) {
230 Poll::Pending => {}
231 Poll::Ready(Ok(socket)) => {
232 this.socket.set(Socket::Resolved(socket));
233 pending = false;
234 }
235 Poll::Ready(Err(err)) => {
236 *this.error = Some(err);
237 return Poll::Ready(());
238 }
239 },
240 SocketProj::Resolved(mut socket) => {
241 if !*this.read_closed && *this.read_buffer_valid < this.read_buffer.len() {
242 let read_result = AsyncRead::poll_read(
243 socket.as_mut(),
244 cx,
245 &mut this.read_buffer[*this.read_buffer_valid..],
246 );
247
248 match read_result {
249 Poll::Pending => {}
250 Poll::Ready(Ok(0)) => {
251 *this.read_closed = true;
252 pending = false;
253 }
254 Poll::Ready(Ok(n)) => {
255 *this.read_buffer_valid += n;
256 pending = false;
258 }
259 Poll::Ready(Err(err)) => {
260 *this.error = Some(err);
261 return Poll::Ready(());
262 }
263 };
264 }
265
266 loop {
267 if this.write_buffers.iter().any(|b| !b.is_empty()) {
268 let write_result = {
269 let buffers = this
270 .write_buffers
271 .iter()
272 .map(|buf| io::IoSlice::new(buf))
273 .collect::<Vec<_>>();
274 AsyncWrite::poll_write_vectored(socket.as_mut(), cx, &buffers)
275 };
276
277 match write_result {
278 Poll::Ready(Ok(0)) => {
279 unreachable!();
281 }
282 Poll::Ready(Ok(mut n)) => {
283 *this.flush_pending = true;
284 while n > 0 {
285 let first_buf = this.write_buffers.first_mut().unwrap();
286 if first_buf.len() <= n {
287 n -= first_buf.len();
288 this.write_buffers.remove(0);
289 } else {
290 first_buf.copy_within(n.., 0);
292 first_buf.truncate(first_buf.len() - n);
293 break;
294 }
295 }
296 if this.write_buffers.is_empty() {
298 pending = false;
299 }
300 }
301 Poll::Ready(Err(err)) => {
302 *this.error = Some(err);
303 return Poll::Ready(());
304 }
305 Poll::Pending => break,
306 };
307 } else if *this.flush_pending {
308 match AsyncWrite::poll_flush(socket.as_mut(), cx) {
309 Poll::Ready(Ok(())) => {
310 *this.flush_pending = false;
311 }
312 Poll::Ready(Err(err)) => {
313 *this.error = Some(err);
314 return Poll::Ready(());
315 }
316 Poll::Pending => break,
317 }
318 } else if *this.close_pending {
319 match AsyncWrite::poll_close(socket.as_mut(), cx) {
320 Poll::Ready(Ok(())) => {
321 *this.close_pending = false;
322 pending = false;
323 break;
324 }
325 Poll::Ready(Err(err)) => {
326 *this.error = Some(err);
327 return Poll::Ready(());
328 }
329 Poll::Pending => break,
330 }
331 } else {
332 break;
333 }
334 }
335 }
336 };
337
338 if !pending {
339 Poll::Ready(())
340 } else {
341 Poll::Pending
342 }
343 })
344 .await;
345 }
346}
347
348impl<TSocketFut, TSocket: fmt::Debug, TNow> fmt::Debug for WithBuffers<TSocketFut, TSocket, TNow> {
349 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
350 let mut t = f.debug_tuple("WithBuffers");
351 if let Socket::Resolved(socket) = &self.socket {
352 t.field(socket);
353 } else {
354 t.field(&"<pending>");
355 }
356 t.finish()
357 }
358}
359
/// Access to the buffers of a [`WithBuffers`], returned by [`WithBuffers::read_write_access`].
///
/// Dereferences to a [`read_write::ReadWrite`] that owns the buffers while this object is
/// alive. When dropped, the buffers and the state modified through that `ReadWrite` are
/// written back into the [`WithBuffers`].
pub struct ReadWriteAccess<'a, TNow: Clone> {
    /// Object exposed through `Deref`/`DerefMut`.
    read_write: read_write::ReadWrite<TNow>,

    /// Length of the incoming buffer when this access was created.
    read_buffer_len_before: usize,
    /// Number of write buffers when this access was created.
    write_buffers_len_before: usize,

    /// Receives the incoming buffer back on drop.
    read_buffer: &'a mut Vec<u8>,
    /// Set on drop to the length of the incoming buffer.
    read_buffer_valid: &'a mut usize,
    /// Copy of the field of the same name in [`WithBuffers`].
    read_buffer_reasonable_capacity: usize,
    /// Receives the write buffers back on drop.
    write_buffers: &'a mut Vec<Vec<u8>>,
    /// Set to `true` on drop if the consumer's `write_bytes_queueable` is `None`.
    write_closed: &'a mut bool,
    /// Set to `true` on drop, together with `write_closed`, when the writing side becomes
    /// closed.
    close_pending: &'a mut bool,
    /// Receives the consumer's wake-up time on drop.
    read_write_wake_up_after: &'a mut Option<TNow>,
}
376
377impl<'a, TNow: Clone> ops::Deref for ReadWriteAccess<'a, TNow> {
378 type Target = read_write::ReadWrite<TNow>;
379
380 fn deref(&self) -> &Self::Target {
381 &self.read_write
382 }
383}
384
385impl<'a, TNow: Clone> ops::DerefMut for ReadWriteAccess<'a, TNow> {
386 fn deref_mut(&mut self) -> &mut Self::Target {
387 &mut self.read_write
388 }
389}
390
391impl<'a, TNow: Clone> Drop for ReadWriteAccess<'a, TNow> {
392 fn drop(&mut self) {
393 *self.read_buffer = mem::take(&mut self.read_write.incoming_buffer);
394 *self.read_buffer_valid = self.read_buffer.len();
395
396 if let Some(expected_incoming_bytes) = self.read_write.expected_incoming_bytes {
398 if expected_incoming_bytes < self.read_buffer_reasonable_capacity
399 && self.read_buffer.is_empty()
400 {
401 self.read_buffer.shrink_to(0);
404 self.read_buffer
405 .reserve(self.read_buffer_reasonable_capacity);
406 } else if expected_incoming_bytes > self.read_buffer.len() {
407 self.read_buffer
408 .reserve(expected_incoming_bytes - self.read_buffer.len());
409 }
410 debug_assert!(self.read_buffer.capacity() >= expected_incoming_bytes);
411 }
412
413 *self.write_buffers = mem::take(&mut self.read_write.write_buffers);
414
415 if self.read_write.write_bytes_queueable.is_none() && !*self.write_closed {
416 *self.write_closed = true;
417 *self.close_pending = true;
418 }
419
420 *self.read_write_wake_up_after = self.read_write.wake_up_after.take();
421
422 if (self.read_buffer_len_before != self.read_buffer.len()
425 && self
426 .read_write
427 .expected_incoming_bytes
428 .map_or(false, |b| b <= self.read_buffer.len()))
429 || (self.write_buffers_len_before != self.write_buffers.len() && !*self.write_closed)
430 {
431 *self.read_write_wake_up_after = Some(self.read_write.now.clone());
432 }
433 }
434}
435
436