From 60cd6a831a8503de4c3a5d7c3e4bbe2e52ec0c7c Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Sat, 31 Jul 2021 09:55:43 -0400 Subject: [PATCH] Improve SSL upgrade --- nut-client/src/blocking/filter.rs | 47 ----------------------------- nut-client/src/blocking/mod.rs | 23 ++++++-------- nut-client/src/blocking/stream.rs | 50 +++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 61 deletions(-) delete mode 100644 nut-client/src/blocking/filter.rs create mode 100644 nut-client/src/blocking/stream.rs diff --git a/nut-client/src/blocking/filter.rs b/nut-client/src/blocking/filter.rs deleted file mode 100644 index e11035b..0000000 --- a/nut-client/src/blocking/filter.rs +++ /dev/null @@ -1,47 +0,0 @@ -use std::io::{Read, Write}; -use std::net::TcpStream; - -#[allow(clippy::large_enum_variant)] -pub enum ConnectionPipeline { - Tcp(TcpStream), - - #[cfg(feature = "ssl")] - Ssl(rustls::StreamOwned), -} - -impl ConnectionPipeline { - pub fn tcp(&self) -> Option { - match self { - Self::Tcp(stream) => Some(stream.try_clone().ok()).flatten(), - _ => None, - } - } -} - -impl Read for ConnectionPipeline { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - match self { - Self::Tcp(stream) => stream.read(buf), - #[cfg(feature = "ssl")] - Self::Ssl(stream) => stream.read(buf), - } - } -} - -impl Write for ConnectionPipeline { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - match self { - Self::Tcp(stream) => stream.write(buf), - #[cfg(feature = "ssl")] - Self::Ssl(stream) => stream.write(buf), - } - } - - fn flush(&mut self) -> std::io::Result<()> { - match self { - Self::Tcp(stream) => stream.flush(), - #[cfg(feature = "ssl")] - Self::Ssl(stream) => stream.flush(), - } - } -} diff --git a/nut-client/src/blocking/mod.rs b/nut-client/src/blocking/mod.rs index 4d377d9..ac2e75f 100644 --- a/nut-client/src/blocking/mod.rs +++ b/nut-client/src/blocking/mod.rs @@ -1,11 +1,11 @@ use std::io::{BufRead, BufReader, Write}; use std::net::{SocketAddr, TcpStream}; -use crate::blocking::filter::ConnectionPipeline; +use crate::blocking::stream::ConnectionStream; use crate::cmd::{Command, Response}; use crate::{ClientError, Config, Host, NutError, Variable}; -mod filter; +mod stream; /// A blocking NUT client connection. pub enum Connection { @@ -62,7 +62,7 @@ impl Connection { /// A blocking TCP NUT client connection. pub struct TcpConnection { config: Config, - pipeline: ConnectionPipeline, + pipeline: ConnectionStream, } impl TcpConnection { @@ -71,11 +71,11 @@ impl TcpConnection { let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?; let mut connection = Self { config, - pipeline: ConnectionPipeline::Tcp(tcp_stream), + pipeline: ConnectionStream::Plain(tcp_stream), }; // Initialize SSL connection - connection.enable_ssl()?; + connection = connection.enable_ssl()?; // Attempt login using `config.auth` connection.login()?; @@ -84,7 +84,7 @@ impl TcpConnection { } #[cfg(feature = "ssl")] - fn enable_ssl(&mut self) -> crate::Result<()> { + fn enable_ssl(mut self) -> crate::Result { if self.config.ssl { // Send TLS request and check for 'OK' self.write_cmd(Command::StartTLS)?; @@ -110,17 +110,12 @@ impl TcpConnection { let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name); // Wrap and override the TCP stream - let tcp = self - .pipeline - .tcp() - .ok_or_else(|| ClientError::from(NutError::SslNotSupported))?; - let tls = rustls::StreamOwned::new(sess, tcp); - self.pipeline = ConnectionPipeline::Ssl(tls); + self.pipeline = self.pipeline.upgrade_ssl(sess)?; // Send a test command self.get_network_version()?; } - Ok(()) + Ok(self) } #[cfg(not(feature = "ssl"))] @@ -190,7 +185,7 @@ impl TcpConnection { } fn parse_line( - reader: &mut BufReader<&mut ConnectionPipeline>, + reader: &mut BufReader<&mut ConnectionStream>, debug: bool, ) -> crate::Result> { let mut raw = String::new(); diff --git a/nut-client/src/blocking/stream.rs b/nut-client/src/blocking/stream.rs new file mode 100644 index 0000000..2131ce8 --- /dev/null +++ b/nut-client/src/blocking/stream.rs @@ -0,0 +1,50 @@ +use std::io::{Read, Write}; +use std::net::TcpStream; + +/// A wrapper for various synchronous stream types. +pub enum ConnectionStream { + /// A plain TCP stream. + Plain(TcpStream), + + /// A stream wrapped with SSL using `rustls`. + #[cfg(feature = "ssl")] + Ssl(Box>), +} + +impl ConnectionStream { + /// Wraps the current stream with SSL using `rustls`. + #[cfg(feature = "ssl")] + pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result { + Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new( + session, self, + )))) + } +} + +impl Read for ConnectionStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + Self::Plain(stream) => stream.read(buf), + #[cfg(feature = "ssl")] + Self::Ssl(stream) => stream.read(buf), + } + } +} + +impl Write for ConnectionStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + Self::Plain(stream) => stream.write(buf), + #[cfg(feature = "ssl")] + Self::Ssl(stream) => stream.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + Self::Plain(stream) => stream.flush(), + #[cfg(feature = "ssl")] + Self::Ssl(stream) => stream.flush(), + } + } +}