Improve SSL upgrade

This commit is contained in:
Aram 🍐 2021-07-31 09:55:43 -04:00
parent 137251d29a
commit 60cd6a831a
3 changed files with 59 additions and 61 deletions

View file

@ -1,47 +0,0 @@
use std::io::{Read, Write};
use std::net::TcpStream;
#[allow(clippy::large_enum_variant)]
pub enum ConnectionPipeline {
Tcp(TcpStream),
#[cfg(feature = "ssl")]
Ssl(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
}
impl ConnectionPipeline {
pub fn tcp(&self) -> Option<TcpStream> {
match self {
Self::Tcp(stream) => Some(stream.try_clone().ok()).flatten(),
_ => None,
}
}
}
impl Read for ConnectionPipeline {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.read(buf),
}
}
}
impl Write for ConnectionPipeline {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.flush(),
}
}
}

View file

@ -1,11 +1,11 @@
use std::io::{BufRead, BufReader, Write}; use std::io::{BufRead, BufReader, Write};
use std::net::{SocketAddr, TcpStream}; use std::net::{SocketAddr, TcpStream};
use crate::blocking::filter::ConnectionPipeline; use crate::blocking::stream::ConnectionStream;
use crate::cmd::{Command, Response}; use crate::cmd::{Command, Response};
use crate::{ClientError, Config, Host, NutError, Variable}; use crate::{ClientError, Config, Host, NutError, Variable};
mod filter; mod stream;
/// A blocking NUT client connection. /// A blocking NUT client connection.
pub enum Connection { pub enum Connection {
@ -62,7 +62,7 @@ impl Connection {
/// A blocking TCP NUT client connection. /// A blocking TCP NUT client connection.
pub struct TcpConnection { pub struct TcpConnection {
config: Config, config: Config,
pipeline: ConnectionPipeline, pipeline: ConnectionStream,
} }
impl TcpConnection { impl TcpConnection {
@ -71,11 +71,11 @@ impl TcpConnection {
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?; let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self { let mut connection = Self {
config, config,
pipeline: ConnectionPipeline::Tcp(tcp_stream), pipeline: ConnectionStream::Plain(tcp_stream),
}; };
// Initialize SSL connection // Initialize SSL connection
connection.enable_ssl()?; connection = connection.enable_ssl()?;
// Attempt login using `config.auth` // Attempt login using `config.auth`
connection.login()?; connection.login()?;
@ -84,7 +84,7 @@ impl TcpConnection {
} }
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
fn enable_ssl(&mut self) -> crate::Result<()> { fn enable_ssl(mut self) -> crate::Result<Self> {
if self.config.ssl { if self.config.ssl {
// Send TLS request and check for 'OK' // Send TLS request and check for 'OK'
self.write_cmd(Command::StartTLS)?; self.write_cmd(Command::StartTLS)?;
@ -110,17 +110,12 @@ impl TcpConnection {
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name); let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
// Wrap and override the TCP stream // Wrap and override the TCP stream
let tcp = self self.pipeline = self.pipeline.upgrade_ssl(sess)?;
.pipeline
.tcp()
.ok_or_else(|| ClientError::from(NutError::SslNotSupported))?;
let tls = rustls::StreamOwned::new(sess, tcp);
self.pipeline = ConnectionPipeline::Ssl(tls);
// Send a test command // Send a test command
self.get_network_version()?; self.get_network_version()?;
} }
Ok(()) Ok(self)
} }
#[cfg(not(feature = "ssl"))] #[cfg(not(feature = "ssl"))]
@ -190,7 +185,7 @@ impl TcpConnection {
} }
fn parse_line( fn parse_line(
reader: &mut BufReader<&mut ConnectionPipeline>, reader: &mut BufReader<&mut ConnectionStream>,
debug: bool, debug: bool,
) -> crate::Result<Vec<String>> { ) -> crate::Result<Vec<String>> {
let mut raw = String::new(); let mut raw = String::new();

View file

@ -0,0 +1,50 @@
use std::io::{Read, Write};
use std::net::TcpStream;
/// A wrapper for various synchronous stream types.
pub enum ConnectionStream {
/// A plain TCP stream.
Plain(TcpStream),
/// A stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")]
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
}
impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "ssl")]
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new(
session, self,
))))
}
}
impl Read for ConnectionStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Self::Plain(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.read(buf),
}
}
}
impl Write for ConnectionStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Self::Plain(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Plain(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.flush(),
}
}
}