//! Functionality for simulating the passage of time in unit tests.
//!
//! We do this by providing [`MockSleepProvider`], a "SleepProvider"
//! instance that can simulate timeouts and retries without requiring
//! the actual system clock to advance.
6

            
7
#![allow(clippy::missing_docs_in_private_items)]
8

            
9
use std::{
10
    cmp::{Eq, Ordering, PartialEq, PartialOrd},
11
    collections::BinaryHeap,
12
    pin::Pin,
13
    sync::{Arc, Mutex, Weak},
14
    task::{Context, Poll, Waker},
15
    time::{Duration, Instant, SystemTime},
16
};
17

            
18
use futures::Future;
19
use tracing::trace;
20

            
21
use std::collections::HashSet;
22
use tor_rtcompat::SleepProvider;
23

            
24
/// A dummy [`SleepProvider`] instance for testing.
25
///
26
/// The MockSleepProvider ignores the current time, and instead keeps
27
/// its own view of the current `Instant` and `SystemTime`.  You
28
/// can advance them in-step by calling `advance()`, and you can simulate
29
/// jumps in the system clock by calling `jump()`.
30
///
31
/// This is *not* for production use.
32
326
#[derive(Clone)]
33
pub struct MockSleepProvider {
34
    /// The shared backend for this MockSleepProvider and its futures.
35
    state: Arc<Mutex<SleepSchedule>>,
36
}
37

            
38
/// Shared backend for sleep provider and Sleeping futures.
39
struct SleepSchedule {
40
    /// What time do we pretend it is (monotonic)?  This value only
41
    /// moves forward.
42
    instant: Instant,
43
    /// What time do we pretend it is (wall clock)? This value can move
44
    /// in any way, but usually moves in step with `instant`.
45
    wallclock: SystemTime,
46
    /// Priority queue of events, in the order that we should wake them.
47
    sleepers: BinaryHeap<SleepEntry>,
48
    /// If the mock time system is being driven by a `WaitFor`, holds a `Waker` to wake up that
49
    /// `WaitFor` in order for it to make more progress.
50
    waitfor_waker: Option<Waker>,
51
    /// Number of sleepers instantiated.
52
    sleepers_made: usize,
53
    /// Number of sleepers polled.
54
    sleepers_polled: usize,
55
    /// Whether an advance is needed.
56
    should_advance: bool,
57
    /// A set of reasons why advances shouldn't be allowed right now.
58
    blocked_advance: HashSet<String>,
59
    /// A time up to which advances are allowed, irrespective of them being blocked.
60
    allowed_advance: Duration,
61
}
62

            
63
/// One scheduled wakeup: a deadline plus the waker to notify once the
/// simulated clock reaches it.
struct SleepEntry {
    /// Simulated instant at (or after) which `waker` should be woken.
    when: Instant,
    /// Waker of the task that is waiting on this entry.
    waker: Waker,
}
70

            
71
/// A future returned by [`MockSleepProvider::sleep()`].
72
pub struct Sleeping {
73
    /// The instant when we should become ready.
74
    when: Instant,
75
    /// True if we have pushed this into the queue.
76
    inserted: bool,
77
    /// The schedule to queue ourselves in if we're polled before we're ready.
78
    provider: Weak<Mutex<SleepSchedule>>,
79
}
80

            
81
impl MockSleepProvider {
82
    /// Create a new MockSleepProvider, starting at a given wall-clock time.
83
353
    pub fn new(wallclock: SystemTime) -> Self {
84
353
        let instant = Instant::now();
85
353
        let sleepers = BinaryHeap::new();
86
353
        let state = SleepSchedule {
87
353
            instant,
88
353
            wallclock,
89
353
            sleepers,
90
353
            waitfor_waker: None,
91
353
            sleepers_made: 0,
92
353
            sleepers_polled: 0,
93
353
            should_advance: false,
94
353
            blocked_advance: HashSet::new(),
95
353
            allowed_advance: Duration::from_nanos(0),
96
353
        };
97
353
        MockSleepProvider {
98
353
            state: Arc::new(Mutex::new(state)),
99
353
        }
100
353
    }
101

            
102
    /// Advance the simulated timeline forward by `dur`.
103
    ///
104
    /// Calling this function will wake any pending futures as
105
    /// appropriate, and yield to the scheduler so they get a chance
106
    /// to run.
107
    ///
108
    /// # Limitations
109
    ///
110
    /// This function advances time in one big step.  We might instead
111
    /// want to advance in small steps and make sure that each step's
112
    /// futures can get run before the ones scheduled to run after it.
113
10148
    pub async fn advance(&self, dur: Duration) {
114
2548
        self.advance_noyield(dur);
115
2548
        tor_rtcompat::task::yield_now().await;
116
2540
    }
117

            
118
    /// Advance the simulated timeline forward by `dur`.
119
    ///
120
    /// Calling this function will wake any pending futures as
121
    /// appropriate, but not yield to the scheduler.  Mostly you
122
    /// should call [`advance`](Self::advance) instead.
123
10829
    pub(crate) fn advance_noyield(&self, dur: Duration) {
124
10829
        // It's not so great to unwrap here in general, but since this is
125
10829
        // only testing code we don't really care.
126
10829
        let mut state = self.state.lock().expect("Poisoned lock for state");
127
10829
        state.wallclock += dur;
128
10829
        state.instant += dur;
129
10829
        state.fire();
130
10829
    }
131

            
132
    /// Simulate a discontinuity in the system clock, by jumping to
133
    /// `new_wallclock`.
134
    ///
135
    /// # Panics
136
    ///
137
    /// Panics if we have already panicked while holding the lock on
138
    /// the internal timer state, and the lock is poisoned.
139
37
    pub fn jump_to(&self, new_wallclock: SystemTime) {
140
37
        let mut state = self.state.lock().expect("Poisoned lock for state");
141
37
        state.wallclock = new_wallclock;
142
37
    }
143

            
144
    /// Return the amount of virtual time until the next timeout
145
    /// should elapse.
146
    ///
147
    /// If there are no more timeouts, return None.  If the next
148
    /// timeout should elapse right now, return Some(0).
149
672
    pub(crate) fn time_until_next_timeout(&self) -> Option<Duration> {
150
672
        let state = self.state.lock().expect("Poisoned lock for state");
151
672
        let now = state.instant;
152
672
        state
153
672
            .sleepers
154
672
            .peek()
155
676
            .map(|sleepent| sleepent.when.saturating_duration_since(now))
156
672
    }
157

            
158
    /// Return true if a `WaitFor` driving this sleep provider should advance time in order for
159
    /// futures blocked on sleeping to make progress.
160
    ///
161
    /// NOTE: This function has side-effects; if it returns true, the caller is expected to do an
162
    /// advance before calling it again.
163
1376
    pub(crate) fn should_advance(&mut self) -> bool {
164
1376
        let mut state = self.state.lock().expect("Poisoned lock for state");
165
1376
        if !state.blocked_advance.is_empty() && state.allowed_advance == Duration::from_nanos(0) {
166
            // We've had advances blocked, and don't have any quota for doing allowances while
167
            // blocked left.
168
556
            trace!(
169
                "should_advance = false: blocked by {:?}",
170
                state.blocked_advance
171
            );
172
556
            return false;
173
824
        }
174
824
        if !state.should_advance {
175
            // The advance flag wasn't set.
176
116
            trace!("should_advance = false; bit not previously set");
177
116
            return false;
178
708
        }
179
708
        // Clear the advance flag; we'll either return true and cause an advance to happen,
180
708
        // or the reasons to return false below also imply that the advance flag will be set again
181
708
        // later on.
182
708
        state.should_advance = false;
183
708
        if state.sleepers_polled < state.sleepers_made {
184
            // Something did set the advance flag before, but it's not valid any more now because
185
            // more unpolled sleepers were created.
186
            trace!("should_advance = false; advancing no longer valid");
187
            return false;
188
708
        }
189
708
        if !state.blocked_advance.is_empty() && state.allowed_advance > Duration::from_nanos(0) {
190
            // If we're here, we would've returned earlier due to having advances blocked, but
191
            // we have quota to advance up to a certain time while advances are blocked.
192
            // Let's see when the next timeout is, and whether it falls within that quota.
193
568
            let next_timeout = {
194
568
                let now = state.instant;
195
568
                state
196
568
                    .sleepers
197
568
                    .peek()
198
568
                    .map(|sleepent| sleepent.when.saturating_duration_since(now))
199
            };
200
568
            let next_timeout = match next_timeout {
201
568
                Some(x) => x,
202
                None => {
203
                    // There's no timeout set, so we really shouldn't be here anyway.
204
                    trace!("should_advance = false; allow_one set but no timeout yet");
205
                    return false;
206
                }
207
            };
208
568
            if next_timeout <= state.allowed_advance {
209
                // We can advance up to the next timeout, since it's in our quota.
210
                // Subtract the amount we're going to advance by from said quota.
211
536
                state.allowed_advance -= next_timeout;
212
536
                trace!(
213
                    "WARNING: allowing advance due to allow_one; new allowed is {:?}",
214
                    state.allowed_advance
215
                );
216
            } else {
217
                // The next timeout is too far in the future.
218
32
                trace!(
219
                    "should_advance = false; allow_one set but only up to {:?}, next is {:?}",
220
                    state.allowed_advance,
221
                    next_timeout
222
                );
223
28
                return false;
224
            }
225
136
        }
226
672
        true
227
1372
    }
228

            
229
    /// Register a `Waker` to be woken up when an advance in time is required to make progress.
230
    ///
231
    /// This is used by `WaitFor`.
232
1620
    pub(crate) fn register_waitfor_waker(&mut self, waker: Waker) {
233
1620
        let mut state = self.state.lock().expect("Poisoned lock for state");
234
1620
        state.waitfor_waker = Some(waker);
235
1620
    }
236

            
237
    /// Remove a previously registered `Waker` registered with `register_waitfor_waker()`.
238
240
    pub(crate) fn clear_waitfor_waker(&mut self) {
239
240
        let mut state = self.state.lock().expect("Poisoned lock for state");
240
240
        state.waitfor_waker = None;
241
240
    }
242

            
243
    /// Returns true if a `Waker` has been registered with `register_waitfor_waker()`.
244
    ///
245
    /// This is used to ensure that you don't have two concurrent `WaitFor`s running.
246
240
    pub(crate) fn has_waitfor_waker(&self) -> bool {
247
240
        let state = self.state.lock().expect("Poisoned lock for state");
248
240
        state.waitfor_waker.is_some()
249
240
    }
250
}
251

            
252
impl SleepSchedule {
253
    /// Wake any pending events that are ready according to the
254
    /// current simulated time.
255
10825
    fn fire(&mut self) {
256
10825
        use std::collections::binary_heap::PeekMut;
257
10825

            
258
10825
        let now = self.instant;
259
21533
        while let Some(top) = self.sleepers.peek_mut() {
260
21300
            if now < top.when {
261
10568
                return;
262
10708
            }
263
10708

            
264
10708
            PeekMut::pop(top).waker.wake();
265
        }
266
10801
    }
267

            
268
    /// Add a new SleepEntry to this schedule.
269
11332
    fn push(&mut self, ent: SleepEntry) {
270
11332
        self.sleepers.push(ent);
271
11332
    }
272

            
273
    /// If all sleepers made have been polled, set the advance flag and wake up any `WaitFor` that
274
    /// might be waiting.
275
12816
    fn maybe_advance(&mut self) {
276
12816
        if self.sleepers_polled >= self.sleepers_made {
277
12156
            if let Some(ref waker) = self.waitfor_waker {
278
1812
                trace!("setting advance flag");
279
1812
                self.should_advance = true;
280
1812
                waker.wake_by_ref();
281
            } else {
282
10344
                trace!("would advance, but no waker");
283
            }
284
660
        }
285
12800
    }
286

            
287
    /// Register a sleeper as having been polled, and advance if necessary.
288
11500
    fn increment_poll_count(&mut self) {
289
11500
        self.sleepers_polled += 1;
290
11500
        trace!(
291
            "sleeper polled, {}/{}",
292
            self.sleepers_polled,
293
            self.sleepers_made
294
        );
295
11476
        self.maybe_advance();
296
11476
    }
297
}
298

            
299
impl SleepProvider for MockSleepProvider {
300
    type SleepFuture = Sleeping;
301
11520
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
302
11520
        let mut provider = self.state.lock().expect("Poisoned lock for state");
303
11520
        let when = provider.instant + duration;
304
11520
        // We're making a new sleeper, so register this in the state.
305
11520
        provider.sleepers_made += 1;
306
11520
        trace!(
307
            "sleeper made for {:?}, {}/{}",
308
            duration,
309
            provider.sleepers_polled,
310
            provider.sleepers_made
311
        );
312

            
313
11488
        Sleeping {
314
11488
            when,
315
11488
            inserted: false,
316
11488
            provider: Arc::downgrade(&self.state),
317
11488
        }
318
11488
    }
319

            
320
81
    fn block_advance<T: Into<String>>(&self, reason: T) {
321
81
        let mut provider = self.state.lock().expect("Poisoned lock for state");
322
81
        let reason = reason.into();
323
81
        trace!("advancing blocked: {}", reason);
324
81
        provider.blocked_advance.insert(reason);
325
81
    }
326

            
327
45
    fn release_advance<T: Into<String>>(&self, reason: T) {
328
45
        let mut provider = self.state.lock().expect("Poisoned lock for state");
329
45
        let reason = reason.into();
330
45
        trace!("advancing released: {}", reason);
331
45
        provider.blocked_advance.remove(&reason);
332
45
        if provider.blocked_advance.is_empty() {
333
37
            provider.maybe_advance();
334
37
        }
335
45
    }
336

            
337
504
    fn allow_one_advance(&self, dur: Duration) {
338
504
        let mut provider = self.state.lock().expect("Poisoned lock for state");
339
504
        provider.allowed_advance = Duration::max(provider.allowed_advance, dur);
340
504
        trace!(
341
            "** allow_one_advance fired; may advance up to {:?} **",
342
            provider.allowed_advance
343
        );
344
504
        provider.maybe_advance();
345
504
    }
346

            
347
1211
    fn now(&self) -> Instant {
348
1211
        self.state.lock().expect("Poisoned lock for state").instant
349
1211
    }
350

            
351
15147
    fn wallclock(&self) -> SystemTime {
352
15147
        self.state
353
15147
            .lock()
354
15147
            .expect("Poisoned lock for state")
355
15147
            .wallclock
356
15147
    }
357
}
358

            
359
impl PartialEq for SleepEntry {
360
    fn eq(&self, other: &Self) -> bool {
361
        self.when == other.when
362
    }
363
}
364
impl Eq for SleepEntry {}
365
impl PartialOrd for SleepEntry {
366
12288
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
367
12288
        Some(self.cmp(other))
368
12288
    }
369
}
370
impl Ord for SleepEntry {
371
12280
    fn cmp(&self, other: &Self) -> Ordering {
372
12280
        self.when.cmp(&other.when).reverse()
373
12280
    }
374
}
375

            
376
impl Drop for Sleeping {
377
    fn drop(&mut self) {
378
11404
        if let Some(provider) = Weak::upgrade(&self.provider) {
379
11380
            let mut provider = provider.lock().expect("Poisoned lock for provider");
380
11380
            if !self.inserted {
381
                // A sleeper being dropped will never be polled, so there's no point waiting;
382
                // act as if it's been polled in order to avoid waiting forever.
383
120
                trace!("sleeper dropped, incrementing count");
384
128
                provider.increment_poll_count();
385
128
                self.inserted = true;
386
11260
            }
387
24
        }
388
11412
    }
389
}
390

            
391
impl Future for Sleeping {
392
    type Output = ();
393
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
394
13520
        if let Some(provider) = Weak::upgrade(&self.provider) {
395
13520
            let mut provider = provider.lock().expect("Poisoned lock for provider");
396
13520
            let now = provider.instant;
397
13520

            
398
13520
            if now >= self.when {
399
                // The sleep time's elapsed.
400
744
                if !self.inserted {
401
44
                    // If we never registered this sleeper as being polled, do so now.
402
44
                    provider.increment_poll_count();
403
44
                    self.inserted = true;
404
700
                }
405
744
                if !provider.should_advance {
406
692
                    // The first advance during a `WaitFor` gets triggered by all sleepers that
407
692
                    // have been created being polled.
408
692
                    // However, this only happens once.
409
692
                    // What we do to get around this is have sleepers that return Ready kick off
410
692
                    // another advance, in order to wake the next waiting sleeper.
411
692
                    provider.maybe_advance();
412
692
                }
413
744
                return Poll::Ready(());
414
12776
            }
415
12776
            // dbg!("sleep check with", self.when-now);
416
12776

            
417
12776
            if !self.inserted {
418
11332
                let entry = SleepEntry {
419
11332
                    when: self.when,
420
11332
                    waker: cx.waker().clone(),
421
11332
                };
422
11332

            
423
11332
                provider.push(entry);
424
11332
                self.inserted = true;
425
11332
                // Register this sleeper as having been polled.
426
11332
                provider.increment_poll_count();
427
11332
            }
428
            // dbg!(provider.sleepers.len());
429
        }
430
12816
        Poll::Pending
431
13560
    }
432
}
433

            
434
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use tor_rtcompat::test_with_all_runtimes;

    /// `advance_noyield` moves both clocks together; `jump_to` moves only the
    /// wall clock.
    #[test]
    fn basics_of_time_travel() {
        let start_wall = SystemTime::now();
        let sp = MockSleepProvider::new(start_wall);
        let start_instant = sp.now();
        assert_eq!(sp.wallclock(), start_wall);

        // Four hours and thirteen minutes.
        let step = Duration::new(4 * 3600 + 13 * 60, 0);
        sp.advance_noyield(step);
        assert_eq!(sp.now(), start_instant + step);
        assert_eq!(sp.wallclock(), start_wall + step);

        sp.jump_to(start_wall + step * 3);
        assert_eq!(sp.now(), start_instant + step);
        assert_eq!(sp.wallclock(), start_wall + step * 3);
    }

    /// Sleepers wake in deadline order as simulated time advances, without
    /// any real time passing.
    #[test]
    fn time_moves_on() {
        test_with_all_runtimes!(|_| async {
            use futures::channel::oneshot;
            use std::sync::atomic::{AtomicBool, Ordering};

            let sp = MockSleepProvider::new(SystemTime::now());
            let one_hour = Duration::new(3600, 0);

            // Each sleeper signals completion both through a flag and a channel.
            let (tx1, rx1) = oneshot::channel();
            let (tx2, rx2) = oneshot::channel();
            let (tx3, rx3) = oneshot::channel();

            let b1 = AtomicBool::new(false);
            let b2 = AtomicBool::new(false);
            let b3 = AtomicBool::new(false);

            let real_start = Instant::now();

            futures::join!(
                // Sleepers due at 1h, 3h and 5h of simulated time.
                async {
                    sp.sleep(one_hour).await;
                    b1.store(true, Ordering::SeqCst);
                    tx1.send(()).unwrap();
                },
                async {
                    sp.sleep(one_hour * 3).await;
                    b2.store(true, Ordering::SeqCst);
                    tx2.send(()).unwrap();
                },
                async {
                    sp.sleep(one_hour * 5).await;
                    b3.store(true, Ordering::SeqCst);
                    tx3.send(()).unwrap();
                },
                // Driver: step two hours at a time, checking who woke up.
                async {
                    // t = 2h: only the first sleeper is due.
                    sp.advance(one_hour * 2).await;
                    rx1.await.unwrap();
                    assert!(b1.load(Ordering::SeqCst));
                    assert!(!b2.load(Ordering::SeqCst));
                    assert!(!b3.load(Ordering::SeqCst));

                    // t = 4h: the second sleeper is due as well.
                    sp.advance(one_hour * 2).await;
                    rx2.await.unwrap();
                    assert!(b1.load(Ordering::SeqCst));
                    assert!(b2.load(Ordering::SeqCst));
                    assert!(!b3.load(Ordering::SeqCst));

                    // t = 6h: everyone is done.
                    sp.advance(one_hour * 2).await;
                    rx3.await.unwrap();
                    assert!(b1.load(Ordering::SeqCst));
                    assert!(b2.load(Ordering::SeqCst));
                    assert!(b3.load(Ordering::SeqCst));

                    // None of this should have taken an hour of real time.
                    let real_end = Instant::now();
                    assert!(real_end - real_start < one_hour);
                }
            );
            std::io::Result::Ok(())
        });
    }
}