mirror of
https://github.com/aramperes/nut-rs.git
synced 2025-09-08 05:08:31 -04:00
Improve SSL upgrade
This commit is contained in:
parent
137251d29a
commit
60cd6a831a
3 changed files with 59 additions and 61 deletions
|
@ -1,47 +0,0 @@
|
|||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum ConnectionPipeline {
|
||||
Tcp(TcpStream),
|
||||
|
||||
#[cfg(feature = "ssl")]
|
||||
Ssl(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
|
||||
}
|
||||
|
||||
impl ConnectionPipeline {
|
||||
pub fn tcp(&self) -> Option<TcpStream> {
|
||||
match self {
|
||||
Self::Tcp(stream) => Some(stream.try_clone().ok()).flatten(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for ConnectionPipeline {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.read(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for ConnectionPipeline {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.write(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.flush(),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +1,11 @@
|
|||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::net::{SocketAddr, TcpStream};
|
||||
|
||||
use crate::blocking::filter::ConnectionPipeline;
|
||||
use crate::blocking::stream::ConnectionStream;
|
||||
use crate::cmd::{Command, Response};
|
||||
use crate::{ClientError, Config, Host, NutError, Variable};
|
||||
|
||||
mod filter;
|
||||
mod stream;
|
||||
|
||||
/// A blocking NUT client connection.
|
||||
pub enum Connection {
|
||||
|
@ -62,7 +62,7 @@ impl Connection {
|
|||
/// A blocking TCP NUT client connection.
|
||||
pub struct TcpConnection {
|
||||
config: Config,
|
||||
pipeline: ConnectionPipeline,
|
||||
pipeline: ConnectionStream,
|
||||
}
|
||||
|
||||
impl TcpConnection {
|
||||
|
@ -71,11 +71,11 @@ impl TcpConnection {
|
|||
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
|
||||
let mut connection = Self {
|
||||
config,
|
||||
pipeline: ConnectionPipeline::Tcp(tcp_stream),
|
||||
pipeline: ConnectionStream::Plain(tcp_stream),
|
||||
};
|
||||
|
||||
// Initialize SSL connection
|
||||
connection.enable_ssl()?;
|
||||
connection = connection.enable_ssl()?;
|
||||
|
||||
// Attempt login using `config.auth`
|
||||
connection.login()?;
|
||||
|
@ -84,7 +84,7 @@ impl TcpConnection {
|
|||
}
|
||||
|
||||
#[cfg(feature = "ssl")]
|
||||
fn enable_ssl(&mut self) -> crate::Result<()> {
|
||||
fn enable_ssl(mut self) -> crate::Result<Self> {
|
||||
if self.config.ssl {
|
||||
// Send TLS request and check for 'OK'
|
||||
self.write_cmd(Command::StartTLS)?;
|
||||
|
@ -110,17 +110,12 @@ impl TcpConnection {
|
|||
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
|
||||
|
||||
// Wrap and override the TCP stream
|
||||
let tcp = self
|
||||
.pipeline
|
||||
.tcp()
|
||||
.ok_or_else(|| ClientError::from(NutError::SslNotSupported))?;
|
||||
let tls = rustls::StreamOwned::new(sess, tcp);
|
||||
self.pipeline = ConnectionPipeline::Ssl(tls);
|
||||
self.pipeline = self.pipeline.upgrade_ssl(sess)?;
|
||||
|
||||
// Send a test command
|
||||
self.get_network_version()?;
|
||||
}
|
||||
Ok(())
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ssl"))]
|
||||
|
@ -190,7 +185,7 @@ impl TcpConnection {
|
|||
}
|
||||
|
||||
fn parse_line(
|
||||
reader: &mut BufReader<&mut ConnectionPipeline>,
|
||||
reader: &mut BufReader<&mut ConnectionStream>,
|
||||
debug: bool,
|
||||
) -> crate::Result<Vec<String>> {
|
||||
let mut raw = String::new();
|
||||
|
|
50
nut-client/src/blocking/stream.rs
Normal file
50
nut-client/src/blocking/stream.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
|
||||
/// A wrapper for various synchronous stream types.
|
||||
pub enum ConnectionStream {
|
||||
/// A plain TCP stream.
|
||||
Plain(TcpStream),
|
||||
|
||||
/// A stream wrapped with SSL using `rustls`.
|
||||
#[cfg(feature = "ssl")]
|
||||
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
|
||||
}
|
||||
|
||||
impl ConnectionStream {
|
||||
/// Wraps the current stream with SSL using `rustls`.
|
||||
#[cfg(feature = "ssl")]
|
||||
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> {
|
||||
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new(
|
||||
session, self,
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for ConnectionStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
Self::Plain(stream) => stream.read(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for ConnectionStream {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
Self::Plain(stream) => stream.write(buf),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
match self {
|
||||
Self::Plain(stream) => stream.flush(),
|
||||
#[cfg(feature = "ssl")]
|
||||
Self::Ssl(stream) => stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue