Bring Tokio side up to date

This commit is contained in:
Aram 🍐 2021-08-05 01:15:42 -04:00
parent 162c015680
commit 3b3f9278de
9 changed files with 461 additions and 62 deletions

55
Cargo.lock generated
View file

@ -52,22 +52,6 @@ version = "3.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c"
dependencies = [
"byteorder",
"iovec",
]
[[package]]
name = "bytes"
version = "1.0.1"
@ -101,12 +85,6 @@ dependencies = [
"vec_map",
]
[[package]]
name = "futures"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678"
[[package]]
name = "hermit-abi"
version = "0.1.19"
@ -116,15 +94,6 @@ dependencies = [
"libc",
]
[[package]]
name = "iovec"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e"
dependencies = [
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.51"
@ -261,7 +230,6 @@ dependencies = [
"rustls",
"shell-words",
"tokio",
"tokio-mockstream",
"tokio-rustls",
"webpki",
"webpki-roots",
@ -344,7 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b7b349f11a7047e6d1276853e612d152f5e8a352c61917887cc2169e2366b4c"
dependencies = [
"autocfg",
"bytes 1.0.1",
"bytes",
"libc",
"memchr",
"mio",
@ -354,17 +322,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "tokio-io"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57fc868aae093479e3131e3d165c93b1c7474109d13c90ec0dda2a1bbfff0674"
dependencies = [
"bytes 0.4.12",
"futures",
"log",
]
[[package]]
name = "tokio-macros"
version = "1.3.0"
@ -376,16 +333,6 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-mockstream"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41bfc436ef8b7f60c19adf3df086330ae9992385e4d8c53b17a323cad288e155"
dependencies = [
"futures",
"tokio-io",
]
[[package]]
name = "tokio-rustls"
version = "0.22.0"

View file

@ -23,7 +23,6 @@ tokio-rustls = { version = "0.22", optional = true }
[dev-dependencies]
mockstream = "0.0.3"
tokio-mockstream = "1.1.0"
[features]
default = []

View file

@ -1,5 +1,4 @@
use crate::blocking::stream::ConnectionStream;
use crate::proto::{ClientSentences, Sentence, ServerSentences};
use crate::{Config, Error, Host, NutError, TcpHost};
use std::net::TcpStream;

View file

@ -10,6 +10,7 @@ mod stream;
pub use client::Client;
// TODO: Remove me
/// A blocking NUT client connection.
pub enum Connection {
/// A TCP connection.

View file

@ -721,6 +721,25 @@ macro_rules! implement_client_action_commands {
}
)*
}
#[cfg(feature = "async")]
impl crate::tokio::Client {
$(
$(#[$attr])*
#[allow(dead_code)]
$vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<()> {
use crate::proto::{Sentence, ClientSentences, ClientSentences::*, ServerSentences::*};
self.stream
.write_sentence(&$cmd)
.await?;
self.stream
.read_sentence::<ClientSentences>()
.await?
.exactly($ret)
.map(|_| ())
}
)*
}
};
}

110
rups/src/tokio/client.rs Normal file
View file

@ -0,0 +1,110 @@
use crate::tokio::stream::ConnectionStream;
use crate::{Config, Error, Host, NutError, TcpHost};
use tokio::net::TcpStream;
/// An asynchronous NUT client, using Tokio.
pub struct Client {
/// The client configuration.
config: Config,
/// The client connection.
pub(crate) stream: ConnectionStream,
}
impl Client {
/// Connects to a remote NUT server using a blocking connection.
pub async fn new(config: &Config) -> crate::Result<Self> {
match &config.host {
Host::Tcp(host) => Self::new_tcp(config, host).await,
}
// TODO: Support Unix domain sockets
}
/// Connects to a remote NUT server using a blocking TCP connection.
async fn new_tcp(config: &Config, host: &TcpHost) -> crate::Result<Self> {
let tcp_stream = TcpStream::connect(&host.addr).await?;
let mut client = Client {
config: config.clone(),
stream: ConnectionStream::Tcp(tcp_stream).buffered(),
};
client = client.enable_ssl().await?;
Ok(client)
}
/// Authenticates to the given UPS device with the username and password set in the config.
pub async fn login(&mut self, ups_name: String) -> crate::Result<()> {
if let Some(auth) = self.config.auth.clone() {
// Pass username and check for 'OK'
self.set_username(auth.username).await?;
// Pass password and check for 'OK'
if let Some(password) = auth.password {
self.set_password(password).await?;
}
// Submit login
self.exec_login(ups_name).await
} else {
Ok(())
}
}
#[cfg(feature = "async-ssl")]
async fn enable_ssl(mut self) -> crate::Result<Self> {
if self.config.ssl {
self.exec_start_tls().await?;
// Initialize SSL configurations
let mut ssl_config = rustls::ClientConfig::new();
let dns_name: webpki::DNSName;
if self.config.ssl_insecure {
ssl_config
.dangerous()
.set_certificate_verifier(std::sync::Arc::new(
crate::ssl::InsecureCertificateValidator::new(&self.config),
));
dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com")
.unwrap()
.to_owned();
} else {
// Try to get hostname as given (e.g. localhost can be used for strict SSL, but not 127.0.0.1)
let hostname = self
.config
.host
.hostname()
.ok_or(Error::Nut(NutError::SslInvalidHostname))?;
dns_name = webpki::DNSNameRef::try_from_ascii_str(&hostname)
.map_err(|_| Error::Nut(NutError::SslInvalidHostname))?
.to_owned();
ssl_config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
};
let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config));
// Un-buffer to get back underlying stream
self.stream = self.stream.unbuffered();
// Upgrade to SSL
self.stream = self
.stream
.upgrade_ssl_client(config, dns_name.as_ref())
.await?;
// Re-buffer
self.stream = self.stream.buffered();
}
Ok(self)
}
#[cfg(not(feature = "async-ssl"))]
async fn enable_ssl(self) -> crate::Result<Self> {
Ok(self)
}
}

View file

@ -0,0 +1,82 @@
use std::fmt;
use std::io::{Error, Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Async stream for unit testing.
#[derive(Clone, Default)]
pub struct AsyncMockStream(mockstream::SyncMockStream);
impl AsyncMockStream {
/// Create empty stream
pub fn new() -> AsyncMockStream {
AsyncMockStream::default()
}
/// Extract all bytes written by Write trait calls.
pub fn push_bytes_to_read(&mut self, bytes: &[u8]) {
self.0.push_bytes_to_read(bytes)
}
}
impl fmt::Debug for AsyncMockStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AsyncMockStream").finish()
}
}
impl Read for AsyncMockStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.0.read(buf)
}
}
impl Write for AsyncMockStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.0.flush()
}
}
impl AsyncRead for AsyncMockStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let mut vec = Vec::new();
match self.get_mut().read_to_end(&mut vec) {
Ok(_) => {
let slice = vec.as_slice();
buf.put_slice(slice);
Poll::Ready(Ok(()))
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl AsyncWrite for AsyncMockStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
let len = buf.len();
self.get_mut().push_bytes_to_read(buf);
Poll::Ready(Ok(len))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(self.get_mut().flush())
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}

View file

@ -6,8 +6,14 @@ use crate::{Config, Error, Host, NutError};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
mod client;
#[cfg(test)]
mod mockstream;
mod stream;
pub use client::Client;
// TODO: Remove me
/// An async NUT client connection.
pub enum Connection {
/// A TCP connection.

View file

@ -1,10 +1,15 @@
use crate::Error;
use crate::proto::util::{join_sentence, split_sentence};
use crate::proto::Sentence;
use crate::{Error, NutError};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, ReadBuf};
use tokio::io::{
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf,
};
use tokio::net::TcpStream;
/// A wrapper for various Tokio stream types.
#[derive(Debug)]
pub enum ConnectionStream {
/// A plain TCP stream.
Tcp(TcpStream),
@ -22,6 +27,10 @@ pub enum ConnectionStream {
/// A server stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")]
SslServer(Box<tokio_rustls::server::TlsStream<ConnectionStream>>),
/// A mock stream, used for testing.
#[cfg(test)]
Mock(crate::tokio::mockstream::AsyncMockStream),
}
impl ConnectionStream {
@ -47,6 +56,101 @@ impl ConnectionStream {
acceptor.accept(self).await.map_err(Error::Io)?,
)))
}
/// Writes a sentence on the current stream.
pub async fn write_sentence<T: Sentence>(&mut self, sentence: &T) -> crate::Result<()> {
let encoded = sentence.encode();
let joined = join_sentence(encoded);
self.write_literal(&joined).await?;
self.flush().await.map_err(crate::Error::Io)
}
/// Writes a collection of sentences on the current stream.
pub async fn write_sentences<T: Sentence>(&mut self, sentences: &[T]) -> crate::Result<()> {
for sentence in sentences {
let encoded = sentence.encode();
let joined = join_sentence(encoded);
self.write_literal(&joined).await?;
}
self.flush().await.map_err(crate::Error::Io)
}
/// Writes a literal string to the current stream.
/// Note: the literal string MUST end with a new-line character (`\n`).
///
/// Note: does not automatically flush.
pub async fn write_literal(&mut self, literal: &str) -> crate::Result<()> {
assert!(literal.ends_with('\n'));
self.write_all(literal.as_bytes()).await?;
Ok(())
}
/// Reads a literal string from the current stream.
///
/// Note: the literal string will ends with a new-line character (`\n`)
/// Note: requires stream to be buffered.
pub async fn read_literal(&mut self) -> crate::Result<String> {
let mut raw = String::new();
self.read_line(&mut raw).await?;
Ok(raw)
}
/// Reads a sentence from the current stream.
///
/// Note: requires stream to be buffered.
pub async fn read_sentence<T: Sentence>(&mut self) -> crate::Result<T> {
dbg!(&self);
let raw = self.read_literal().await?;
if raw.is_empty() {
return Err(Error::eof());
}
let split = split_sentence(raw).ok_or(crate::NutError::NotProcessable)?;
T::decode(split)
.ok_or(Error::Nut(NutError::InvalidArgument))?
.into()
}
/// Reads all sentences in the stream until the given `matcher` function evaluates to `true`.
///
/// The final sentence is excluded.
///
/// Note: requires stream to be buffered.
pub async fn read_sentences_until<T: Sentence, F: Fn(&T) -> bool>(
&mut self,
matcher: F,
) -> crate::Result<Vec<T>> {
let mut result = Vec::new();
let max_iter = 1000; // Exit after 1000 lines to prevent overflow.
for _ in 0..max_iter {
let sentence: T = self.read_sentence().await?;
if matcher(&sentence) {
return Ok(result);
} else {
result.push(sentence);
}
}
Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::Interrupted,
"Reached maximum read capacity.",
)))
}
/// Wraps the current stream with a `BufReader`.
pub fn buffered(self) -> ConnectionStream {
Self::Buffered(Box::new(BufReader::new(self)))
}
/// Unwraps the underlying stream from the current `BufReader`.
/// If the current stream is not buffered, it returns itself (no-op).
///
/// Note that, if the stream is buffered, any un-consumed data will be discarded.
pub fn unbuffered(self) -> ConnectionStream {
if let Self::Buffered(buf) = self {
buf.into_inner()
} else {
self
}
}
}
impl AsyncRead for ConnectionStream {
@ -74,27 +178,33 @@ impl AsyncRead for ConnectionStream {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
#[cfg(test)]
Self::Mock(stream) => {
let pinned = Pin::new(stream);
pinned.poll_read(cx, buf)
}
}
}
}
impl AsyncBufRead for ConnectionStream {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
dbg!(&self);
match self.get_mut() {
Self::Buffered(reader) => {
let pinned = Pin::new(reader.get_mut());
let pinned = Pin::new(reader);
pinned.poll_fill_buf(cx)
}
_ => core::task::Poll::Ready(Err(std::io::Error::new(
s => core::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"Stream is not buffered",
format!("Stream is not buffered: {:?}", s),
))),
}
}
fn consume(self: Pin<&mut Self>, amt: usize) {
if let Self::Buffered(reader) = self.get_mut() {
let pinned = Pin::new(reader.get_mut());
let pinned = Pin::new(reader);
pinned.consume(amt)
}
}
@ -125,6 +235,11 @@ impl AsyncWrite for ConnectionStream {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
#[cfg(test)]
Self::Mock(stream) => {
let pinned = Pin::new(stream);
pinned.poll_write(cx, buf)
}
}
}
@ -148,6 +263,11 @@ impl AsyncWrite for ConnectionStream {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
#[cfg(test)]
Self::Mock(stream) => {
let pinned = Pin::new(stream);
pinned.poll_flush(cx)
}
}
}
@ -171,6 +291,122 @@ impl AsyncWrite for ConnectionStream {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}
#[cfg(test)]
Self::Mock(stream) => {
let pinned = Pin::new(stream);
pinned.poll_shutdown(cx)
}
}
}
}
#[cfg(test)]
mod tests {
use super::ConnectionStream;
use crate::proto::{ClientSentences, Sentence, ServerSentences};
#[tokio::test]
async fn read_write_sentence() {
let client_stream = crate::tokio::mockstream::AsyncMockStream::new();
let server_stream = client_stream.clone();
let mut client_stream = ConnectionStream::Mock(client_stream).buffered();
let mut server_stream = ConnectionStream::Mock(server_stream).buffered();
// Client requests list of UPS devices
client_stream
.write_sentence(&ServerSentences::QueryListUps {})
.await
.expect("Failed to write LIST UPS");
dbg!(&client_stream);
dbg!(&server_stream);
// Server reads query for list of UPS devices
let sentence = server_stream
.read_sentence::<ServerSentences>()
.await
.expect("Failed to read LIST UPS");
assert_eq!(sentence, ServerSentences::QueryListUps {});
// Server sends list of UPS devices.
server_stream
.write_sentences(&[
ClientSentences::BeginListUps {},
ClientSentences::RespondUps {
ups_name: "nutdev0".into(),
description: "A NUT device.".into(),
},
ClientSentences::RespondUps {
ups_name: "nutdev1".into(),
description: "Another NUT device.".into(),
},
ClientSentences::EndListUps {},
])
.await
.expect("Failed to write UPS LIST");
// Client reads list of UPS devices.
client_stream
.read_sentence::<ClientSentences>()
.await
.expect("Failed to read BEGIN LIST UPS")
.exactly(ClientSentences::BeginListUps {})
.unwrap();
let sentences: Vec<ClientSentences> = client_stream
.read_sentences_until(|s| matches!(s, ClientSentences::EndListUps {}))
.await
.expect("Failed to read UPS items");
assert_eq!(
sentences,
vec![
ClientSentences::RespondUps {
ups_name: "nutdev0".into(),
description: "A NUT device.".into(),
},
ClientSentences::RespondUps {
ups_name: "nutdev1".into(),
description: "Another NUT device.".into(),
},
]
);
// Client sends login
client_stream
.write_sentence(&ServerSentences::ExecLogin {
ups_name: "nutdev0".into(),
})
.await
.expect("Failed to write LOGIN nutdev0");
// Server receives login
server_stream
.read_sentence::<ServerSentences>()
.await
.expect("Failed to read LOGIN nutdev0")
.exactly(ServerSentences::ExecLogin {
ups_name: "nutdev0".into(),
})
.unwrap();
// Server rejects login
server_stream
.write_sentence(&ClientSentences::RespondErr {
message: "USERNAME-REQUIRED".into(),
extras: vec![],
})
.await
.expect("Failed to write ERR USERNAME-REQUIRED");
// Client expects error
let error: crate::Error = client_stream
.read_sentence::<ClientSentences>()
.await
.expect_err("Failed to read ERR");
assert!(matches!(
error,
crate::Error::Nut(crate::NutError::UsernameRequired)
));
}
}