1
//! Utility module to safely refer to a mutable Arc.
2

            
3
use std::sync::{Arc, RwLock};
4

            
5
use educe::Educe;
6

            
7
use crate::{Error, Result};
8

            
9
/// A shareable mutable-ish optional reference to a an [`Arc`].
10
///
11
/// Because you can't actually change a shared [`Arc`], this type implements
12
/// mutability by replacing the Arc itself with a new value.  It tries
13
/// to avoid needless clones by taking advantage of [`Arc::make_mut`].
14
///
15
// We give this construction its own type to simplify its users, and make
16
// sure we don't hold the lock against any async suspend points.
17
17
#[derive(Debug, Educe)]
18
#[educe(Default)]
19
#[cfg_attr(not(feature = "experimental-api"), allow(unreachable_pub))]
20
pub struct SharedMutArc<T> {
21
    /// Locked reference to the current value.
22
    ///
23
    /// (It's okay to use RwLock here, because we never suspend
24
    /// while holding the lock.)
25
    dir: RwLock<Option<Arc<T>>>,
26
}
27

            
28
#[cfg_attr(not(feature = "experimental-api"), allow(unreachable_pub))]
29
impl<T> SharedMutArc<T> {
30
    /// Construct a new empty SharedMutArc.
31
10
    pub fn new() -> Self {
32
10
        SharedMutArc::default()
33
10
    }
34

            
35
    /// Replace the current value with `new_val`.
36
2
    pub fn replace(&self, new_val: T) {
37
2
        let mut w = self
38
2
            .dir
39
2
            .write()
40
2
            .expect("Poisoned lock for directory reference");
41
2
        *w = Some(Arc::new(new_val));
42
2
    }
43

            
44
    /// Remove the current value of this SharedMutArc.
45
    #[allow(unused)]
46
1
    pub(crate) fn clear(&self) {
47
1
        let mut w = self
48
1
            .dir
49
1
            .write()
50
1
            .expect("Poisoned lock for directory reference");
51
1
        *w = None;
52
1
    }
53

            
54
    /// Return a new reference to the current value, if there is one.
55
20
    pub fn get(&self) -> Option<Arc<T>> {
56
20
        let r = self
57
20
            .dir
58
20
            .read()
59
20
            .expect("Poisoned lock for directory reference");
60
20
        r.as_ref().map(Arc::clone)
61
20
    }
62

            
63
    /// Replace the contents of this SharedMutArc with the results of applying
64
    /// `func` to the inner value.
65
    ///
66
    /// Gives an error if there is no inner value.
67
    ///
68
    /// Other threads will not abe able to access the inner value
69
    /// while the function is running.
70
    ///
71
    /// # Limitation: No panic-safety
72
    ///
73
    /// If `func` panics while it's running, this object will become invalid
74
    /// and future attempts to use it will panic. (TODO: Fix this.)
75
    // Note: If we decide to make this type public, we'll probably
76
    // want to fiddle with how we handle the return type.
77
2
    pub fn mutate<F, U>(&self, func: F) -> Result<U>
78
2
    where
79
2
        F: FnOnce(&mut T) -> Result<U>,
80
2
        T: Clone,
81
2
    {
82
2
        match self
83
2
            .dir
84
2
            .write()
85
2
            .expect("Poisoned lock for directory reference")
86
2
            .as_mut()
87
        {
88
1
            None => Err(Error::DirectoryNotPresent), // Kinda bogus.
89
1
            Some(arc) => func(Arc::make_mut(arc)),
90
        }
91
2
    }
92
}
93

            
94
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Basic lifecycle: empty -> replace -> mutate -> clear -> mutate fails.
    #[test]
    fn shared_mut_arc() {
        let val: SharedMutArc<Vec<u32>> = SharedMutArc::new();
        assert_eq!(val.get(), None);

        val.replace(Vec::new());
        assert_eq!(val.get().unwrap().as_ref()[..], Vec::<u32>::new());

        val.mutate(|v| {
            v.push(99);
            Ok(())
        })
        .unwrap();
        assert_eq!(val.get().unwrap().as_ref()[..], [99]);

        val.clear();
        assert_eq!(val.get(), None);

        assert!(val
            .mutate(|v| {
                v.push(99);
                Ok(())
            })
            .is_err());
    }

    /// `mutate` returns `func`'s result and, thanks to `Arc::make_mut`,
    /// leaves previously obtained references untouched.
    #[test]
    fn mutate_preserves_outstanding_references() {
        let val: SharedMutArc<Vec<u32>> = SharedMutArc::new();
        val.replace(vec![1, 2]);
        let before = val.get().unwrap();

        let len = val
            .mutate(|v| {
                v.push(3);
                Ok(v.len())
            })
            .unwrap();
        assert_eq!(len, 3);

        // The snapshot taken before `mutate` still sees the old value...
        assert_eq!(before.as_ref()[..], [1, 2]);
        // ...while new lookups see the updated one.
        assert_eq!(val.get().unwrap().as_ref()[..], [1, 2, 3]);
    }

    /// Mutating an empty SharedMutArc reports `DirectoryNotPresent`.
    #[test]
    fn mutate_empty_is_directory_not_present() {
        let val: SharedMutArc<Vec<u32>> = SharedMutArc::new();
        let r = val.mutate(|_| Ok(()));
        assert!(matches!(r, Err(Error::DirectoryNotPresent)));
        assert_eq!(val.get(), None);
    }
}