1
//! Types and code to map circuit IDs to circuits.
2

            
3
// NOTE: This is a work in progress and I bet I'll refactor it a lot;
4
// it needs to stay opaque!
5

            
6
use crate::{Error, Result};
7
use tor_cell::chancell::CircId;
8

            
9
use crate::circuit::celltypes::{ClientCircChanMsg, CreateResponse};
10
use crate::circuit::halfcirc::HalfCirc;
11

            
12
use futures::channel::{mpsc, oneshot};
13

            
14
use rand::distributions::Distribution;
15
use rand::Rng;
16
use std::collections::{hash_map::Entry, HashMap};
17
use std::ops::{Deref, DerefMut};
18

            
19
/// Which group of circuit IDs are we allowed to allocate in this map?
20
///
21
/// If we initiated the channel, we use High circuit ids.  If we're the
22
/// responder, we use low circuit ids.
23
#[derive(Copy, Clone)]
24
pub(super) enum CircIdRange {
25
    /// Only use circuit IDs with the MSB cleared.
26
    #[allow(dead_code)] // Relays will need this.
27
    Low,
28
    /// Only use circuit IDs with the MSB set.
29
    High,
30
    // Historical note: There used to be an "All" range of circuit IDs
31
    // available to clients only.  We stopped using "All" when we moved to link
32
    // protocol version 4.
33
}
34

            
35
impl rand::distributions::Distribution<CircId> for CircIdRange {
36
    /// Return a random circuit ID in the appropriate range.
37
260
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> CircId {
38
        // Make sure v is nonzero.
39
260
        let v = loop {
40
260
            match rng.gen() {
41
                0_u32 => (), // zero is not a valid circuit ID
42
260
                x => break x,
43
260
            }
44
260
        };
45
260
        // Force the high bit of v to the appropriate value.
46
260
        match self {
47
128
            CircIdRange::Low => v & 0x7fff_ffff,
48
132
            CircIdRange::High => v | 0x8000_0000,
49
        }
50
260
        .into()
51
260
    }
52
}
53

            
54
/// An entry in the circuit map.  Right now, we only have "here's the
55
/// way to send cells to a given circuit", but that's likely to
56
/// change.
57
#[derive(Debug)]
58
pub(super) enum CircEnt {
59
    /// A circuit that has not yet received a CREATED cell.
60
    ///
61
    /// For this circuit, the CREATED* cell or DESTROY cell gets sent
62
    /// to the oneshot sender to tell the corresponding
63
    /// PendingClientCirc that the handshake is done.
64
    ///
65
    /// Once that's done, the mpsc sender will be used to send subsequent
66
    /// cells to the circuit.
67
    Opening(
68
        oneshot::Sender<CreateResponse>,
69
        mpsc::Sender<ClientCircChanMsg>,
70
    ),
71

            
72
    /// A circuit that is open and can be given relay cells.
73
    Open(mpsc::Sender<ClientCircChanMsg>),
74

            
75
    /// A circuit where we have sent a DESTROY, but the other end might
76
    /// not have gotten a DESTROY yet.
77
    DestroySent(HalfCirc),
78
}
79

            
80
/// An "smart pointer" that wraps an exclusive reference
81
/// of a `CircEnt`.
82
///
83
/// When being dropped, this object updates the open or opening entries
84
/// counter of the `CircMap`.
85
pub(super) struct MutCircEnt<'a> {
86
    /// An exclusive reference to the `CircEnt`.
87
    value: &'a mut CircEnt,
88
    /// An exclusive reference to the open or opening
89
    ///  entries counter.
90
    open_count: &'a mut usize,
91
    /// True if the entry was open or opening when borrowed.
92
    was_open: bool,
93
}
94

            
95
impl<'a> Drop for MutCircEnt<'a> {
96
252
    fn drop(&mut self) {
97
252
        let is_open = !matches!(self.value, CircEnt::DestroySent(_));
98
252
        match (self.was_open, is_open) {
99
            (false, true) => *self.open_count = self.open_count.saturating_add(1),
100
            (true, false) => *self.open_count = self.open_count.saturating_sub(1),
101
252
            (_, _) => (),
102
        };
103
252
    }
104
}
105

            
106
impl<'a> Deref for MutCircEnt<'a> {
107
    type Target = CircEnt;
108
138
    fn deref(&self) -> &Self::Target {
109
138
        self.value
110
138
    }
111
}
112

            
113
impl<'a> DerefMut for MutCircEnt<'a> {
114
112
    fn deref_mut(&mut self) -> &mut Self::Target {
115
112
        self.value
116
112
    }
117
}
118

            
119
/// A map from circuit IDs to circuit entries. Each channel has one.
120
pub(super) struct CircMap {
121
    /// Map from circuit IDs to entries
122
    m: HashMap<CircId, CircEnt>,
123
    /// Rule for allocating new circuit IDs.
124
    range: CircIdRange,
125
    /// Number of open or opening entry in this map.
126
    open_count: usize,
127
}
128

            
129
impl CircMap {
130
    /// Make a new empty CircMap
131
88
    pub(super) fn new(idrange: CircIdRange) -> Self {
132
88
        CircMap {
133
88
            m: HashMap::new(),
134
88
            range: idrange,
135
88
            open_count: 0,
136
88
        }
137
88
    }
138

            
139
    /// Add a new pair of elements (corresponding to a PendingClientCirc)
140
    /// to this map.
141
    ///
142
    /// On success return the allocated circuit ID.
143
260
    pub(super) fn add_ent<R: Rng>(
144
260
        &mut self,
145
260
        rng: &mut R,
146
260
        createdsink: oneshot::Sender<CreateResponse>,
147
260
        sink: mpsc::Sender<ClientCircChanMsg>,
148
260
    ) -> Result<CircId> {
149
260
        /// How many times do we probe for a random circuit ID before
150
260
        /// we assume that the range is fully populated?
151
260
        ///
152
260
        /// TODO: C tor does 64, but that is probably overkill with 4-byte circuit IDs.
153
260
        const N_ATTEMPTS: usize = 16;
154
260
        let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
155
260
        let circ_ent = CircEnt::Opening(createdsink, sink);
156
260
        for id in iter {
157
260
            let ent = self.m.entry(id);
158
260
            if let Entry::Vacant(_) = &ent {
159
260
                ent.or_insert(circ_ent);
160
260
                self.open_count += 1;
161
260
                return Ok(id);
162
            }
163
        }
164
        Err(Error::IdRangeFull)
165
260
    }
166

            
167
    /// Testing only: install an entry in this circuit map without regard
168
    /// for consistency.
169
    #[cfg(test)]
170
24
    pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
171
24
        self.m.insert(id, ent);
172
24
    }
173

            
174
    /// Return the entry for `id` in this map, if any.
175
262
    pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
176
262
        let open_count = &mut self.open_count;
177
262
        self.m.get_mut(&id).map(move |ent| MutCircEnt {
178
252
            open_count,
179
252
            was_open: !matches!(ent, CircEnt::DestroySent(_)),
180
252
            value: ent,
181
262
        })
182
262
    }
183

            
184
    /// See whether 'id' is an opening circuit.  If so, mark it "open" and
185
    /// return a oneshot::Sender that is waiting for its create cell.
186
7
    pub(super) fn advance_from_opening(
187
7
        &mut self,
188
7
        id: CircId,
189
7
    ) -> Result<oneshot::Sender<CreateResponse>> {
190
        // TODO: there should be a better way to do
191
        // this. hash_map::Entry seems like it could be better, but
192
        // there seems to be no way to replace the object in-place as
193
        // a consuming function of itself.
194
7
        let ok = matches!(self.m.get(&id), Some(CircEnt::Opening(_, _)));
195
7
        if ok {
196
1
            if let Some(CircEnt::Opening(oneshot, sink)) = self.m.remove(&id) {
197
1
                self.m.insert(id, CircEnt::Open(sink));
198
1
                Ok(oneshot)
199
            } else {
200
                panic!("internal error: inconsistent circuit state");
201
            }
202
        } else {
203
6
            Err(Error::ChanProto(
204
6
                "Unexpected CREATED* cell not on opening circuit".into(),
205
6
            ))
206
        }
207
7
    }
208

            
209
    /// Called when we have sent a DESTROY on a circuit.  Configures
210
    /// a "HalfCirc" object to track how many cells we get on this
211
    /// circuit, and to prevent us from reusing it immediately.
212
    pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
213
5
        if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
214
4
            if !matches!(replaced, CircEnt::DestroySent(_)) {
215
4
                // replaced an Open/Opening entry with DestroySent
216
4
                self.open_count = self.open_count.saturating_sub(1);
217
4
            }
218
1
        }
219
5
    }
220

            
221
    /// Extract the value from this map with 'id' if any
222
17
    pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
223
17
        self.m.remove(&id).map(|removed| {
224
13
            if !matches!(removed, CircEnt::DestroySent(_)) {
225
9
                self.open_count = self.open_count.saturating_sub(1);
226
9
            }
227
13
            removed
228
17
        })
229
17
    }
230

            
231
    /// Return the total number of open and opening entries in the map
232
28
    pub(super) fn open_ent_count(&self) -> usize {
233
28
        self.open_count
234
28
    }
235

            
236
    // TODO: Eventually if we want relay support, we'll need to support
237
    // circuit IDs chosen by somebody else. But for now, we don't need those.
238
}
239

            
240
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use futures::channel::{mpsc, oneshot};

    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = rand::thread_rng();

        assert!(map_low.get_mut(CircId::from(77)).is_none());

        for _ in 0..128 {
            let (csnd, _) = oneshot::channel();
            let (snd, _) = mpsc::channel(8);
            let id_low = map_low.add_ent(&mut rng, csnd, snd).unwrap();
            assert!(u32::from(id_low) > 0);
            assert!(u32::from(id_low) < 0x80000000);
            assert!(!ids_low.contains(&id_low));
            ids_low.push(id_low);

            assert!(matches!(
                *map_low.get_mut(id_low).unwrap(),
                CircEnt::Opening(_, _)
            ));

            let (csnd, _) = oneshot::channel();
            let (snd, _) = mpsc::channel(8);
            let id_high = map_high.add_ent(&mut rng, csnd, snd).unwrap();
            assert!(u32::from(id_high) >= 0x80000000);
            assert!(!ids_high.contains(&id_high));
            ids_high.push(id_high);
        }

        // Test open / opening entry counting
        assert_eq!(128, map_low.open_ent_count());
        assert_eq!(128, map_high.open_ent_count());

        // Test remove
        assert!(map_low.get_mut(ids_low[0]).is_some());
        map_low.remove(ids_low[0]);
        assert!(map_low.get_mut(ids_low[0]).is_none());
        assert_eq!(127, map_low.open_ent_count());

        // Test DestroySent doesn't count.  Use an ID outside map_low's
        // range, so it can never collide with a randomly allocated entry.
        let destroyed = CircId::from(0x8000_0100);
        map_low.destroy_sent(destroyed, HalfCirc::new(1));
        assert!(matches!(
            *map_low.get_mut(destroyed).unwrap(),
            CircEnt::DestroySent(_)
        ));
        assert_eq!(127, map_low.open_ent_count());

        // Test advance_from_opening.

        // Good case.
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Opening(_, _)
        ));
        let adv = map_high.advance_from_opening(ids_high[0]);
        assert!(adv.is_ok());
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Open(_)
        ));
        // Opening -> Open does not change the count.
        assert_eq!(128, map_high.open_ent_count());

        // Can't double-advance.
        let adv = map_high.advance_from_opening(ids_high[0]);
        assert!(adv.is_err());
        // A failed advance leaves the entry in place.
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Open(_)
        ));

        // Can't advance an entry that is not there.  We know "77"
        // can't be in map_high, since we only added high circids to
        // it.
        let adv = map_high.advance_from_opening(77.into());
        assert!(adv.is_err());
    }

    #[test]
    fn circmap_open_count_tracking() {
        let mut map = CircMap::new(CircIdRange::High);
        let mut rng = rand::thread_rng();

        let (csnd, _) = oneshot::channel();
        let (snd, _) = mpsc::channel(8);
        let id = map.add_ent(&mut rng, csnd, snd).unwrap();
        assert_eq!(1, map.open_ent_count());

        // Borrowing without changing the category leaves the count alone.
        {
            let ent = map.get_mut(id).unwrap();
            assert!(matches!(*ent, CircEnt::Opening(_, _)));
        }
        assert_eq!(1, map.open_ent_count());

        // Opening -> DestroySent through MutCircEnt decrements on drop.
        {
            let mut ent = map.get_mut(id).unwrap();
            *ent = CircEnt::DestroySent(HalfCirc::new(1));
        }
        assert_eq!(0, map.open_ent_count());

        // DestroySent -> Open through MutCircEnt increments on drop.
        {
            let (snd, _) = mpsc::channel(8);
            let mut ent = map.get_mut(id).unwrap();
            *ent = CircEnt::Open(snd);
        }
        assert_eq!(1, map.open_ent_count());

        // destroy_sent over an open entry decrements.
        map.destroy_sent(id, HalfCirc::new(1));
        assert_eq!(0, map.open_ent_count());

        // Removing a DestroySent entry does not change the count.
        assert!(matches!(map.remove(id), Some(CircEnt::DestroySent(_))));
        assert_eq!(0, map.open_ent_count());
        assert!(map.remove(id).is_none());
    }
}