1
//! Declare MockNetRuntime, a runtime wrapper whose networking is replaced by a MockNetProvider.
2

            
3
// TODO(nickm): This is mostly copy-paste from MockSleepRuntime.  If possible,
// we should share more of this code between the two.
5

            
6
use crate::net::MockNetProvider;
7
use tor_rtcompat::{BlockOn, Runtime, SleepProvider, TcpProvider, TlsProvider};
8

            
9
use crate::io::LocalStream;
10
use async_trait::async_trait;
11
use futures::task::{FutureObj, Spawn, SpawnError};
12
use futures::Future;
13
use std::io::Result as IoResult;
14
use std::net::SocketAddr;
15
use std::time::{Duration, Instant, SystemTime};
16

            
17
/// A wrapper Runtime that overrides the SleepProvider trait for the
18
/// underlying runtime.
19
#[derive(Clone)]
20
pub struct MockNetRuntime<R: Runtime> {
21
    /// The underlying runtime. Most calls get delegated here.
22
    runtime: R,
23
    /// A MockNetProvider.  Time-related calls get delegated here.
24
    net: MockNetProvider,
25
}
26

            
27
impl<R: Runtime> MockNetRuntime<R> {
28
    /// Create a new runtime that wraps `runtime`, but overrides
29
    /// its view of the network with a [`MockNetProvider`], `net`.
30
2
    pub fn new(runtime: R, net: MockNetProvider) -> Self {
31
2
        MockNetRuntime { runtime, net }
32
2
    }
33

            
34
    /// Return a reference to the underlying runtime.
35
    pub fn inner(&self) -> &R {
36
        &self.runtime
37
    }
38

            
39
    /// Return a reference to the [`MockNetProvider`]
40
1
    pub fn mock_net(&self) -> &MockNetProvider {
41
1
        &self.net
42
1
    }
43
}
44

            
45
impl<R: Runtime> Spawn for MockNetRuntime<R> {
46
1
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
47
1
        self.runtime.spawn_obj(future)
48
1
    }
49
}
50

            
51
impl<R: Runtime> BlockOn for MockNetRuntime<R> {
52
    fn block_on<F: Future>(&self, future: F) -> F::Output {
53
        self.runtime.block_on(future)
54
    }
55
}
56

            
57
#[async_trait]
58
impl<R: Runtime> TcpProvider for MockNetRuntime<R> {
59
    type TcpStream = <MockNetProvider as TcpProvider>::TcpStream;
60
    type TcpListener = <MockNetProvider as TcpProvider>::TcpListener;
61

            
62
1
    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::TcpStream> {
63
1
        self.net.connect(addr).await
64
2
    }
65
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
66
        self.net.listen(addr).await
67
    }
68
}
69

            
70
impl<R: Runtime> TlsProvider<LocalStream> for MockNetRuntime<R> {
71
    type Connector = <MockNetProvider as TlsProvider<LocalStream>>::Connector;
72
    type TlsStream = <MockNetProvider as TlsProvider<LocalStream>>::TlsStream;
73
1
    fn tls_connector(&self) -> Self::Connector {
74
1
        self.net.tls_connector()
75
1
    }
76
}
77

            
78
impl<R: Runtime> SleepProvider for MockNetRuntime<R> {
79
    type SleepFuture = R::SleepFuture;
80
    fn sleep(&self, dur: Duration) -> Self::SleepFuture {
81
        self.runtime.sleep(dur)
82
    }
83
    fn now(&self) -> Instant {
84
        self.runtime.now()
85
    }
86
    fn wallclock(&self) -> SystemTime {
87
        self.runtime.wallclock()
88
    }
89
}