Add SSL support

This commit is contained in:
Aram 🍐 2021-07-31 07:00:46 -04:00
parent d78fd8c141
commit 64811b7421
15 changed files with 550 additions and 102 deletions

View file

@ -15,7 +15,8 @@ license = "MIT"
[dependencies]
shell-words = "1.0.0"
dotenv = { version = "0.15.0", optional = true }
rustls = { version = "0.19", optional = true, features = ["dangerous_configuration"] }
webpki = { version = "0.21", optional = true }
[features]
env-file = ["dotenv"]
ssl = ["rustls", "webpki"]

View file

@ -22,7 +22,7 @@ fn main() -> nut_client::Result<()> {
.with_debug(false) // Turn this on for debugging network chatter
.build();
let mut conn = Connection::new(config)?;
let mut conn = Connection::new(&config)?;
// Print a list of all UPS devices
println!("Connected UPS devices:");

View file

@ -0,0 +1,47 @@
use std::io::{Read, Write};
use std::net::TcpStream;
#[allow(clippy::large_enum_variant)]
pub enum ConnectionPipeline {
Tcp(TcpStream),
#[cfg(feature = "ssl")]
Ssl(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
}
impl ConnectionPipeline {
pub fn tcp(&self) -> Option<TcpStream> {
match self {
Self::Tcp(stream) => Some(stream.try_clone().ok()).flatten(),
_ => None,
}
}
}
impl Read for ConnectionPipeline {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.read(buf),
}
}
}
impl Write for ConnectionPipeline {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.flush(),
}
}
}

View file

@ -1,10 +1,13 @@
use std::io;
use std::io::{BufRead, BufReader, Write};
use std::net::{SocketAddr, TcpStream};
use crate::blocking::filter::ConnectionPipeline;
use crate::cmd::{Command, Response};
use crate::{Config, Host, NutError, Variable};
mod filter;
mod reader;
/// A blocking NUT client connection.
pub enum Connection {
/// A TCP connection.
@ -13,7 +16,7 @@ pub enum Connection {
impl Connection {
/// Initializes a connection to a NUT server (upsd).
pub fn new(config: Config) -> crate::Result<Self> {
pub fn new(config: &Config) -> crate::Result<Self> {
match &config.host {
Host::Tcp(socket_addr) => {
Ok(Self::Tcp(TcpConnection::new(config.clone(), socket_addr)?))
@ -58,17 +61,22 @@ impl Connection {
}
/// A blocking TCP NUT client connection.
#[derive(Debug)]
pub struct TcpConnection {
config: Config,
tcp_stream: TcpStream,
pipeline: ConnectionPipeline,
}
impl TcpConnection {
fn new(config: Config, socket_addr: &SocketAddr) -> crate::Result<Self> {
// Create the TCP connection
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self { config, tcp_stream };
let mut connection = Self {
config,
pipeline: ConnectionPipeline::Tcp(tcp_stream),
};
// Initialize SSL connection
connection.enable_ssl()?;
// Attempt login using `config.auth`
connection.login()?;
@ -76,84 +84,102 @@ impl TcpConnection {
Ok(connection)
}
#[cfg(feature = "ssl")]
fn enable_ssl(&mut self) -> crate::Result<()> {
if self.config.ssl {
// Send TLS request and check for 'OK'
self.write_cmd(Command::StartTLS)?;
self.read_response()?.expect_ok()?;
let mut config = rustls::ClientConfig::new();
config
.dangerous()
.set_certificate_verifier(std::sync::Arc::new(
crate::ssl::NutCertificateValidator::new(&self.config),
));
let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap();
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
// Wrap and override the TCP stream
let tcp = self.pipeline.tcp().unwrap();
let tls = rustls::StreamOwned::new(sess, tcp);
self.pipeline = ConnectionPipeline::Ssl(tls);
// Send a test command
self.get_version()?;
}
Ok(())
}
#[cfg(not(feature = "ssl"))]
fn enable_ssl(&mut self) -> crate::Result<()> {
Ok(())
}
fn login(&mut self) -> crate::Result<()> {
if let Some(auth) = &self.config.auth {
if let Some(auth) = self.config.auth.clone() {
// Pass username and check for 'OK'
Self::write_cmd(
&mut self.tcp_stream,
Command::SetUsername(&auth.username),
self.config.debug,
)?;
Self::read_response(&mut self.tcp_stream, self.config.debug)?.expect_ok()?;
self.write_cmd(Command::SetUsername(&auth.username))?;
self.read_response()?.expect_ok()?;
// Pass password and check for 'OK'
if let Some(password) = &auth.password {
Self::write_cmd(
&mut self.tcp_stream,
Command::SetPassword(password),
self.config.debug,
)?;
Self::read_response(&mut self.tcp_stream, self.config.debug)?.expect_ok()?;
self.write_cmd(Command::SetPassword(password))?;
self.read_response()?.expect_ok()?;
}
}
Ok(())
}
fn list_ups(&mut self) -> crate::Result<Vec<(String, String)>> {
Self::write_cmd(
&mut self.tcp_stream,
Command::List(&["UPS"]),
self.config.debug,
)?;
let list = Self::read_list(&mut self.tcp_stream, &["UPS"], self.config.debug)?;
let query = &["UPS"];
self.write_cmd(Command::List(query))?;
let list = self.read_list(query)?;
list.into_iter().map(|row| row.expect_ups()).collect()
}
fn list_clients(&mut self, ups_name: &str) -> crate::Result<Vec<String>> {
let query = &["CLIENT", ups_name];
Self::write_cmd(
&mut self.tcp_stream,
Command::List(query),
self.config.debug,
)?;
let list = Self::read_list(&mut self.tcp_stream, query, self.config.debug)?;
self.write_cmd(Command::List(query))?;
let list = self.read_list(query)?;
list.into_iter().map(|row| row.expect_client()).collect()
}
fn list_vars(&mut self, ups_name: &str) -> crate::Result<Vec<(String, String)>> {
let query = &["VAR", ups_name];
Self::write_cmd(
&mut self.tcp_stream,
Command::List(query),
self.config.debug,
)?;
let list = Self::read_list(&mut self.tcp_stream, query, self.config.debug)?;
self.write_cmd(Command::List(query))?;
let list = self.read_list(query)?;
list.into_iter().map(|row| row.expect_var()).collect()
}
fn get_var(&mut self, ups_name: &str, variable: &str) -> crate::Result<(String, String)> {
let query = &["VAR", ups_name, variable];
Self::write_cmd(&mut self.tcp_stream, Command::Get(query), self.config.debug)?;
self.write_cmd(Command::Get(query))?;
let resp = Self::read_response(&mut self.tcp_stream, self.config.debug)?;
resp.expect_var()
self.read_response()?.expect_var()
}
fn write_cmd(stream: &mut TcpStream, line: Command, debug: bool) -> crate::Result<()> {
fn get_version(&mut self) -> crate::Result<String> {
self.write_cmd(Command::NetworkVersion)?;
self.read_plain_response()
}
fn write_cmd(&mut self, line: Command) -> crate::Result<()> {
let line = format!("{}\n", line);
if debug {
if self.config.debug {
eprint!("DEBUG -> {}", line);
}
stream.write_all(line.as_bytes())?;
stream.flush()?;
self.pipeline.write_all(line.as_bytes())?;
self.pipeline.flush()?;
Ok(())
}
fn parse_line(
reader: &mut BufReader<&mut TcpStream>,
reader: &mut BufReader<&mut ConnectionPipeline>,
debug: bool,
) -> crate::Result<Vec<String>> {
let mut raw = String::new();
@ -170,25 +196,27 @@ impl TcpConnection {
Ok(args)
}
fn read_response(stream: &mut TcpStream, debug: bool) -> crate::Result<Response> {
let mut reader = io::BufReader::new(stream);
let args = Self::parse_line(&mut reader, debug)?;
fn read_response(&mut self) -> crate::Result<Response> {
let mut reader = BufReader::new(&mut self.pipeline);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Response::from_args(args)
}
fn read_list(
stream: &mut TcpStream,
query: &[&str],
debug: bool,
) -> crate::Result<Vec<Response>> {
let mut reader = io::BufReader::new(stream);
let args = Self::parse_line(&mut reader, debug)?;
fn read_plain_response(&mut self) -> crate::Result<String> {
let mut reader = BufReader::new(&mut self.pipeline);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Ok(args.join(" "))
}
fn read_list(&mut self, query: &[&str]) -> crate::Result<Vec<Response>> {
let mut reader = BufReader::new(&mut self.pipeline);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Response::from_args(args)?.expect_begin_list(query)?;
let mut lines: Vec<Response> = Vec::new();
loop {
let args = Self::parse_line(&mut reader, debug)?;
let args = Self::parse_line(&mut reader, self.config.debug)?;
let resp = Response::from_args(args)?;
match resp {

View file

@ -0,0 +1,93 @@
// use std::io::Read;
// use std::io::Write;
// use std::net::TcpStream;
//
// /// A Read implementation with optional SSL support.
// pub struct SslOptionalReader<'a> {
// pub(crate) inner_stream: TcpStream,
// #[cfg(feature = "ssl")]
// ssl_stream: Option<rustls::Stream<'a, rustls::ClientSession, TcpStream>>,
// }
//
// impl<'a> SslOptionalReader<'a> {
// #[cfg(feature = "ssl")]
// pub fn new(inner: TcpStream) -> Self {
// Self {
// inner_stream: inner,
// ssl_stream: None,
// }
// }
//
// #[cfg(not(feature = "ssl"))]
// pub fn new(inner: TcpStream) -> Self {
// Self {
// inner_stream: inner,
// }
// }
//
// #[cfg(feature = "ssl")]
// pub fn set_ssl_stream(
// &mut self,
// ssl_stream: rustls::Stream<'a, rustls::ClientSession, TcpStream>,
// ) {
// self.ssl_stream = Some(ssl_stream)
// }
// }
//
// impl<'a> Read for SslOptionalReader<'a> {
// #[cfg(feature = "ssl")]
// fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// if let Some(ssl_stream) = &mut self.ssl_stream {
// ssl_stream.read(buf)
// } else {
// self.inner_stream.read(buf)
// }
// }
//
// #[cfg(not(feature = "ssl"))]
// fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// self.inner_stream.read(buf)
// }
// }
//
// impl<'a> Write for SslOptionalReader<'a> {
// #[cfg(feature = "ssl")]
// fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
// if let Some(ssl_stream) = &mut self.ssl_stream {
// ssl_stream.write(buf)
// } else {
// self.inner_stream.write(buf)
// }
// }
//
// #[cfg(not(feature = "ssl"))]
// fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
// self.inner_stream.write(buf)
// }
//
// #[cfg(feature = "ssl")]
// fn flush(&mut self) -> std::io::Result<()> {
// if let Some(ssl_stream) = &mut self.ssl_stream {
// ssl_stream.flush()
// } else {
// self.inner_stream.flush()
// }
// }
//
// #[cfg(not(feature = "ssl"))]
// fn flush(&mut self) -> std::io::Result<()> {
// self.inner_stream.flush()
// }
// }
//
// impl <'a> std::fmt::Debug for SslOptionalReader <'a> {
// #[cfg(feature = "ssl")]
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// write!(f, "SslOptionalReader[ssl={}]", self.ssl_stream.is_some())
// }
//
// #[cfg(not(feature = "ssl"))]
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// write!(f, "SslOptionalReader[ssl={}]", false)
// }
// }

View file

@ -11,6 +11,10 @@ pub enum Command<'a> {
SetPassword(&'a str),
/// Queries for a list. Allows for any number of arguments, which forms a single query.
List(&'a [&'a str]),
/// Tells upsd to switch to TLS, so all future communications will be encrypted.
StartTLS,
/// Queries the network version.
NetworkVersion,
}
impl<'a> Command<'a> {
@ -21,6 +25,8 @@ impl<'a> Command<'a> {
Self::SetUsername(_) => "USERNAME",
Self::SetPassword(_) => "PASSWORD",
Self::List(_) => "LIST",
Self::StartTLS => "STARTTLS",
Self::NetworkVersion => "NETVER",
}
}
@ -31,6 +37,7 @@ impl<'a> Command<'a> {
Self::SetUsername(username) => vec![username],
Self::SetPassword(password) => vec![password],
Self::List(query) => query.to_vec(),
_ => Vec::new(),
}
}
}

View file

@ -58,16 +58,18 @@ pub struct Config {
pub(crate) host: Host,
pub(crate) auth: Option<Auth>,
pub(crate) timeout: Duration,
pub(crate) ssl: bool,
pub(crate) debug: bool,
}
impl Config {
/// Creates a connection configuration.
pub fn new(host: Host, auth: Option<Auth>, timeout: Duration, debug: bool) -> Self {
pub fn new(host: Host, auth: Option<Auth>, timeout: Duration, ssl: bool, debug: bool) -> Self {
Config {
host,
auth,
timeout,
ssl,
debug,
}
}
@ -79,6 +81,7 @@ pub struct ConfigBuilder {
host: Option<Host>,
auth: Option<Auth>,
timeout: Option<Duration>,
ssl: Option<bool>,
debug: Option<bool>,
}
@ -107,6 +110,13 @@ impl ConfigBuilder {
self
}
/// Enables SSL on the connection.
#[cfg(feature = "ssl")]
pub fn with_ssl(mut self, ssl: bool) -> Self {
self.ssl = Some(ssl);
self
}
/// Enables debugging network calls by printing to stderr.
pub fn with_debug(mut self, debug: bool) -> Self {
self.debug = Some(debug);
@ -119,6 +129,7 @@ impl ConfigBuilder {
self.host.unwrap_or_default(),
self.auth,
self.timeout.unwrap_or_else(|| Duration::from_secs(5)),
self.ssl.unwrap_or(false),
self.debug.unwrap_or(false),
)
}

View file

@ -12,6 +12,8 @@ pub enum NutError {
UnexpectedResponse,
/// Occurs when the response type is not recognized by the client.
UnknownResponseType(String),
/// Occurs when attempting to use SSL in a transport that doesn't support it.
SslNotSupported,
/// Generic (usually internal) client error.
Generic(String),
}
@ -23,6 +25,7 @@ impl fmt::Display for NutError {
Self::UnknownUps => write!(f, "Unknown UPS device name"),
Self::UnexpectedResponse => write!(f, "Unexpected server response content"),
Self::UnknownResponseType(ty) => write!(f, "Unknown response type: {}", ty),
Self::SslNotSupported => write!(f, "SSL not supported by transport"),
Self::Generic(msg) => write!(f, "Internal client error: {}", msg),
}
}

View file

@ -15,4 +15,6 @@ pub mod blocking;
mod cmd;
mod config;
mod error;
#[cfg(feature = "ssl")]
mod ssl;
mod var;

40
nut-client/src/ssl/mod.rs Normal file
View file

@ -0,0 +1,40 @@
use crate::Config;
/// The certificate validation mechanism for NUT.
pub struct NutCertificateValidator {
debug: bool,
}
impl NutCertificateValidator {
/// Initialize a new instance.
pub fn new(config: &Config) -> Self {
NutCertificateValidator {
debug: config.debug,
}
}
}
impl rustls::ServerCertVerifier for NutCertificateValidator {
fn verify_server_cert(
&self,
_roots: &rustls::RootCertStore,
presented_certs: &[rustls::Certificate],
_dns_name: webpki::DNSNameRef<'_>,
_ocsp: &[u8],
) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
// todo: verify certificates, but not hostnames
if self.debug {
let parsed = webpki::EndEntityCert::from(presented_certs[0].0.as_slice()).ok();
if let Some(_parsed) = parsed {
eprintln!("DEBUG <- Certificate received and parsed");
// todo: reading values here... https://github.com/briansmith/webpki/pull/103
} else {
eprintln!("DEBUG <- Certificate not-parseable");
}
}
// trust everything for now
Ok(rustls::ServerCertVerified::assertion())
}
}