From 7d26301571288a683f402a4efe6ad4c163d0410c Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 4 Aug 2021 14:11:57 -0400 Subject: [PATCH] Add server variants to blocking/tokio streams --- rups/src/blocking/mod.rs | 2 +- rups/src/blocking/stream.rs | 44 +++++++++++++++++++++++------- rups/src/tokio/mod.rs | 2 +- rups/src/tokio/stream.rs | 54 +++++++++++++++++++++++++++++++------ 4 files changed, 82 insertions(+), 20 deletions(-) diff --git a/rups/src/blocking/mod.rs b/rups/src/blocking/mod.rs index a0f2d5d..0ee4e44 100644 --- a/rups/src/blocking/mod.rs +++ b/rups/src/blocking/mod.rs @@ -109,7 +109,7 @@ impl TcpConnection { }; // Wrap and override the TCP stream - self.stream = self.stream.upgrade_ssl(sess)?; + self.stream = self.stream.upgrade_ssl_client(sess)?; } Ok(self) } diff --git a/rups/src/blocking/stream.rs b/rups/src/blocking/stream.rs index 2131ce8..96f3783 100644 --- a/rups/src/blocking/stream.rs +++ b/rups/src/blocking/stream.rs @@ -6,18 +6,36 @@ pub enum ConnectionStream { /// A plain TCP stream. Plain(TcpStream), - /// A stream wrapped with SSL using `rustls`. + /// A client stream wrapped with SSL using `rustls`. #[cfg(feature = "ssl")] - Ssl(Box>), + SslClient(Box>), + + /// A server stream wrapped with SSL using `rustls`. + #[cfg(feature = "ssl")] + SslServer(Box>), } impl ConnectionStream { - /// Wraps the current stream with SSL using `rustls`. + /// Wraps the current stream with SSL using `rustls` (client-side). #[cfg(feature = "ssl")] - pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result { - Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new( - session, self, - )))) + pub fn upgrade_ssl_client( + self, + session: rustls::ClientSession, + ) -> crate::Result { + Ok(ConnectionStream::SslClient(Box::new( + rustls::StreamOwned::new(session, self), + ))) + } + + /// Wraps the current stream with SSL using `rustls` (client-side). + #[cfg(feature = "ssl")] + pub fn upgrade_ssl_server( + self, + session: rustls::ServerSession, + ) -> crate::Result { + Ok(ConnectionStream::SslServer(Box::new( + rustls::StreamOwned::new(session, self), + ))) } } @@ -26,7 +44,9 @@ impl Read for ConnectionStream { match self { Self::Plain(stream) => stream.read(buf), #[cfg(feature = "ssl")] - Self::Ssl(stream) => stream.read(buf), + Self::SslClient(stream) => stream.read(buf), + #[cfg(feature = "ssl")] + Self::SslServer(stream) => stream.read(buf), } } } @@ -36,7 +56,9 @@ impl Write for ConnectionStream { match self { Self::Plain(stream) => stream.write(buf), #[cfg(feature = "ssl")] - Self::Ssl(stream) => stream.write(buf), + Self::SslClient(stream) => stream.write(buf), + #[cfg(feature = "ssl")] + Self::SslServer(stream) => stream.write(buf), } } @@ -44,7 +66,9 @@ impl Write for ConnectionStream { match self { Self::Plain(stream) => stream.flush(), #[cfg(feature = "ssl")] - Self::Ssl(stream) => stream.flush(), + Self::SslClient(stream) => stream.flush(), + #[cfg(feature = "ssl")] + Self::SslServer(stream) => stream.flush(), } } } diff --git a/rups/src/tokio/mod.rs b/rups/src/tokio/mod.rs index dc6677a..d7ad791 100644 --- a/rups/src/tokio/mod.rs +++ b/rups/src/tokio/mod.rs @@ -115,7 +115,7 @@ impl TcpConnection { let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config)); // Wrap and override the TCP stream - self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?; + self.stream = self.stream.upgrade_ssl_client(config, dns_name.as_ref()).await?; } Ok(self) } diff --git a/rups/src/tokio/stream.rs b/rups/src/tokio/stream.rs index 03c0008..319b50f 100644 --- a/rups/src/tokio/stream.rs +++ b/rups/src/tokio/stream.rs @@ -9,26 +9,44 @@ pub enum ConnectionStream { /// A plain TCP stream. Plain(TcpStream), - /// A stream wrapped with SSL using `rustls`. + /// A client stream wrapped with SSL using `rustls`. #[cfg(feature = "async-ssl")] - Ssl(Box>), + SslClient(Box>), + + /// A server stream wrapped with SSL using `rustls`. + #[cfg(feature = "async-ssl")] + SslServer(Box>), } impl ConnectionStream { /// Wraps the current stream with SSL using `rustls`. #[cfg(feature = "async-ssl")] - pub async fn upgrade_ssl( + pub async fn upgrade_ssl_client( self, config: tokio_rustls::TlsConnector, dns_name: webpki::DNSNameRef<'_>, ) -> crate::Result { - Ok(ConnectionStream::Ssl(Box::new( + Ok(ConnectionStream::SslClient(Box::new( config .connect(dns_name, self) .await .map_err(crate::ClientError::Io)?, ))) } + + /// Wraps the current stream with SSL using `rustls`. + #[cfg(feature = "async-ssl")] + pub async fn upgrade_ssl_server( + self, + acceptor: tokio_rustls::TlsAcceptor, + ) -> crate::Result { + Ok(ConnectionStream::SslServer(Box::new( + acceptor + .accept(self) + .await + .map_err(crate::ClientError::Io)?, + ))) + } } impl AsyncRead for ConnectionStream { @@ -43,7 +61,12 @@ impl AsyncRead for ConnectionStream { pinned.poll_read(cx, buf) } #[cfg(feature = "async-ssl")] - Self::Ssl(stream) => { + Self::SslClient(stream) => { + let pinned = Pin::new(stream); + pinned.poll_read(cx, buf) + } + #[cfg(feature = "async-ssl")] + Self::SslServer(stream) => { let pinned = Pin::new(stream); pinned.poll_read(cx, buf) } @@ -63,7 +86,12 @@ impl AsyncWrite for ConnectionStream { pinned.poll_write(cx, buf) } #[cfg(feature = "async-ssl")] - Self::Ssl(stream) => { + Self::SslClient(stream) => { + let pinned = Pin::new(stream); + pinned.poll_write(cx, buf) + } + #[cfg(feature = "async-ssl")] + Self::SslServer(stream) => { let pinned = Pin::new(stream); pinned.poll_write(cx, buf) } @@ -77,7 +105,12 @@ impl AsyncWrite for ConnectionStream { pinned.poll_flush(cx) } #[cfg(feature = "async-ssl")] - Self::Ssl(stream) => { + Self::SslClient(stream) => { + let pinned = Pin::new(stream); + pinned.poll_flush(cx) + } + #[cfg(feature = "async-ssl")] + Self::SslServer(stream) => { let pinned = Pin::new(stream); pinned.poll_flush(cx) } @@ -91,7 +124,12 @@ impl AsyncWrite for ConnectionStream { pinned.poll_shutdown(cx) } #[cfg(feature = "async-ssl")] - Self::Ssl(stream) => { + Self::SslClient(stream) => { + let pinned = Pin::new(stream); + pinned.poll_shutdown(cx) + } + #[cfg(feature = "async-ssl")] + Self::SslServer(stream) => { let pinned = Pin::new(stream); pinned.poll_shutdown(cx) }