1
//! Declare MockSleepRuntime.
2

            
3
use crate::time::MockSleepProvider;
4
use tor_rtcompat::{BlockOn, Runtime, SleepProvider, TcpProvider, TlsProvider};
5

            
6
use async_trait::async_trait;
7
use futures::task::{FutureObj, Spawn, SpawnError};
8
use futures::Future;
9
use pin_project::pin_project;
10
use std::io::Result as IoResult;
11
use std::net::SocketAddr;
12
use std::time::{Duration, Instant, SystemTime};
13
use tracing::trace;
14

            
15
/// A wrapper Runtime that overrides the SleepProvider trait for the
16
/// underlying runtime.
17
266
#[derive(Clone)]
18
pub struct MockSleepRuntime<R: Runtime> {
19
    /// The underlying runtime. Most calls get delegated here.
20
    runtime: R,
21
    /// A MockSleepProvider.  Time-related calls get delegated here.
22
    sleep: MockSleepProvider,
23
}
24

            
25
impl<R: Runtime> MockSleepRuntime<R> {
26
    /// Create a new runtime that wraps `runtime`, but overrides
27
    /// its view of time with a [`MockSleepProvider`].
28
42
    pub fn new(runtime: R) -> Self {
29
42
        let sleep = MockSleepProvider::new(SystemTime::now());
30
42
        MockSleepRuntime { runtime, sleep }
31
42
    }
32

            
33
    /// Return a reference to the underlying runtime.
34
    pub fn inner(&self) -> &R {
35
        &self.runtime
36
    }
37

            
38
    /// Return a reference to the [`MockSleepProvider`]
39
    pub fn mock_sleep(&self) -> &MockSleepProvider {
40
        &self.sleep
41
    }
42

            
43
    /// See [`MockSleepProvider::advance()`]
44
2
    pub async fn advance(&self, dur: Duration) {
45
2
        self.sleep.advance(dur).await;
46
2
    }
47
    /// See [`MockSleepProvider::jump_to()`]
48
1
    pub fn jump_to(&self, new_wallclock: SystemTime) {
49
1
        self.sleep.jump_to(new_wallclock);
50
1
    }
51
    /// Run a future under mock time, advancing time forward where necessary until it completes.
52
    /// Users of this function should read the whole of this documentation before using!
53
    ///
54
    /// The returned future will run `fut`, expecting it to create `Sleeping` futures (as returned
55
    /// by `MockSleepProvider::sleep()` and similar functions). When all such created futures have
56
    /// been polled (indicating the future is waiting on them), time will be advanced in order that
57
    /// the first (or only) of said futures returns `Ready`. This process then repeats until `fut`
58
    /// returns `Ready` itself (as in, the returned wrapper future will wait for all created
59
    /// `Sleeping` futures to be polled, and advance time again).
60
    ///
61
    /// **Note:** The above described algorithm interacts poorly with futures that spawn
62
    /// asynchronous background tasks, or otherwise expect work to complete in the background
63
    /// before time is advanced. These futures will need to make use of the
64
    /// `SleepProvider::block_advance` (and similar) APIs in order to prevent time advancing while
65
    /// said tasks complete; see the documentation for those APIs for more detail.
66
    ///
67
    /// # Panics
68
    ///
69
    /// Panics if another `WaitFor` future is already running. (If two ran simultaneously, they
70
    /// would both try and advance the same mock time clock, which would be bad.)
71
60
    pub fn wait_for<F: futures::Future>(&self, fut: F) -> WaitFor<F> {
72
60
        assert!(
73
60
            !self.sleep.has_waitfor_waker(),
74
60
            "attempted to call MockSleepRuntime::wait_for while another WaitFor is active"
75
60
        );
76
60
        WaitFor {
77
60
            sleep: self.sleep.clone(),
78
60
            fut,
79
60
        }
80
60
    }
81
}
82

            
83
impl<R: Runtime> Spawn for MockSleepRuntime<R> {
84
97
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
85
97
        self.runtime.spawn_obj(future)
86
97
    }
87
}
88

            
89
impl<R: Runtime> BlockOn for MockSleepRuntime<R> {
90
    fn block_on<F: Future>(&self, future: F) -> F::Output {
91
        self.runtime.block_on(future)
92
    }
93
}
94

            
95
#[async_trait]
96
impl<R: Runtime> TcpProvider for MockSleepRuntime<R> {
97
    type TcpStream = R::TcpStream;
98
    type TcpListener = R::TcpListener;
99

            
100
1
    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::TcpStream> {
101
1
        self.runtime.connect(addr).await
102
2
    }
103
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
104
        self.runtime.listen(addr).await
105
    }
106
}
107

            
108
/// TLS is forwarded unchanged to the wrapped runtime, reusing its connector and
/// stream types.
impl<R: Runtime> TlsProvider<R::TcpStream> for MockSleepRuntime<R> {
    type Connector = R::Connector;
    type TlsStream = R::TlsStream;

    fn tls_connector(&self) -> Self::Connector {
        <R as TlsProvider<R::TcpStream>>::tls_connector(&self.runtime)
    }
}
115

            
116
impl<R: Runtime> SleepProvider for MockSleepRuntime<R> {
117
    type SleepFuture = crate::time::Sleeping;
118
321
    fn sleep(&self, dur: Duration) -> Self::SleepFuture {
119
321
        self.sleep.sleep(dur)
120
321
    }
121
286
    fn now(&self) -> Instant {
122
286
        self.sleep.now()
123
286
    }
124
1
    fn wallclock(&self) -> SystemTime {
125
1
        self.sleep.wallclock()
126
1
    }
127
81
    fn block_advance<T: Into<String>>(&self, reason: T) {
128
81
        self.sleep.block_advance(reason);
129
81
    }
130
45
    fn release_advance<T: Into<String>>(&self, reason: T) {
131
45
        self.sleep.release_advance(reason);
132
45
    }
133
126
    fn allow_one_advance(&self, dur: Duration) {
134
126
        self.sleep.allow_one_advance(dur);
135
126
    }
136
}
137

            
138
/// A future that advances time until another future is ready to complete.
139
404
#[pin_project]
140
pub struct WaitFor<F> {
141
    /// A reference to the sleep provider that's simulating time for us.
142
    #[pin]
143
    sleep: MockSleepProvider,
144
    /// The future that we're waiting for.
145
    #[pin]
146
    fut: F,
147
}
148

            
149
use std::pin::Pin;
150
use std::task::{Context, Poll};
151

            
152
impl<F: Future> Future for WaitFor<F> {
153
    type Output = F::Output;
154

            
155
405
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156
405
        trace!("waitfor poll");
157
405
        let mut this = self.project();
158
405
        this.sleep.register_waitfor_waker(cx.waker().clone());
159

            
160
405
        if let Poll::Ready(r) = this.fut.poll(cx) {
161
61
            trace!("waitfor done!");
162
60
            this.sleep.clear_waitfor_waker();
163
60
            return Poll::Ready(r);
164
344
        }
165
344
        trace!("waitfor poll complete");
166

            
167
343
        if this.sleep.should_advance() {
168
169
            if let Some(duration) = this.sleep.time_until_next_timeout() {
169
169
                trace!("Advancing by {:?}", duration);
170
169
                this.sleep.advance_noyield(duration);
171
            } else {
172
                // If we get here, something's probably wedged and the test isn't going to complete
173
                // anyway: we were expecting to advance in order to make progress, but we can't.
174
                // If we don't panic, the test will just run forever, which is really annoying, so
175
                // just panic and fail quickly.
176
                panic!("WaitFor told to advance, but didn't have any duration to advance by");
177
            }
178
        } else {
179
174
            trace!("waiting for sleepers to advance");
180
        }
181
343
        Poll::Pending
182
403
    }
183
}