mirror of
https://github.com/aramperes/nut-rs.git
synced 2025-09-09 05:28:31 -04:00
Bring Tokio side up to date
This commit is contained in:
parent
162c015680
commit
3b3f9278de
9 changed files with 461 additions and 62 deletions
|
@ -23,7 +23,6 @@ tokio-rustls = { version = "0.22", optional = true }
|
|||
|
||||
[dev-dependencies]
|
||||
mockstream = "0.0.3"
|
||||
tokio-mockstream = "1.1.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::blocking::stream::ConnectionStream;
|
||||
use crate::proto::{ClientSentences, Sentence, ServerSentences};
|
||||
use crate::{Config, Error, Host, NutError, TcpHost};
|
||||
use std::net::TcpStream;
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ mod stream;
|
|||
|
||||
pub use client::Client;
|
||||
|
||||
// TODO: Remove me
|
||||
/// A blocking NUT client connection.
|
||||
pub enum Connection {
|
||||
/// A TCP connection.
|
||||
|
|
|
@ -721,6 +721,25 @@ macro_rules! implement_client_action_commands {
|
|||
}
|
||||
)*
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
impl crate::tokio::Client {
|
||||
$(
|
||||
$(#[$attr])*
|
||||
#[allow(dead_code)]
|
||||
$vis async fn $name(&mut self$(, $argname: $argty)*) -> crate::Result<()> {
|
||||
use crate::proto::{Sentence, ClientSentences, ClientSentences::*, ServerSentences::*};
|
||||
self.stream
|
||||
.write_sentence(&$cmd)
|
||||
.await?;
|
||||
self.stream
|
||||
.read_sentence::<ClientSentences>()
|
||||
.await?
|
||||
.exactly($ret)
|
||||
.map(|_| ())
|
||||
}
|
||||
)*
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
110
rups/src/tokio/client.rs
Normal file
110
rups/src/tokio/client.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
use crate::tokio::stream::ConnectionStream;
|
||||
use crate::{Config, Error, Host, NutError, TcpHost};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// An asynchronous NUT client, using Tokio.
|
||||
pub struct Client {
|
||||
/// The client configuration.
|
||||
config: Config,
|
||||
/// The client connection.
|
||||
pub(crate) stream: ConnectionStream,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Connects to a remote NUT server using a blocking connection.
|
||||
pub async fn new(config: &Config) -> crate::Result<Self> {
|
||||
match &config.host {
|
||||
Host::Tcp(host) => Self::new_tcp(config, host).await,
|
||||
}
|
||||
// TODO: Support Unix domain sockets
|
||||
}
|
||||
|
||||
/// Connects to a remote NUT server using a blocking TCP connection.
|
||||
async fn new_tcp(config: &Config, host: &TcpHost) -> crate::Result<Self> {
|
||||
let tcp_stream = TcpStream::connect(&host.addr).await?;
|
||||
let mut client = Client {
|
||||
config: config.clone(),
|
||||
stream: ConnectionStream::Tcp(tcp_stream).buffered(),
|
||||
};
|
||||
|
||||
client = client.enable_ssl().await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Authenticates to the given UPS device with the username and password set in the config.
|
||||
pub async fn login(&mut self, ups_name: String) -> crate::Result<()> {
|
||||
if let Some(auth) = self.config.auth.clone() {
|
||||
// Pass username and check for 'OK'
|
||||
self.set_username(auth.username).await?;
|
||||
|
||||
// Pass password and check for 'OK'
|
||||
if let Some(password) = auth.password {
|
||||
self.set_password(password).await?;
|
||||
}
|
||||
|
||||
// Submit login
|
||||
self.exec_login(ups_name).await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "async-ssl")]
|
||||
async fn enable_ssl(mut self) -> crate::Result<Self> {
|
||||
if self.config.ssl {
|
||||
self.exec_start_tls().await?;
|
||||
|
||||
// Initialize SSL configurations
|
||||
let mut ssl_config = rustls::ClientConfig::new();
|
||||
let dns_name: webpki::DNSName;
|
||||
|
||||
if self.config.ssl_insecure {
|
||||
ssl_config
|
||||
.dangerous()
|
||||
.set_certificate_verifier(std::sync::Arc::new(
|
||||
crate::ssl::InsecureCertificateValidator::new(&self.config),
|
||||
));
|
||||
|
||||
dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com")
|
||||
.unwrap()
|
||||
.to_owned();
|
||||
} else {
|
||||
// Try to get hostname as given (e.g. localhost can be used for strict SSL, but not 127.0.0.1)
|
||||
let hostname = self
|
||||
.config
|
||||
.host
|
||||
.hostname()
|
||||
.ok_or(Error::Nut(NutError::SslInvalidHostname))?;
|
||||
|
||||
dns_name = webpki::DNSNameRef::try_from_ascii_str(&hostname)
|
||||
.map_err(|_| Error::Nut(NutError::SslInvalidHostname))?
|
||||
.to_owned();
|
||||
|
||||
ssl_config
|
||||
.root_store
|
||||
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
||||
};
|
||||
|
||||
let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config));
|
||||
|
||||
// Un-buffer to get back underlying stream
|
||||
self.stream = self.stream.unbuffered();
|
||||
|
||||
// Upgrade to SSL
|
||||
self.stream = self
|
||||
.stream
|
||||
.upgrade_ssl_client(config, dns_name.as_ref())
|
||||
.await?;
|
||||
|
||||
// Re-buffer
|
||||
self.stream = self.stream.buffered();
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "async-ssl"))]
|
||||
async fn enable_ssl(self) -> crate::Result<Self> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
82
rups/src/tokio/mockstream.rs
Normal file
82
rups/src/tokio/mockstream.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use std::fmt;
|
||||
use std::io::{Error, Read, Write};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
/// Async stream for unit testing.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct AsyncMockStream(mockstream::SyncMockStream);
|
||||
|
||||
impl AsyncMockStream {
|
||||
/// Create empty stream
|
||||
pub fn new() -> AsyncMockStream {
|
||||
AsyncMockStream::default()
|
||||
}
|
||||
|
||||
/// Extract all bytes written by Write trait calls.
|
||||
pub fn push_bytes_to_read(&mut self, bytes: &[u8]) {
|
||||
self.0.push_bytes_to_read(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for AsyncMockStream {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("AsyncMockStream").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for AsyncMockStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for AsyncMockStream {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.0.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.0.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for AsyncMockStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let mut vec = Vec::new();
|
||||
match self.get_mut().read_to_end(&mut vec) {
|
||||
Ok(_) => {
|
||||
let slice = vec.as_slice();
|
||||
buf.put_slice(slice);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for AsyncMockStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
let len = buf.len();
|
||||
self.get_mut().push_bytes_to_read(buf);
|
||||
Poll::Ready(Ok(len))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(self.get_mut().flush())
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
|
@ -6,8 +6,14 @@ use crate::{Config, Error, Host, NutError};
|
|||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
mod client;
|
||||
#[cfg(test)]
|
||||
mod mockstream;
|
||||
mod stream;
|
||||
|
||||
pub use client::Client;
|
||||
|
||||
// TODO: Remove me
|
||||
/// An async NUT client connection.
|
||||
pub enum Connection {
|
||||
/// A TCP connection.
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
use crate::Error;
|
||||
use crate::proto::util::{join_sentence, split_sentence};
|
||||
use crate::proto::Sentence;
|
||||
use crate::{Error, NutError};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, ReadBuf};
|
||||
use tokio::io::{
|
||||
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// A wrapper for various Tokio stream types.
|
||||
#[derive(Debug)]
|
||||
pub enum ConnectionStream {
|
||||
/// A plain TCP stream.
|
||||
Tcp(TcpStream),
|
||||
|
@ -22,6 +27,10 @@ pub enum ConnectionStream {
|
|||
/// A server stream wrapped with SSL using `rustls`.
|
||||
#[cfg(feature = "async-ssl")]
|
||||
SslServer(Box<tokio_rustls::server::TlsStream<ConnectionStream>>),
|
||||
|
||||
/// A mock stream, used for testing.
|
||||
#[cfg(test)]
|
||||
Mock(crate::tokio::mockstream::AsyncMockStream),
|
||||
}
|
||||
|
||||
impl ConnectionStream {
|
||||
|
@ -47,6 +56,101 @@ impl ConnectionStream {
|
|||
acceptor.accept(self).await.map_err(Error::Io)?,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Writes a sentence on the current stream.
|
||||
pub async fn write_sentence<T: Sentence>(&mut self, sentence: &T) -> crate::Result<()> {
|
||||
let encoded = sentence.encode();
|
||||
let joined = join_sentence(encoded);
|
||||
self.write_literal(&joined).await?;
|
||||
self.flush().await.map_err(crate::Error::Io)
|
||||
}
|
||||
|
||||
/// Writes a collection of sentences on the current stream.
|
||||
pub async fn write_sentences<T: Sentence>(&mut self, sentences: &[T]) -> crate::Result<()> {
|
||||
for sentence in sentences {
|
||||
let encoded = sentence.encode();
|
||||
let joined = join_sentence(encoded);
|
||||
self.write_literal(&joined).await?;
|
||||
}
|
||||
self.flush().await.map_err(crate::Error::Io)
|
||||
}
|
||||
|
||||
/// Writes a literal string to the current stream.
|
||||
/// Note: the literal string MUST end with a new-line character (`\n`).
|
||||
///
|
||||
/// Note: does not automatically flush.
|
||||
pub async fn write_literal(&mut self, literal: &str) -> crate::Result<()> {
|
||||
assert!(literal.ends_with('\n'));
|
||||
self.write_all(literal.as_bytes()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reads a literal string from the current stream.
|
||||
///
|
||||
/// Note: the literal string will ends with a new-line character (`\n`)
|
||||
/// Note: requires stream to be buffered.
|
||||
pub async fn read_literal(&mut self) -> crate::Result<String> {
|
||||
let mut raw = String::new();
|
||||
self.read_line(&mut raw).await?;
|
||||
Ok(raw)
|
||||
}
|
||||
|
||||
/// Reads a sentence from the current stream.
|
||||
///
|
||||
/// Note: requires stream to be buffered.
|
||||
pub async fn read_sentence<T: Sentence>(&mut self) -> crate::Result<T> {
|
||||
dbg!(&self);
|
||||
let raw = self.read_literal().await?;
|
||||
if raw.is_empty() {
|
||||
return Err(Error::eof());
|
||||
}
|
||||
let split = split_sentence(raw).ok_or(crate::NutError::NotProcessable)?;
|
||||
T::decode(split)
|
||||
.ok_or(Error::Nut(NutError::InvalidArgument))?
|
||||
.into()
|
||||
}
|
||||
|
||||
/// Reads all sentences in the stream until the given `matcher` function evaluates to `true`.
|
||||
///
|
||||
/// The final sentence is excluded.
|
||||
///
|
||||
/// Note: requires stream to be buffered.
|
||||
pub async fn read_sentences_until<T: Sentence, F: Fn(&T) -> bool>(
|
||||
&mut self,
|
||||
matcher: F,
|
||||
) -> crate::Result<Vec<T>> {
|
||||
let mut result = Vec::new();
|
||||
let max_iter = 1000; // Exit after 1000 lines to prevent overflow.
|
||||
for _ in 0..max_iter {
|
||||
let sentence: T = self.read_sentence().await?;
|
||||
if matcher(&sentence) {
|
||||
return Ok(result);
|
||||
} else {
|
||||
result.push(sentence);
|
||||
}
|
||||
}
|
||||
Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::Interrupted,
|
||||
"Reached maximum read capacity.",
|
||||
)))
|
||||
}
|
||||
|
||||
/// Wraps the current stream with a `BufReader`.
|
||||
pub fn buffered(self) -> ConnectionStream {
|
||||
Self::Buffered(Box::new(BufReader::new(self)))
|
||||
}
|
||||
|
||||
/// Unwraps the underlying stream from the current `BufReader`.
|
||||
/// If the current stream is not buffered, it returns itself (no-op).
|
||||
///
|
||||
/// Note that, if the stream is buffered, any un-consumed data will be discarded.
|
||||
pub fn unbuffered(self) -> ConnectionStream {
|
||||
if let Self::Buffered(buf) = self {
|
||||
buf.into_inner()
|
||||
} else {
|
||||
self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ConnectionStream {
|
||||
|
@ -74,27 +178,33 @@ impl AsyncRead for ConnectionStream {
|
|||
let pinned = Pin::new(stream);
|
||||
pinned.poll_read(cx, buf)
|
||||
}
|
||||
#[cfg(test)]
|
||||
Self::Mock(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncBufRead for ConnectionStream {
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
|
||||
dbg!(&self);
|
||||
match self.get_mut() {
|
||||
Self::Buffered(reader) => {
|
||||
let pinned = Pin::new(reader.get_mut());
|
||||
let pinned = Pin::new(reader);
|
||||
pinned.poll_fill_buf(cx)
|
||||
}
|
||||
_ => core::task::Poll::Ready(Err(std::io::Error::new(
|
||||
s => core::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Unsupported,
|
||||
"Stream is not buffered",
|
||||
format!("Stream is not buffered: {:?}", s),
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||
if let Self::Buffered(reader) = self.get_mut() {
|
||||
let pinned = Pin::new(reader.get_mut());
|
||||
let pinned = Pin::new(reader);
|
||||
pinned.consume(amt)
|
||||
}
|
||||
}
|
||||
|
@ -125,6 +235,11 @@ impl AsyncWrite for ConnectionStream {
|
|||
let pinned = Pin::new(stream);
|
||||
pinned.poll_write(cx, buf)
|
||||
}
|
||||
#[cfg(test)]
|
||||
Self::Mock(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_write(cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,6 +263,11 @@ impl AsyncWrite for ConnectionStream {
|
|||
let pinned = Pin::new(stream);
|
||||
pinned.poll_flush(cx)
|
||||
}
|
||||
#[cfg(test)]
|
||||
Self::Mock(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,6 +291,122 @@ impl AsyncWrite for ConnectionStream {
|
|||
let pinned = Pin::new(stream);
|
||||
pinned.poll_shutdown(cx)
|
||||
}
|
||||
#[cfg(test)]
|
||||
Self::Mock(stream) => {
|
||||
let pinned = Pin::new(stream);
|
||||
pinned.poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ConnectionStream;
|
||||
use crate::proto::{ClientSentences, Sentence, ServerSentences};
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_write_sentence() {
|
||||
let client_stream = crate::tokio::mockstream::AsyncMockStream::new();
|
||||
let server_stream = client_stream.clone();
|
||||
|
||||
let mut client_stream = ConnectionStream::Mock(client_stream).buffered();
|
||||
let mut server_stream = ConnectionStream::Mock(server_stream).buffered();
|
||||
|
||||
// Client requests list of UPS devices
|
||||
client_stream
|
||||
.write_sentence(&ServerSentences::QueryListUps {})
|
||||
.await
|
||||
.expect("Failed to write LIST UPS");
|
||||
|
||||
dbg!(&client_stream);
|
||||
dbg!(&server_stream);
|
||||
|
||||
// Server reads query for list of UPS devices
|
||||
let sentence = server_stream
|
||||
.read_sentence::<ServerSentences>()
|
||||
.await
|
||||
.expect("Failed to read LIST UPS");
|
||||
assert_eq!(sentence, ServerSentences::QueryListUps {});
|
||||
|
||||
// Server sends list of UPS devices.
|
||||
server_stream
|
||||
.write_sentences(&[
|
||||
ClientSentences::BeginListUps {},
|
||||
ClientSentences::RespondUps {
|
||||
ups_name: "nutdev0".into(),
|
||||
description: "A NUT device.".into(),
|
||||
},
|
||||
ClientSentences::RespondUps {
|
||||
ups_name: "nutdev1".into(),
|
||||
description: "Another NUT device.".into(),
|
||||
},
|
||||
ClientSentences::EndListUps {},
|
||||
])
|
||||
.await
|
||||
.expect("Failed to write UPS LIST");
|
||||
|
||||
// Client reads list of UPS devices.
|
||||
client_stream
|
||||
.read_sentence::<ClientSentences>()
|
||||
.await
|
||||
.expect("Failed to read BEGIN LIST UPS")
|
||||
.exactly(ClientSentences::BeginListUps {})
|
||||
.unwrap();
|
||||
|
||||
let sentences: Vec<ClientSentences> = client_stream
|
||||
.read_sentences_until(|s| matches!(s, ClientSentences::EndListUps {}))
|
||||
.await
|
||||
.expect("Failed to read UPS items");
|
||||
|
||||
assert_eq!(
|
||||
sentences,
|
||||
vec![
|
||||
ClientSentences::RespondUps {
|
||||
ups_name: "nutdev0".into(),
|
||||
description: "A NUT device.".into(),
|
||||
},
|
||||
ClientSentences::RespondUps {
|
||||
ups_name: "nutdev1".into(),
|
||||
description: "Another NUT device.".into(),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
// Client sends login
|
||||
client_stream
|
||||
.write_sentence(&ServerSentences::ExecLogin {
|
||||
ups_name: "nutdev0".into(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to write LOGIN nutdev0");
|
||||
|
||||
// Server receives login
|
||||
server_stream
|
||||
.read_sentence::<ServerSentences>()
|
||||
.await
|
||||
.expect("Failed to read LOGIN nutdev0")
|
||||
.exactly(ServerSentences::ExecLogin {
|
||||
ups_name: "nutdev0".into(),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Server rejects login
|
||||
server_stream
|
||||
.write_sentence(&ClientSentences::RespondErr {
|
||||
message: "USERNAME-REQUIRED".into(),
|
||||
extras: vec![],
|
||||
})
|
||||
.await
|
||||
.expect("Failed to write ERR USERNAME-REQUIRED");
|
||||
|
||||
// Client expects error
|
||||
let error: crate::Error = client_stream
|
||||
.read_sentence::<ClientSentences>()
|
||||
.await
|
||||
.expect_err("Failed to read ERR");
|
||||
assert!(matches!(
|
||||
error,
|
||||
crate::Error::Nut(crate::NutError::UsernameRequired)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue