1
//! Abstract implementation of a channel manager
2

            
3
use crate::mgr::map::OpenEntry;
4
use crate::{Error, Result};
5

            
6
use async_trait::async_trait;
7
use futures::channel::oneshot;
8
use futures::future::{FutureExt, Shared};
9
use rand::Rng;
10
use std::hash::Hash;
11
use std::time::Duration;
12
use tor_error::internal;
13

            
14
mod map;
15

            
16
/// The subset of a [`Channel`](tor_proto::channel::Channel)'s behavior that
/// `AbstractChanMgr` relies on.
///
/// Keeping this abstract lets the manager logic run against stand-in
/// channel types.
pub(crate) trait AbstractChannel: Clone {
    /// Type used to identify the peer at the other end of the channel.
    type Ident: Eq + Hash + Clone;

    /// Report whether this channel may still be handed out to callers.
    ///
    /// A channel can become unusable because it has closed, because it hit a
    /// bug, or for some other reason; unusable channels are never given back
    /// to the user.
    fn is_usable(&self) -> bool;

    /// Return the identity of this channel's peer.
    fn ident(&self) -> &Self::Ident;

    /// Return how long this channel has gone without being used, or `None`
    /// if it is in use right now.
    fn duration_unused(&self) -> Option<Duration>;
}
34

            
35
/// Trait to describe how channels are created.
36
#[async_trait]
37
pub(crate) trait ChannelFactory {
38
    /// The type of channel that this factory can build.
39
    type Channel: AbstractChannel;
40
    /// Type that explains how to build a channel.
41
    type BuildSpec;
42

            
43
    /// Construct a new channel to the destination described at `target`.
44
    ///
45
    /// This function must take care of all timeouts, error detection,
46
    /// and so on.
47
    ///
48
    /// It should not retry; that is handled at a higher level.
49
    async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Self::Channel>;
50
}
51

            
52
/// A type- and network-agnostic implementation for
53
/// [`ChanMgr`](crate::ChanMgr).
54
///
55
/// This type does the work of keeping track of open channels and
56
/// pending channel requests, launching requests as needed, waiting
57
/// for pending requests, and so forth.
58
///
59
/// The actual job of launching connections is deferred to a ChannelFactory
60
/// type.
61
pub(crate) struct AbstractChanMgr<CF: ChannelFactory> {
62
    /// A 'connector' object that we use to create channels.
63
    connector: CF,
64

            
65
    /// A map from ed25519 identity to channel, or to pending channel status.
66
    channels: map::ChannelMap<CF::Channel>,
67
}
68

            
69
/// Type alias for a future that we wait on to see when a pending
70
/// channel is done or failed.
71
type Pending<C> = Shared<oneshot::Receiver<Result<C>>>;
72

            
73
/// Type alias for the sender we notify when we complete a channel (or
74
/// fail to complete it).
75
type Sending<C> = oneshot::Sender<Result<C>>;
76

            
77
impl<CF: ChannelFactory> AbstractChanMgr<CF> {
78
    /// Make a new empty channel manager.
79
22
    pub(crate) fn new(connector: CF) -> Self {
80
22
        AbstractChanMgr {
81
22
            connector,
82
22
            channels: map::ChannelMap::new(),
83
22
        }
84
22
    }
85

            
86
    /// Remove every unusable entry from this channel manager.
87
    #[cfg(test)]
88
1
    pub(crate) fn remove_unusable_entries(&self) -> Result<()> {
89
1
        self.channels.remove_unusable()
90
1
    }
91

            
92
    /// Helper: return the objects used to inform pending tasks
93
    /// about a newly open or failed channel.
94
11
    fn setup_launch<C: Clone>(&self) -> (map::ChannelState<C>, Sending<C>) {
95
11
        let (snd, rcv) = oneshot::channel();
96
11
        let shared = rcv.shared();
97
11
        (map::ChannelState::Building(shared), snd)
98
11
    }
99

            
100
    /// Get a channel whose identity is `ident`.
101
    ///
102
    /// If a usable channel exists with that identity, return it.
103
    ///
104
    /// If no such channel exists already, and none is in progress,
105
    /// launch a new request using `target`, which must match `ident`.
106
    ///
107
    /// If no such channel exists already, but we have one that's in
108
    /// progress, wait for it to succeed or fail.
109
13
    pub(crate) async fn get_or_launch(
110
13
        &self,
111
13
        ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
112
13
        target: CF::BuildSpec,
113
13
    ) -> Result<CF::Channel> {
114
13
        use map::ChannelState::*;
115
13

            
116
13
        /// Possible actions that we'll decide to take based on the
117
13
        /// channel's initial state.
118
13
        enum Action<C> {
119
13
            /// We found no channel.  We're going to launch a new one,
120
13
            /// then tell everybody about it.
121
13
            Launch(Sending<C>),
122
13
            /// We found an in-progress attempt at making a channel.
123
13
            /// We're going to wait for it to finish.
124
13
            Wait(Pending<C>),
125
13
            /// We found a usable channel.  We're going to return it.
126
13
            Return(Result<C>),
127
13
        }
128
13
        /// How many times do we try?
129
13
        const N_ATTEMPTS: usize = 2;
130
13

            
131
13
        // TODO(nickm): It would be neat to use tor_retry instead.
132
13
        let mut last_err = None;
133

            
134
19
        for _ in 0..N_ATTEMPTS {
135
            // First, see what state we're in, and what we should do
136
            // about it.
137
16
            let action = self
138
16
                .channels
139
16
                .change_state(&ident, |oldstate| match oldstate {
140
2
                    Some(Open(ref ent)) => {
141
2
                        if ent.channel.is_usable() {
142
                            // Good channel. Return it.
143
1
                            let action = Action::Return(Ok(ent.channel.clone()));
144
1
                            (oldstate, action)
145
                        } else {
146
                            // Unusable channel.  Move to the Building
147
                            // state and launch a new channel.
148
1
                            let (newstate, send) = self.setup_launch();
149
1
                            let action = Action::Launch(send);
150
1
                            (Some(newstate), action)
151
                        }
152
                    }
153
4
                    Some(Building(ref pending)) => {
154
4
                        let action = Action::Wait(pending.clone());
155
4
                        (oldstate, action)
156
                    }
157
                    Some(Poisoned(_)) => {
158
                        // We should never be able to see this state; this
159
                        // is a bug.
160
                        (
161
                            None,
162
                            Action::Return(Err(Error::Internal(internal!(
163
                                "Found a poisoned entry"
164
                            )))),
165
                        )
166
                    }
167
                    None => {
168
                        // No channel.  Move to the Building
169
                        // state and launch a new channel.
170
10
                        let (newstate, send) = self.setup_launch();
171
10
                        let action = Action::Launch(send);
172
10
                        (Some(newstate), action)
173
                    }
174
16
                })?;
175

            
176
            // Now we act based on the channel.
177
16
            match action {
178
                // Easy case: we have an error or a channel to return.
179
1
                Action::Return(v) => {
180
1
                    return v;
181
                }
182
                // There's an in-progress channel.  Wait for it.
183
4
                Action::Wait(pend) => match pend.await {
184
2
                    Ok(Ok(chan)) => return Ok(chan),
185
2
                    Ok(Err(e)) => {
186
2
                        last_err = Some(e);
187
2
                    }
188
                    Err(_) => {
189
                        last_err =
190
                            Some(Error::Internal(internal!("channel build task disappeared")));
191
                    }
192
                },
193
                // We need to launch a channel.
194
11
                Action::Launch(send) => match self.connector.build_channel(&target).await {
195
7
                    Ok(chan) => {
196
7
                        // The channel got built: remember it, tell the
197
7
                        // others, and return it.
198
7
                        self.channels.replace(
199
7
                            ident.clone(),
200
7
                            Open(OpenEntry {
201
7
                                channel: chan.clone(),
202
7
                                max_unused_duration: Duration::from_secs(
203
7
                                    rand::thread_rng().gen_range(180..270),
204
7
                                ),
205
7
                            }),
206
7
                        )?;
207
                        // It's okay if all the receivers went away:
208
                        // that means that nobody was waiting for this channel.
209
6
                        let _ignore_err = send.send(Ok(chan.clone()));
210
6
                        return Ok(chan);
211
                    }
212
4
                    Err(e) => {
213
4
                        // The channel failed. Make it non-pending, tell the
214
4
                        // others, and set the error.
215
4
                        self.channels.remove(&ident)?;
216
                        // (As above)
217
4
                        let _ignore_err = send.send(Err(e.clone()));
218
4
                        last_err = Some(e);
219
                    }
220
                },
221
            }
222
        }
223

            
224
3
        Err(last_err.unwrap_or_else(|| Error::Internal(internal!("no error was set!?"))))
225
12
    }
226

            
227
    /// Expire any channels that have been unused longer than
228
    /// their maximum unused duration assigned during creation.
229
    ///
230
    /// Return a duration from now until next channel expires.
231
    ///
232
    /// If all channels are in use or there are no open channels,
233
    /// return 180 seconds which is the minimum value of
234
    /// max_unused_duration.
235
2
    pub(crate) fn expire_channels(&self) -> Duration {
236
2
        self.channels.expire_channels()
237
2
    }
238

            
239
    /// Test only: return the current open usable channel with a given
240
    /// `ident`, if any.
241
    #[cfg(test)]
242
    pub(crate) fn get_nowait(
243
        &self,
244
        ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
245
    ) -> Option<CF::Channel> {
246
        use map::ChannelState::*;
247
5
        match self.channels.get(ident) {
248
3
            Ok(Some(Open(ref ent))) if ent.channel.is_usable() => Some(ent.channel.clone()),
249
2
            _ => None,
250
        }
251
5
    }
252
}
253

            
254
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::Error;

    use futures::join;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tor_error::bad_api_usage;

    use tor_rtcompat::{task::yield_now, test_with_one_runtime, Runtime};

    /// Channel factory that fabricates `FakeChannel`s; what it does is
    /// controlled by the "mood" character in each build target.
    struct FakeChannelFactory<RT> {
        runtime: RT,
    }

    /// Stand-in channel. Two `FakeChannel`s compare equal only when they are
    /// clones of one original (they share the same `detect_reuse` Arc).
    #[derive(Clone, Debug)]
    struct FakeChannel {
        ident: u32,
        mood: char,
        closing: Arc<AtomicBool>,
        detect_reuse: Arc<char>,
    }

    impl PartialEq for FakeChannel {
        fn eq(&self, rhs: &FakeChannel) -> bool {
            Arc::ptr_eq(&self.detect_reuse, &rhs.detect_reuse)
        }
    }

    impl AbstractChannel for FakeChannel {
        type Ident = u32;
        fn ident(&self) -> &Self::Ident {
            &self.ident
        }
        fn is_usable(&self) -> bool {
            let closing = self.closing.load(Ordering::SeqCst);
            !closing
        }
        fn duration_unused(&self) -> Option<Duration> {
            None
        }
    }

    impl FakeChannel {
        /// Mark this channel (and all its clones) as closing, i.e. unusable.
        fn start_closing(&self) {
            self.closing.store(true, Ordering::SeqCst);
        }
    }

    impl<RT: Runtime> FakeChannelFactory<RT> {
        fn new(runtime: RT) -> Self {
            Self { runtime }
        }
    }

    #[async_trait]
    impl<RT: Runtime> ChannelFactory for FakeChannelFactory<RT> {
        type Channel = FakeChannel;
        type BuildSpec = (u32, char);

        async fn build_channel(&self, target: &Self::BuildSpec) -> Result<FakeChannel> {
            yield_now().await;
            let &(ident, mood) = target;
            if mood == '❌' || mood == '🔥' {
                // These moods mean "never connect".
                return Err(Error::UnusableTarget(bad_api_usage!("emoji")));
            }
            if mood == '💤' {
                // Sleepy mood: wait 15 seconds, then succeed.
                self.runtime.sleep(Duration::from_secs(15)).await;
            }
            Ok(FakeChannel {
                ident,
                mood,
                closing: Arc::new(AtomicBool::new(false)),
                detect_reuse: Arc::default(),
            })
        }
    }

    #[test]
    fn connect_one_ok() {
        test_with_one_runtime!(|runtime| async {
            let mgr = AbstractChanMgr::new(FakeChannelFactory::new(runtime));
            let spec = (413, '!');

            let first = mgr.get_or_launch(413, spec).await.unwrap();
            let second = mgr.get_or_launch(413, spec).await.unwrap();
            assert_eq!(first, second);

            let peeked = mgr.get_nowait(&413).unwrap();
            assert_eq!(first, peeked);
        });
    }

    #[test]
    fn connect_one_fail() {
        test_with_one_runtime!(|runtime| async {
            let mgr = AbstractChanMgr::new(FakeChannelFactory::new(runtime));

            // The '❌' mood makes every build attempt fail.
            let spec = (999, '❌');
            let outcome = mgr.get_or_launch(999, spec).await;
            assert!(matches!(outcome, Err(Error::UnusableTarget(_))));

            assert!(mgr.get_nowait(&999).is_none());
        });
    }

    #[test]
    fn test_concurrent() {
        test_with_one_runtime!(|runtime| async {
            let mgr = AbstractChanMgr::new(FakeChannelFactory::new(runtime));

            // TODO(nickm): figure out how to make these actually run
            // concurrently. Right now it seems that they don't actually
            // interact.
            let (r3a, r3b, r44a, r44b, r86a, r86b) = join!(
                mgr.get_or_launch(3, (3, 'a')),
                mgr.get_or_launch(3, (3, 'b')),
                mgr.get_or_launch(44, (44, 'a')),
                mgr.get_or_launch(44, (44, 'b')),
                mgr.get_or_launch(86, (86, '❌')),
                mgr.get_or_launch(86, (86, '🔥')),
            );
            let (c3a, c3b) = (r3a.unwrap(), r3b.unwrap());
            let (c44a, c44b) = (r44a.unwrap(), r44b.unwrap());
            let (e86a, e86b) = (r86a.unwrap_err(), r86b.unwrap_err());

            // Requests for the same identity share one channel; different
            // identities get different channels.
            assert_eq!(c3a, c3b);
            assert_eq!(c44a, c44b);
            assert_ne!(c44a, c3a);

            assert!(matches!(e86a, Error::UnusableTarget(_)));
            assert!(matches!(e86b, Error::UnusableTarget(_)));
        });
    }

    #[test]
    fn unusable_entries() {
        test_with_one_runtime!(|runtime| async {
            let mgr = AbstractChanMgr::new(FakeChannelFactory::new(runtime));

            let (r3, r4, r5) = join!(
                mgr.get_or_launch(3, (3, 'a')),
                mgr.get_or_launch(4, (4, 'a')),
                mgr.get_or_launch(5, (5, 'a')),
            );

            let c3 = r3.unwrap();
            let _c4 = r4.unwrap();
            let c5 = r5.unwrap();

            c3.start_closing();
            c5.start_closing();

            // A closing channel must not be reused: we get a new one.
            let replacement = mgr.get_or_launch(3, (3, 'b')).await.unwrap();
            assert_ne!(c3, replacement);
            assert_eq!(replacement.mood, 'b');

            mgr.remove_unusable_entries().unwrap();

            assert!(mgr.get_nowait(&3).is_some());
            assert!(mgr.get_nowait(&4).is_some());
            assert!(mgr.get_nowait(&5).is_none());
        });
    }
}