diff --git a/Cargo.lock b/Cargo.lock index 82d62ec..83b74a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,22 +52,6 @@ version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631" -[[package]] -name = "byteorder" -version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" - -[[package]] -name = "bytes" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" -dependencies = [ - "byteorder", - "iovec", -] - [[package]] name = "bytes" version = "1.0.1" @@ -101,12 +85,6 @@ dependencies = [ "vec_map", ] -[[package]] -name = "futures" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678" - [[package]] name = "hermit-abi" version = "0.1.19" @@ -116,15 +94,6 @@ dependencies = [ "libc", ] -[[package]] -name = "iovec" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" version = "0.3.51" @@ -261,7 +230,6 @@ dependencies = [ "rustls", "shell-words", "tokio", - "tokio-mockstream", "tokio-rustls", "webpki", "webpki-roots", @@ -344,7 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b7b349f11a7047e6d1276853e612d152f5e8a352c61917887cc2169e2366b4c" dependencies = [ "autocfg", - "bytes 1.0.1", + "bytes", "libc", "memchr", "mio", @@ -354,17 +322,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "tokio-io" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57fc868aae093479e3131e3d165c93b1c7474109d13c90ec0dda2a1bbfff0674" -dependencies = [ - "bytes 0.4.12", - "futures", - "log", -] - [[package]] name = "tokio-macros" version = "1.3.0" @@ -376,16 +333,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-mockstream" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41bfc436ef8b7f60c19adf3df086330ae9992385e4d8c53b17a323cad288e155" -dependencies = [ - "futures", - "tokio-io", -] - [[package]] name = "tokio-rustls" version = "0.22.0" diff --git a/rups/Cargo.toml b/rups/Cargo.toml index 781c602..29ba795 100644 --- a/rups/Cargo.toml +++ b/rups/Cargo.toml @@ -23,7 +23,6 @@ tokio-rustls = { version = "0.22", optional = true } [dev-dependencies] mockstream = "0.0.3" -tokio-mockstream = "1.1.0" [features] default = [] diff --git a/rups/src/blocking/client.rs b/rups/src/blocking/client.rs index 8ca041e..9fe00b9 100644 --- a/rups/src/blocking/client.rs +++ b/rups/src/blocking/client.rs @@ -1,5 +1,4 @@ use crate::blocking::stream::ConnectionStream; -use crate::proto::{ClientSentences, Sentence, ServerSentences}; use crate::{Config, Error, Host, NutError, TcpHost}; use std::net::TcpStream; diff --git a/rups/src/blocking/mod.rs b/rups/src/blocking/mod.rs index 0ebb78e..1e40a68 100644 --- a/rups/src/blocking/mod.rs +++ b/rups/src/blocking/mod.rs @@ -10,6 +10,7 @@ mod stream; pub use client::Client; +// TODO: Remove me /// A blocking NUT client connection. pub enum Connection { /// A TCP connection. diff --git a/rups/src/cmd.rs b/rups/src/cmd.rs index 3f1b701..ad16273 100644 --- a/rups/src/cmd.rs +++ b/rups/src/cmd.rs @@ -721,6 +721,25 @@ macro_rules! implement_client_action_commands { } )* } + + #[cfg(feature = "async")] + impl crate::tokio::Client { + $( + $(#[$attr])* + #[allow(dead_code)] + $vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<()> { + use crate::proto::{Sentence, ClientSentences, ClientSentences::*, ServerSentences::*}; + self.stream + .write_sentence(&$cmd) + .await?; + self.stream + .read_sentence::() + .await? + .exactly($ret) + .map(|_| ()) + } + )* + } }; } diff --git a/rups/src/tokio/client.rs b/rups/src/tokio/client.rs new file mode 100644 index 0000000..7de241f --- /dev/null +++ b/rups/src/tokio/client.rs @@ -0,0 +1,110 @@ +use crate::tokio::stream::ConnectionStream; +use crate::{Config, Error, Host, NutError, TcpHost}; +use tokio::net::TcpStream; + +/// An asynchronous NUT client, using Tokio. +pub struct Client { + /// The client configuration. + config: Config, + /// The client connection. + pub(crate) stream: ConnectionStream, +} + +impl Client { + /// Connects to a remote NUT server using a blocking connection. + pub async fn new(config: &Config) -> crate::Result { + match &config.host { + Host::Tcp(host) => Self::new_tcp(config, host).await, + } + // TODO: Support Unix domain sockets + } + + /// Connects to a remote NUT server using a blocking TCP connection. + async fn new_tcp(config: &Config, host: &TcpHost) -> crate::Result { + let tcp_stream = TcpStream::connect(&host.addr).await?; + let mut client = Client { + config: config.clone(), + stream: ConnectionStream::Tcp(tcp_stream).buffered(), + }; + + client = client.enable_ssl().await?; + + Ok(client) + } + + /// Authenticates to the given UPS device with the username and password set in the config. + pub async fn login(&mut self, ups_name: String) -> crate::Result<()> { + if let Some(auth) = self.config.auth.clone() { + // Pass username and check for 'OK' + self.set_username(auth.username).await?; + + // Pass password and check for 'OK' + if let Some(password) = auth.password { + self.set_password(password).await?; + } + + // Submit login + self.exec_login(ups_name).await + } else { + Ok(()) + } + } + + #[cfg(feature = "async-ssl")] + async fn enable_ssl(mut self) -> crate::Result { + if self.config.ssl { + self.exec_start_tls().await?; + + // Initialize SSL configurations + let mut ssl_config = rustls::ClientConfig::new(); + let dns_name: webpki::DNSName; + + if self.config.ssl_insecure { + ssl_config + .dangerous() + .set_certificate_verifier(std::sync::Arc::new( + crate::ssl::InsecureCertificateValidator::new(&self.config), + )); + + dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com") + .unwrap() + .to_owned(); + } else { + // Try to get hostname as given (e.g. localhost can be used for strict SSL, but not 127.0.0.1) + let hostname = self + .config + .host + .hostname() + .ok_or(Error::Nut(NutError::SslInvalidHostname))?; + + dns_name = webpki::DNSNameRef::try_from_ascii_str(&hostname) + .map_err(|_| Error::Nut(NutError::SslInvalidHostname))? + .to_owned(); + + ssl_config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + }; + + let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config)); + + // Un-buffer to get back underlying stream + self.stream = self.stream.unbuffered(); + + // Upgrade to SSL + self.stream = self + .stream + .upgrade_ssl_client(config, dns_name.as_ref()) + .await?; + + // Re-buffer + self.stream = self.stream.buffered(); + } + Ok(self) + } + + #[cfg(not(feature = "async-ssl"))] + async fn enable_ssl(self) -> crate::Result { + Ok(self) + } +} diff --git a/rups/src/tokio/mockstream.rs b/rups/src/tokio/mockstream.rs new file mode 100644 index 0000000..fb3cb61 --- /dev/null +++ b/rups/src/tokio/mockstream.rs @@ -0,0 +1,82 @@ +use std::fmt; +use std::io::{Error, Read, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// Async stream for unit testing. +#[derive(Clone, Default)] +pub struct AsyncMockStream(mockstream::SyncMockStream); + +impl AsyncMockStream { + /// Create empty stream + pub fn new() -> AsyncMockStream { + AsyncMockStream::default() + } + + /// Extract all bytes written by Write trait calls. + pub fn push_bytes_to_read(&mut self, bytes: &[u8]) { + self.0.push_bytes_to_read(bytes) + } +} + +impl fmt::Debug for AsyncMockStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AsyncMockStream").finish() + } +} + +impl Read for AsyncMockStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.0.read(buf) + } +} + +impl Write for AsyncMockStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.0.flush() + } +} + +impl AsyncRead for AsyncMockStream { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut vec = Vec::new(); + match self.get_mut().read_to_end(&mut vec) { + Ok(_) => { + let slice = vec.as_slice(); + buf.put_slice(slice); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), + } + } +} + +impl AsyncWrite for AsyncMockStream { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let len = buf.len(); + self.get_mut().push_bytes_to_read(buf); + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.get_mut().flush()) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} diff --git a/rups/src/tokio/mod.rs b/rups/src/tokio/mod.rs index f8d8e45..4e4dfbf 100644 --- a/rups/src/tokio/mod.rs +++ b/rups/src/tokio/mod.rs @@ -6,8 +6,14 @@ use crate::{Config, Error, Host, NutError}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; +mod client; +#[cfg(test)] +mod mockstream; mod stream; +pub use client::Client; + +// TODO: Remove me /// An async NUT client connection. pub enum Connection { /// A TCP connection. diff --git a/rups/src/tokio/stream.rs b/rups/src/tokio/stream.rs index 282a709..41a32cd 100644 --- a/rups/src/tokio/stream.rs +++ b/rups/src/tokio/stream.rs @@ -1,10 +1,15 @@ -use crate::Error; +use crate::proto::util::{join_sentence, split_sentence}; +use crate::proto::Sentence; +use crate::{Error, NutError}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, ReadBuf}; +use tokio::io::{ + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf, +}; use tokio::net::TcpStream; /// A wrapper for various Tokio stream types. +#[derive(Debug)] pub enum ConnectionStream { /// A plain TCP stream. Tcp(TcpStream), @@ -22,6 +27,10 @@ pub enum ConnectionStream { /// A server stream wrapped with SSL using `rustls`. #[cfg(feature = "async-ssl")] SslServer(Box>), + + /// A mock stream, used for testing. + #[cfg(test)] + Mock(crate::tokio::mockstream::AsyncMockStream), } impl ConnectionStream { @@ -47,6 +56,101 @@ impl ConnectionStream { acceptor.accept(self).await.map_err(Error::Io)?, ))) } + + /// Writes a sentence on the current stream. + pub async fn write_sentence(&mut self, sentence: &T) -> crate::Result<()> { + let encoded = sentence.encode(); + let joined = join_sentence(encoded); + self.write_literal(&joined).await?; + self.flush().await.map_err(crate::Error::Io) + } + + /// Writes a collection of sentences on the current stream. + pub async fn write_sentences(&mut self, sentences: &[T]) -> crate::Result<()> { + for sentence in sentences { + let encoded = sentence.encode(); + let joined = join_sentence(encoded); + self.write_literal(&joined).await?; + } + self.flush().await.map_err(crate::Error::Io) + } + + /// Writes a literal string to the current stream. + /// Note: the literal string MUST end with a new-line character (`\n`). + /// + /// Note: does not automatically flush. + pub async fn write_literal(&mut self, literal: &str) -> crate::Result<()> { + assert!(literal.ends_with('\n')); + self.write_all(literal.as_bytes()).await?; + Ok(()) + } + + /// Reads a literal string from the current stream. + /// + /// Note: the literal string will ends with a new-line character (`\n`) + /// Note: requires stream to be buffered. + pub async fn read_literal(&mut self) -> crate::Result { + let mut raw = String::new(); + self.read_line(&mut raw).await?; + Ok(raw) + } + + /// Reads a sentence from the current stream. + /// + /// Note: requires stream to be buffered. + pub async fn read_sentence(&mut self) -> crate::Result { + dbg!(&self); + let raw = self.read_literal().await?; + if raw.is_empty() { + return Err(Error::eof()); + } + let split = split_sentence(raw).ok_or(crate::NutError::NotProcessable)?; + T::decode(split) + .ok_or(Error::Nut(NutError::InvalidArgument))? + .into() + } + + /// Reads all sentences in the stream until the given `matcher` function evaluates to `true`. + /// + /// The final sentence is excluded. + /// + /// Note: requires stream to be buffered. + pub async fn read_sentences_until bool>( + &mut self, + matcher: F, + ) -> crate::Result> { + let mut result = Vec::new(); + let max_iter = 1000; // Exit after 1000 lines to prevent overflow. + for _ in 0..max_iter { + let sentence: T = self.read_sentence().await?; + if matcher(&sentence) { + return Ok(result); + } else { + result.push(sentence); + } + } + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::Interrupted, + "Reached maximum read capacity.", + ))) + } + + /// Wraps the current stream with a `BufReader`. + pub fn buffered(self) -> ConnectionStream { + Self::Buffered(Box::new(BufReader::new(self))) + } + + /// Unwraps the underlying stream from the current `BufReader`. + /// If the current stream is not buffered, it returns itself (no-op). + /// + /// Note that, if the stream is buffered, any un-consumed data will be discarded. + pub fn unbuffered(self) -> ConnectionStream { + if let Self::Buffered(buf) = self { + buf.into_inner() + } else { + self + } + } } impl AsyncRead for ConnectionStream { @@ -74,27 +178,33 @@ impl AsyncRead for ConnectionStream { let pinned = Pin::new(stream); pinned.poll_read(cx, buf) } + #[cfg(test)] + Self::Mock(stream) => { + let pinned = Pin::new(stream); + pinned.poll_read(cx, buf) + } } } } impl AsyncBufRead for ConnectionStream { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + dbg!(&self); match self.get_mut() { Self::Buffered(reader) => { - let pinned = Pin::new(reader.get_mut()); + let pinned = Pin::new(reader); pinned.poll_fill_buf(cx) } - _ => core::task::Poll::Ready(Err(std::io::Error::new( + s => core::task::Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::Unsupported, - "Stream is not buffered", + format!("Stream is not buffered: {:?}", s), ))), } } fn consume(self: Pin<&mut Self>, amt: usize) { if let Self::Buffered(reader) = self.get_mut() { - let pinned = Pin::new(reader.get_mut()); + let pinned = Pin::new(reader); pinned.consume(amt) } } @@ -125,6 +235,11 @@ impl AsyncWrite for ConnectionStream { let pinned = Pin::new(stream); pinned.poll_write(cx, buf) } + #[cfg(test)] + Self::Mock(stream) => { + let pinned = Pin::new(stream); + pinned.poll_write(cx, buf) + } } } @@ -148,6 +263,11 @@ impl AsyncWrite for ConnectionStream { let pinned = Pin::new(stream); pinned.poll_flush(cx) } + #[cfg(test)] + Self::Mock(stream) => { + let pinned = Pin::new(stream); + pinned.poll_flush(cx) + } } } @@ -171,6 +291,122 @@ impl AsyncWrite for ConnectionStream { let pinned = Pin::new(stream); pinned.poll_shutdown(cx) } + #[cfg(test)] + Self::Mock(stream) => { + let pinned = Pin::new(stream); + pinned.poll_shutdown(cx) + } } } } +#[cfg(test)] +mod tests { + use super::ConnectionStream; + use crate::proto::{ClientSentences, Sentence, ServerSentences}; + + #[tokio::test] + async fn read_write_sentence() { + let client_stream = crate::tokio::mockstream::AsyncMockStream::new(); + let server_stream = client_stream.clone(); + + let mut client_stream = ConnectionStream::Mock(client_stream).buffered(); + let mut server_stream = ConnectionStream::Mock(server_stream).buffered(); + + // Client requests list of UPS devices + client_stream + .write_sentence(&ServerSentences::QueryListUps {}) + .await + .expect("Failed to write LIST UPS"); + + dbg!(&client_stream); + dbg!(&server_stream); + + // Server reads query for list of UPS devices + let sentence = server_stream + .read_sentence::() + .await + .expect("Failed to read LIST UPS"); + assert_eq!(sentence, ServerSentences::QueryListUps {}); + + // Server sends list of UPS devices. + server_stream + .write_sentences(&[ + ClientSentences::BeginListUps {}, + ClientSentences::RespondUps { + ups_name: "nutdev0".into(), + description: "A NUT device.".into(), + }, + ClientSentences::RespondUps { + ups_name: "nutdev1".into(), + description: "Another NUT device.".into(), + }, + ClientSentences::EndListUps {}, + ]) + .await + .expect("Failed to write UPS LIST"); + + // Client reads list of UPS devices. + client_stream + .read_sentence::() + .await + .expect("Failed to read BEGIN LIST UPS") + .exactly(ClientSentences::BeginListUps {}) + .unwrap(); + + let sentences: Vec = client_stream + .read_sentences_until(|s| matches!(s, ClientSentences::EndListUps {})) + .await + .expect("Failed to read UPS items"); + + assert_eq!( + sentences, + vec![ + ClientSentences::RespondUps { + ups_name: "nutdev0".into(), + description: "A NUT device.".into(), + }, + ClientSentences::RespondUps { + ups_name: "nutdev1".into(), + description: "Another NUT device.".into(), + }, + ] + ); + + // Client sends login + client_stream + .write_sentence(&ServerSentences::ExecLogin { + ups_name: "nutdev0".into(), + }) + .await + .expect("Failed to write LOGIN nutdev0"); + + // Server receives login + server_stream + .read_sentence::() + .await + .expect("Failed to read LOGIN nutdev0") + .exactly(ServerSentences::ExecLogin { + ups_name: "nutdev0".into(), + }) + .unwrap(); + + // Server rejects login + server_stream + .write_sentence(&ClientSentences::RespondErr { + message: "USERNAME-REQUIRED".into(), + extras: vec![], + }) + .await + .expect("Failed to write ERR USERNAME-REQUIRED"); + + // Client expects error + let error: crate::Error = client_stream + .read_sentence::() + .await + .expect_err("Failed to read ERR"); + assert!(matches!( + error, + crate::Error::Nut(crate::NutError::UsernameRequired) + )); + } +}