Strict SSL verification (#9)

Fixes #8
This commit is contained in:
Aram Peres 2021-07-31 11:12:45 -04:00 committed by GitHub
parent f22867d2d2
commit 3002b4de53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 171 additions and 75 deletions

View file

@ -17,6 +17,7 @@ license = "MIT"
shell-words = "1.0.0"
rustls = { version = "0.19", optional = true, features = ["dangerous_configuration"] }
webpki = { version = "0.21", optional = true }
webpki-roots = { version = "0.21", optional = true }
[features]
ssl = ["rustls", "webpki"]
ssl = ["rustls", "webpki", "webpki-roots"]

View file

@ -1,23 +1,23 @@
use std::env;
use std::net::ToSocketAddrs;
use nut_client::blocking::Connection;
use nut_client::{Auth, ConfigBuilder, Host};
use nut_client::{Auth, ConfigBuilder};
use std::convert::TryInto;
fn main() -> nut_client::Result<()> {
let addr = env::var("NUT_ADDR")
.unwrap_or_else(|_| "localhost:3493".into())
.to_socket_addrs()
.unwrap()
.next()
.unwrap();
let host = env::var("NUT_HOST").unwrap_or_else(|_| "localhost".into());
let port = env::var("NUT_PORT")
.ok()
.map(|s| s.parse::<u16>().ok())
.flatten()
.unwrap_or(3493);
let username = env::var("NUT_USER").ok();
let password = env::var("NUT_PASSWORD").ok();
let auth = username.map(|username| Auth::new(username, password));
let config = ConfigBuilder::new()
.with_host(Host::Tcp(addr))
.with_host((host, port).try_into().unwrap_or_default())
.with_auth(auth)
.with_debug(false) // Turn this on for debugging network chatter
.build();

View file

@ -17,9 +17,7 @@ impl Connection {
/// Initializes a connection to a NUT server (upsd).
pub fn new(config: &Config) -> crate::Result<Self> {
match &config.host {
Host::Tcp(socket_addr) => {
Ok(Self::Tcp(TcpConnection::new(config.clone(), socket_addr)?))
}
Host::Tcp(host) => Ok(Self::Tcp(TcpConnection::new(config.clone(), &host.addr)?)),
}
}
@ -62,7 +60,7 @@ impl Connection {
/// A blocking TCP NUT client connection.
pub struct TcpConnection {
config: Config,
pipeline: ConnectionStream,
stream: ConnectionStream,
}
impl TcpConnection {
@ -71,7 +69,7 @@ impl TcpConnection {
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self {
config,
pipeline: ConnectionStream::Plain(tcp_stream),
stream: ConnectionStream::Plain(tcp_stream),
};
// Initialize SSL connection
@ -98,19 +96,37 @@ impl TcpConnection {
})?
.expect_ok()?;
let mut config = rustls::ClientConfig::new();
config
.dangerous()
.set_certificate_verifier(std::sync::Arc::new(
crate::ssl::NutCertificateValidator::new(&self.config),
));
let mut ssl_config = rustls::ClientConfig::new();
let sess = if self.config.ssl_insecure {
ssl_config
.dangerous()
.set_certificate_verifier(std::sync::Arc::new(
crate::ssl::InsecureCertificateValidator::new(&self.config),
));
// todo: this DNS name is temporary; should get from connection hostname? (#8)
let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap();
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);
let dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com").unwrap();
rustls::ClientSession::new(&std::sync::Arc::new(ssl_config), dns_name)
} else {
// Try to get hostname as given (e.g. localhost can be used for strict SSL, but not 127.0.0.1)
let hostname = self
.config
.host
.hostname()
.ok_or(ClientError::Nut(NutError::SslInvalidHostname))?;
let dns_name = webpki::DNSNameRef::try_from_ascii_str(&hostname)
.map_err(|_| ClientError::Nut(NutError::SslInvalidHostname))?;
ssl_config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
rustls::ClientSession::new(&std::sync::Arc::new(ssl_config), dns_name)
};
// Wrap and override the TCP stream
self.pipeline = self.pipeline.upgrade_ssl(sess)?;
self.stream = self.stream.upgrade_ssl(sess)?;
// Send a test command
self.get_network_version()?;
@ -179,8 +195,8 @@ impl TcpConnection {
if self.config.debug {
eprint!("DEBUG -> {}", line);
}
self.pipeline.write_all(line.as_bytes())?;
self.pipeline.flush()?;
self.stream.write_all(line.as_bytes())?;
self.stream.flush()?;
Ok(())
}
@ -203,19 +219,19 @@ impl TcpConnection {
}
fn read_response(&mut self) -> crate::Result<Response> {
let mut reader = BufReader::new(&mut self.pipeline);
let mut reader = BufReader::new(&mut self.stream);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Response::from_args(args)
}
fn read_plain_response(&mut self) -> crate::Result<String> {
let mut reader = BufReader::new(&mut self.pipeline);
let mut reader = BufReader::new(&mut self.stream);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Ok(args.join(" "))
}
fn read_list(&mut self, query: &[&str]) -> crate::Result<Vec<Response>> {
let mut reader = BufReader::new(&mut self.pipeline);
let mut reader = BufReader::new(&mut self.stream);
let args = Self::parse_line(&mut reader, self.config.debug)?;
Response::from_args(args)?.expect_begin_list(query)?;

View file

@ -1,29 +1,66 @@
use core::fmt;
use std::convert::{TryFrom, TryInto};
use std::net::{SocketAddr, ToSocketAddrs};
use std::time::Duration;
use crate::ClientError;
/// A host specification.
#[derive(Clone, Debug)]
pub enum Host {
/// A TCP hostname and port.
Tcp(SocketAddr),
/// A TCP hostname, and address (IP + port).
Tcp(TcpHost),
// TODO: Support Unix socket streams.
}
impl Host {
/// Returns the hostname as given, if any.
pub fn hostname(&self) -> Option<String> {
match self {
Host::Tcp(host) => Some(host.hostname.to_owned()),
// _ => None,
}
}
}
impl Default for Host {
fn default() -> Self {
let addr = (String::from("127.0.0.1"), 3493)
.to_socket_addrs()
.expect("Failed to create local UPS socket address. This is a bug.")
.next()
.expect("Failed to create local UPS socket address. This is a bug.");
Self::Tcp(addr)
(String::from("localhost"), 3493)
.try_into()
.expect("Failed to parse local hostname; this is a bug.")
}
}
impl From<SocketAddr> for Host {
fn from(addr: SocketAddr) -> Self {
Self::Tcp(addr)
let hostname = addr.ip().to_string();
Self::Tcp(TcpHost { hostname, addr })
}
}
/// A TCP address, preserving the original DNS hostname if any.
#[derive(Clone, Debug)]
pub struct TcpHost {
pub(crate) hostname: String,
pub(crate) addr: SocketAddr,
}
impl TryFrom<(String, u16)> for Host {
type Error = ClientError;
fn try_from(hostname_port: (String, u16)) -> Result<Self, Self::Error> {
let (hostname, _) = hostname_port.clone();
let addr = hostname_port
.to_socket_addrs()
.map_err(ClientError::Io)?
.next()
.ok_or_else(|| {
ClientError::Io(std::io::Error::new(
std::io::ErrorKind::AddrNotAvailable,
"no address given",
))
})?;
Ok(Host::Tcp(TcpHost { hostname, addr }))
}
}
@ -59,17 +96,26 @@ pub struct Config {
pub(crate) auth: Option<Auth>,
pub(crate) timeout: Duration,
pub(crate) ssl: bool,
pub(crate) ssl_insecure: bool,
pub(crate) debug: bool,
}
impl Config {
/// Creates a connection configuration.
pub fn new(host: Host, auth: Option<Auth>, timeout: Duration, ssl: bool, debug: bool) -> Self {
pub fn new(
host: Host,
auth: Option<Auth>,
timeout: Duration,
ssl: bool,
ssl_insecure: bool,
debug: bool,
) -> Self {
Config {
host,
auth,
timeout,
ssl,
ssl_insecure,
debug,
}
}
@ -82,6 +128,7 @@ pub struct ConfigBuilder {
auth: Option<Auth>,
timeout: Option<Duration>,
ssl: Option<bool>,
ssl_insecure: Option<bool>,
debug: Option<bool>,
}
@ -111,12 +158,24 @@ impl ConfigBuilder {
}
/// Enables SSL on the connection.
///
/// This will enable strict SSL verification (including hostname),
/// unless `.with_insecure_ssl` is also set to `true`.
#[cfg(feature = "ssl")]
pub fn with_ssl(mut self, ssl: bool) -> Self {
self.ssl = Some(ssl);
self
}
/// Turns off SSL verification.
///
/// Note: you must still use `.with_ssl(true)` to turn on SSL.
#[cfg(feature = "ssl")]
pub fn with_insecure_ssl(mut self, ssl_insecure: bool) -> Self {
self.ssl_insecure = Some(ssl_insecure);
self
}
/// Enables debugging network calls by printing to stderr.
pub fn with_debug(mut self, debug: bool) -> Self {
self.debug = Some(debug);
@ -130,6 +189,7 @@ impl ConfigBuilder {
self.auth,
self.timeout.unwrap_or_else(|| Duration::from_secs(5)),
self.ssl.unwrap_or(false),
self.ssl_insecure.unwrap_or(false),
self.debug.unwrap_or(false),
)
}

View file

@ -15,6 +15,8 @@ pub enum NutError {
/// Occurs when attempting to use SSL in a transport that doesn't support it, or
/// if the server is not configured for it.
SslNotSupported,
/// Occurs when trying to initialize a strict SSL connection with an invalid hostname.
SslInvalidHostname,
/// Occurs when the client used a feature that is disabled by the server.
FeatureNotConfigured,
/// Generic (usually internal) client error.
@ -29,6 +31,10 @@ impl fmt::Display for NutError {
Self::UnexpectedResponse => write!(f, "Unexpected server response content"),
Self::UnknownResponseType(ty) => write!(f, "Unknown response type: {}", ty),
Self::SslNotSupported => write!(f, "SSL not supported by server or transport"),
Self::SslInvalidHostname => write!(
f,
"Given hostname cannot be used for a strict SSL connection"
),
Self::FeatureNotConfigured => write!(f, "Feature not configured by server"),
Self::Generic(msg) => write!(f, "Internal client error: {}", msg),
}

View file

@ -1,40 +1,30 @@
use crate::Config;
/// The certificate validation mechanism for NUT.
pub struct NutCertificateValidator {
/// The certificate validation mechanism that allows any certificate.
pub struct InsecureCertificateValidator {
debug: bool,
}
impl NutCertificateValidator {
impl InsecureCertificateValidator {
/// Initialize a new instance.
pub fn new(config: &Config) -> Self {
NutCertificateValidator {
InsecureCertificateValidator {
debug: config.debug,
}
}
}
impl rustls::ServerCertVerifier for NutCertificateValidator {
impl rustls::ServerCertVerifier for InsecureCertificateValidator {
fn verify_server_cert(
&self,
_roots: &rustls::RootCertStore,
presented_certs: &[rustls::Certificate],
_presented_certs: &[rustls::Certificate],
_dns_name: webpki::DNSNameRef<'_>,
_ocsp: &[u8],
) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
// todo: verify certificates, but not hostnames
if self.debug {
let parsed = webpki::EndEntityCert::from(presented_certs[0].0.as_slice()).ok();
if let Some(_parsed) = parsed {
eprintln!("DEBUG <- Certificate received and parsed");
// todo: reading values here... https://github.com/briansmith/webpki/pull/103
} else {
eprintln!("DEBUG <- Certificate not-parseable");
}
eprintln!("DEBUG <- (!) Certificate received, but not verified");
}
// trust everything for now
Ok(rustls::ServerCertVerified::assertion())
}
}