smoldot/libp2p/
with_buffers.rs

1// Smoldot
2// Copyright (C) 2019-2022  Parity Technologies (UK) Ltd.
3// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
4
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13// GNU General Public License for more details.
14
15// You should have received a copy of the GNU General Public License
16// along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18#![cfg(feature = "std")]
19
20//! Augments an implementation of `AsyncRead` and `AsyncWrite` with a read buffer and a write
21//! buffer.
22//!
23//! While this module is generic, the targeted use-case is TCP connections.
24
25// TODO: usage and example
26
27use crate::libp2p::read_write;
28
29use core::{
30    fmt, future, mem, ops,
31    pin::{self, Pin},
32    task::Poll,
33};
34use futures_util::{AsyncRead, AsyncWrite};
35use std::io;
36
/// Holds an implementation of `AsyncRead` and `AsyncWrite`, alongside a read buffer and a
/// write buffer.
///
/// The socket is initially provided as a future (`TSocketFut`) that yields the actual socket
/// (`TSocket`). Until that future has resolved, no data can be read from or written to it.
#[pin_project::pin_project]
pub struct WithBuffers<TSocketFut, TSocket, TNow> {
    /// Actual socket to read from/write to, or the future that yields it.
    #[pin]
    socket: Socket<TSocketFut, TSocket>,
    /// Error that has happened on the socket, if any. Once set, it is never cleared.
    error: Option<io::Error>,
    /// Storage for data read from the socket. The first [`WithBuffers::read_buffer_valid`] bytes
    /// contain actual socket data, while the rest contains garbage data.
    /// The capacity of this buffer is at least equal to the amount of bytes requested by the
    /// inner data consumer.
    read_buffer: Vec<u8>,
    /// Number of bytes of data in [`WithBuffers::read_buffer`] that contain actual data.
    read_buffer_valid: usize,
    /// Initial capacity of [`WithBuffers::read_buffer`], and the capacity it is brought back to
    /// whenever it is empty and the consumer expects fewer bytes than this value.
    read_buffer_reasonable_capacity: usize,
    /// True if reading from the socket has returned `Ok(0)` earlier, in other words "end of
    /// file".
    read_closed: bool,
    /// Storage for data to write to the socket, in order. Buffers are removed from the front
    /// as they get written out.
    write_buffers: Vec<Vec<u8>>,
    /// True if the consumer has closed the writing side earlier.
    write_closed: bool,
    /// True if the consumer has closed the writing side earlier, and the socket still has to
    /// be closed.
    close_pending: bool,
    /// True if data has been written on the socket and the socket needs to be flushed.
    flush_pending: bool,

    /// Value of [`read_write::ReadWrite::now`] that was fed by the latest call to
    /// [`WithBuffers::read_write_access`].
    read_write_now: Option<TNow>,
    /// Value of [`read_write::ReadWrite::wake_up_after`] produced by the latest call
    /// to [`WithBuffers::read_write_access`], or the `now` of that call if the consumer has
    /// made progress and should be called again immediately.
    read_write_wake_up_after: Option<TNow>,
}
74
/// State of the socket held by a [`WithBuffers`].
#[pin_project::pin_project(project = SocketProj)]
enum Socket<TSocketFut, TSocket> {
    /// The socket is still being opened. The future yields either the socket or an error.
    Pending(#[pin] TSocketFut),
    /// The socket has been opened and can be read from and written to.
    Resolved(#[pin] TSocket),
}
80
81impl<TSocketFut, TSocket, TNow> WithBuffers<TSocketFut, TSocket, TNow>
82where
83    TNow: Clone + Ord,
84{
85    /// Initializes a new [`WithBuffers`] with the given socket-yielding future.
86    pub fn new(socket: TSocketFut) -> Self {
87        let read_buffer_reasonable_capacity = 65536; // TODO: make configurable?
88
89        WithBuffers {
90            socket: Socket::Pending(socket),
91            error: None,
92            read_buffer: Vec::with_capacity(read_buffer_reasonable_capacity),
93            read_buffer_valid: 0,
94            read_buffer_reasonable_capacity,
95            read_closed: false,
96            write_buffers: Vec::with_capacity(64),
97            write_closed: false,
98            close_pending: false,
99            flush_pending: false,
100            read_write_now: None,
101            read_write_wake_up_after: None,
102        }
103    }
104
105    /// Returns an object that implements `Deref<Target = ReadWrite>`. This object can be used
106    /// to push or pull data to/from the socket.
107    ///
108    /// > **Note**: The parameter requires `Self` to be pinned for consistency with
109    /// >           [`WithBuffers::wait_read_write_again`].
110    pub fn read_write_access(
111        self: Pin<&mut Self>,
112        now: TNow,
113    ) -> Result<ReadWriteAccess<TNow>, &io::Error> {
114        let this = self.project();
115
116        debug_assert!(
117            this.read_write_now
118                .as_ref()
119                .map_or(true, |old_now| *old_now <= now)
120        );
121        *this.read_write_wake_up_after = None;
122        *this.read_write_now = Some(now.clone());
123
124        if let Some(error) = this.error.as_ref() {
125            return Err(error);
126        }
127
128        this.read_buffer.truncate(*this.read_buffer_valid);
129
130        let is_resolved = matches!(*this.socket, Socket::Resolved(_));
131
132        let write_bytes_queued = this.write_buffers.iter().map(Vec::len).sum();
133
134        Ok(ReadWriteAccess {
135            read_buffer_len_before: this.read_buffer.len(),
136            write_buffers_len_before: this.write_buffers.len(),
137            read_write: read_write::ReadWrite {
138                now,
139                incoming_buffer: mem::take(this.read_buffer),
140                expected_incoming_bytes: if !*this.read_closed { Some(0) } else { None },
141                read_bytes: 0,
142                write_bytes_queued,
143                write_buffers: mem::take(this.write_buffers),
144                write_bytes_queueable: if !is_resolved {
145                    Some(0)
146                } else if !*this.write_closed {
147                    // Limit outgoing buffer size to 128kiB.
148                    // TODO: make configurable?
149                    Some((128 * 1024usize).saturating_sub(write_bytes_queued))
150                } else {
151                    None
152                },
153                wake_up_after: this.read_write_wake_up_after.take(),
154            },
155            read_buffer: this.read_buffer,
156            read_buffer_valid: this.read_buffer_valid,
157            read_buffer_reasonable_capacity: *this.read_buffer_reasonable_capacity,
158            write_buffers: this.write_buffers,
159            write_closed: this.write_closed,
160            close_pending: this.close_pending,
161            read_write_wake_up_after: this.read_write_wake_up_after,
162        })
163    }
164}
165
166impl<TSocketFut, TSocket, TNow> WithBuffers<TSocketFut, TSocket, TNow>
167where
168    TSocket: AsyncRead + AsyncWrite,
169    TSocketFut: Future<Output = Result<TSocket, io::Error>>,
170    TNow: Clone + Ord,
171{
172    /// Waits until [`WithBuffers::read_write_access`] should be called again.
173    ///
174    /// Returns immediately if [`WithBuffers::read_write_access`] has never been called.
175    ///
176    /// Returns if an error happens on the socket. If an error happened in the past on the socket,
177    /// the future never yields.
178    pub async fn wait_read_write_again<F>(
179        self: Pin<&mut Self>,
180        timer_builder: impl FnOnce(TNow) -> F,
181    ) where
182        F: Future<Output = ()>,
183    {
184        let mut this = self.project();
185
186        // Return immediately if `read_write_access` was never called or if `wake_up_after <= now`.
187        match (&*this.read_write_wake_up_after, &*this.read_write_now) {
188            (_, None) => return,
189            (Some(when_wake_up), Some(now)) if *when_wake_up <= *now => {
190                return;
191            }
192            _ => {}
193        }
194
195        let mut timer = pin::pin!({
196            let fut = this
197                .read_write_wake_up_after
198                .as_ref()
199                .map(|when| timer_builder(when.clone()));
200            async {
201                if let Some(fut) = fut {
202                    fut.await;
203                } else {
204                    future::pending::<()>().await;
205                }
206            }
207        });
208
209        // Grow the read buffer in order to make space for potentially more data.
210        this.read_buffer.resize(this.read_buffer.capacity(), 0);
211
212        future::poll_fn(move |cx| {
213            if this.error.is_some() {
214                // Never return.
215                return Poll::Pending;
216            }
217
218            // If still `true` at the end of the function, `Poll::Pending` is returned.
219            let mut pending = true;
220
221            match Future::poll(Pin::new(&mut timer), cx) {
222                Poll::Pending => {}
223                Poll::Ready(()) => {
224                    pending = false;
225                }
226            }
227
228            match this.socket.as_mut().project() {
229                SocketProj::Pending(future) => match Future::poll(future, cx) {
230                    Poll::Pending => {}
231                    Poll::Ready(Ok(socket)) => {
232                        this.socket.set(Socket::Resolved(socket));
233                        pending = false;
234                    }
235                    Poll::Ready(Err(err)) => {
236                        *this.error = Some(err);
237                        return Poll::Ready(());
238                    }
239                },
240                SocketProj::Resolved(mut socket) => {
241                    if !*this.read_closed && *this.read_buffer_valid < this.read_buffer.len() {
242                        let read_result = AsyncRead::poll_read(
243                            socket.as_mut(),
244                            cx,
245                            &mut this.read_buffer[*this.read_buffer_valid..],
246                        );
247
248                        match read_result {
249                            Poll::Pending => {}
250                            Poll::Ready(Ok(0)) => {
251                                *this.read_closed = true;
252                                pending = false;
253                            }
254                            Poll::Ready(Ok(n)) => {
255                                *this.read_buffer_valid += n;
256                                // TODO: consider waking up only if the expected bytes of the consumer are exceeded
257                                pending = false;
258                            }
259                            Poll::Ready(Err(err)) => {
260                                *this.error = Some(err);
261                                return Poll::Ready(());
262                            }
263                        };
264                    }
265
266                    loop {
267                        if this.write_buffers.iter().any(|b| !b.is_empty()) {
268                            let write_result = {
269                                let buffers = this
270                                    .write_buffers
271                                    .iter()
272                                    .map(|buf| io::IoSlice::new(buf))
273                                    .collect::<Vec<_>>();
274                                AsyncWrite::poll_write_vectored(socket.as_mut(), cx, &buffers)
275                            };
276
277                            match write_result {
278                                Poll::Ready(Ok(0)) => {
279                                    // It is not legal for `poll_write` to return 0 bytes written.
280                                    unreachable!();
281                                }
282                                Poll::Ready(Ok(mut n)) => {
283                                    *this.flush_pending = true;
284                                    while n > 0 {
285                                        let first_buf = this.write_buffers.first_mut().unwrap();
286                                        if first_buf.len() <= n {
287                                            n -= first_buf.len();
288                                            this.write_buffers.remove(0);
289                                        } else {
290                                            // TODO: consider keeping the buffer as is but starting the next write at a later offset
291                                            first_buf.copy_within(n.., 0);
292                                            first_buf.truncate(first_buf.len() - n);
293                                            break;
294                                        }
295                                    }
296                                    // Wake up if the write buffers switch from non-empty to empty.
297                                    if this.write_buffers.is_empty() {
298                                        pending = false;
299                                    }
300                                }
301                                Poll::Ready(Err(err)) => {
302                                    *this.error = Some(err);
303                                    return Poll::Ready(());
304                                }
305                                Poll::Pending => break,
306                            };
307                        } else if *this.flush_pending {
308                            match AsyncWrite::poll_flush(socket.as_mut(), cx) {
309                                Poll::Ready(Ok(())) => {
310                                    *this.flush_pending = false;
311                                }
312                                Poll::Ready(Err(err)) => {
313                                    *this.error = Some(err);
314                                    return Poll::Ready(());
315                                }
316                                Poll::Pending => break,
317                            }
318                        } else if *this.close_pending {
319                            match AsyncWrite::poll_close(socket.as_mut(), cx) {
320                                Poll::Ready(Ok(())) => {
321                                    *this.close_pending = false;
322                                    pending = false;
323                                    break;
324                                }
325                                Poll::Ready(Err(err)) => {
326                                    *this.error = Some(err);
327                                    return Poll::Ready(());
328                                }
329                                Poll::Pending => break,
330                            }
331                        } else {
332                            break;
333                        }
334                    }
335                }
336            };
337
338            if !pending {
339                Poll::Ready(())
340            } else {
341                Poll::Pending
342            }
343        })
344        .await;
345    }
346}
347
348impl<TSocketFut, TSocket: fmt::Debug, TNow> fmt::Debug for WithBuffers<TSocketFut, TSocket, TNow> {
349    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
350        let mut t = f.debug_tuple("WithBuffers");
351        if let Socket::Resolved(socket) = &self.socket {
352            t.field(socket);
353        } else {
354            t.field(&"<pending>");
355        }
356        t.finish()
357    }
358}
359
/// See [`WithBuffers::read_write_access`].
///
/// Dereferences to the [`read_write::ReadWrite`] that the consumer reads from and writes to.
/// When dropped, the content of that `ReadWrite` is moved back into the [`WithBuffers`] that
/// this object borrows from.
pub struct ReadWriteAccess<'a, TNow: Clone> {
    /// Object exposed through `Deref` and `DerefMut`.
    read_write: read_write::ReadWrite<TNow>,

    /// Length of the read buffer when this object was created.
    read_buffer_len_before: usize,
    /// Number of write buffers when this object was created.
    write_buffers_len_before: usize,

    // Fields below are references to the content of the `WithBuffers`, which are updated when
    // this object is dropped (except `read_buffer_reasonable_capacity`, which is a copy).
    read_buffer: &'a mut Vec<u8>,
    read_buffer_valid: &'a mut usize,
    read_buffer_reasonable_capacity: usize,
    write_buffers: &'a mut Vec<Vec<u8>>,
    write_closed: &'a mut bool,
    close_pending: &'a mut bool,
    read_write_wake_up_after: &'a mut Option<TNow>,
}
376
377impl<'a, TNow: Clone> ops::Deref for ReadWriteAccess<'a, TNow> {
378    type Target = read_write::ReadWrite<TNow>;
379
380    fn deref(&self) -> &Self::Target {
381        &self.read_write
382    }
383}
384
385impl<'a, TNow: Clone> ops::DerefMut for ReadWriteAccess<'a, TNow> {
386    fn deref_mut(&mut self) -> &mut Self::Target {
387        &mut self.read_write
388    }
389}
390
391impl<'a, TNow: Clone> Drop for ReadWriteAccess<'a, TNow> {
392    fn drop(&mut self) {
393        *self.read_buffer = mem::take(&mut self.read_write.incoming_buffer);
394        *self.read_buffer_valid = self.read_buffer.len();
395
396        // Adjust `read_buffer` to the number of bytes requested by the consumer.
397        if let Some(expected_incoming_bytes) = self.read_write.expected_incoming_bytes {
398            if expected_incoming_bytes < self.read_buffer_reasonable_capacity
399                && self.read_buffer.is_empty()
400            {
401                // We use `shrink_to(0)` then `reserve(cap)` rather than just `shrink_to(cap)`
402                // so that the `Vec` doesn't try to preserve the data in the read buffer.
403                self.read_buffer.shrink_to(0);
404                self.read_buffer
405                    .reserve(self.read_buffer_reasonable_capacity);
406            } else if expected_incoming_bytes > self.read_buffer.len() {
407                self.read_buffer
408                    .reserve(expected_incoming_bytes - self.read_buffer.len());
409            }
410            debug_assert!(self.read_buffer.capacity() >= expected_incoming_bytes);
411        }
412
413        *self.write_buffers = mem::take(&mut self.read_write.write_buffers);
414
415        if self.read_write.write_bytes_queueable.is_none() && !*self.write_closed {
416            *self.write_closed = true;
417            *self.close_pending = true;
418        }
419
420        *self.read_write_wake_up_after = self.read_write.wake_up_after.take();
421
422        // If the consumer has advanced its reading or writing sides, we make the next call to
423        // `read_write_access` return immediately by setting `wake_up_after`.
424        if (self.read_buffer_len_before != self.read_buffer.len()
425            && self
426                .read_write
427                .expected_incoming_bytes
428                .map_or(false, |b| b <= self.read_buffer.len()))
429            || (self.write_buffers_len_before != self.write_buffers.len() && !*self.write_closed)
430        {
431            *self.read_write_wake_up_after = Some(self.read_write.now.clone());
432        }
433    }
434}
435
436// TODO: tests