Add SSL support (#7)

Fixes #1
This commit is contained in:
Aram Peres 2021-07-31 08:43:26 -04:00 committed by GitHub
parent d78fd8c141
commit d36999db6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 494 additions and 104 deletions

View file

@ -1,9 +1,11 @@
use std::io;
use std::io::{BufRead, BufReader, Write};
use std::net::{SocketAddr, TcpStream};
use crate::blocking::filter::ConnectionPipeline;
use crate::cmd::{Command, Response};
use crate::{Config, Host, NutError, Variable};
use crate::{ClientError, Config, Host, NutError, Variable};
mod filter;
/// A blocking NUT client connection.
pub enum Connection {
@ -13,7 +15,7 @@ pub enum Connection {
impl Connection {
/// Initializes a connection to a NUT server (upsd).
pub fn new(config: Config) -> crate::Result<Self> {
pub fn new(config: &Config) -> crate::Result<Self> {
match &config.host {
Host::Tcp(socket_addr) => {
Ok(Self::Tcp(TcpConnection::new(config.clone(), socket_addr)?))
@ -58,17 +60,22 @@ impl Connection {
}
/// A blocking TCP NUT client connection.
#[derive(Debug)]
pub struct TcpConnection {
config: Config,
tcp_stream: TcpStream,
pipeline: ConnectionPipeline,
}
impl TcpConnection {
fn new(config: Config, socket_addr: &SocketAddr) -> crate::Result<Self> {
// Create the TCP connection
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self { config, tcp_stream };
let mut connection = Self {
config,
pipeline: ConnectionPipeline::Tcp(tcp_stream),
};
// Initialize SSL connection
connection.enable_ssl()?;
// Attempt login using `config.auth`
connection.login()?;
@ -76,84 +83,114 @@ impl TcpConnection {
Ok(connection)
}
#[cfg(feature = "ssl")]
fn enable_ssl(&mut self) -> crate::Result<()> {
if self.config.ssl {
// Send TLS request and check for 'OK'
self.write_cmd(Command::StartTLS)?;
self.read_response()
.map_err(|e| {
if let ClientError::Nut(NutError::FeatureNotConfigured) = e {
ClientError::Nut(NutError::SslNotSupported)
} else {
e
}
})?
.expect_ok()?;
let mut config = rustls::ClientConfig::new();
config
.dangerous()
.set_certificate_verifier(std::sync::Arc::new(
crate::ssl::NutCertificateValidator::new(&self.config),
));
// todo: this DNS name is temporary; should get from connection hostname? (#8)
let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap();
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
// Wrap and override the TCP stream
let tcp = self
.pipeline
.tcp()
.ok_or_else(|| ClientError::from(NutError::SslNotSupported))?;
let tls = rustls::StreamOwned::new(sess, tcp);
self.pipeline = ConnectionPipeline::Ssl(tls);
// Send a test command
self.get_network_version()?;
}
Ok(())
}
#[cfg(not(feature = "ssl"))]
fn enable_ssl(&mut self) -> crate::Result<()> {
Ok(())
}
fn login(&mut self) -> crate::Result<()> {
if let Some(auth) = &self.config.auth {
if let Some(auth) = self.config.auth.clone() {
// Pass username and check for 'OK'
Self::write_cmd(
&mut self.tcp_stream,
Command::SetUsername(&auth.username),
self.config.debug,
)?;
Self::read_response(&mut self.tcp_stream, self.config.debug)?.expect_ok()?;
self.write_cmd(Command::SetUsername(&auth.username))?;
self.read_response()?.expect_ok()?;
// Pass password and check for 'OK'
if let Some(password) = &auth.password {
Self::write_cmd(
&mut self.tcp_stream,
Command::SetPassword(password),
self.config.debug,
)?;
Self::read_response(&mut self.tcp_stream, self.config.debug)?.expect_ok()?;
self.write_cmd(Command::SetPassword(password))?;
self.read_response()?.expect_ok()?;
}
}
Ok(())
}
fn list_ups(&mut self) -> crate::Result<Vec<(String, String)>> {
Self::write_cmd(
&mut self.tcp_stream,
Command::List(&["UPS"]),
self.config.debug,
)?;
let list = Self::read_list(&mut self.tcp_stream, &["UPS"], self.config.debug)?;
let query = &["UPS"];
self.write_cmd(Command::List(query))?;
let list = self.read_list(query)?;
list.into_iter().map(|row| row.expect_ups()).collect()
}
fn list_clients(&mut self, ups_name: &str) -> crate::Result<Vec<String>> {
let query = &["CLIENT", ups_name];
Self::write_cmd(
&mut self.tcp_stream,
Command::List(query),
self.config.debug,
)?;
let list = Self::read_list(&mut self.tcp_stream, query, self.config.debug)?;
self.write_cmd(Command::List(query))?;
let list = self.read_list(query)?;
list.into_iter().map(|row| row.expect_client()).collect()
}
fn list_vars(&mut self, ups_name: &str) -> crate::Result<Vec<(String, String)>> {
let query = &["VAR", ups_name];
Self::write_cmd(
&mut self.tcp_stream,
Command::List(query),
self.config.debug,
)?;
let list = Self::read_list(&mut self.tcp_stream, query, self.config.debug)?;
self.write_cmd(Command::List(query))?;
let list = self.read_list(query)?;
list.into_iter().map(|row| row.expect_var()).collect()
}
fn get_var(&mut self, ups_name: &str, variable: &str) -> crate::Result<(String, String)> {
let query = &["VAR", ups_name, variable];
Self::write_cmd(&mut self.tcp_stream, Command::Get(query), self.config.debug)?;
self.write_cmd(Command::Get(query))?;
let resp = Self::read_response(&mut self.tcp_stream, self.config.debug)?;
resp.expect_var()
self.read_response()?.expect_var()
}
fn write_cmd(stream: &mut TcpStream, line: Command, debug: bool) -> crate::Result<()> {
fn get_network_version(&mut self) -> crate::Result<String> {
self.write_cmd(Command::NetworkVersion)?;
self.read_plain_response()
}
fn write_cmd(&mut self, line: Command) -> crate::Result<()> {
let line = format!("{}\n", line);
if debug {
if self.config.debug {
eprint!("DEBUG -> {}", line);
}
stream.write_all(line.as_bytes())?;
stream.flush()?;
self.pipeline.write_all(line.as_bytes())?;
self.pipeline.flush()?;
Ok(())
}
fn parse_line(
reader: &mut BufReader<&mut TcpStream>,
reader: &mut BufReader<&mut ConnectionPipeline>,
debug: bool,
) -> crate::Result<Vec<String>> {
let mut raw = String::new();
@ -170,25 +207,27 @@ impl TcpConnection {
Ok(args)
}
fn read_response(stream: &mut TcpStream, debug: bool) -> crate::Result<Response> {
let mut reader = io::BufReader::new(stream);
let args = Self::parse_line(&mut reader, debug)?;
fn read_response(&mut self) -> crate::Result<Response> {
let mut reader = BufReader::new(&mut self.pipeline);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Response::from_args(args)
}
fn read_list(
stream: &mut TcpStream,
query: &[&str],
debug: bool,
) -> crate::Result<Vec<Response>> {
let mut reader = io::BufReader::new(stream);
let args = Self::parse_line(&mut reader, debug)?;
fn read_plain_response(&mut self) -> crate::Result<String> {
let mut reader = BufReader::new(&mut self.pipeline);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Ok(args.join(" "))
}
fn read_list(&mut self, query: &[&str]) -> crate::Result<Vec<Response>> {
let mut reader = BufReader::new(&mut self.pipeline);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Response::from_args(args)?.expect_begin_list(query)?;
let mut lines: Vec<Response> = Vec::new();
loop {
let args = Self::parse_line(&mut reader, debug)?;
let args = Self::parse_line(&mut reader, self.config.debug)?;
let resp = Response::from_args(args)?;
match resp {