Buffered connection stream wrapper

This commit is contained in:
Aram 🍐 2021-08-04 18:37:33 -04:00
parent a92500e67b
commit ff03f27b49
4 changed files with 158 additions and 47 deletions

View file

@ -1,5 +1,7 @@
use crate::blocking::stream::ConnectionStream; use crate::blocking::stream::ConnectionStream;
use crate::Config; use crate::proto::{ClientSentences, Sentence, ServerSentences};
use crate::{Config, Host, TcpHost};
use std::net::TcpStream;
/// A synchronous NUT client. /// A synchronous NUT client.
pub struct Client { pub struct Client {
@ -9,4 +11,53 @@ pub struct Client {
stream: ConnectionStream, stream: ConnectionStream,
} }
impl Client {} impl Client {
/// Connects to a remote NUT server using a blocking connection.
pub fn new(config: &Config) -> crate::Result<Self> {
match &config.host {
Host::Tcp(host) => Self::new_tcp(config, host),
}
// TODO: Support Unix domain sockets
}
/// Connects to a remote NUT server using a blocking TCP connection.
fn new_tcp(config: &Config, host: &TcpHost) -> crate::Result<Self> {
let tcp_stream = TcpStream::connect_timeout(&host.addr, config.timeout)?;
let mut client = Client {
config: config.clone(),
stream: ConnectionStream::Tcp(tcp_stream).buffered(),
};
client = client.enable_ssl()?;
// TODO: Enable SSL
// TODO: Login
Ok(client)
}
#[cfg(feature = "ssl")]
fn enable_ssl(mut self) -> crate::Result<Self> {
if self.config.ssl {
// Send STARTTLS
self.stream
.write_sentence(&ServerSentences::ExecStartTLS {})?;
// Expect the OK
self.stream
.read_sentence::<ClientSentences>()?
.as_exactly(ClientSentences::StartTLSOk {})?;
// Un-buffer to get back underlying stream
self.stream = self.stream.unbuffered();
// TODO: Un-buffer
// TODO: Do the upgrade
}
Ok(self)
}
#[cfg(not(feature = "ssl"))]
fn enable_ssl(self) -> crate::Result<Self> {
Ok(self)
}
}

View file

@ -60,7 +60,7 @@ impl TcpConnection {
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?; let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self { let mut connection = Self {
config, config,
stream: ConnectionStream::Plain(tcp_stream), stream: ConnectionStream::Tcp(tcp_stream),
}; };
connection = connection.enable_ssl()?; connection = connection.enable_ssl()?;
Ok(connection) Ok(connection)

View file

@ -7,7 +7,7 @@ use std::net::TcpStream;
/// A wrapper for various synchronous stream types. /// A wrapper for various synchronous stream types.
pub enum ConnectionStream { pub enum ConnectionStream {
/// A plain TCP stream. /// A plain TCP stream.
Plain(TcpStream), Tcp(TcpStream),
/// A client stream wrapped with SSL using `rustls`. /// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
@ -17,6 +17,12 @@ pub enum ConnectionStream {
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>), SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
/// A stream wrapped with `BufReader`.
///
/// Use `.buffered()` to wrap any stream with `BufReader`.
/// It can then be un-wrapped with `.unbuffered()`.
Buffered(Box<BufReader<ConnectionStream>>),
/// A mock stream, used for testing. /// A mock stream, used for testing.
#[cfg(test)] #[cfg(test)]
Mock(mockstream::SharedMockStream), Mock(mockstream::SharedMockStream),
@ -74,16 +80,20 @@ impl ConnectionStream {
} }
/// Reads a literal string from the current stream. /// Reads a literal string from the current stream.
/// Note: the literal string will ends with a new-line character (`\n`). ///
pub fn read_literal(reader: &mut BufReader<&mut Self>) -> crate::Result<String> { /// Note: the literal string will ends with a new-line character (`\n`)
/// Note: requires stream to be buffered.
pub fn read_literal(&mut self) -> crate::Result<String> {
let mut raw = String::new(); let mut raw = String::new();
reader.read_line(&mut raw)?; self.read_line(&mut raw)?;
Ok(raw) Ok(raw)
} }
/// Reads a sentence from the given `BufReader`. /// Reads a sentence from the current stream.
pub fn read_sentence<T: Sentence>(reader: &mut BufReader<&mut Self>) -> crate::Result<T> { ///
let raw = Self::read_literal(reader)?; /// Note: requires stream to be buffered.
pub fn read_sentence<T: Sentence>(&mut self) -> crate::Result<T> {
let raw = self.read_literal()?;
if raw.is_empty() { if raw.is_empty() {
return Err(Error::eof()); return Err(Error::eof());
} }
@ -93,17 +103,19 @@ impl ConnectionStream {
.into() .into()
} }
/// Reads all sentences in the buffer until the given `matcher` function evaluates to `true`. /// Reads all sentences in the stream until the given `matcher` function evaluates to `true`.
/// ///
/// The final sentence is excluded. /// The final sentence is excluded.
///
/// Note: requires stream to be buffered.
pub fn read_sentences_until<T: Sentence, F: Fn(&T) -> bool>( pub fn read_sentences_until<T: Sentence, F: Fn(&T) -> bool>(
reader: &mut BufReader<&mut Self>, &mut self,
matcher: F, matcher: F,
) -> crate::Result<Vec<T>> { ) -> crate::Result<Vec<T>> {
let mut result = Vec::new(); let mut result = Vec::new();
let max_iter = 1000; // Exit after 1000 lines to prevent overflow. let max_iter = 1000; // Exit after 1000 lines to prevent overflow.
for _ in 0..max_iter { for _ in 0..max_iter {
let sentence: T = Self::read_sentence(reader)?; let sentence: T = self.read_sentence()?;
if matcher(&sentence) { if matcher(&sentence) {
return Ok(result); return Ok(result);
} else { } else {
@ -116,16 +128,28 @@ impl ConnectionStream {
))) )))
} }
/// Initializes a new `BufReader` for the current stream. /// Wraps the current stream with a `BufReader`.
pub fn buffer(&mut self) -> BufReader<&mut Self> { pub fn buffered(mut self) -> ConnectionStream {
BufReader::new(self) Self::Buffered(Box::new(BufReader::new(self)))
}
/// Unwraps the underlying stream from the current `BufReader`.
/// If the current stream is not buffered, it returns itself (no-op).
///
/// Note that, if the stream is buffered, any un-consumed data will be discarded.
pub fn unbuffered(mut self) -> ConnectionStream {
match self {
Self::Buffered(buf) => buf.into_inner(),
_ => self,
}
} }
} }
impl Read for ConnectionStream { impl Read for ConnectionStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self { match self {
Self::Plain(stream) => stream.read(buf), Self::Tcp(stream) => stream.read(buf),
Self::Buffered(reader) => reader.read(buf),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Self::SslClient(stream) => stream.read(buf), Self::SslClient(stream) => stream.read(buf),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
@ -136,10 +160,30 @@ impl Read for ConnectionStream {
} }
} }
impl BufRead for ConnectionStream {
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
match self {
Self::Buffered(reader) => reader.fill_buf(),
_ => Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"Stream is not buffered",
)),
}
}
fn consume(&mut self, amt: usize) {
match self {
Self::Buffered(reader) => reader.consume(amt),
_ => (),
}
}
}
impl Write for ConnectionStream { impl Write for ConnectionStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self { match self {
Self::Plain(stream) => stream.write(buf), Self::Tcp(stream) => stream.write(buf),
Self::Buffered(reader) => reader.get_mut().write(buf),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Self::SslClient(stream) => stream.write(buf), Self::SslClient(stream) => stream.write(buf),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
@ -155,7 +199,8 @@ impl Write for ConnectionStream {
fn flush(&mut self) -> std::io::Result<()> { fn flush(&mut self) -> std::io::Result<()> {
match self { match self {
Self::Plain(stream) => stream.flush(), Self::Tcp(stream) => stream.flush(),
Self::Buffered(reader) => reader.get_mut().flush(),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
Self::SslClient(stream) => stream.flush(), Self::SslClient(stream) => stream.flush(),
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
@ -169,7 +214,7 @@ impl Write for ConnectionStream {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::ConnectionStream; use super::ConnectionStream;
use crate::proto::{ClientSentences, ServerSentences}; use crate::proto::{ClientSentences, Sentence, ServerSentences};
use std::io::{Read, Write}; use std::io::{Read, Write};
#[test] #[test]
@ -177,8 +222,8 @@ mod tests {
let mut client_stream = mockstream::SharedMockStream::new(); let mut client_stream = mockstream::SharedMockStream::new();
let mut server_stream = client_stream.clone(); let mut server_stream = client_stream.clone();
let mut client_stream = ConnectionStream::Mock(client_stream); let mut client_stream = ConnectionStream::Mock(client_stream).buffered();
let mut server_stream = ConnectionStream::Mock(server_stream); let mut server_stream = ConnectionStream::Mock(server_stream).buffered();
// Client requests list of UPS devices // Client requests list of UPS devices
client_stream client_stream
@ -186,9 +231,9 @@ mod tests {
.expect("Failed to write LIST UPS"); .expect("Failed to write LIST UPS");
// Server reads query for list of UPS devices // Server reads query for list of UPS devices
let mut server_buffer = server_stream.buffer(); let sentence = server_stream
let sentence: ServerSentences = .read_sentence::<ServerSentences>()
ConnectionStream::read_sentence(&mut server_buffer).expect("Failed to read LIST UPS"); .expect("Failed to read LIST UPS");
assert_eq!(sentence, ServerSentences::QueryListUps {}); assert_eq!(sentence, ServerSentences::QueryListUps {});
// Server sends list of UPS devices. // Server sends list of UPS devices.
@ -208,15 +253,14 @@ mod tests {
.expect("Failed to write UPS LIST"); .expect("Failed to write UPS LIST");
// Client reads list of UPS devices. // Client reads list of UPS devices.
let mut client_buffer = client_stream.buffer(); let sentence = client_stream
let sentence: ClientSentences = ConnectionStream::read_sentence(&mut client_buffer) .read_sentence::<ClientSentences>()
.expect("Failed to read BEGIN LIST UPS"); .expect("Failed to read BEGIN LIST UPS")
assert_eq!(sentence, ClientSentences::BeginListUps {}); .as_exactly(ClientSentences::BeginListUps {})
.unwrap();
let sentences: Vec<ClientSentences> = let sentences: Vec<ClientSentences> = client_stream
ConnectionStream::read_sentences_until(&mut client_buffer, |s| { .read_sentences_until(|s| matches!(s, ClientSentences::EndListUps {}))
matches!(s, ClientSentences::EndListUps {})
})
.expect("Failed to read UPS items"); .expect("Failed to read UPS items");
assert_eq!( assert_eq!(
@ -241,15 +285,13 @@ mod tests {
.expect("Failed to write LOGIN nutdev0"); .expect("Failed to write LOGIN nutdev0");
// Server receives login // Server receives login
let mut server_buffer = server_stream.buffer(); let sentence = server_stream
let sentence: ServerSentences = ConnectionStream::read_sentence(&mut server_buffer) .read_sentence::<ServerSentences>()
.expect("Failed to read LOGIN nutdev0"); .expect("Failed to read LOGIN nutdev0")
assert_eq!( .as_exactly(ServerSentences::ExecLogin {
sentence, ups_name: "nutdev0".into(),
ServerSentences::ExecLogin { })
ups_name: "nutdev0".into() .unwrap();
}
);
// Server rejects login // Server rejects login
server_stream server_stream
@ -260,10 +302,9 @@ mod tests {
.expect("Failed to write ERR USERNAME-REQUIRED"); .expect("Failed to write ERR USERNAME-REQUIRED");
// Client expects error // Client expects error
let mut client_buffer = client_stream.buffer(); let error: crate::Error = client_stream
let error: crate::Error = .read_sentence::<ClientSentences>()
ConnectionStream::read_sentence::<ClientSentences>(&mut client_buffer) .expect_err("Failed to read ERR");
.expect_err("Failed to read ERR");
assert!(matches!( assert!(matches!(
error, error,
crate::Error::Nut(crate::NutError::UsernameRequired) crate::Error::Nut(crate::NutError::UsernameRequired)

View file

@ -105,12 +105,30 @@ macro_rules! impl_words {
} }
/// A NUT protocol sentence that can be encoded and decoded from a Vector of strings. /// A NUT protocol sentence that can be encoded and decoded from a Vector of strings.
pub trait Sentence: Sized + Into<crate::Result<Self>> { pub trait Sentence: Eq + Sized + Into<crate::Result<Self>> {
/// Decodes a sentence. Returns `None` if the pattern cannot be recognized. /// Decodes a sentence. Returns `None` if the pattern cannot be recognized.
fn decode(raw: Vec<String>) -> Option<Self>; fn decode(raw: Vec<String>) -> Option<Self>;
/// Encodes the sentence. /// Encodes the sentence.
fn encode(&self) -> Vec<&str>; fn encode(&self) -> Vec<&str>;
/// Returns an error if the sentence does not match what was expected.
fn as_matching<F: FnOnce(&Self) -> bool>(self, matcher: F) -> crate::Result<Self> {
if matcher(&self) {
Ok(self)
} else {
Err(Error::Nut(NutError::UnexpectedResponse))
}
}
/// Returns an error if the sentence is not equal to what was expected.
fn as_exactly(self, expected: Self) -> crate::Result<Self> {
if expected == self {
Ok(self)
} else {
Err(Error::Nut(NutError::UnexpectedResponse))
}
}
} }
/// Implements the list of sentences, which are combinations /// Implements the list of sentences, which are combinations
@ -313,6 +331,7 @@ impl_words! {
Version("VERSION"), Version("VERSION"),
} }
use crate::{Error, NutError};
pub(crate) use impl_sentences; pub(crate) use impl_sentences;
#[cfg(test)] #[cfg(test)]
pub(crate) use test_encode_decode; pub(crate) use test_encode_decode;