mirror of
https://github.com/arampoire/nut-rs.git
synced 2025-12-01 00:30:23 -05:00
Better SSL error handling
This commit is contained in:
parent
279b10b148
commit
2d0658a7c5
4 changed files with 24 additions and 101 deletions
|
|
@ -3,10 +3,9 @@ use std::net::{SocketAddr, TcpStream};
|
||||||
|
|
||||||
use crate::blocking::filter::ConnectionPipeline;
|
use crate::blocking::filter::ConnectionPipeline;
|
||||||
use crate::cmd::{Command, Response};
|
use crate::cmd::{Command, Response};
|
||||||
use crate::{Config, Host, NutError, Variable};
|
use crate::{ClientError, Config, Host, NutError, Variable};
|
||||||
|
|
||||||
mod filter;
|
mod filter;
|
||||||
mod reader;
|
|
||||||
|
|
||||||
/// A blocking NUT client connection.
|
/// A blocking NUT client connection.
|
||||||
pub enum Connection {
|
pub enum Connection {
|
||||||
|
|
@ -89,7 +88,15 @@ impl TcpConnection {
|
||||||
if self.config.ssl {
|
if self.config.ssl {
|
||||||
// Send TLS request and check for 'OK'
|
// Send TLS request and check for 'OK'
|
||||||
self.write_cmd(Command::StartTLS)?;
|
self.write_cmd(Command::StartTLS)?;
|
||||||
self.read_response()?.expect_ok()?;
|
self.read_response()
|
||||||
|
.map_err(|e| {
|
||||||
|
if let ClientError::Nut(NutError::FeatureNotConfigured) = e {
|
||||||
|
ClientError::Nut(NutError::SslNotSupported)
|
||||||
|
} else {
|
||||||
|
e
|
||||||
|
}
|
||||||
|
})?
|
||||||
|
.expect_ok()?;
|
||||||
|
|
||||||
let mut config = rustls::ClientConfig::new();
|
let mut config = rustls::ClientConfig::new();
|
||||||
config
|
config
|
||||||
|
|
@ -98,16 +105,20 @@ impl TcpConnection {
|
||||||
crate::ssl::NutCertificateValidator::new(&self.config),
|
crate::ssl::NutCertificateValidator::new(&self.config),
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// todo: this DNS name is temporary; should get from connection hostname? (#8)
|
||||||
let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap();
|
let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap();
|
||||||
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
|
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
|
||||||
|
|
||||||
// Wrap and override the TCP stream
|
// Wrap and override the TCP stream
|
||||||
let tcp = self.pipeline.tcp().unwrap();
|
let tcp = self
|
||||||
|
.pipeline
|
||||||
|
.tcp()
|
||||||
|
.ok_or_else(|| ClientError::from(NutError::SslNotSupported))?;
|
||||||
let tls = rustls::StreamOwned::new(sess, tcp);
|
let tls = rustls::StreamOwned::new(sess, tcp);
|
||||||
self.pipeline = ConnectionPipeline::Ssl(tls);
|
self.pipeline = ConnectionPipeline::Ssl(tls);
|
||||||
|
|
||||||
// Send a test command
|
// Send a test command
|
||||||
self.get_version()?;
|
self.get_network_version()?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -163,7 +174,7 @@ impl TcpConnection {
|
||||||
self.read_response()?.expect_var()
|
self.read_response()?.expect_var()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_version(&mut self) -> crate::Result<String> {
|
fn get_network_version(&mut self) -> crate::Result<String> {
|
||||||
self.write_cmd(Command::NetworkVersion)?;
|
self.write_cmd(Command::NetworkVersion)?;
|
||||||
self.read_plain_response()
|
self.read_plain_response()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
// use std::io::Read;
|
|
||||||
// use std::io::Write;
|
|
||||||
// use std::net::TcpStream;
|
|
||||||
//
|
|
||||||
// /// A Read implementation with optional SSL support.
|
|
||||||
// pub struct SslOptionalReader<'a> {
|
|
||||||
// pub(crate) inner_stream: TcpStream,
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// ssl_stream: Option<rustls::Stream<'a, rustls::ClientSession, TcpStream>>,
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// impl<'a> SslOptionalReader<'a> {
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// pub fn new(inner: TcpStream) -> Self {
|
|
||||||
// Self {
|
|
||||||
// inner_stream: inner,
|
|
||||||
// ssl_stream: None,
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(not(feature = "ssl"))]
|
|
||||||
// pub fn new(inner: TcpStream) -> Self {
|
|
||||||
// Self {
|
|
||||||
// inner_stream: inner,
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// pub fn set_ssl_stream(
|
|
||||||
// &mut self,
|
|
||||||
// ssl_stream: rustls::Stream<'a, rustls::ClientSession, TcpStream>,
|
|
||||||
// ) {
|
|
||||||
// self.ssl_stream = Some(ssl_stream)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// impl<'a> Read for SslOptionalReader<'a> {
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
|
||||||
// if let Some(ssl_stream) = &mut self.ssl_stream {
|
|
||||||
// ssl_stream.read(buf)
|
|
||||||
// } else {
|
|
||||||
// self.inner_stream.read(buf)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(not(feature = "ssl"))]
|
|
||||||
// fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
|
||||||
// self.inner_stream.read(buf)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// impl<'a> Write for SslOptionalReader<'a> {
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
|
||||||
// if let Some(ssl_stream) = &mut self.ssl_stream {
|
|
||||||
// ssl_stream.write(buf)
|
|
||||||
// } else {
|
|
||||||
// self.inner_stream.write(buf)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(not(feature = "ssl"))]
|
|
||||||
// fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
|
||||||
// self.inner_stream.write(buf)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// fn flush(&mut self) -> std::io::Result<()> {
|
|
||||||
// if let Some(ssl_stream) = &mut self.ssl_stream {
|
|
||||||
// ssl_stream.flush()
|
|
||||||
// } else {
|
|
||||||
// self.inner_stream.flush()
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(not(feature = "ssl"))]
|
|
||||||
// fn flush(&mut self) -> std::io::Result<()> {
|
|
||||||
// self.inner_stream.flush()
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// impl <'a> std::fmt::Debug for SslOptionalReader <'a> {
|
|
||||||
// #[cfg(feature = "ssl")]
|
|
||||||
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
// write!(f, "SslOptionalReader[ssl={}]", self.ssl_stream.is_some())
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[cfg(not(feature = "ssl"))]
|
|
||||||
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
// write!(f, "SslOptionalReader[ssl={}]", false)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
@ -90,6 +90,7 @@ impl Response {
|
||||||
match err_type.as_str() {
|
match err_type.as_str() {
|
||||||
"ACCESS-DENIED" => Err(NutError::AccessDenied.into()),
|
"ACCESS-DENIED" => Err(NutError::AccessDenied.into()),
|
||||||
"UNKNOWN-UPS" => Err(NutError::UnknownUps.into()),
|
"UNKNOWN-UPS" => Err(NutError::UnknownUps.into()),
|
||||||
|
"FEATURE-NOT-CONFIGURED" => Err(NutError::FeatureNotConfigured.into()),
|
||||||
_ => Err(NutError::Generic(format!(
|
_ => Err(NutError::Generic(format!(
|
||||||
"Server error: {} {}",
|
"Server error: {} {}",
|
||||||
err_type,
|
err_type,
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,11 @@ pub enum NutError {
|
||||||
UnexpectedResponse,
|
UnexpectedResponse,
|
||||||
/// Occurs when the response type is not recognized by the client.
|
/// Occurs when the response type is not recognized by the client.
|
||||||
UnknownResponseType(String),
|
UnknownResponseType(String),
|
||||||
/// Occurs when attempting to use SSL in a transport that doesn't support it.
|
/// Occurs when attempting to use SSL in a transport that doesn't support it, or
|
||||||
|
/// if the server is not configured for it.
|
||||||
SslNotSupported,
|
SslNotSupported,
|
||||||
|
/// Occurs when the client used a feature that is disabled by the server.
|
||||||
|
FeatureNotConfigured,
|
||||||
/// Generic (usually internal) client error.
|
/// Generic (usually internal) client error.
|
||||||
Generic(String),
|
Generic(String),
|
||||||
}
|
}
|
||||||
|
|
@ -25,7 +28,8 @@ impl fmt::Display for NutError {
|
||||||
Self::UnknownUps => write!(f, "Unknown UPS device name"),
|
Self::UnknownUps => write!(f, "Unknown UPS device name"),
|
||||||
Self::UnexpectedResponse => write!(f, "Unexpected server response content"),
|
Self::UnexpectedResponse => write!(f, "Unexpected server response content"),
|
||||||
Self::UnknownResponseType(ty) => write!(f, "Unknown response type: {}", ty),
|
Self::UnknownResponseType(ty) => write!(f, "Unknown response type: {}", ty),
|
||||||
Self::SslNotSupported => write!(f, "SSL not supported by transport"),
|
Self::SslNotSupported => write!(f, "SSL not supported by server or transport"),
|
||||||
|
Self::FeatureNotConfigured => write!(f, "Feature not configured by server"),
|
||||||
Self::Generic(msg) => write!(f, "Internal client error: {}", msg),
|
Self::Generic(msg) => write!(f, "Internal client error: {}", msg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue