mirror of
https://github.com/aramperes/nut-rs.git
synced 2025-09-08 21:18:31 -04:00
Improve SSL upgrade
This commit is contained in:
parent
137251d29a
commit
60cd6a831a
3 changed files with 59 additions and 61 deletions
|
@ -1,47 +0,0 @@
|
||||||
use std::io::{Read, Write};
|
|
||||||
use std::net::TcpStream;
|
|
||||||
|
|
||||||
#[allow(clippy::large_enum_variant)]
|
|
||||||
pub enum ConnectionPipeline {
|
|
||||||
Tcp(TcpStream),
|
|
||||||
|
|
||||||
#[cfg(feature = "ssl")]
|
|
||||||
Ssl(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConnectionPipeline {
|
|
||||||
pub fn tcp(&self) -> Option<TcpStream> {
|
|
||||||
match self {
|
|
||||||
Self::Tcp(stream) => Some(stream.try_clone().ok()).flatten(),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Read for ConnectionPipeline {
|
|
||||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
|
||||||
match self {
|
|
||||||
Self::Tcp(stream) => stream.read(buf),
|
|
||||||
#[cfg(feature = "ssl")]
|
|
||||||
Self::Ssl(stream) => stream.read(buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Write for ConnectionPipeline {
|
|
||||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
|
||||||
match self {
|
|
||||||
Self::Tcp(stream) => stream.write(buf),
|
|
||||||
#[cfg(feature = "ssl")]
|
|
||||||
Self::Ssl(stream) => stream.write(buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> std::io::Result<()> {
|
|
||||||
match self {
|
|
||||||
Self::Tcp(stream) => stream.flush(),
|
|
||||||
#[cfg(feature = "ssl")]
|
|
||||||
Self::Ssl(stream) => stream.flush(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,11 +1,11 @@
|
||||||
use std::io::{BufRead, BufReader, Write};
|
use std::io::{BufRead, BufReader, Write};
|
||||||
use std::net::{SocketAddr, TcpStream};
|
use std::net::{SocketAddr, TcpStream};
|
||||||
|
|
||||||
use crate::blocking::filter::ConnectionPipeline;
|
use crate::blocking::stream::ConnectionStream;
|
||||||
use crate::cmd::{Command, Response};
|
use crate::cmd::{Command, Response};
|
||||||
use crate::{ClientError, Config, Host, NutError, Variable};
|
use crate::{ClientError, Config, Host, NutError, Variable};
|
||||||
|
|
||||||
mod filter;
|
mod stream;
|
||||||
|
|
||||||
/// A blocking NUT client connection.
|
/// A blocking NUT client connection.
|
||||||
pub enum Connection {
|
pub enum Connection {
|
||||||
|
@ -62,7 +62,7 @@ impl Connection {
|
||||||
/// A blocking TCP NUT client connection.
|
/// A blocking TCP NUT client connection.
|
||||||
pub struct TcpConnection {
|
pub struct TcpConnection {
|
||||||
config: Config,
|
config: Config,
|
||||||
pipeline: ConnectionPipeline,
|
pipeline: ConnectionStream,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TcpConnection {
|
impl TcpConnection {
|
||||||
|
@ -71,11 +71,11 @@ impl TcpConnection {
|
||||||
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
|
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
|
||||||
let mut connection = Self {
|
let mut connection = Self {
|
||||||
config,
|
config,
|
||||||
pipeline: ConnectionPipeline::Tcp(tcp_stream),
|
pipeline: ConnectionStream::Plain(tcp_stream),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize SSL connection
|
// Initialize SSL connection
|
||||||
connection.enable_ssl()?;
|
connection = connection.enable_ssl()?;
|
||||||
|
|
||||||
// Attempt login using `config.auth`
|
// Attempt login using `config.auth`
|
||||||
connection.login()?;
|
connection.login()?;
|
||||||
|
@ -84,7 +84,7 @@ impl TcpConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "ssl")]
|
#[cfg(feature = "ssl")]
|
||||||
fn enable_ssl(&mut self) -> crate::Result<()> {
|
fn enable_ssl(mut self) -> crate::Result<Self> {
|
||||||
if self.config.ssl {
|
if self.config.ssl {
|
||||||
// Send TLS request and check for 'OK'
|
// Send TLS request and check for 'OK'
|
||||||
self.write_cmd(Command::StartTLS)?;
|
self.write_cmd(Command::StartTLS)?;
|
||||||
|
@ -110,17 +110,12 @@ impl TcpConnection {
|
||||||
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
|
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
|
||||||
|
|
||||||
// Wrap and override the TCP stream
|
// Wrap and override the TCP stream
|
||||||
let tcp = self
|
self.pipeline = self.pipeline.upgrade_ssl(sess)?;
|
||||||
.pipeline
|
|
||||||
.tcp()
|
|
||||||
.ok_or_else(|| ClientError::from(NutError::SslNotSupported))?;
|
|
||||||
let tls = rustls::StreamOwned::new(sess, tcp);
|
|
||||||
self.pipeline = ConnectionPipeline::Ssl(tls);
|
|
||||||
|
|
||||||
// Send a test command
|
// Send a test command
|
||||||
self.get_network_version()?;
|
self.get_network_version()?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "ssl"))]
|
#[cfg(not(feature = "ssl"))]
|
||||||
|
@ -190,7 +185,7 @@ impl TcpConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_line(
|
fn parse_line(
|
||||||
reader: &mut BufReader<&mut ConnectionPipeline>,
|
reader: &mut BufReader<&mut ConnectionStream>,
|
||||||
debug: bool,
|
debug: bool,
|
||||||
) -> crate::Result<Vec<String>> {
|
) -> crate::Result<Vec<String>> {
|
||||||
let mut raw = String::new();
|
let mut raw = String::new();
|
||||||
|
|
50
nut-client/src/blocking/stream.rs
Normal file
50
nut-client/src/blocking/stream.rs
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
use std::net::TcpStream;
|
||||||
|
|
||||||
|
/// A wrapper for various synchronous stream types.
|
||||||
|
pub enum ConnectionStream {
|
||||||
|
/// A plain TCP stream.
|
||||||
|
Plain(TcpStream),
|
||||||
|
|
||||||
|
/// A stream wrapped with SSL using `rustls`.
|
||||||
|
#[cfg(feature = "ssl")]
|
||||||
|
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionStream {
|
||||||
|
/// Wraps the current stream with SSL using `rustls`.
|
||||||
|
#[cfg(feature = "ssl")]
|
||||||
|
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> {
|
||||||
|
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new(
|
||||||
|
session, self,
|
||||||
|
))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Read for ConnectionStream {
|
||||||
|
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||||
|
match self {
|
||||||
|
Self::Plain(stream) => stream.read(buf),
|
||||||
|
#[cfg(feature = "ssl")]
|
||||||
|
Self::Ssl(stream) => stream.read(buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Write for ConnectionStream {
|
||||||
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||||
|
match self {
|
||||||
|
Self::Plain(stream) => stream.write(buf),
|
||||||
|
#[cfg(feature = "ssl")]
|
||||||
|
Self::Ssl(stream) => stream.write(buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> std::io::Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Plain(stream) => stream.flush(),
|
||||||
|
#[cfg(feature = "ssl")]
|
||||||
|
Self::Ssl(stream) => stream.flush(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue