//! A simple implementation of the internal map state of a `ChanMgr`.

            
3
use std::time::Duration;
4

            
5
use super::{AbstractChannel, Pending};
6
use crate::{Error, Result};
7

            
8
use std::collections::{hash_map, HashMap};
9
use tor_error::internal;
10

            
11
/// A map from channel id to channel state.
12
///
13
/// We make this a separate type instead of just using
14
/// `Mutex<HashMap<...>>` to limit the amount of code that can see and
15
/// lock the Mutex here.  (We're using a blocking mutex close to async
16
/// code, so we need to be careful.)
17
pub(crate) struct ChannelMap<C: AbstractChannel> {
18
    /// A map from identity to channel, or to pending channel status.
19
    ///
20
    /// (Danger: this uses a blocking mutex close to async code.  This mutex
21
    /// must never be held while an await is happening.)
22
    channels: std::sync::Mutex<HashMap<C::Ident, ChannelState<C>>>,
23
}
24

            
25
/// A token type that only this module can construct.
///
/// Its sole field is private, so code outside this module cannot build a
/// `Priv` value and therefore cannot create a `ChannelState::Poisoned`.
pub(crate) struct Priv {
    /// Private placeholder whose only job is to keep `Priv` from being
    /// constructed outside this module.
    _unused: (),
}
31

            
32
/// The state of a channel (or channel build attempt) within a map.
33
pub(crate) enum ChannelState<C> {
34
    /// An open channel.
35
    ///
36
    /// This channel might not be usable: it might be closing or
37
    /// broken.  We need to check its is_usable() method before
38
    /// yielding it to the user.
39
    Open(OpenEntry<C>),
40
    /// A channel that's getting built.
41
    Building(Pending<C>),
42
    /// A temporary invalid state.
43
    ///
44
    /// We insert this into the map temporarily as a placeholder in
45
    /// `change_state()`.
46
    Poisoned(Priv),
47
}
48

            
49
/// The map entry stored for a channel that has been opened.
#[derive(Clone)]
pub(crate) struct OpenEntry<C> {
    /// The open channel itself.
    pub(crate) channel: C,
    /// How long this channel may go unused before it becomes eligible for
    /// expiry.
    pub(crate) max_unused_duration: Duration,
}
57

            
58
impl<C: Clone> ChannelState<C> {
59
    /// Create a new shallow copy of this ChannelState.
60
    #[cfg(test)]
61
12
    fn clone_ref(&self) -> Result<Self> {
62
12
        use ChannelState::*;
63
12
        match self {
64
11
            Open(ent) => Ok(Open(ent.clone())),
65
            Building(pending) => Ok(Building(pending.clone())),
66
1
            Poisoned(_) => Err(Error::Internal(internal!("Poisoned state in channel map"))),
67
        }
68
12
    }
69

            
70
    /// For testing: either give the Open channel inside this state,
71
    /// or panic if there is none.
72
    #[cfg(test)]
73
4
    fn unwrap_open(&self) -> C {
74
4
        match self {
75
4
            ChannelState::Open(ent) => ent.clone().channel,
76
            _ => panic!("Not an open channel"),
77
        }
78
4
    }
79
}
80

            
81
impl<C: AbstractChannel> ChannelState<C> {
82
    /// Return an error if `ident`is definitely not a matching
83
    /// matching identity for this state.
84
43
    fn check_ident(&self, ident: &C::Ident) -> Result<()> {
85
43
        match self {
86
28
            ChannelState::Open(ent) => {
87
28
                if ent.channel.ident() == ident {
88
26
                    Ok(())
89
                } else {
90
2
                    Err(Error::Internal(internal!("Identity mismatch")))
91
                }
92
            }
93
            ChannelState::Poisoned(_) => {
94
                Err(Error::Internal(internal!("Poisoned state in channel map")))
95
            }
96
15
            ChannelState::Building(_) => Ok(()),
97
        }
98
43
    }
99

            
100
    /// Return true if a channel is ready to expire.
101
    /// Update `expire_after` if a smaller duration than
102
    /// the given value is required to expire this channel.
103
    fn ready_to_expire(&self, expire_after: &mut Duration) -> bool {
104
5
        if let ChannelState::Open(ent) = self {
105
5
            let unused_duration = ent.channel.duration_unused();
106
5
            if let Some(unused_duration) = unused_duration {
107
4
                let max_unused_duration = ent.max_unused_duration;
108

            
109
4
                if let Some(remaining) = max_unused_duration.checked_sub(unused_duration) {
110
2
                    *expire_after = std::cmp::min(*expire_after, remaining);
111
2
                    false
112
                } else {
113
2
                    true
114
                }
115
            } else {
116
                // still in use
117
1
                false
118
            }
119
        } else {
120
            false
121
        }
122
5
    }
123
}
124

            
125
impl<C: AbstractChannel> ChannelMap<C> {
126
    /// Create a new empty ChannelMap.
127
27
    pub(crate) fn new() -> Self {
128
27
        ChannelMap {
129
27
            channels: std::sync::Mutex::new(HashMap::new()),
130
27
        }
131
27
    }
132

            
133
    /// Return the channel state for the given identity, if any.
134
    #[cfg(test)]
135
22
    pub(crate) fn get(&self, ident: &C::Ident) -> Result<Option<ChannelState<C>>> {
136
22
        let map = self.channels.lock()?;
137
22
        map.get(ident).map(ChannelState::clone_ref).transpose()
138
22
    }
139

            
140
    /// Replace the channel state for `ident` with `newval`, and return the
141
    /// previous value if any.
142
    pub(crate) fn replace(
143
        &self,
144
        ident: C::Ident,
145
        newval: ChannelState<C>,
146
    ) -> Result<Option<ChannelState<C>>> {
147
23
        newval.check_ident(&ident)?;
148
23
        let mut map = self.channels.lock()?;
149
23
        Ok(map.insert(ident, newval))
150
23
    }
151

            
152
    /// Remove and return the state for `ident`, if any.
153
6
    pub(crate) fn remove(&self, ident: &C::Ident) -> Result<Option<ChannelState<C>>> {
154
6
        let mut map = self.channels.lock()?;
155
6
        Ok(map.remove(ident))
156
6
    }
157

            
158
    /// Remove every unusable state from the map.
159
    #[cfg(test)]
160
2
    pub(crate) fn remove_unusable(&self) -> Result<()> {
161
2
        let mut map = self.channels.lock()?;
162
7
        map.retain(|_, state| match state {
163
            ChannelState::Poisoned(_) => false,
164
7
            ChannelState::Open(ent) => ent.channel.is_usable(),
165
            ChannelState::Building(_) => true,
166
7
        });
167
2
        Ok(())
168
2
    }
169

            
170
    /// Replace the state whose identity is `ident` with a new state.
171
    ///
172
    /// The provided function `func` is invoked on the old state (if
173
    /// any), and must return a tuple containing an optional new
174
    /// state, and an arbitrary return value for this function.
175
    ///
176
    /// Because `func` is run while holding the lock on this object,
177
    /// it should be fast and nonblocking.  In return, you can be sure
178
    /// that it's running atomically with respect to other accessors
179
    /// of this map.
180
    ///
181
    /// If `func` panics, or if it returns a channel with a different
182
    /// identity, this position in the map will be become unusable and
183
    /// future accesses to that position may fail.
184
22
    pub(crate) fn change_state<F, V>(&self, ident: &C::Ident, func: F) -> Result<V>
185
22
    where
186
22
        F: FnOnce(Option<ChannelState<C>>) -> (Option<ChannelState<C>>, V),
187
22
    {
188
        use hash_map::Entry::*;
189
22
        let mut map = self.channels.lock()?;
190
22
        let entry = map.entry(ident.clone());
191
22
        match entry {
192
9
            Occupied(mut occupied) => {
193
9
                // Temporarily replace the entry for this identity with
194
9
                // a poisoned entry.
195
9
                let mut oldent = ChannelState::Poisoned(Priv { _unused: () });
196
9
                std::mem::swap(occupied.get_mut(), &mut oldent);
197
9
                let (newval, output) = func(Some(oldent));
198
9
                match newval {
199
8
                    Some(mut newent) => {
200
8
                        newent.check_ident(ident)?;
201
7
                        std::mem::swap(occupied.get_mut(), &mut newent);
202
                    }
203
1
                    None => {
204
1
                        occupied.remove();
205
1
                    }
206
                };
207
8
                Ok(output)
208
            }
209
13
            Vacant(vacant) => {
210
13
                let (newval, output) = func(None);
211
13
                if let Some(newent) = newval {
212
12
                    newent.check_ident(ident)?;
213
11
                    vacant.insert(newent);
214
1
                }
215
12
                Ok(output)
216
            }
217
        }
218
22
    }
219

            
220
    /// Expire all channels that have been unused for too long.
221
    ///
222
    /// Return a Duration until the next time at which
223
    /// a channel _could_ expire.
224
4
    pub(crate) fn expire_channels(&self) -> Duration {
225
4
        let mut ret = Duration::from_secs(180);
226
4
        self.channels
227
4
            .lock()
228
4
            .expect("Poisoned lock")
229
7
            .retain(|_id, chan| !chan.ready_to_expire(&mut ret));
230
4
        ret
231
4
    }
232
}
233

            
234
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// A stand-in channel whose identity is the first byte of its name.
    #[derive(Eq, PartialEq, Clone, Debug)]
    struct FakeChannel {
        ident: &'static str,
        usable: bool,
        unused_duration: Option<u64>,
    }

    impl AbstractChannel for FakeChannel {
        type Ident = u8;
        fn ident(&self) -> &Self::Ident {
            &self.ident.as_bytes()[0]
        }
        fn is_usable(&self) -> bool {
            self.usable
        }
        fn duration_unused(&self) -> Option<Duration> {
            self.unused_duration.map(Duration::from_secs)
        }
    }

    /// Build an `Open` state wrapping a `FakeChannel` with the given details.
    fn open_state(
        ident: &'static str,
        usable: bool,
        max_unused_duration: Duration,
        unused_duration: Option<u64>,
    ) -> ChannelState<FakeChannel> {
        ChannelState::Open(OpenEntry {
            channel: FakeChannel {
                ident,
                usable,
                unused_duration,
            },
            max_unused_duration,
        })
    }

    /// A usable, in-use channel with a 180-second idle limit.
    fn ch(ident: &'static str) -> ChannelState<FakeChannel> {
        open_state(ident, true, Duration::from_secs(180), None)
    }

    /// A usable channel with an explicit idle limit and idle time.
    fn ch_with_details(
        ident: &'static str,
        max_unused_duration: Duration,
        unused_duration: Option<u64>,
    ) -> ChannelState<FakeChannel> {
        open_state(ident, true, max_unused_duration, unused_duration)
    }

    /// An unusable (closed), in-use channel with a 180-second idle limit.
    fn closed(ident: &'static str) -> ChannelState<FakeChannel> {
        open_state(ident, false, Duration::from_secs(180), None)
    }

    #[test]
    fn simple_ops() {
        use ChannelState::Open;
        let map = ChannelMap::new();

        assert!(map.replace(b'h', ch("hello")).unwrap().is_none());
        assert!(map.replace(b'w', ch("wello")).unwrap().is_none());

        assert!(matches!(
            map.get(&b'h'),
            Ok(Some(Open(ent))) if ent.channel.ident == "hello"
        ));

        assert!(map.get(&b'W').unwrap().is_none());

        assert!(matches!(
            map.replace(b'h', ch("hebbo")),
            Ok(Some(Open(ent))) if ent.channel.ident == "hello"
        ));

        assert!(map.remove(&b'Z').unwrap().is_none());
        assert!(matches!(
            map.remove(&b'h'),
            Ok(Some(Open(ent))) if ent.channel.ident == "hebbo"
        ));
    }

    #[test]
    fn rmv_unusable() {
        let map = ChannelMap::new();
        for (key, state) in vec![
            (b'm', closed("machen")),
            (b'f', ch("feinen")),
            (b'w', closed("wir")),
            (b'F', ch("Fug")),
        ] {
            map.replace(key, state).unwrap();
        }

        map.remove_unusable().unwrap();

        assert!(map.get(&b'm').unwrap().is_none());
        assert!(map.get(&b'w').unwrap().is_none());
        assert!(map.get(&b'f').unwrap().is_some());
        assert!(map.get(&b'F').unwrap().is_some());
    }

    #[test]
    fn change() {
        let map = ChannelMap::new();
        for (key, state) in vec![
            (b'w', ch("wir")),
            (b'm', ch("machen")),
            (b'f', ch("feinen")),
            (b'F', ch("Fug")),
        ] {
            map.replace(key, state).unwrap();
        }

        // Some -> Some: the new state is stored and the old one handed back.
        let (prev, val) = map
            .change_state(&b'F', |state| (Some(ch("FUG")), (state, 99_u8)))
            .unwrap();
        assert_eq!(prev.unwrap().unwrap_open().ident, "Fug");
        assert_eq!(val, 99);
        assert_eq!(map.get(&b'F').unwrap().unwrap().unwrap_open().ident, "FUG");

        // Some -> None: the entry is removed.
        let (prev, val) = map
            .change_state(&b'f', |state| (None, (state, 123_u8)))
            .unwrap();
        assert_eq!(prev.unwrap().unwrap_open().ident, "feinen");
        assert_eq!(val, 123);
        assert!(map.get(&b'f').unwrap().is_none());

        // None -> Some: a new entry is created.
        let (prev, val) = map
            .change_state(&b'G', |state| (Some(ch("Geheimnisse")), (state, "Hi")))
            .unwrap();
        assert!(prev.is_none());
        assert_eq!(val, "Hi");
        assert_eq!(
            map.get(&b'G').unwrap().unwrap().unwrap_open().ident,
            "Geheimnisse"
        );

        // None -> None: nothing happens.
        let (prev, val) = map
            .change_state(&b'Q', |state| (None, (state, "---")))
            .unwrap();
        assert!(prev.is_none());
        assert_eq!(val, "---");
        assert!(map.get(&b'Q').unwrap().is_none());

        // None -> mismatched identity: rejected, and nothing is inserted.
        let outcome = map.change_state(&b'P', |state| (Some(ch("Geheimnisse")), (state, "Hi")));
        assert!(matches!(outcome, Err(Error::Internal(_))));
        assert!(matches!(map.get(&b'P'), Ok(None)));

        // Some -> mismatched identity: rejected, and the slot stays poisoned.
        let outcome = map.change_state(&b'G', |state| (Some(ch("Wobbledy")), (state, "Hi")));
        assert!(matches!(outcome, Err(Error::Internal(_))));
        assert!(matches!(map.get(&b'G'), Err(Error::Internal(_))));
    }

    #[test]
    fn expire_channels() {
        let map = ChannelMap::new();

        // Idle beyond its limit: expired.
        map.replace(
            b'w',
            ch_with_details("wello", Duration::from_secs(180), Some(181)),
        )
        .unwrap();

        // With no remaining channel to constrain it, the default upper bound
        // of 180 seconds is returned.
        assert_eq!(180, map.expire_channels().as_secs());
        assert!(map.get(&b'w').unwrap().is_none());

        let map = ChannelMap::new();

        // Idle for less than their limit: kept.
        map.replace(
            b'w',
            ch_with_details("wello", Duration::from_secs(180), Some(120)),
        )
        .unwrap();
        map.replace(
            b'y',
            ch_with_details("yello", Duration::from_secs(180), Some(170)),
        )
        .unwrap();

        // Idle beyond its limit: expired.
        map.replace(
            b'g',
            ch_with_details("gello", Duration::from_secs(180), Some(181)),
        )
        .unwrap();

        // A closed channel that reports no idle time is kept; expiry does
        // not look at usability.
        map.replace(b'h', closed("hello")).unwrap();

        // "yello" has 10 seconds left, the smallest remaining time.
        assert_eq!(10, map.expire_channels().as_secs());
        assert!(map.get(&b'w').unwrap().is_some());
        assert!(map.get(&b'y').unwrap().is_some());
        assert!(map.get(&b'h').unwrap().is_some());
        assert!(map.get(&b'g').unwrap().is_none());
    }
}