diff --git a/Cargo.lock b/Cargo.lock index 82146e6..0df50c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,6 +28,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + [[package]] name = "base64" version = "0.13.0" @@ -46,6 +52,12 @@ version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631" +[[package]] +name = "bytes" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" + [[package]] name = "cc" version = "1.0.69" @@ -112,12 +124,61 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "memchr" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" + +[[package]] +name = "mio" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" +dependencies = [ + "libc", + "log", + "miow", + "ntapi", + "winapi", +] + +[[package]] +name = "miow" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" +dependencies = [ + "winapi", +] + +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "nut-client" version = "0.2.2" dependencies = [ "rustls", "shell-words", + "tokio", + "tokio-rustls", "webpki", "webpki-roots", ] @@ -128,6 +189,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" +[[package]] +name = "pin-project-lite" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" + [[package]] name = "proc-macro2" version = "1.0.28" @@ -231,6 +298,45 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "tokio" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b7b349f11a7047e6d1276853e612d152f5e8a352c61917887cc2169e2366b4c" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "pin-project-lite", + "tokio-macros", + "winapi", +] + +[[package]] +name = "tokio-macros" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls", + "tokio", + "webpki", +] + [[package]] name = "unicode-width" version = "0.1.8" diff --git a/README.md b/README.md index bdb9121..a8b9fa4 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ A [Network UPS Tools](https://github.com/networkupstools/nut) (NUT) client libra - List UPS devices - List variables for a UPS device - Connect securely with SSL (optional feature) +- Supports blocking and async (Tokio) ## Getting Started @@ -36,7 +37,12 @@ built-in [upsc](https://networkupstools.org/docs/man/upsc.html) tool. Below is a sample program using this library (`cargo run --example blocking`). +You can also run the async version of this code using +`cargo run --example async --features async-rt` (source: `nut-client/example/async.rs`). + ```rust +// nut-client/example/blocking.rs + use std::env; use nut_client::blocking::Connection; @@ -91,3 +97,10 @@ that the connection hostname is a valid DNS name (e.g. `localhost`, not `127.0.0 If the server is using a self-signed certificate, and you'd like to ignore the strict validation, you can add `.with_insecure_ssl(true)` along with `.with_ssl(true)`. + +## Async (Tokio) + +The `nut-client` library supports async network requests. This requires the `async` feature, which uses Tokio v1 under +the hood. + +For SSL support, you must use the `async-ssl` feature as well. diff --git a/nut-client/Cargo.toml b/nut-client/Cargo.toml index 84e2e53..8bed77c 100644 --- a/nut-client/Cargo.toml +++ b/nut-client/Cargo.toml @@ -15,9 +15,21 @@ license = "MIT" [dependencies] shell-words = "1.0.0" -rustls = { version = "0.19", optional = true, features = ["dangerous_configuration"] } +rustls = { version = "0.19", optional = true } webpki = { version = "0.21", optional = true } webpki-roots = { version = "0.21", optional = true } +tokio = { version = "1", optional = true, features = ["net", "io-util", "rt"] } +tokio-rustls = { version = "0.22", optional = true } [features] -ssl = ["rustls", "webpki", "webpki-roots"] +default = [] +ssl = ["rustls", "rustls/dangerous_configuration", "webpki", "webpki-roots"] +async = ["tokio"] +async-ssl = ["async", "tokio-rustls", "ssl"] + +# a feature gate for examples +async-rt = ["tokio/rt-multi-thread", "tokio/macros"] + +[[example]] +name = "async" +required-features = ["async-rt"] diff --git a/nut-client/examples/async.rs b/nut-client/examples/async.rs new file mode 100644 index 0000000..a8fe5a6 --- /dev/null +++ b/nut-client/examples/async.rs @@ -0,0 +1,42 @@ +use std::env; + +use nut_client::tokio::Connection; +use nut_client::{Auth, ConfigBuilder}; +use std::convert::TryInto; + +#[tokio::main] +async fn main() -> nut_client::Result<()> { + let host = env::var("NUT_HOST").unwrap_or_else(|_| "localhost".into()); + let port = env::var("NUT_PORT") + .ok() + .map(|s| s.parse::().ok()) + .flatten() + .unwrap_or(3493); + + let username = env::var("NUT_USER").ok(); + let password = env::var("NUT_PASSWORD").ok(); + let auth = username.map(|username| Auth::new(username, password)); + + let config = ConfigBuilder::new() + .with_host((host, port).try_into().unwrap_or_default()) + .with_auth(auth) + .with_debug(false) // Turn this on for debugging network chatter + .build(); + + let mut conn = Connection::new(&config).await?; + + // Print a list of all UPS devices + println!("Connected UPS devices:"); + for (name, description) in conn.list_ups().await? { + println!("\t- Name: {}", name); + println!("\t Description: {}", description); + + // List UPS variables (key = val) + println!("\t Variables:"); + for var in conn.list_vars(&name).await? { + println!("\t\t- {}", var); + } + } + + Ok(()) +} diff --git a/nut-client/src/blocking/mod.rs b/nut-client/src/blocking/mod.rs index 4081177..9041edc 100644 --- a/nut-client/src/blocking/mod.rs +++ b/nut-client/src/blocking/mod.rs @@ -135,7 +135,7 @@ impl TcpConnection { } #[cfg(not(feature = "ssl"))] - fn enable_ssl(mut self) -> crate::Result { + fn enable_ssl(self) -> crate::Result { Ok(self) } @@ -185,6 +185,7 @@ impl TcpConnection { self.read_response()?.expect_var() } + #[allow(dead_code)] fn get_network_version(&mut self) -> crate::Result { self.write_cmd(Command::NetworkVersion)?; self.read_plain_response() diff --git a/nut-client/src/lib.rs b/nut-client/src/lib.rs index 3c3ff02..594cd41 100644 --- a/nut-client/src/lib.rs +++ b/nut-client/src/lib.rs @@ -11,6 +11,9 @@ pub use var::*; /// Blocking client implementation for NUT. pub mod blocking; +/// Async client implementation for NUT, using Tokio. +#[cfg(feature = "tokio")] +pub mod tokio; mod cmd; mod config; diff --git a/nut-client/src/tokio/mod.rs b/nut-client/src/tokio/mod.rs new file mode 100644 index 0000000..dc550b9 --- /dev/null +++ b/nut-client/src/tokio/mod.rs @@ -0,0 +1,267 @@ +use std::net::SocketAddr; + +use crate::cmd::{Command, Response}; +use crate::tokio::stream::ConnectionStream; +use crate::{Config, Host, NutError, Variable}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; + +mod stream; + +/// An async NUT client connection. +pub enum Connection { + /// A TCP connection. + Tcp(TcpConnection), +} + +impl Connection { + /// Initializes a connection to a NUT server (upsd). + pub async fn new(config: &Config) -> crate::Result { + match &config.host { + Host::Tcp(host) => Ok(Self::Tcp( + TcpConnection::new(config.clone(), &host.addr).await?, + )), + } + } + + /// Queries a list of UPS devices. + pub async fn list_ups(&mut self) -> crate::Result> { + match self { + Self::Tcp(conn) => conn.list_ups().await, + } + } + + /// Queries a list of client IP addresses connected to the given device. + pub async fn list_clients(&mut self, ups_name: &str) -> crate::Result> { + match self { + Self::Tcp(conn) => conn.list_clients(ups_name).await, + } + } + + /// Queries the list of variables for a UPS device. + pub async fn list_vars(&mut self, ups_name: &str) -> crate::Result> { + match self { + Self::Tcp(conn) => Ok(conn + .list_vars(ups_name) + .await? + .into_iter() + .map(|(key, val)| Variable::parse(key.as_str(), val)) + .collect()), + } + } + + /// Queries one variable for a UPS device. + pub async fn get_var(&mut self, ups_name: &str, variable: &str) -> crate::Result { + match self { + Self::Tcp(conn) => { + let var = conn.get_var(ups_name, variable).await?; + Ok(Variable::parse(var.0.as_str(), var.1)) + } + } + } +} + +/// A blocking TCP NUT client connection. +pub struct TcpConnection { + config: Config, + stream: ConnectionStream, +} + +impl TcpConnection { + async fn new(config: Config, socket_addr: &SocketAddr) -> crate::Result { + // Create the TCP connection + let tcp_stream = TcpStream::connect(socket_addr).await?; + let mut connection = Self { + config, + stream: ConnectionStream::Plain(tcp_stream), + }; + + // Initialize SSL connection + connection = connection.enable_ssl().await?; + + // Attempt login using `config.auth` + connection.login().await?; + + Ok(connection) + } + + #[cfg(feature = "async-ssl")] + async fn enable_ssl(mut self) -> crate::Result { + if self.config.ssl { + // Send TLS request and check for 'OK' + self.write_cmd(Command::StartTLS).await?; + self.read_response() + .await + .map_err(|e| { + if let crate::ClientError::Nut(NutError::FeatureNotConfigured) = e { + crate::ClientError::Nut(NutError::SslNotSupported) + } else { + e + } + })? + .expect_ok()?; + + let mut ssl_config = rustls::ClientConfig::new(); + let dns_name: webpki::DNSName; + + if self.config.ssl_insecure { + ssl_config + .dangerous() + .set_certificate_verifier(std::sync::Arc::new( + crate::ssl::InsecureCertificateValidator::new(&self.config), + )); + + dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com") + .unwrap() + .to_owned(); + } else { + // Try to get hostname as given (e.g. localhost can be used for strict SSL, but not 127.0.0.1) + let hostname = self + .config + .host + .hostname() + .ok_or(crate::ClientError::Nut(NutError::SslInvalidHostname))?; + + dns_name = webpki::DNSNameRef::try_from_ascii_str(&hostname) + .map_err(|_| crate::ClientError::Nut(NutError::SslInvalidHostname))? + .to_owned(); + + ssl_config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + }; + + let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config)); + + // Wrap and override the TCP stream + self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?; + + // Send a test command + self.get_network_version().await?; + } + Ok(self) + } + + #[cfg(not(feature = "async-ssl"))] + async fn enable_ssl(self) -> crate::Result { + Ok(self) + } + + async fn login(&mut self) -> crate::Result<()> { + if let Some(auth) = self.config.auth.clone() { + // Pass username and check for 'OK' + self.write_cmd(Command::SetUsername(&auth.username)).await?; + self.read_response().await?.expect_ok()?; + + // Pass password and check for 'OK' + if let Some(password) = &auth.password { + self.write_cmd(Command::SetPassword(password)).await?; + self.read_response().await?.expect_ok()?; + } + } + Ok(()) + } + + async fn list_ups(&mut self) -> crate::Result> { + let query = &["UPS"]; + self.write_cmd(Command::List(query)).await?; + + let list = self.read_list(query).await?; + list.into_iter().map(|row| row.expect_ups()).collect() + } + + async fn list_clients(&mut self, ups_name: &str) -> crate::Result> { + let query = &["CLIENT", ups_name]; + self.write_cmd(Command::List(query)).await?; + + let list = self.read_list(query).await?; + list.into_iter().map(|row| row.expect_client()).collect() + } + + async fn list_vars(&mut self, ups_name: &str) -> crate::Result> { + let query = &["VAR", ups_name]; + self.write_cmd(Command::List(query)).await?; + + let list = self.read_list(query).await?; + list.into_iter().map(|row| row.expect_var()).collect() + } + + async fn get_var<'a>( + &mut self, + ups_name: &'a str, + variable: &'a str, + ) -> crate::Result<(String, String)> { + let query = &["VAR", ups_name, variable]; + self.write_cmd(Command::Get(query)).await?; + + self.read_response().await?.expect_var() + } + + #[allow(dead_code)] + async fn get_network_version(&mut self) -> crate::Result { + self.write_cmd(Command::NetworkVersion).await?; + self.read_plain_response().await + } + + async fn write_cmd(&mut self, line: Command<'_>) -> crate::Result<()> { + let line = format!("{}\n", line); + if self.config.debug { + eprint!("DEBUG -> {}", line); + } + self.stream.write_all(line.as_bytes()).await?; + self.stream.flush().await?; + Ok(()) + } + + async fn parse_line( + reader: &mut BufReader<&mut ConnectionStream>, + debug: bool, + ) -> crate::Result> { + let mut raw = String::new(); + reader.read_line(&mut raw).await?; + if debug { + eprint!("DEBUG <- {}", raw); + } + raw = raw[..raw.len() - 1].to_string(); // Strip off \n + + // Parse args by splitting whitespace, minding quotes for args with multiple words + let args = shell_words::split(&raw) + .map_err(|e| NutError::Generic(format!("Parsing server response failed: {}", e)))?; + + Ok(args) + } + + async fn read_response(&mut self) -> crate::Result { + let mut reader = BufReader::new(&mut self.stream); + let args = Self::parse_line(&mut reader, self.config.debug).await?; + Response::from_args(args) + } + + async fn read_plain_response(&mut self) -> crate::Result { + let mut reader = BufReader::new(&mut self.stream); + let args = Self::parse_line(&mut reader, self.config.debug).await?; + Ok(args.join(" ")) + } + + async fn read_list(&mut self, query: &[&str]) -> crate::Result> { + let mut reader = BufReader::new(&mut self.stream); + let args = Self::parse_line(&mut reader, self.config.debug).await?; + + Response::from_args(args)?.expect_begin_list(query)?; + let mut lines: Vec = Vec::new(); + + loop { + let args = Self::parse_line(&mut reader, self.config.debug).await?; + let resp = Response::from_args(args)?; + + match resp { + Response::EndList(_) => { + break; + } + _ => lines.push(resp), + } + } + + Ok(lines) + } +} diff --git a/nut-client/src/tokio/stream.rs b/nut-client/src/tokio/stream.rs new file mode 100644 index 0000000..03c0008 --- /dev/null +++ b/nut-client/src/tokio/stream.rs @@ -0,0 +1,100 @@ +use std::io::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::TcpStream; + +/// A wrapper for various Tokio stream types. +pub enum ConnectionStream { + /// A plain TCP stream. + Plain(TcpStream), + + /// A stream wrapped with SSL using `rustls`. + #[cfg(feature = "async-ssl")] + Ssl(Box>), +} + +impl ConnectionStream { + /// Wraps the current stream with SSL using `rustls`. + #[cfg(feature = "async-ssl")] + pub async fn upgrade_ssl( + self, + config: tokio_rustls::TlsConnector, + dns_name: webpki::DNSNameRef<'_>, + ) -> crate::Result { + Ok(ConnectionStream::Ssl(Box::new( + config + .connect(dns_name, self) + .await + .map_err(crate::ClientError::Io)?, + ))) + } +} + +impl AsyncRead for ConnectionStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Plain(stream) => { + let pinned = Pin::new(stream); + pinned.poll_read(cx, buf) + } + #[cfg(feature = "async-ssl")] + Self::Ssl(stream) => { + let pinned = Pin::new(stream); + pinned.poll_read(cx, buf) + } + } + } +} + +impl AsyncWrite for ConnectionStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Plain(stream) => { + let pinned = Pin::new(stream); + pinned.poll_write(cx, buf) + } + #[cfg(feature = "async-ssl")] + Self::Ssl(stream) => { + let pinned = Pin::new(stream); + pinned.poll_write(cx, buf) + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(stream) => { + let pinned = Pin::new(stream); + pinned.poll_flush(cx) + } + #[cfg(feature = "async-ssl")] + Self::Ssl(stream) => { + let pinned = Pin::new(stream); + pinned.poll_flush(cx) + } + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(stream) => { + let pinned = Pin::new(stream); + pinned.poll_shutdown(cx) + } + #[cfg(feature = "async-ssl")] + Self::Ssl(stream) => { + let pinned = Pin::new(stream); + pinned.poll_shutdown(cx) + } + } + } +}