Add server variants to blocking/tokio streams

This commit is contained in:
Aram 🍐 2021-08-04 14:11:57 -04:00
parent 35d40d3111
commit 7d26301571
4 changed files with 82 additions and 20 deletions

View file

@ -109,7 +109,7 @@ impl TcpConnection {
};
// Wrap and override the TCP stream
self.stream = self.stream.upgrade_ssl(sess)?;
self.stream = self.stream.upgrade_ssl_client(sess)?;
}
Ok(self)
}

View file

@ -6,18 +6,36 @@ pub enum ConnectionStream {
/// A plain TCP stream.
Plain(TcpStream),
/// A stream wrapped with SSL using `rustls`.
/// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")]
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
/// A server stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")]
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
}
impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`.
/// Wraps the current stream with SSL using `rustls` (client-side).
#[cfg(feature = "ssl")]
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new(
session, self,
))))
pub fn upgrade_ssl_client(
self,
session: rustls::ClientSession,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslClient(Box::new(
rustls::StreamOwned::new(session, self),
)))
}
/// Wraps the current stream with SSL using `rustls` (client-side).
#[cfg(feature = "ssl")]
pub fn upgrade_ssl_server(
self,
session: rustls::ServerSession,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslServer(Box::new(
rustls::StreamOwned::new(session, self),
)))
}
}
@ -26,7 +44,9 @@ impl Read for ConnectionStream {
match self {
Self::Plain(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.read(buf),
Self::SslClient(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::SslServer(stream) => stream.read(buf),
}
}
}
@ -36,7 +56,9 @@ impl Write for ConnectionStream {
match self {
Self::Plain(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.write(buf),
Self::SslClient(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::SslServer(stream) => stream.write(buf),
}
}
@ -44,7 +66,9 @@ impl Write for ConnectionStream {
match self {
Self::Plain(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.flush(),
Self::SslClient(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::SslServer(stream) => stream.flush(),
}
}
}

View file

@ -115,7 +115,7 @@ impl TcpConnection {
let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config));
// Wrap and override the TCP stream
self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?;
self.stream = self.stream.upgrade_ssl_client(config, dns_name.as_ref()).await?;
}
Ok(self)
}

View file

@ -9,26 +9,44 @@ pub enum ConnectionStream {
/// A plain TCP stream.
Plain(TcpStream),
/// A stream wrapped with SSL using `rustls`.
/// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
Ssl(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
SslClient(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
/// A server stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
SslServer(Box<tokio_rustls::server::TlsStream<ConnectionStream>>),
}
impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
pub async fn upgrade_ssl(
pub async fn upgrade_ssl_client(
self,
config: tokio_rustls::TlsConnector,
dns_name: webpki::DNSNameRef<'_>,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::Ssl(Box::new(
Ok(ConnectionStream::SslClient(Box::new(
config
.connect(dns_name, self)
.await
.map_err(crate::ClientError::Io)?,
)))
}
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
pub async fn upgrade_ssl_server(
self,
acceptor: tokio_rustls::TlsAcceptor,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslServer(Box::new(
acceptor
.accept(self)
.await
.map_err(crate::ClientError::Io)?,
)))
}
}
impl AsyncRead for ConnectionStream {
@ -43,7 +61,12 @@ impl AsyncRead for ConnectionStream {
pinned.poll_read(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
@ -63,7 +86,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_write(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
@ -77,7 +105,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_flush(cx)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
@ -91,7 +124,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_shutdown(cx)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}