1
//! Mocking helpers for testing with futures::io types.
2
//!
3
//! Note that some of this code might be of general use, but for now
4
//! we only use it for testing.
5

            
6
use futures::channel::mpsc;
7
use futures::io::{AsyncRead, AsyncWrite};
8
use futures::sink::{Sink, SinkExt};
9
use futures::stream::Stream;
10
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
11
use std::pin::Pin;
12
use std::task::{Context, Poll};
13

            
14
/// Buffer capacity of each of our internal MPSC channels.
///
/// Deliberately tiny, so that writers fill a channel quickly and tests
/// exercise the case where a write has to wait for the reader.
const CAPACITY: usize = 4;
19

            
20
/// Largest number of bytes that a single write places on the channel.
///
/// The odd value is intentional: it is unlikely to line up with the buffer
/// sizes that callers use, which helps shake out chunk-boundary bugs.
const CHUNKSZ: usize = 213;
24

            
25
/// Construct a new pair of linked LocalStream objects.
26
///
27
/// Any bytes written to one will be readable on the other, and vice
28
/// versa.  These streams will behave more or less like a socketpair,
29
/// except without actually going through the operating system.
30
///
31
/// Note that this implementation is intended for testing only, and
32
/// isn't optimized.
33
57
pub fn stream_pair() -> (LocalStream, LocalStream) {
34
57
    let (w1, r2) = mpsc::channel(CAPACITY);
35
57
    let (w2, r1) = mpsc::channel(CAPACITY);
36
57
    let s1 = LocalStream {
37
57
        w: w1,
38
57
        r: r1,
39
57
        pending_bytes: Vec::new(),
40
57
        tls_cert: None,
41
57
    };
42
57
    let s2 = LocalStream {
43
57
        w: w2,
44
57
        r: r2,
45
57
        pending_bytes: Vec::new(),
46
57
        tls_cert: None,
47
57
    };
48
57
    (s1, s2)
49
57
}
50

            
51
/// One half of a pair of linked streams returned by [`stream_pair`].
52
//
53
// Implementation notes: linked streams are made out a pair of mpsc
54
// channels.  There's one channel for sending bytes in each direction.
55
// Bytes are sent as IoResult<Vec<u8>>: sending an error causes an error
56
// to occur on the other side.
57
pub struct LocalStream {
58
    /// The writing side of the channel that we use to implement this
59
    /// stream.
60
    ///
61
    /// The reading side is held by the other linked stream.
62
    w: mpsc::Sender<IoResult<Vec<u8>>>,
63
    /// The reading side of the channel that we use to implement this
64
    /// stream.
65
    ///
66
    /// The writing side is held by the other linked stream.
67
    r: mpsc::Receiver<IoResult<Vec<u8>>>,
68
    /// Bytes that we have read from `r` but not yet delivered.
69
    pending_bytes: Vec<u8>,
70
    /// Data about the other side of this stream's fake TLS certificate, if any.
71
    /// If this is present, I/O operations will fail with an error.
72
    ///
73
    /// How this is intended to work: things that return `LocalStream`s that could potentially
74
    /// be connected to a fake TLS listener should set this field. Then, a fake TLS wrapper
75
    /// type would clear this field (after checking its contents are as expected).
76
    ///
77
    /// FIXME(eta): this is a bit of a layering violation, but it's hard to do otherwise
78
    pub(crate) tls_cert: Option<Vec<u8>>,
79
}
80

            
81
/// Helper: move as many bytes as fit from the front of `pending_bytes`
/// into the start of `buf`.
///
/// Returns how many bytes were moved: the smaller of `buf.len()` and
/// `pending_bytes.len()`. The moved bytes are removed from `pending_bytes`;
/// the rest of `buf` is left untouched.
fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize {
    let count = buf.len().min(pending_bytes.len());
    for (dst, src) in buf[..count].iter_mut().zip(pending_bytes.drain(..count)) {
        *dst = src;
    }
    count
}
89

            
90
impl AsyncRead for LocalStream {
91
4226
    fn poll_read(
92
4226
        mut self: Pin<&mut Self>,
93
4226
        cx: &mut Context<'_>,
94
4226
        buf: &mut [u8],
95
4226
    ) -> Poll<IoResult<usize>> {
96
4226
        if buf.is_empty() {
97
            return Poll::Ready(Ok(0));
98
4226
        }
99
4226
        if self.tls_cert.is_some() {
100
            return Poll::Ready(Err(std::io::Error::new(
101
                std::io::ErrorKind::Other,
102
                "attempted to treat a TLS stream as non-TLS!",
103
            )));
104
4226
        }
105
4226
        if !self.pending_bytes.is_empty() {
106
2900
            return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
107
1326
        }
108

            
109
1326
        match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) {
110
1
            Some(Err(e)) => Poll::Ready(Err(e)),
111
875
            Some(Ok(bytes)) => {
112
875
                self.pending_bytes = bytes;
113
875
                let n = drain_helper(buf, &mut self.pending_bytes);
114
875
                Poll::Ready(Ok(n))
115
            }
116
85
            None => Poll::Ready(Ok(0)), // This is an EOF
117
        }
118
4234
    }
119
}
120

            
121
impl AsyncWrite for LocalStream {
122
1063
    fn poll_write(
123
1063
        mut self: Pin<&mut Self>,
124
1063
        cx: &mut Context<'_>,
125
1063
        buf: &[u8],
126
1063
    ) -> Poll<IoResult<usize>> {
127
1063
        if self.tls_cert.is_some() {
128
            return Poll::Ready(Err(std::io::Error::new(
129
                std::io::ErrorKind::Other,
130
                "attempted to treat a TLS stream as non-TLS!",
131
            )));
132
1063
        }
133

            
134
1063
        match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
135
896
            Ok(()) => (),
136
1
            Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
137
        }
138

            
139
896
        let buf = if buf.len() > CHUNKSZ {
140
809
            &buf[..CHUNKSZ]
141
        } else {
142
87
            buf
143
        };
144
896
        let len = buf.len();
145
896
        match Pin::new(&mut self.w).start_send(Ok(buf.to_vec())) {
146
896
            Ok(()) => Poll::Ready(Ok(len)),
147
            Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
148
        }
149
1059
    }
150
28
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
151
28
        Pin::new(&mut self.w)
152
28
            .poll_flush(cx)
153
28
            .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
154
28
    }
155
77
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
156
77
        Pin::new(&mut self.w)
157
77
            .poll_close(cx)
158
77
            .map_err(|e| IoError::new(ErrorKind::Other, e))
159
77
    }
160
}
161

            
162
/// An error generated by [`LocalStream::send_err`].
///
/// It carries no data: it exists so that an injected error can be told
/// apart by looking at the inner error of the resulting [`std::io::Error`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct SyntheticError;

impl std::fmt::Display for SyntheticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Synthetic error")
    }
}

impl std::error::Error for SyntheticError {}
172

            
173
impl LocalStream {
174
    /// Send an error to the other linked local stream.
175
    ///
176
    /// When the other stream reads this message, it will generate a
177
    /// [`std::io::Error`] with the provided `ErrorKind`.
178
1
    pub async fn send_err(&mut self, kind: ErrorKind) {
179
1
        let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await;
180
1
    }
181
}
182

            
183
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;

    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures_await_test::async_test;
    use rand::thread_rng;
    use rand::Rng;

    #[async_test]
    async fn basic_rw() {
        let (mut s1, mut s2) = stream_pair();
        let mut text1 = vec![0_u8; 9999];
        thread_rng().fill(&mut text1[..]);
        // Number of times the writer sends `text1`.
        let n_writes = 10_usize;

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0..n_writes {
                    s1.write_all(&text1[..]).await?;
                }
                s1.close().await?;
                Ok(())
            },
            async {
                let mut text2: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                    text2.extend(&buf[..n]);
                }
                // The chunk comparison below only checks the data that did
                // arrive, so verify that nothing was lost as well.
                assert_eq!(text2.len(), text1.len() * n_writes);
                for ch in text2.chunks(text1.len()) {
                    assert_eq!(ch, &text1[..]);
                }
                Ok(())
            }
        );

        v1.unwrap();
        v2.unwrap();
    }

    #[async_test]
    async fn send_error() {
        let (mut s1, mut s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                s1.write_all(b"hello world").await?;
                s1.send_err(ErrorKind::PermissionDenied).await;
                Ok(())
            },
            async {
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                }
                Ok(())
            }
        );

        v1.unwrap();
        let e = v2.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        let synth = e.into_inner().unwrap();
        assert_eq!(synth.to_string(), "Synthetic error");
    }

    #[async_test]
    async fn drop_reader() {
        let (mut s1, s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u16..1000 {
                    s1.write_all(&[9_u8; 9999]).await?;
                }
                Ok(())
            },
            async {
                drop(s2);
                Ok(())
            }
        );

        v2.unwrap();
        let e = v1.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}