mirror of
https://github.com/aramperes/nut-rs.git
synced 2025-09-09 13:38:30 -04:00
Add server variants to blocking/tokio streams
This commit is contained in:
parent
35d40d3111
commit
7d26301571
4 changed files with 82 additions and 20 deletions
|
@ -109,7 +109,7 @@ impl TcpConnection {
|
|||
};
|
||||
|
||||
// Wrap and override the TCP stream
|
||||
self.stream = self.stream.upgrade_ssl(sess)?;
|
||||
self.stream = self.stream.upgrade_ssl_client(sess)?;
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
|
|
@ -6,18 +6,36 @@ pub enum ConnectionStream {
|
|||
/// A plain TCP stream.
|
||||
Plain(TcpStream),
|
||||
|
||||
/// A stream wrapped with SSL using `rustls`.
|
||||
/// A client stream wrapped with SSL using `rustls`.
|
||||
#[cfg(feature = "ssl")]
|
||||
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
|
||||
SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
|
||||
|
||||
/// A server stream wrapped with SSL using `rustls`.
|
||||
#[cfg(feature = "ssl")]
|
||||
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
|
||||
}
|
||||
|
||||
impl ConnectionStream {
|
||||
/// Wraps the current stream with SSL using `rustls`.
|
||||
/// Wraps the current stream with SSL using `rustls` (client-side).
|
||||
#[cfg(feature = "ssl")]
|
||||
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> {
|
||||
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new(
|
||||
session, self,
|
||||
))))
|
||||
pub fn upgrade_ssl_client(
|
||||
self,
|
||||
session: rustls::ClientSession,
|
||||
) -> crate::Result<ConnectionStream> {
|
||||
Ok(ConnectionStream::SslClient(Box::new(
|
||||
rustls::StreamOwned::new(session, self),
|
||||
)))
|
||||
}
|
||||
|
||||
/// Wraps the current stream with SSL using `rustls` (client-side).
|
||||
#[cfg(feature = "ssl")]
|
||||
pub fn upgrade_ssl_server(
|
||||
self,
|
||||
session: rustls::ServerSession,
|
||||
) -> crate::Result<ConnectionStream> {
|
||||
Ok(ConnectionStream::SslServer(Box::new(
|
||||
rustls::StreamOwned::new(session, self),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,7 +44,9 @@ impl Read for ConnectionStream {
|
|||
match self {
|
||||
Self::Plain(stream) => stream.read(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.read(buf),
|
||||
Self::SslClient(stream) => stream.read(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::SslServer(stream) => stream.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -36,7 +56,9 @@ impl Write for ConnectionStream {
|
|||
match self {
|
||||
Self::Plain(stream) => stream.write(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.write(buf),
|
||||
Self::SslClient(stream) => stream.write(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::SslServer(stream) => stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,7 +66,9 @@ impl Write for ConnectionStream {
|
|||
match self {
|
||||
Self::Plain(stream) => stream.flush(),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.flush(),
|
||||
Self::SslClient(stream) => stream.flush(),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::SslServer(stream) => stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,7 +115,7 @@ impl TcpConnection {
|
|||
let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config));
|
||||
|
||||
// Wrap and override the TCP stream
|
||||
self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?;
|
||||
self.stream = self.stream.upgrade_ssl_client(config, dns_name.as_ref()).await?;
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
|
|
@ -9,26 +9,44 @@ pub enum ConnectionStream {
|
|||
/// A plain TCP stream.
|
||||
Plain(TcpStream),
|
||||
|
||||
/// A stream wrapped with SSL using `rustls`.
|
||||
/// A client stream wrapped with SSL using `rustls`.
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Ssl(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
|
||||
SslClient(Box<tokio_rustls::client::TlsStream<ConnectionStream>>),
|
||||
|
||||
/// A server stream wrapped with SSL using `rustls`.
|
||||
#[cfg(feature = "async-ssl")]
|
||||
SslServer(Box<tokio_rustls::server::TlsStream<ConnectionStream>>),
|
||||
}
|
||||
|
||||
impl ConnectionStream {
|
||||
/// Wraps the current stream with SSL using `rustls`.
|
||||
#[cfg(feature = "async-ssl")]
|
||||
pub async fn upgrade_ssl(
|
||||
pub async fn upgrade_ssl_client(
|
||||
self,
|
||||
config: tokio_rustls::TlsConnector,
|
||||
dns_name: webpki::DNSNameRef<'_>,
|
||||
) -> crate::Result<ConnectionStream> {
|
||||
Ok(ConnectionStream::Ssl(Box::new(
|
||||
Ok(ConnectionStream::SslClient(Box::new(
|
||||
config
|
||||
.connect(dns_name, self)
|
||||
.await
|
||||
.map_err(crate::ClientError::Io)?,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Wraps the current stream with SSL using `rustls`.
|
||||
#[cfg(feature = "async-ssl")]
|
||||
pub async fn upgrade_ssl_server(
|
||||
self,
|
||||
acceptor: tokio_rustls::TlsAcceptor,
|
||||
) -> crate::Result<ConnectionStream> {
|
||||
Ok(ConnectionStream::SslServer(Box::new(
|
||||
acceptor
|
||||
.accept(self)
|
||||
.await
|
||||
.map_err(crate::ClientError::Io)?,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ConnectionStream {
|
||||
|
@ -43,7 +61,12 @@ impl AsyncRead for ConnectionStream {
|
|||
pinned.poll_read(cx, buf)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::Ssl(stream) => {
|
||||
Self::SslClient(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_read(cx, buf)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::SslServer(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_read(cx, buf)
|
||||
}
|
||||
|
@ -63,7 +86,12 @@ impl AsyncWrite for ConnectionStream {
|
|||
pinned.poll_write(cx, buf)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::Ssl(stream) => {
|
||||
Self::SslClient(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_write(cx, buf)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::SslServer(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_write(cx, buf)
|
||||
}
|
||||
|
@ -77,7 +105,12 @@ impl AsyncWrite for ConnectionStream {
|
|||
pinned.poll_flush(cx)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::Ssl(stream) => {
|
||||
Self::SslClient(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_flush(cx)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::SslServer(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_flush(cx)
|
||||
}
|
||||
|
@ -91,7 +124,12 @@ impl AsyncWrite for ConnectionStream {
|
|||
pinned.poll_shutdown(cx)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::Ssl(stream) => {
|
||||
Self::SslClient(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_shutdown(cx)
|
||||
}
|
||||
#[cfg(feature = "async-ssl")]
|
||||
Self::SslServer(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_shutdown(cx)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue