1
//! Define a [`CompoundRuntime`] type that can be built from several component
2
//! pieces.
3

            
4
use std::{net::SocketAddr, sync::Arc, time::Duration};
5

            
6
use crate::traits::*;
7
use async_trait::async_trait;
8
use educe::Educe;
9
use futures::{future::FutureObj, task::Spawn};
10
use std::io::Result as IoResult;
11

            
12
/// A runtime made of several parts, each of which implements one trait-group.
13
///
14
/// The `SpawnR` component should implements [`Spawn`] and [`BlockOn`];
15
/// the `SleepR` component should implement [`SleepProvider`]; the `TcpR`
16
/// component should implement [`TcpProvider`]; and the `TlsR` component should
17
/// implement [`TlsProvider`].
18
///
19
/// You can use this structure to create new runtimes in two ways: either by
20
/// overriding a single part of an existing runtime, or by building an entirely
21
/// new runtime from pieces.
22
641
#[derive(Educe)]
23
#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
24
pub struct CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
25
    /// The actual collection of Runtime objects.
26
    ///
27
    /// We wrap this in an Arc rather than requiring that each item implement
28
    /// Clone, though we could change our minds later on.
29
    inner: Arc<Inner<SpawnR, SleepR, TcpR, TlsR>>,
30
}
31

            
32
/// A collection of objects implementing the traits that make up a [`Runtime`].
struct Inner<SpawnR, SleepR, TcpR, TlsR> {
    /// A `Spawn` and `BlockOn` implementation.
    spawn: SpawnR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `TcpProvider` implementation.
    tcp: TcpR,
    /// A `TlsProvider<S>` implementation (intended for `S = TcpR::TcpStream`).
    tls: TlsR,
}
43

            
44
impl<SpawnR, SleepR, TcpR, TlsR> CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
45
    /// Construct a new CompoundRuntime from its components.
46
4087
    pub fn new(spawn: SpawnR, sleep: SleepR, tcp: TcpR, tls: TlsR) -> Self {
47
4087
        CompoundRuntime {
48
4087
            inner: Arc::new(Inner {
49
4087
                spawn,
50
4087
                sleep,
51
4087
                tcp,
52
4087
                tls,
53
4087
            }),
54
4087
        }
55
4087
    }
56
}
57

            
58
impl<SpawnR, SleepR, TcpR, TlsR> Spawn for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
59
where
60
    SpawnR: Spawn,
61
{
62
    #[inline]
63
269
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
64
269
        self.inner.spawn.spawn_obj(future)
65
269
    }
66
}
67

            
68
impl<SpawnR, SleepR, TcpR, TlsR> BlockOn for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
69
where
70
    SpawnR: BlockOn,
71
{
72
    #[inline]
73
262
    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
74
262
        self.inner.spawn.block_on(future)
75
262
    }
76
}
77

            
78
impl<SpawnR, SleepR, TcpR, TlsR> SleepProvider for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
79
where
80
    SleepR: SleepProvider,
81
{
82
    type SleepFuture = SleepR::SleepFuture;
83

            
84
    #[inline]
85
79
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
86
79
        self.inner.sleep.sleep(duration)
87
79
    }
88
}
89

            
90
#[async_trait]
91
impl<SpawnR, SleepR, TcpR, TlsR> TcpProvider for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
92
where
93
    TcpR: TcpProvider,
94
    SpawnR: Send + Sync + 'static,
95
    SleepR: Send + Sync + 'static,
96
    TcpR: Send + Sync + 'static,
97
    TlsR: Send + Sync + 'static,
98
{
99
    type TcpStream = TcpR::TcpStream;
100

            
101
    type TcpListener = TcpR::TcpListener;
102

            
103
    #[inline]
104
26
    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::TcpStream> {
105
30
        self.inner.tcp.connect(addr).await
106
52
    }
107

            
108
    #[inline]
109
6
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
110
6
        self.inner.tcp.listen(addr).await
111
12
    }
112
}
113

            
114
impl<SpawnR, SleepR, TcpR, TlsR, S> TlsProvider<S> for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
115
where
116
    TcpR: TcpProvider,
117
    TlsR: TlsProvider<S>,
118
{
119
    type Connector = TlsR::Connector;
120
    type TlsStream = TlsR::TlsStream;
121

            
122
    #[inline]
123
23
    fn tls_connector(&self) -> Self::Connector {
124
23
        self.inner.tls.tls_connector()
125
23
    }
126
}
127

            
128
impl<SpawnR, SleepR, TcpR, TlsR> std::fmt::Debug for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
129
1
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130
1
        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
131
1
    }
132
}