Add server variants to blocking/tokio streams

This commit is contained in:
Aram 🍐 2021-08-04 14:11:57 -04:00
parent 35d40d3111
commit 7d26301571
4 changed files with 82 additions and 20 deletions

View file

@ -109,7 +109,7 @@ impl TcpConnection {
}; };
// Wrap and override the TCP stream // Wrap and override the TCP stream
self.stream = self.stream.upgrade_ssl(sess)?; self.stream = self.stream.upgrade_ssl_client(sess)?;
} }
Ok(self) Ok(self)
} }

View file

@ -6,18 +6,36 @@ pub enum ConnectionStream {
/// A plain TCP stream. /// A plain TCP stream.
Plain(TcpStream), Plain(TcpStream),
/// A stream wrapped with SSL using `rustls`. /// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>), SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
/// A server stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")]
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
} }
impl ConnectionStream { impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`. /// Wraps the current stream with SSL using `rustls` (client-side).
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> { pub fn upgrade_ssl_client(
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new( self,
session, self, session: rustls::ClientSession,
)))) ) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslClient(Box::new(
rustls::StreamOwned::new(session, self),
)))
}
/// Wraps the current stream with SSL using `rustls` (client-side).
#[cfg(feature = "ssl")]
pub fn upgrade_ssl_server(
self,
session: rustls::ServerSession,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslServer(Box::new(
rustls::StreamOwned::new(session, self),
)))
} }
} }
@ -26,7 +44,9 @@ impl Read for ConnectionStream {
match self { match self {
Self::Plain(stream) => stream.read(buf), Self::Plain(stream) => stream.read(buf),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.read(buf), Self::SslClient(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::SslServer(stream) => stream.read(buf),
} }
} }
} }
@ -36,7 +56,9 @@ impl Write for ConnectionStream {
match self { match self {
Self::Plain(stream) => stream.write(buf), Self::Plain(stream) => stream.write(buf),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.write(buf), Self::SslClient(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::SslServer(stream) => stream.write(buf),
} }
} }
@ -44,7 +66,9 @@ impl Write for ConnectionStream {
match self { match self {
Self::Plain(stream) => stream.flush(), Self::Plain(stream) => stream.flush(),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.flush(), Self::SslClient(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::SslServer(stream) => stream.flush(),
} }
} }
} }

View file

@ -115,7 +115,7 @@ impl TcpConnection {
let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config)); let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config));
// Wrap and override the TCP stream // Wrap and override the TCP stream
self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?; self.stream = self.stream.upgrade_ssl_client(config, dns_name.as_ref()).await?;
} }
Ok(self) Ok(self)
} }

View file

@ -9,26 +9,44 @@ pub enum ConnectionStream {
/// A plain TCP stream. /// A plain TCP stream.
Plain(TcpStream), Plain(TcpStream),
/// A stream wrapped with SSL using `rustls`. /// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Ssl(Box<tokio_rustls::client::TlsStream<ConnectionStream>>), SslClient(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
/// A server stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
SslServer(Box<tokio_rustls::server::TlsStream<ConnectionStream>>),
} }
impl ConnectionStream { impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`. /// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
pub async fn upgrade_ssl( pub async fn upgrade_ssl_client(
self, self,
config: tokio_rustls::TlsConnector, config: tokio_rustls::TlsConnector,
dns_name: webpki::DNSNameRef<'_>, dns_name: webpki::DNSNameRef<'_>,
) -> crate::Result<ConnectionStream> { ) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::Ssl(Box::new( Ok(ConnectionStream::SslClient(Box::new(
config config
.connect(dns_name, self) .connect(dns_name, self)
.await .await
.map_err(crate::ClientError::Io)?, .map_err(crate::ClientError::Io)?,
))) )))
} }
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
pub async fn upgrade_ssl_server(
self,
acceptor: tokio_rustls::TlsAcceptor,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslServer(Box::new(
acceptor
.accept(self)
.await
.map_err(crate::ClientError::Io)?,
)))
}
} }
impl AsyncRead for ConnectionStream { impl AsyncRead for ConnectionStream {
@ -43,7 +61,12 @@ impl AsyncRead for ConnectionStream {
pinned.poll_read(cx, buf) pinned.poll_read(cx, buf)
} }
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::Ssl(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_read(cx, buf) pinned.poll_read(cx, buf)
} }
@ -63,7 +86,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_write(cx, buf) pinned.poll_write(cx, buf)
} }
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::Ssl(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_write(cx, buf) pinned.poll_write(cx, buf)
} }
@ -77,7 +105,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_flush(cx) pinned.poll_flush(cx)
} }
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::Ssl(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_flush(cx) pinned.poll_flush(cx)
} }
@ -91,7 +124,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_shutdown(cx) pinned.poll_shutdown(cx)
} }
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::Ssl(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_shutdown(cx) pinned.poll_shutdown(cx)
} }