1
//! An implementation of [`TlsProvider`] that uses `native_tls`.
2

            
3
use crate::traits::{CertifiedConn, TlsConnector, TlsProvider};
4

            
5
use async_trait::async_trait;
6
use futures::{AsyncRead, AsyncWrite};
7
use native_tls_crate as native_tls;
8
use std::{
9
    convert::TryInto,
10
    io::{Error as IoError, Result as IoResult},
11
};
12

            
13
/// A [`TlsProvider`] that uses `native_tls`.
///
/// It supports wrapping any reasonable stream type that implements `AsyncRead` + `AsyncWrite`.
///
/// The provider carries no state of its own (it has no fields), so cloning it
/// is free and its `Debug` output is just the type name.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct NativeTlsProvider {}
19

            
20
impl<S> CertifiedConn for async_native_tls::TlsStream<S>
21
where
22
    S: AsyncRead + AsyncWrite + Unpin,
23
{
24
3
    fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
25
3
        let cert = self.peer_certificate();
26
3
        match cert {
27
3
            Ok(Some(c)) => {
28
3
                let der = c
29
3
                    .to_der()
30
3
                    .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
31
3
                Ok(Some(der))
32
            }
33
            Ok(None) => Ok(None),
34
            Err(e) => Err(IoError::new(std::io::ErrorKind::Other, e)),
35
        }
36
3
    }
37
}
38

            
39
/// An implementation of [`TlsConnector`] built with `native_tls`.
40
pub struct NativeTlsConnector<S> {
41
    /// The inner connector object.
42
    connector: async_native_tls::TlsConnector,
43
    /// Phantom data to ensure proper variance.
44
    _phantom: std::marker::PhantomData<fn(S) -> S>,
45
}
46

            
47
#[async_trait]
48
impl<S> TlsConnector<S> for NativeTlsConnector<S>
49
where
50
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
51
{
52
    type Conn = async_native_tls::TlsStream<S>;
53

            
54
3
    async fn negotiate_unvalidated(&self, stream: S, sni_hostname: &str) -> IoResult<Self::Conn> {
55
3
        let conn = self
56
3
            .connector
57
6
            .connect(sni_hostname, stream)
58
6
            .await
59
3
            .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
60
3
        Ok(conn)
61
6
    }
62
}
63

            
64
impl<S> TlsProvider<S> for NativeTlsProvider
65
where
66
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
67
{
68
    type Connector = NativeTlsConnector<S>;
69

            
70
    type TlsStream = async_native_tls::TlsStream<S>;
71

            
72
13
    fn tls_connector(&self) -> Self::Connector {
73
13
        let mut builder = native_tls::TlsConnector::builder();
74
13
        // These function names are scary, but they just mean that we
75
13
        // aren't checking whether the signer of this cert
76
13
        // participates in the web PKI, and we aren't checking the
77
13
        // hostname in the cert.
78
13
        builder
79
13
            .danger_accept_invalid_certs(true)
80
13
            .danger_accept_invalid_hostnames(true);
81
13

            
82
13
        let connector = builder.try_into().expect("Couldn't build a TLS connector!");
83
13

            
84
13
        NativeTlsConnector {
85
13
            connector,
86
13
            _phantom: std::marker::PhantomData,
87
13
        }
88
13
    }
89
}