1
//! Definitions for [`SleepProviderExt`] and related types.
2

            
3
use crate::traits::SleepProvider;
4
use futures::{Future, FutureExt};
5
use pin_project::pin_project;
6
use std::{
7
    pin::Pin,
8
    task::{Context, Poll},
9
    time::{Duration, SystemTime},
10
};
11

            
12
/// Error produced when the timeout set by [`SleepProviderExt::timeout`]
/// elapses before the wrapped future has finished.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// Deliberately exhaustive: this is a plain unit value that callers may
// construct and compare directly.
#[allow(clippy::exhaustive_structs)]
pub struct TimeoutError;

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Timeout expired")
    }
}

impl std::error::Error for TimeoutError {}

impl From<TimeoutError> for std::io::Error {
    /// Wrap the timeout as an I/O error of kind
    /// [`std::io::ErrorKind::TimedOut`], keeping the `TimeoutError` as the
    /// inner error.
    fn from(err: TimeoutError) -> Self {
        Self::new(std::io::ErrorKind::TimedOut, err)
    }
}
32

            
33
/// An extension trait on [`SleepProvider`] for timeouts and clock delays.
34
pub trait SleepProviderExt: SleepProvider {
35
    /// Wrap a [`Future`] with a timeout.
36
    ///
37
    /// The output of the new future will be the returned value of
38
    /// `future` if it completes within `duration`.  Otherwise, it
39
    /// will be `Err(TimeoutError)`.
40
    ///
41
    /// # Limitations
42
    ///
43
    /// This uses [`SleepProvider::sleep`] for its timer, and is
44
    /// subject to the same limitations.
45
    #[must_use = "timeout() returns a future, which does nothing unless used"]
46
162
    fn timeout<F: Future>(&self, duration: Duration, future: F) -> Timeout<F, Self::SleepFuture> {
47
162
        let sleep_future = self.sleep(duration);
48
162

            
49
162
        Timeout {
50
162
            future,
51
162
            sleep_future,
52
162
        }
53
162
    }
54

            
55
    /// Pause until the wall-clock is at `when` or later, trying to
56
    /// recover from clock jumps.
57
    ///
58
    /// Unlike [`SleepProvider::sleep()`], the future returned by this function will
59
    /// wake up periodically to check the current time, and see if
60
    /// it is at or past the target.
61
    ///
62
    /// # Limitations
63
    ///
64
    /// The ability of this function to detect clock jumps is limited
65
    /// to its granularity; it may finish a while after the declared
66
    /// wallclock time if the system clock jumps forward.
67
    ///
68
    /// This function does not detect backward clock jumps; arguably,
69
    /// we should have another function to do that.
70
    ///
71
    /// This uses [`SleepProvider::sleep`] for its timer, and is
72
    /// subject to the same limitations.
73
    #[must_use = "sleep_until_wallclock() returns a future, which does nothing unless used"]
74
20
    fn sleep_until_wallclock(&self, when: SystemTime) -> SleepUntilWallclock<'_, Self> {
75
20
        SleepUntilWallclock {
76
20
            provider: self,
77
20
            target: when,
78
20
            sleep_future: None,
79
20
        }
80
20
    }
81
}
82

            
83
impl<T: SleepProvider> SleepProviderExt for T {}
84

            
85
/// A timeout returned by [`SleepProviderExt::timeout`].
86
518
#[pin_project]
87
pub struct Timeout<T, S> {
88
    /// The future we want to execute.
89
    #[pin]
90
    future: T,
91
    /// The future implementing the timeout.
92
    #[pin]
93
    sleep_future: S,
94
}
95

            
96
impl<T, S> Future for Timeout<T, S>
97
where
98
    T: Future,
99
    S: Future<Output = ()>,
100
{
101
    type Output = Result<T::Output, TimeoutError>;
102

            
103
518
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104
518
        let this = self.project();
105
518
        if let Poll::Ready(x) = this.future.poll(cx) {
106
125
            return Poll::Ready(Ok(x));
107
393
        }
108
393

            
109
393
        match this.sleep_future.poll(cx) {
110
364
            Poll::Pending => Poll::Pending,
111
29
            Poll::Ready(()) => Poll::Ready(Err(TimeoutError)),
112
        }
113
518
    }
114
}
115

            
116
/// A future implementing [`SleepProviderExt::sleep_until_wallclock`].
117
pub struct SleepUntilWallclock<'a, SP: SleepProvider + ?Sized> {
118
    /// Reference to the provider that we use to make new SleepFutures.
119
    provider: &'a SP,
120
    /// The time that we are waiting for.
121
    target: SystemTime,
122
    /// The future representing our current delay.
123
    sleep_future: Option<Pin<Box<SP::SleepFuture>>>,
124
}
125

            
126
impl<'a, SP> Future for SleepUntilWallclock<'a, SP>
127
where
128
    SP: SleepProvider + ?Sized,
129
{
130
    type Output = ();
131

            
132
2541
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
133
2541
        // Strategy: we implement sleep_until_wallclock by
134
2541
        // waiting in increments of up to MAX_SLEEP, checking the
135
2541
        // wall clock before and after each increment.  This makes
136
2541
        // us wake up a bit more frequently, but enables us to detect it
137
2541
        // if the system clock jumps forward.
138
2541
        let target = self.target;
139
2541
        loop {
140
2541
            let now = self.provider.wallclock();
141
2541
            if now >= target {
142
11
                return Poll::Ready(());
143
2530
            }
144
2530

            
145
2530
            let (last_delay, delay) = calc_next_delay(now, target);
146
2530

            
147
2530
            // Note that we store this future to keep it from being
148
2530
            // cancelled, even though we don't ever poll it more than
149
2530
            // once.
150
2530
            //
151
2530
            // TODO: I'm not sure that it's actually necessary to keep
152
2530
            // this future around.
153
2530
            self.sleep_future.take();
154
2530

            
155
2530
            let mut sleep_future = Box::pin(self.provider.sleep(delay));
156
2530
            match sleep_future.poll_unpin(cx) {
157
                Poll::Pending => {
158
2525
                    self.sleep_future = Some(sleep_future);
159
2525
                    return Poll::Pending;
160
                }
161
                Poll::Ready(()) => {
162
5
                    if last_delay {
163
                        return Poll::Ready(());
164
                    }
165
                }
166
            }
167
        }
168
2536
    }
169
}
170

            
171
/// Upper bound on any single sleep inside `sleep_until_wallclock()`, so that a
/// jump in the system clock is noticed within roughly this much time.
///
/// This is a tradeoff. A shorter bound notices clock jumps sooner, but causes
/// more frequent wake-ups and uses more CPU.
const MAX_SLEEP: Duration = Duration::from_secs(600);

/// Decide how long `sleep_until_wallclock()` should sleep next.
///
/// Returns `(is_final, delay)`:
/// - `delay` is the time left until `when`, capped at [`MAX_SLEEP`]. It is zero
///   if `when` is not after `now`.
/// - `is_final` is `true` exactly when no cap was applied, meaning this delay
///   is expected to reach `when`.
///
/// (Kept as a standalone function so it can be unit-tested.)
fn calc_next_delay(now: SystemTime, when: SystemTime) -> (bool, Duration) {
    // `duration_since` fails when `when` is earlier than `now`; treat that as
    // "no time left".
    let remaining = when.duration_since(now).unwrap_or_default();
    let is_final = remaining <= MAX_SLEEP;
    (is_final, remaining.min(MAX_SLEEP))
}
193

            
194
#[cfg(test)]
mod test {
    // `minute * 0` below is written on purpose so that every expected value
    // reads in the same units.
    #![allow(clippy::erasing_op)]
    use super::*;

    /// Check that `calc_next_delay` returns the remaining time, clamped to zero
    /// in the past and capped at `MAX_SLEEP`.
    #[test]
    fn sleep_delay() {
        let minute = Duration::from_secs(60);
        let second = Duration::from_secs(1);
        let start = SystemTime::now();
        let target = start + 30 * minute;

        // (current time, expected next delay) pairs, all relative to `target`.
        let cases = [
            (start, minute * 10),
            (target + minute, minute * 0),
            (target, minute * 0),
            (target - second, second),
            (target - minute * 9, minute * 9),
            (target - minute * 11, minute * 10),
        ];

        for &(now, expected) in &cases {
            let (_, delay) = calc_next_delay(now, target);
            assert_eq!(delay, expected);
        }
    }
}