1
//! Wrap `tor_cell::chancell::codec::ChannelCodec` for use with the
//! `futures_codec` crate (provided here by `asynchronous_codec`).
3
use std::io::Error as IoError;
4

            
5
use tor_cell::chancell::{codec, ChanCell};
6

            
7
use asynchronous_codec as futures_codec;
8
use bytes::BytesMut;
9

            
10
/// An error from a ChannelCodec.
11
///
12
/// This is a separate error type for now because I suspect that we'll want to
13
/// handle these differently in the rest of our channel code.
14
#[derive(Debug, thiserror::Error)]
15
pub(crate) enum CodecError {
16
    /// An error from the underlying IO stream underneath a codec.
17
    ///
18
    /// (This isn't wrapped in an Arc, because we don't need this type to be
19
    /// clone; it's crate-internal.)
20
    #[error("Io error")]
21
    Io(#[from] IoError),
22
    /// An error from the cell encoding/decoding logic.
23
    #[error("encoding/decoding error")]
24
    Cell(#[from] tor_cell::Error),
25
}
26

            
27
/// Asynchronous wrapper around ChannelCodec in tor_cell, with implementation
28
/// for use with futures_codec.
29
///
30
/// This type lets us wrap a TLS channel (or some other secure
31
/// AsyncRead+AsyncWrite type) as a Sink and a Stream of ChanCell, so we
32
/// can forget about byte-oriented communication.
33
pub(crate) struct ChannelCodec(codec::ChannelCodec);
34

            
35
impl ChannelCodec {
36
    /// Create a new ChannelCodec with a given link protocol.
37
44
    pub(crate) fn new(link_proto: u16) -> Self {
38
44
        ChannelCodec(codec::ChannelCodec::new(link_proto))
39
44
    }
40
}
41

            
42
impl futures_codec::Encoder for ChannelCodec {
43
    type Item = ChanCell;
44
    type Error = CodecError;
45

            
46
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
47
26
        self.0.write_cell(item, dst)?;
48
26
        Ok(())
49
26
    }
50
}
51

            
52
impl futures_codec::Decoder for ChannelCodec {
53
    type Item = ChanCell;
54
    type Error = CodecError;
55

            
56
272
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
57
272
        Ok(self.0.decode_cell(src)?)
58
272
    }
59
}
60

            
61
#[cfg(test)]
pub(crate) mod test {
    #![allow(clippy::unwrap_used)]
    use futures::io::{AsyncRead, AsyncWrite, Cursor, Result};
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use futures::task::{Context, Poll};
    use hex_literal::hex;
    use std::pin::Pin;

    use super::{futures_codec, ChannelCodec};
    use tor_cell::chancell::{msg, ChanCell, ChanCmd, CircId};

    /// Helper type for reading and writing bytes to/from buffers.
    ///
    /// Reads are served from an input buffer supplied at construction time.
    /// Writes go to a separate output buffer that can be retrieved afterwards
    /// with [`MsgBuf::into_response`].
    // TODO: We might want to move this
    pub(crate) struct MsgBuf {
        /// Data we have received as a reader.
        inbuf: futures::io::Cursor<Vec<u8>>,
        /// Data we write as a writer.
        outbuf: futures::io::Cursor<Vec<u8>>,
    }

    impl AsyncRead for MsgBuf {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.inbuf).poll_read(cx, buf)
        }
    }
    impl AsyncWrite for MsgBuf {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.outbuf).poll_write(cx, buf)
        }
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.outbuf).poll_flush(cx)
        }
        fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.outbuf).poll_close(cx)
        }
    }

    impl MsgBuf {
        /// Create a `MsgBuf` whose readers will receive the bytes of `output`.
        /// Its output buffer starts out empty.
        pub(crate) fn new<T: Into<Vec<u8>>>(output: T) -> Self {
            let inbuf = Cursor::new(output.into());
            let outbuf = Cursor::new(Vec::new());
            MsgBuf { inbuf, outbuf }
        }

        /// Return how many bytes of the input buffer have been read so far.
        pub(crate) fn consumed(&self) -> usize {
            // The cursor position is a u64. A checked conversion panics on
            // overflow, where an `as` cast would silently truncate on
            // platforms whose usize is narrower than 64 bits.
            usize::try_from(self.inbuf.position()).unwrap()
        }

        /// Return true if every byte of the input buffer has been read.
        pub(crate) fn all_consumed(&self) -> bool {
            self.inbuf.get_ref().len() == self.consumed()
        }

        /// Consume this `MsgBuf` and return everything that was written to it.
        pub(crate) fn into_response(self) -> Vec<u8> {
            self.outbuf.into_inner()
        }
    }

    /// Wrap `mbuf` in a `Framed` using a [`ChannelCodec`] for link protocol 4.
    fn frame_buf(mbuf: MsgBuf) -> futures_codec::Framed<MsgBuf, ChannelCodec> {
        futures_codec::Framed::new(mbuf, ChannelCodec::new(4))
    }

    #[test]
    fn check_encoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mb = MsgBuf::new(&b""[..]);
            let mut framed = frame_buf(mb);

            let destroycell = msg::Destroy::new(2.into());
            framed
                .send(ChanCell::new(7.into(), destroycell.into()))
                .await
                .unwrap();

            let nocerts = msg::Certs::new_empty();
            framed
                .send(ChanCell::new(0.into(), nocerts.into()))
                .await
                .unwrap();

            framed.flush().await.unwrap();

            let data = framed.into_inner().into_response();

            // First cell: circuit ID 7, DESTROY command byte, reason 2, then
            // zero padding. Only the leading 10 bytes are checked here.
            assert_eq!(&data[0..10], &hex!("00000007 04 0200000000")[..]);

            // The fixed-length first cell ends at byte 514. Next comes the
            // variable-length CERTS cell: circuit ID 0, command byte,
            // a 2-byte length of 1, and a zero certificate count.
            assert_eq!(&data[514..], &hex!("00000000 81 0001 00")[..]);
        });
    }

    #[test]
    fn check_decoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            // Build the same byte stream that `check_encoding` expects: a
            // DESTROY cell zero-padded to 514 bytes, followed by an empty
            // CERTS cell.
            let mut dat = Vec::new();
            dat.extend_from_slice(&hex!("00000007 04 0200000000")[..]);
            dat.resize(514, 0);
            dat.extend_from_slice(&hex!("00000000 81 0001 00")[..]);
            let mb = MsgBuf::new(&dat[..]);
            let mut framed = frame_buf(mb);

            let destroy = framed.next().await.unwrap().unwrap();
            let nocerts = framed.next().await.unwrap().unwrap();

            assert_eq!(destroy.circid(), CircId::from(7));
            assert_eq!(destroy.msg().cmd(), ChanCmd::DESTROY);
            assert_eq!(nocerts.circid(), CircId::from(0));
            assert_eq!(nocerts.msg().cmd(), ChanCmd::CERTS);

            // Both cells were decoded from exactly the bytes provided.
            assert!(framed.into_inner().all_consumed());
        });
    }
}