Add server variants to blocking/tokio streams

This commit is contained in:
Aram 🍐 2021-08-04 14:11:57 -04:00
parent 35d40d3111
commit 7d26301571
4 changed files with 82 additions and 20 deletions

View file

@ -9,26 +9,44 @@ pub enum ConnectionStream {
/// A plain TCP stream.
Plain(TcpStream),
/// A stream wrapped with SSL using `rustls`.
/// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
Ssl(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
SslClient(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
/// A server stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
SslServer(Box<tokio_rustls::server::TlsStream<ConnectionStream>>),
}
impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
pub async fn upgrade_ssl(
pub async fn upgrade_ssl_client(
self,
config: tokio_rustls::TlsConnector,
dns_name: webpki::DNSNameRef<'_>,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::Ssl(Box::new(
Ok(ConnectionStream::SslClient(Box::new(
config
.connect(dns_name, self)
.await
.map_err(crate::ClientError::Io)?,
)))
}
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
pub async fn upgrade_ssl_server(
self,
acceptor: tokio_rustls::TlsAcceptor,
) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::SslServer(Box::new(
acceptor
.accept(self)
.await
.map_err(crate::ClientError::Io)?,
)))
}
}
impl AsyncRead for ConnectionStream {
@ -43,7 +61,12 @@ impl AsyncRead for ConnectionStream {
pinned.poll_read(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
@ -63,7 +86,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_write(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
@ -77,7 +105,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_flush(cx)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
@ -91,7 +124,12 @@ impl AsyncWrite for ConnectionStream {
pinned.poll_shutdown(cx)
}
#[cfg(feature = "async-ssl")]
Self::Ssl(stream) => {
Self::SslClient(stream) => {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}
#[cfg(feature = "async-ssl")]
Self::SslServer(stream) => {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}