1
//! Implement Tor's sort-of-Pareto estimator for circuit build timeouts.
2
//!
3
//! Our build times don't truly follow a
4
//! [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution)
5
//! distribution; instead they seem to be closer to a
6
//! [Fréchet](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution)
7
//! distribution.  But those are hard to work with, and we only care
8
//! about the right tail, so we're using Pareto instead.
9
//!
10
//! This estimator also includes several heuristics and kludges to
11
//! try to behave better on unreliable networks.
12
//! For more information on the exact algorithms and their rationales,
13
//! see [`path-spec.txt`](https://gitlab.torproject.org/tpo/core/torspec/-/blob/master/path-spec.txt).
14

            
15
use bounded_vec_deque::BoundedVecDeque;
16
use serde::{Deserialize, Serialize};
17
use std::collections::{BTreeMap, HashMap};
18
use std::convert::TryInto;
19
use std::time::Duration;
20
use tor_netdir::params::NetParameters;
21

            
22
use super::Action;
23
use tor_persist::JsonValue;
24

            
25
/// Maximum number of circuit build time observations that we keep.
const TIME_HISTORY_LEN: usize = 1_000;

/// Default number of success-versus-timeout observations that we keep.
const SUCCESS_HISTORY_DEFAULT_LEN: usize = 20;

/// Width, in milliseconds, of a single bucket of the build-time histogram.
const BUCKET_WIDTH_MSEC: u32 = 10;
34

            
35
/// A circuit build time or timeout duration, measured in milliseconds.
36
///
37
/// Requires that we don't care about tracking timeouts above u32::MAX
38
/// milliseconds (about 49 days).
39
25005
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
40
#[serde(transparent)]
41
struct MsecDuration(u32);
42

            
43
impl MsecDuration {
44
    /// Convert a Duration into a MsecDuration, saturating
45
    /// extra-high values to u32::MAX milliseconds.
46
2223
    fn new_saturating(d: &Duration) -> Self {
47
2223
        let msec = std::cmp::min(d.as_millis(), u128::from(u32::MAX)) as u32;
48
2223
        MsecDuration(msec)
49
2223
    }
50
}
51

            
52
/// Module to hold calls to const_assert.
53
///
54
/// This is a separate module so we can change the clippy warnings on it.
55
#[allow(clippy::checked_conversions)]
56
mod assertion {
57
    use static_assertions::const_assert;
58
    // If this assertion is untrue, then we can't safely use u16 fields in
59
    // time_histogram.
60
    const_assert!(super::TIME_HISTORY_LEN <= u16::MAX as usize);
61
}
62

            
63
/// A history of circuit timeout observations, used to estimate our
64
/// likely circuit timeouts.
65
#[derive(Debug, Clone)]
66
struct History {
67
    /// Our most recent observed circuit construction times.
68
    ///
69
    /// For the purpose of this estimator, a circuit counts as
70
    /// "constructed" when a certain "significant" hop (typically the third)
71
    /// is completed.
72
    time_history: BoundedVecDeque<MsecDuration>,
73

            
74
    /// A histogram representation of the values in [`History::time_history`].
75
    ///
76
    /// This histogram is implemented as a sparse map from the center
77
    /// value of each histogram bucket to the number of entries in
78
    /// that bucket.  It is completely derivable from time_history; we
79
    /// keep it separate here for efficiency.
80
    time_histogram: BTreeMap<MsecDuration, u16>,
81

            
82
    /// Our most recent circuit timeout statuses.
83
    ///
84
    /// Each `true` value represents a successfully completed circuit
85
    /// (all hops).  Each `false` value represents a circuit that
86
    /// timed out after having completed at least one hop.
87
    success_history: BoundedVecDeque<bool>,
88
}
89

            
90
impl History {
91
    /// Initialize a new empty `History` with no observations.
92
40
    fn new_empty() -> Self {
93
40
        History {
94
40
            time_history: BoundedVecDeque::new(TIME_HISTORY_LEN),
95
40
            time_histogram: BTreeMap::new(),
96
40
            success_history: BoundedVecDeque::new(SUCCESS_HISTORY_DEFAULT_LEN),
97
40
        }
98
40
    }
99

            
100
    /// Remove all observations from this `History`.
101
3
    fn clear(&mut self) {
102
3
        self.time_history.clear();
103
3
        self.time_histogram.clear();
104
3
        self.success_history.clear();
105
3
    }
106

            
107
    /// Change the number of successes to record in our success
108
    /// history to `n`.
109
4
    fn set_success_history_len(&mut self, n: usize) {
110
4
        if n < self.success_history.len() {
111
1
            self.success_history
112
1
                .drain(0..(self.success_history.len() - n));
113
3
        }
114
4
        self.success_history.set_max_len(n);
115
4
    }
116

            
117
    /// Set the capacity of our build-time history to `n` observations.
    ///
    /// Only compiled for tests.
    #[cfg(test)]
    fn set_time_history_len(&mut self, n: usize) {
        let _ = self.time_history.set_max_len(n);
    }
125

            
126
    /// Construct a new `History` from an iterator representing a sparse
127
    /// histogram of values.
128
    ///
129
    /// The input must be a sequence of `(D,N)` tuples, where each `D`
130
    /// represents a circuit build duration, and `N` represents the
131
    /// number of observations with that duration.
132
    ///
133
    /// These observations are shuffled into a random order, then
134
    /// added to a new History.
135
32
    fn from_sparse_histogram<I>(iter: I) -> Self
136
32
    where
137
32
        I: Iterator<Item = (MsecDuration, u16)>,
138
32
    {
139
32
        use rand::seq::{IteratorRandom, SliceRandom};
140
32
        use std::iter;
141
32
        let mut rng = rand::thread_rng();
142
32

            
143
32
        // We want to build a vector with the elements of the old histogram in
144
32
        // random order, but we want to defend ourselves against bogus inputs
145
32
        // that would take too much RAM.
146
32
        let mut observations = iter
147
32
            .take(TIME_HISTORY_LEN) // limit number of bins
148
326
            .flat_map(|(dur, n)| iter::repeat(dur).take(n as usize))
149
32
            .choose_multiple(&mut rng, TIME_HISTORY_LEN);
150
32
        // choose_multiple doesn't guarantee anything about the order of its output.
151
32
        observations.shuffle(&mut rng);
152
32

            
153
32
        let mut result = History::new_empty();
154
2040
        for obs in observations {
155
2008
            result.add_time(obs);
156
2008
        }
157

            
158
32
        result
159
32
    }
160

            
161
    /// Return an iterator yielding a sparse histogram of the circuit build
162
    /// time values in this `History`.
163
    ///
164
    /// Each histogram entry is a `(D,N)` tuple, where `D` is the
165
    /// center of a histogram bucket, and `N` is the number of
166
    /// observations in that bucket.
167
    ///
168
    /// Buckets with `N=0` are omitted.  Buckets are yielded in order.
169
21
    fn sparse_histogram(&self) -> impl Iterator<Item = (MsecDuration, u16)> + '_ {
170
935
        self.time_histogram.iter().map(|(d, n)| (*d, *n))
171
21
    }
172

            
173
    /// Return the center value for the bucket containing `time`.
174
4475
    fn bucket_center(time: MsecDuration) -> MsecDuration {
175
4475
        let idx = time.0 / BUCKET_WIDTH_MSEC;
176
4475
        let msec = (idx * BUCKET_WIDTH_MSEC) + (BUCKET_WIDTH_MSEC) / 2;
177
4475
        MsecDuration(msec)
178
4475
    }
179

            
180
    /// Increment the histogram bucket containing `time` by one.
181
4266
    fn inc_bucket(&mut self, time: MsecDuration) {
182
4266
        let center = History::bucket_center(time);
183
4266
        *self.time_histogram.entry(center).or_insert(0) += 1;
184
4266
    }
185

            
186
    /// Decrement the histogram bucket containing `time` by one, removing
187
    /// it if it becomes 0.
188
205
    fn dec_bucket(&mut self, time: MsecDuration) {
189
205
        use std::collections::btree_map::Entry;
190
205
        let center = History::bucket_center(time);
191
205
        match self.time_histogram.entry(center) {
192
1
            Entry::Vacant(_) => {
193
1
                // this is a bug.
194
1
            }
195
204
            Entry::Occupied(e) if e.get() <= &1 => {
196
2
                e.remove();
197
2
            }
198
202
            Entry::Occupied(mut e) => {
199
202
                *e.get_mut() -= 1;
200
202
            }
201
        }
202
205
    }
203

            
204
    /// Add `time` to our list of circuit build time observations, and
205
    /// adjust the histogram accordingly.
206
4258
    fn add_time(&mut self, time: MsecDuration) {
207
4258
        match self.time_history.push_back(time) {
208
4056
            None => {}
209
202
            Some(removed_time) => {
210
202
                // `removed_time` just fell off the end of the deque:
211
202
                // remove it from the histogram.
212
202
                self.dec_bucket(removed_time);
213
202
            }
214
        }
215
4258
        self.inc_bucket(time);
216
4258
    }
217

            
218
    /// Return the number of observations in our time history.
219
    ///
220
    /// This will always be `<= TIME_HISTORY_LEN`.
221
21
    fn n_times(&self) -> usize {
222
21
        self.time_history.len()
223
21
    }
224

            
225
    /// Record a success (true) or timeout (false) in our record of whether
226
    /// circuits timed out or not.
227
2466
    fn add_success(&mut self, succeeded: bool) {
228
2466
        self.success_history.push_back(succeeded);
229
2466
    }
230

            
231
    /// Return the number of timeouts recorded in our success history.
232
46
    fn n_recent_timeouts(&self) -> usize {
233
589
        self.success_history.iter().filter(|x| !**x).count()
234
46
    }
235

            
236
    /// Helper: return the `n` most frequent histogram bins.
237
15
    fn n_most_frequent_bins(&self, n: usize) -> Vec<(MsecDuration, u16)> {
238
15
        use itertools::Itertools;
239
15
        // we use cmp::Reverse here so that we can use k_smallest as
240
15
        // if it were "k_largest".
241
15
        use std::cmp::Reverse;
242
15

            
243
15
        // We want the buckets that have the _largest_ counts; we want
244
15
        // to break ties in favor of the _smallest_ values.  So we
245
15
        // apply Reverse only to the counts before passing the tuples
246
15
        // to k_smallest.
247
15

            
248
15
        self.sparse_histogram()
249
628
            .map(|(center, count)| (Reverse(count), center))
250
15
            // (k_smallest runs in O(n_bins * lg(n))
251
15
            .k_smallest(n)
252
15
            .into_iter()
253
42
            .map(|(Reverse(count), center)| (center, count))
254
15
            .collect()
255
15
    }
256

            
257
    /// Return an estimator for the `X_m` of our Pareto distribution,
258
    /// by looking at the `n_modes` most frequently filled histogram
259
    /// bins.
260
    ///
261
    /// It is not a true `X_m` value, since there are definitely
262
    /// values less than this, but it seems to work as a decent
263
    /// heuristic.
264
    ///
265
    /// Return `None` if we have no observations.
266
12
    fn estimate_xm(&self, n_modes: usize) -> Option<u32> {
267
12
        // From path-spec:
268
12
        //   Tor clients compute the Xm parameter using the weighted
269
12
        //   average of the the midpoints of the 'cbtnummodes' (10)
270
12
        //   most frequently occurring 10ms histogram bins.
271
12

            
272
12
        // The most frequently used bins.
273
12
        let bins = self.n_most_frequent_bins(n_modes);
274
12
        // Total number of observations in these bins.
275
35
        let n_observations: u16 = bins.iter().map(|(_, n)| n).sum();
276
12
        // Sum of all observations in these bins.
277
12
        let total_observations: u64 = bins
278
12
            .iter()
279
35
            .map(|(d, n)| u64::from(d.0 * u32::from(*n)))
280
12
            .sum();
281
12

            
282
12
        if n_observations == 0 {
283
3
            None
284
        } else {
285
9
            Some((total_observations / u64::from(n_observations)) as u32)
286
        }
287
12
    }
288

            
289
    /// Compute a maximum-likelihood pareto distribution based on this
290
    /// history, computing `X_m` based on the `n_modes` most frequent
291
    /// histograms.
292
    ///
293
    /// Return None if we have no observations.
294
10
    fn pareto_estimate(&self, n_modes: usize) -> Option<ParetoDist> {
295
10
        let xm = self.estimate_xm(n_modes)?;
296

            
297
        // From path-spec:
298
        //     alpha = n/(Sum_n{ln(MAX(Xm, x_i))} - n*ln(Xm))
299

            
300
8
        let n = self.time_history.len();
301
8
        let sum_of_log_observations: f64 = self
302
8
            .time_history
303
8
            .iter()
304
5030
            .map(|m| f64::from(std::cmp::max(m.0, xm)).ln())
305
8
            .sum();
306
8
        let sum_of_log_xm = (n as f64) * f64::from(xm).ln();
307
8

            
308
8
        // We're computing 1/alpha here, instead of alpha.  This avoids
309
8
        // division by zero, and has the advantage of being what our
310
8
        // quantile estimator actually needs.
311
8
        let inv_alpha = (sum_of_log_observations - sum_of_log_xm) / (n as f64);
312
8

            
313
8
        Some(ParetoDist {
314
8
            x_m: f64::from(xm),
315
8
            inv_alpha,
316
8
        })
317
10
    }
318
}
319

            
320
/// A Pareto distribution that we use when estimating timeouts.
///
/// All values are numbers of milliseconds.
#[derive(Debug)]
struct ParetoDist {
    /// Lower bound of the distribution.
    x_m: f64,
    /// Inverse of the distribution's alpha parameter.
    ///
    /// (Storing 1/alpha saves a step in [`ParetoDist::quantile`].)
    inv_alpha: f64,
}

impl ParetoDist {
    /// Evaluate the inverse CDF of this distribution at `q`.
    ///
    /// For a `q` between 0 and 1, the result is the value `v` such that a
    /// fraction `q` of the distribution is expected to be less than `v`.
    ///
    /// A `q` outside of [0.0, 1.0] is clamped into that range.
    fn quantile(&self, q: f64) -> f64 {
        let upper_tail = 1.0 - q.clamp(0.0, 1.0);
        let divisor = upper_tail.powf(self.inv_alpha);
        self.x_m / divisor
    }
}
346

            
347
/// The tunable settings that control a ParetoTimeoutEstimator.
///
/// Usually these are derived from a set of consensus parameters.
#[derive(Clone, Debug)]
pub(crate) struct Params {
    /// Do our estimates influence the circuit timeouts we report?
    ///
    /// If this is false, our timeouts stay fixed at the default.
    use_estimates: bool,
    /// The number of observations we need before our Pareto estimator is
    /// allowed to guess a good set of timeouts.
    min_observations: u16,
    /// The "significant hop": the hop whose completion time we record as
    /// the circuit build time.  (Careful: this is zero-indexed.)
    significant_hop: u8,
    /// The quantile (in range [0.0,1.0]) of the Pareto distribution at
    /// which a circuit is considered to have "timed out".
    ///
    /// (A "timed out" circuit keeps building so that it can be measured,
    /// but cannot be used for traffic.)
    timeout_quantile: f64,
    /// The quantile (in range [0.0,1.0]) of the Pareto distribution at
    /// which a circuit is "abandoned".
    ///
    /// (An "abandoned" circuit is stopped entirely, and is not included
    /// in measurements.)
    abandon_quantile: f64,
    /// The values that the `timeouts` function returns while we have no
    /// observations.
    default_thresholds: (Duration, Duration),
    /// How many histogram buckets to consult when computing the Xm
    /// estimate.
    ///
    /// (See [`History::estimate_xm`] for details.)
    n_modes_for_xm: usize,
    /// The number of entries to keep in our success/timeout history.
    success_history_len: usize,
    /// The number of timeouts we tolerate in our success/timeout history
    /// before concluding that the network has changed in a way that makes
    /// our estimates completely wrong.
    reset_after_timeouts: usize,
    /// The smallest base timeout that we will ever infer or return.
    min_timeout: Duration,
}
392

            
393
impl Default for Params {
394
35
    fn default() -> Self {
395
35
        Params {
396
35
            use_estimates: true,
397
35
            min_observations: 100,
398
35
            significant_hop: 2,
399
35
            timeout_quantile: 0.80,
400
35
            abandon_quantile: 0.99,
401
35
            default_thresholds: (Duration::from_secs(60), Duration::from_secs(60)),
402
35
            n_modes_for_xm: 10,
403
35
            success_history_len: SUCCESS_HISTORY_DEFAULT_LEN,
404
35
            reset_after_timeouts: 18,
405
35
            min_timeout: Duration::from_millis(10),
406
35
        }
407
35
    }
408
}
409

            
410
impl From<&NetParameters> for Params {
411
3
    fn from(p: &NetParameters) -> Params {
412
3
        // Because of the underlying bounds, the "unwrap_or_else"
413
3
        // conversions here should be impossible, and the "as"
414
3
        // conversions should always be in-range.
415
3

            
416
3
        let timeout = p
417
3
            .cbt_initial_timeout
418
3
            .try_into()
419
3
            .unwrap_or_else(|_| Duration::from_secs(60));
420
3
        let learning_disabled: bool = p.cbt_learning_disabled.into();
421
3
        Params {
422
3
            use_estimates: !learning_disabled,
423
3
            min_observations: p.cbt_min_circs_for_estimate.get() as u16,
424
3
            significant_hop: 2,
425
3
            timeout_quantile: p.cbt_timeout_quantile.as_fraction(),
426
3
            abandon_quantile: p.cbt_abandon_quantile.as_fraction(),
427
3
            default_thresholds: (timeout, timeout),
428
3
            n_modes_for_xm: p.cbt_num_xm_modes.get() as usize,
429
3
            success_history_len: p.cbt_success_count.get() as usize,
430
3
            reset_after_timeouts: p.cbt_max_timeouts.get() as usize,
431
3
            min_timeout: p
432
3
                .cbt_min_timeout
433
3
                .try_into()
434
3
                .unwrap_or_else(|_| Duration::from_millis(10)),
435
3
        }
436
3
    }
437
}
438

            
439
/// Tor's default circuit build timeout estimator.
440
///
441
/// This object records a set of observed circuit build times, and
442
/// uses it to determine good values for how long we should allow
443
/// circuits to build.
444
///
445
/// For full details of the algorithms used, see
446
/// [`path-spec.txt`](https://gitlab.torproject.org/tpo/core/torspec/-/blob/master/path-spec.txt).
447
pub(crate) struct ParetoTimeoutEstimator {
448
    /// Our observations for circuit build times and success/failure
449
    /// history.
450
    history: History,
451

            
452
    /// Our most recent timeout estimate, if we have one that is
453
    /// up-to-date.
454
    ///
455
    /// (We reset this to None whenever we get a new observation.)
456
    timeouts: Option<(Duration, Duration)>,
457

            
458
    /// The timeouts that we use when we do not have sufficient observations
459
    /// to conclude anything about our circuit build times.
460
    ///
461
    /// These start out as `p.default_thresholds`, but can be adjusted
462
    /// depending on how many timeouts we've been seeing.
463
    fallback_timeouts: (Duration, Duration),
464

            
465
    /// A set of parameters to use in computing circuit build timeout
466
    /// estimates.
467
    p: Params,
468
}
469

            
470
impl Default for ParetoTimeoutEstimator {
471
3
    fn default() -> Self {
472
3
        Self::from_history(History::new_empty())
473
3
    }
474
}
475

            
476
/// An object used to serialize our timeout history for persistent state.
477
77
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
478
#[serde(default)]
479
pub(crate) struct ParetoTimeoutState {
480
    /// A version field used to help encoding and decoding.
481
    #[allow(dead_code)]
482
    version: usize,
483
    /// A record of observed timeouts, as returned by `sparse_histogram()`.
484
    histogram: Vec<(MsecDuration, u16)>,
485
    /// The current timeout estimate: kept for reference.
486
    current_timeout: Option<MsecDuration>,
487

            
488
    /// Fields from the state file that was used to make this `ParetoTimeoutState` that
489
    /// this version of Arti doesn't understand.
490
    #[serde(flatten)]
491
    unknown_fields: HashMap<String, JsonValue>,
492
}
493

            
494
impl ParetoTimeoutState {
495
    /// Return the latest base timeout estimate, as recorded in this state.
496
30
    pub(crate) fn latest_estimate(&self) -> Option<Duration> {
497
30
        self.current_timeout
498
30
            .map(|m| Duration::from_millis(m.0.into()))
499
30
    }
500
}
501

            
502
impl ParetoTimeoutEstimator {
503
    /// Construct a new ParetoTimeoutEstimator from the provided history
504
    /// object.
505
34
    fn from_history(history: History) -> Self {
506
34
        let p = Params::default();
507
34
        ParetoTimeoutEstimator {
508
34
            history,
509
34
            timeouts: None,
510
34
            fallback_timeouts: p.default_thresholds,
511
34
            p,
512
34
        }
513
34
    }
514

            
515
    /// Create a new ParetoTimeoutEstimator based on a loaded
516
    /// ParetoTimeoutState.
517
31
    pub(crate) fn from_state(state: ParetoTimeoutState) -> Self {
518
31
        let history = History::from_sparse_histogram(state.histogram.into_iter());
519
31
        Self::from_history(history)
520
31
    }
521

            
522
    /// Compute an unscaled basic pair of timeouts for a circuit of
523
    /// the "normal" length.
524
    ///
525
    /// Return a cached value if we have no observations since the
526
    /// last time this function was called.
527
    fn base_timeouts(&mut self) -> (Duration, Duration) {
528
22
        if let Some(x) = self.timeouts {
529
            // Great; we have a cached value.
530
7
            return x;
531
15
        }
532
15

            
533
15
        if self.history.n_times() < self.p.min_observations as usize {
534
            // We don't have enough values to estimate.
535
7
            return self.fallback_timeouts;
536
8
        }
537

            
538
        // Here we're going to compute the timeouts, cache them, and
539
        // return them.
540
8
        let dist = match self.history.pareto_estimate(self.p.n_modes_for_xm) {
541
7
            Some(dist) => dist,
542
            None => {
543
1
                return self.fallback_timeouts;
544
            }
545
        };
546
7
        let timeout_threshold = dist.quantile(self.p.timeout_quantile);
547
7
        let abandon_threshold = dist
548
7
            .quantile(self.p.abandon_quantile)
549
7
            .max(timeout_threshold);
550
7

            
551
7
        let timeouts = (
552
7
            Duration::from_secs_f64(timeout_threshold / 1000.0).max(self.p.min_timeout),
553
7
            Duration::from_secs_f64(abandon_threshold / 1000.0).max(self.p.min_timeout),
554
7
        );
555
7
        self.timeouts = Some(timeouts);
556
7

            
557
7
        timeouts
558
22
    }
559
}
560

            
561
impl super::TimeoutEstimator for ParetoTimeoutEstimator {
562
2
    fn update_params(&mut self, p: &NetParameters) {
563
2
        let parameters = p.into();
564
2
        self.p = parameters;
565
2
        let new_success_len = self.p.success_history_len;
566
2
        self.history.set_success_history_len(new_success_len);
567
2
    }
568

            
569
2220
    fn note_hop_completed(&mut self, hop: u8, delay: Duration, is_last: bool) {
570
2220
        if hop == self.p.significant_hop {
571
2220
            let time = MsecDuration::new_saturating(&delay);
572
2220
            self.history.add_time(time);
573
2220
            self.timeouts.take();
574
2220
        }
575
2220
        if is_last {
576
2220
            self.history.add_success(true);
577
2220
        }
578
2220
    }
579

            
580
39
    fn note_circ_timeout(&mut self, hop: u8, delay: Duration) {
581
        // Only record this timeout if we have seen some network activity since
582
        // we launched the circuit.
583
39
        let have_seen_recent_activity =
584
39
            if let Some(last_traffic) = tor_proto::time_since_last_incoming_traffic() {
585
                last_traffic < delay
586
            } else {
587
                // TODO: Is this the correct behavior in this case?
588
39
                true
589
            };
590

            
591
39
        if hop > 0 && have_seen_recent_activity {
592
39
            self.history.add_success(false);
593
39
            if self.history.n_recent_timeouts() > self.p.reset_after_timeouts {
594
2
                let base_timeouts = self.base_timeouts();
595
2
                self.history.clear();
596
2
                self.timeouts.take();
597
2
                // If we already had a timeout that was at least the
598
2
                // length of our fallback timeouts, we should double
599
2
                // those fallback timeouts.
600
2
                if base_timeouts.0 >= self.fallback_timeouts.0 {
601
1
                    self.fallback_timeouts.0 *= 2;
602
1
                    self.fallback_timeouts.1 *= 2;
603
1
                }
604
37
            }
605
        }
606
39
    }
607

            
608
17
    fn timeouts(&mut self, action: &Action) -> (Duration, Duration) {
609
17
        let (base_t, base_a) = if self.p.use_estimates {
610
17
            self.base_timeouts()
611
        } else {
612
            // If we aren't using this estimator, then just return the
613
            // default thresholds from our parameters.
614
            return self.p.default_thresholds;
615
        };
616

            
617
17
        let reference_action = Action::BuildCircuit {
618
17
            length: self.p.significant_hop as usize + 1,
619
17
        };
620
17
        debug_assert!(reference_action.timeout_scale() > 0);
621

            
622
17
        let multiplier =
623
17
            (action.timeout_scale() as f64) / (reference_action.timeout_scale() as f64);
624
17

            
625
17
        // TODO-SPEC The spec doesn't define any of this
626
17
        // action-based-multiplier stuff.  Tor doesn't multiply the
627
17
        // abandon timeout.
628
17
        use super::mul_duration_f64_saturating as mul;
629
17
        (mul(base_t, multiplier), mul(base_a, multiplier))
630
17
    }
631

            
632
4
    fn learning_timeouts(&self) -> bool {
633
4
        self.p.use_estimates && self.history.n_times() < self.p.min_observations.into()
634
4
    }
635

            
636
3
    fn build_state(&mut self) -> Option<ParetoTimeoutState> {
637
3
        let cur_timeout = MsecDuration::new_saturating(&self.base_timeouts().0);
638
3
        Some(ParetoTimeoutState {
639
3
            version: 1,
640
3
            histogram: self.history.sparse_histogram().collect(),
641
3
            current_timeout: Some(cur_timeout),
642
3
            unknown_fields: Default::default(),
643
3
        })
644
3
    }
645
}
646

            
647
#[cfg(test)]
648
mod test {
649
    #![allow(clippy::unwrap_used)]
650
    use super::*;
651
    use crate::timeouts::TimeoutEstimator;
652

            
653
    /// Return an action to build a 3-hop circuit.
654
    fn b3() -> Action {
655
        Action::BuildCircuit { length: 3 }
656
    }
657

            
658
    impl From<u32> for MsecDuration {
659
        fn from(v: u32) -> Self {
660
            Self(v)
661
        }
662
    }
663

            
664
    #[test]
    fn ms_partial_cmp() {
        #![allow(clippy::eq_op)]
        // Check that MsecDuration values compare like their inner counts.
        let myriad = MsecDuration::from(10_000);
        let lakh = MsecDuration::from(100_000);
        let crore = MsecDuration::from(10_000_000);

        // Strict comparisons between different values.
        assert!(myriad < lakh);
        assert!(crore > lakh);

        // Comparisons of a value with itself.
        assert!(myriad == myriad);
        assert!(crore >= crore);
        assert!(crore <= crore);
    }
677

            
678
    #[test]
    fn history_lowlev() {
        // Bucket centers, including both ends of the u32 range.
        let center_cases: &[(u32, u32)] = &[
            (1, 5),
            (903, 905),
            (0, 5),
            (u32::MAX, 4294967295),
        ];
        for &(input, expected) in center_cases {
            assert_eq!(
                History::bucket_center(input.into()),
                MsecDuration::from(expected)
            );
        }

        // Three values land in the bucket centered on 5, four in the one
        // centered on 15, and one in the one centered on 295.
        let mut h = History::new_empty();
        for msec in 7_u32..=13 {
            h.inc_bucket(msec.into());
        }
        h.inc_bucket(299.into());
        assert_eq!(h.time_histogram.get(&5.into()), Some(&3));
        assert_eq!(h.time_histogram.get(&15.into()), Some(&4));
        assert_eq!(h.time_histogram.get(&25.into()), None);
        assert_eq!(h.time_histogram.get(&295.into()), Some(&1));

        // Decrement a bucket holding one entry, a bucket that does not
        // exist, and a bucket holding several entries.
        for &msec in &[299_u32, 24, 12] {
            h.dec_bucket(msec.into());
        }

        assert_eq!(h.time_histogram.get(&15.into()), Some(&3));
        assert_eq!(h.time_histogram.get(&25.into()), None);
        assert_eq!(h.time_histogram.get(&295.into()), None);

        for &outcome in &[true, false] {
            h.add_success(outcome);
        }
        assert_eq!(h.success_history.len(), 2);

        // Clearing empties all three containers.
        h.clear();
        assert_eq!(h.time_histogram.len(), 0);
        assert_eq!(h.time_history.len(), 0);
        assert_eq!(h.success_history.len(), 0);
    }
716

            
717
    #[test]
    fn time_observation_management() {
        let mut h = History::new_empty();
        // A small window makes it easy to overflow.
        h.set_time_history_len(8);

        for &msec in &[300_u32, 500, 542, 305, 543, 307] {
            h.add_time(msec.into());
        }

        assert_eq!(h.n_times(), 6);
        let most_frequent = h.n_most_frequent_bins(10);
        assert_eq!(
            &most_frequent[..],
            [(305.into(), 3), (545.into(), 2), (505.into(), 1)]
        );
        let most_frequent = h.n_most_frequent_bins(2);
        assert_eq!(&most_frequent[..], [(305.into(), 3), (545.into(), 2)]);

        let histogram: Vec<_> = h.sparse_histogram().collect();
        assert_eq!(
            &histogram[..],
            [(305.into(), 3), (505.into(), 1), (545.into(), 2)]
        );

        // The first two of these fill the window; the last two push out
        // the two oldest observations.
        for &msec in &[212_u32, 203, 617, 413] {
            h.add_time(msec.into());
        }

        assert_eq!(h.n_times(), 8);

        let v: Vec<_> = h.sparse_histogram().collect();
        assert_eq!(
            &v[..],
            [
                (205.into(), 1),
                (215.into(), 1),
                (305.into(), 2),
                (415.into(), 1),
                (545.into(), 2),
                (615.into(), 1)
            ]
        );

        // A histogram must survive a round trip through a new History.
        let rebuilt = History::from_sparse_histogram(v.clone().into_iter());
        let v2: Vec<_> = rebuilt.sparse_histogram().collect();
        assert_eq!(v, v2);
    }
763

            
764
    /// Check that the success/timeout history counts only the timeouts
    /// still inside its window, including after the window is shrunk.
    #[test]
    fn success_observation_mechanism() {
        let mut hist = History::new_empty();
        hist.set_success_history_len(20);

        // Nothing recorded yet, and a success is never a timeout.
        assert_eq!(hist.n_recent_timeouts(), 0);
        hist.add_success(true);
        assert_eq!(hist.n_recent_timeouts(), 0);

        // One failure is counted...
        hist.add_success(false);
        assert_eq!(hist.n_recent_timeouts(), 1);

        // ...but however many follow, the count is capped at the window
        // length.
        (0..200).for_each(|_| hist.add_success(false));
        assert_eq!(hist.n_recent_timeouts(), 20);

        // Three successes each displace one failure from the window.
        for _ in 0..3 {
            hist.add_success(true);
        }
        let after_successes = 20 - 3;
        assert_eq!(hist.n_recent_timeouts(), after_successes);

        // After shrinking the window the three successes are still in it.
        hist.set_success_history_len(10);
        let after_shrink = 10 - 3;
        assert_eq!(hist.n_recent_timeouts(), after_shrink);
    }
786

            
787
    /// Check the Xm estimate: the count-weighted mean of the most
    /// frequently observed bins.
    #[test]
    fn xm_calculation() {
        let mut hist = History::new_empty();
        // No observations, so no estimate.
        assert_eq!(hist.estimate_xm(2), None);

        [300, 500, 542, 305, 543, 307, 212, 203, 617, 413]
            .iter()
            .for_each(|ms| hist.add_time(MsecDuration(*ms)));

        // The two most populated bins...
        let modes = hist.n_most_frequent_bins(2);
        assert_eq!(&modes[..], [(305.into(), 3), (545.into(), 2)]);

        // ...and their weighted average, worked out by hand.
        let weighted_mean = (305 * 3 + 545 * 2) / 5;
        assert_eq!(hist.estimate_xm(2), Some(weighted_mean));
        assert_eq!(weighted_mean, 401);
    }
802

            
803
    /// Fit Pareto parameters to a small fixed sample and check both the
    /// fitted parameters and the quantiles derived from them.
    #[test]
    fn pareto_estimate() {
        let mut h = History::new_empty();
        // With no observations there is nothing to fit.
        assert!(h.pareto_estimate(2).is_none());

        for n in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            h.add_time(MsecDuration(*n));
        }
        // The hand-computed expectation uses the same sample, except that
        // every observation below Xm (401) is replaced by Xm itself.
        let expected_log_sum: f64 = [401, 500, 542, 401, 543, 401, 401, 401, 617, 413]
            .iter()
            .map(|x| f64::from(*x).ln())
            .sum();
        let expected_log_xm: f64 = (401_f64).ln() * 10.0;
        let expected_alpha = 10.0 / (expected_log_sum - expected_log_xm);
        let expected_inv_alpha = 1.0 / expected_alpha;

        let p = h.pareto_estimate(2).unwrap();

        // We can't do "eq" with floats, so we'll do "very close".
        assert!((401.0 - p.x_m).abs() < 1.0e-9);
        assert!((expected_inv_alpha - p.inv_alpha).abs() < 1.0e-9);

        let q60 = p.quantile(0.60);
        let q99 = p.quantile(0.99);

        // Compare the *absolute* error: without `.abs()` any quantile that
        // came out too small, however badly, would pass the check.
        assert!((q60 - 451.127).abs() < 0.001);
        assert!((q99 - 724.841).abs() < 0.001);
    }
831

            
832
    /// Check the timeouts reported by a `ParetoTimeoutEstimator` before
    /// and after it has observations, and for a longer circuit.
    #[test]
    fn pareto_estimate_timeout() {
        let fallback = (Duration::from_secs(60), Duration::from_secs(60));
        let mut estimator = ParetoTimeoutEstimator::default();

        // A fresh estimator reports the 60-second fallback.
        assert_eq!(estimator.timeouts(&b3()), fallback);

        // Use the same settings as the `pareto_estimate` test above.
        // Changing them alone, with no data, must not change the answer.
        estimator.p.min_observations = 0;
        estimator.p.n_modes_for_xm = 2;
        assert_eq!(estimator.timeouts(&b3()), fallback);

        for msec in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            estimator.note_hop_completed(2, Duration::from_millis(*msec), true);
        }

        let three_hop = estimator.timeouts(&b3());
        assert_eq!(three_hop.0.as_micros(), 493_169);
        assert_eq!(three_hop.1.as_micros(), 724_841);

        // Asking again must give the same answer.
        let repeated = estimator.timeouts(&b3());
        assert_eq!(repeated, three_hop);

        // A 4-hop circuit is expected to get both timeouts scaled by 10/6.
        let scale = 10.0 / 6.0;
        let four_hop = estimator.timeouts(&Action::BuildCircuit { length: 4 });
        assert_eq!(four_hop.0, three_hop.0.mul_f64(scale));
        assert_eq!(four_hop.1, three_hop.1.mul_f64(scale));
    }
865

            
866
    /// Check that a long run of circuit timeouts makes the estimator drop
    /// what it has learned, and that a second such run doubles the
    /// fallback timeout.
    #[test]
    fn pareto_estimate_clear() {
        let mut estimator = ParetoTimeoutEstimator::default();
        // Every simulated failure below reports this same delay.
        let failed_after = Duration::from_secs(2000);

        // Use network parameters close to the setup in the
        // `pareto_estimate` test above.
        let net_params = NetParameters::from_map(&"cbtmincircs=1 cbtnummodes=2".parse().unwrap());
        estimator.update_params(&net_params);

        // With no data we get the 60-second fallback and are still learning.
        assert_eq!(estimator.timeouts(&b3()).0.as_micros(), 60_000_000);
        assert!(estimator.learning_timeouts());

        for msec in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            estimator.note_hop_completed(2, Duration::from_millis(*msec), true);
        }
        // Now there is a real estimate, and no timeouts on record.
        assert_ne!(estimator.timeouts(&b3()).0.as_micros(), 60_000_000);
        assert!(!estimator.learning_timeouts());
        assert_eq!(estimator.history.n_recent_timeouts(), 0);

        // 18 timeouts in a row, and we are still getting real numbers...
        for _ in 0..18 {
            estimator.note_circ_timeout(2, failed_after);
        }
        assert_ne!(estimator.timeouts(&b3()).0.as_micros(), 60_000_000);

        // ... but the 19th puts us back on the 60-second fallback.
        estimator.note_circ_timeout(2, failed_after);
        assert_eq!(estimator.timeouts(&b3()).0.as_micros(), 60_000_000);

        // And after 20 more failures, the fallback has doubled.
        for _ in 0..20 {
            estimator.note_circ_timeout(2, failed_after);
        }
        assert_eq!(estimator.timeouts(&b3()).0.as_micros(), 120_000_000);
    }
902

            
903
    /// Check that `Params::default()` agrees with the `Params` derived
    /// from default network parameters.
    #[test]
    fn default_params() {
        // The two values are compared through their `Debug` output, as a
        // cheap stand-in for an equality check.
        let built_in = format!("{:?}", Params::default());
        let from_net = format!("{:?}", Params::from(&NetParameters::default()));
        assert_eq!(built_in, from_net);
    }
910

            
911
    /// Check that an estimator saved to a state object and restored from
    /// it reports roughly the same timeout as the original.
    #[test]
    fn state_conversion() {
        // Converting to and from histograms is tested elsewhere, so all
        // we really need to do here is make sure that the histogram
        // conversion happens.

        use rand::Rng;
        let mut original = ParetoTimeoutEstimator::default();
        let mut rng = rand::thread_rng();
        for _ in 0..1000 {
            let sample_ms = rng.gen_range(10..3_000);
            original.note_hop_completed(2, Duration::from_millis(sample_ms), true);
        }

        let saved = original.build_state().unwrap();
        assert_eq!(saved.version, 1);
        assert!(saved.current_timeout.is_some());

        let mut restored = ParetoTimeoutEstimator::from_state(saved);
        let three_hop = Action::BuildCircuit { length: 3 };
        // The two won't match exactly, because the state records
        // histogram bins rather than exact timeouts.
        let original_ms = original.timeouts(&three_hop).0.as_millis() as i32;
        let restored_ms = restored.timeouts(&three_hop).0.as_millis() as i32;
        let difference = original_ms - restored_ms;
        assert!(difference.abs() < 50);
    }
937

            
938
    // TODO: add tests from Tor.
939
}