diff --git a/nut-client/src/blocking/mod.rs b/nut-client/src/blocking/mod.rs index 6438bbc..4d377d9 100644 --- a/nut-client/src/blocking/mod.rs +++ b/nut-client/src/blocking/mod.rs @@ -3,10 +3,9 @@ use std::net::{SocketAddr, TcpStream}; use crate::blocking::filter::ConnectionPipeline; use crate::cmd::{Command, Response}; -use crate::{Config, Host, NutError, Variable}; +use crate::{ClientError, Config, Host, NutError, Variable}; mod filter; -mod reader; /// A blocking NUT client connection. pub enum Connection { @@ -89,7 +88,15 @@ impl TcpConnection { if self.config.ssl { // Send TLS request and check for 'OK' self.write_cmd(Command::StartTLS)?; - self.read_response()?.expect_ok()?; + self.read_response() + .map_err(|e| { + if let ClientError::Nut(NutError::FeatureNotConfigured) = e { + ClientError::Nut(NutError::SslNotSupported) + } else { + e + } + })? + .expect_ok()?; let mut config = rustls::ClientConfig::new(); config @@ -98,16 +105,20 @@ impl TcpConnection { crate::ssl::NutCertificateValidator::new(&self.config), )); + // todo: this DNS name is temporary; should get from connection hostname? (#8) let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap(); let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name); // Wrap and override the TCP stream - let tcp = self.pipeline.tcp().unwrap(); + let tcp = self + .pipeline + .tcp() + .ok_or_else(|| ClientError::from(NutError::SslNotSupported))?; let tls = rustls::StreamOwned::new(sess, tcp); self.pipeline = ConnectionPipeline::Ssl(tls); // Send a test command - self.get_version()?; + self.get_network_version()?; } Ok(()) } @@ -163,7 +174,7 @@ impl TcpConnection { self.read_response()?.expect_var() } - fn get_version(&mut self) -> crate::Result { + fn get_network_version(&mut self) -> crate::Result { self.write_cmd(Command::NetworkVersion)?; self.read_plain_response() } diff --git a/nut-client/src/blocking/reader.rs b/nut-client/src/blocking/reader.rs deleted file mode 100644 index d6e1c2b..0000000 --- a/nut-client/src/blocking/reader.rs +++ /dev/null @@ -1,93 +0,0 @@ -// use std::io::Read; -// use std::io::Write; -// use std::net::TcpStream; -// -// /// A Read implementation with optional SSL support. -// pub struct SslOptionalReader<'a> { -// pub(crate) inner_stream: TcpStream, -// #[cfg(feature = "ssl")] -// ssl_stream: Option>, -// } -// -// impl<'a> SslOptionalReader<'a> { -// #[cfg(feature = "ssl")] -// pub fn new(inner: TcpStream) -> Self { -// Self { -// inner_stream: inner, -// ssl_stream: None, -// } -// } -// -// #[cfg(not(feature = "ssl"))] -// pub fn new(inner: TcpStream) -> Self { -// Self { -// inner_stream: inner, -// } -// } -// -// #[cfg(feature = "ssl")] -// pub fn set_ssl_stream( -// &mut self, -// ssl_stream: rustls::Stream<'a, rustls::ClientSession, TcpStream>, -// ) { -// self.ssl_stream = Some(ssl_stream) -// } -// } -// -// impl<'a> Read for SslOptionalReader<'a> { -// #[cfg(feature = "ssl")] -// fn read(&mut self, buf: &mut [u8]) -> std::io::Result { -// if let Some(ssl_stream) = &mut self.ssl_stream { -// ssl_stream.read(buf) -// } else { -// self.inner_stream.read(buf) -// } -// } -// -// #[cfg(not(feature = "ssl"))] -// fn read(&mut self, buf: &mut [u8]) -> std::io::Result { -// self.inner_stream.read(buf) -// } -// } -// -// impl<'a> Write for SslOptionalReader<'a> { -// #[cfg(feature = "ssl")] -// fn write(&mut self, buf: &[u8]) -> std::io::Result { -// if let Some(ssl_stream) = &mut self.ssl_stream { -// ssl_stream.write(buf) -// } else { -// self.inner_stream.write(buf) -// } -// } -// -// #[cfg(not(feature = "ssl"))] -// fn write(&mut self, buf: &[u8]) -> std::io::Result { -// self.inner_stream.write(buf) -// } -// -// #[cfg(feature = "ssl")] -// fn flush(&mut self) -> std::io::Result<()> { -// if let Some(ssl_stream) = &mut self.ssl_stream { -// ssl_stream.flush() -// } else { -// self.inner_stream.flush() -// } -// } -// -// #[cfg(not(feature = "ssl"))] -// fn flush(&mut self) -> std::io::Result<()> { -// self.inner_stream.flush() -// } -// } -// -// impl <'a> std::fmt::Debug for SslOptionalReader <'a> { -// #[cfg(feature = "ssl")] -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// write!(f, "SslOptionalReader[ssl={}]", self.ssl_stream.is_some()) -// } -// -// #[cfg(not(feature = "ssl"))] -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// write!(f, "SslOptionalReader[ssl={}]", false) -// } -// } diff --git a/nut-client/src/cmd.rs b/nut-client/src/cmd.rs index 3a8eab4..28dc2e7 100644 --- a/nut-client/src/cmd.rs +++ b/nut-client/src/cmd.rs @@ -90,6 +90,7 @@ impl Response { match err_type.as_str() { "ACCESS-DENIED" => Err(NutError::AccessDenied.into()), "UNKNOWN-UPS" => Err(NutError::UnknownUps.into()), + "FEATURE-NOT-CONFIGURED" => Err(NutError::FeatureNotConfigured.into()), _ => Err(NutError::Generic(format!( "Server error: {} {}", err_type, diff --git a/nut-client/src/error.rs b/nut-client/src/error.rs index 8b4cc98..2939c71 100644 --- a/nut-client/src/error.rs +++ b/nut-client/src/error.rs @@ -12,8 +12,11 @@ pub enum NutError { UnexpectedResponse, /// Occurs when the response type is not recognized by the client. UnknownResponseType(String), - /// Occurs when attempting to use SSL in a transport that doesn't support it. + /// Occurs when attempting to use SSL in a transport that doesn't support it, or + /// if the server is not configured for it. SslNotSupported, + /// Occurs when the client used a feature that is disabled by the server. + FeatureNotConfigured, /// Generic (usually internal) client error. Generic(String), } @@ -25,7 +28,8 @@ impl fmt::Display for NutError { Self::UnknownUps => write!(f, "Unknown UPS device name"), Self::UnexpectedResponse => write!(f, "Unexpected server response content"), Self::UnknownResponseType(ty) => write!(f, "Unknown response type: {}", ty), - Self::SslNotSupported => write!(f, "SSL not supported by transport"), + Self::SslNotSupported => write!(f, "SSL not supported by server or transport"), + Self::FeatureNotConfigured => write!(f, "Feature not configured by server"), Self::Generic(msg) => write!(f, "Internal client error: {}", msg), } }