Buffered connection stream wrapper

This commit is contained in:
Aram 🍐 2021-08-04 18:37:33 -04:00
parent a92500e67b
commit ff03f27b49
4 changed files with 158 additions and 47 deletions

View file

@ -1,5 +1,7 @@
use crate::blocking::stream::ConnectionStream;
use crate::Config;
use crate::proto::{ClientSentences, Sentence, ServerSentences};
use crate::{Config, Host, TcpHost};
use std::net::TcpStream;
/// A synchronous NUT client.
pub struct Client {
@ -9,4 +11,53 @@ pub struct Client {
stream: ConnectionStream,
}
impl Client {}
impl Client {
/// Connects to a remote NUT server using a blocking connection.
pub fn new(config: &Config) -> crate::Result<Self> {
match &config.host {
Host::Tcp(host) => Self::new_tcp(config, host),
}
// TODO: Support Unix domain sockets
}
/// Connects to a remote NUT server using a blocking TCP connection.
fn new_tcp(config: &Config, host: &TcpHost) -> crate::Result<Self> {
let tcp_stream = TcpStream::connect_timeout(&host.addr, config.timeout)?;
let mut client = Client {
config: config.clone(),
stream: ConnectionStream::Tcp(tcp_stream).buffered(),
};
client = client.enable_ssl()?;
// TODO: Enable SSL
// TODO: Login
Ok(client)
}
#[cfg(feature = "ssl")]
fn enable_ssl(mut self) -> crate::Result<Self> {
if self.config.ssl {
// Send STARTTLS
self.stream
.write_sentence(&ServerSentences::ExecStartTLS {})?;
// Expect the OK
self.stream
.read_sentence::<ClientSentences>()?
.as_exactly(ClientSentences::StartTLSOk {})?;
// Un-buffer to get back underlying stream
self.stream = self.stream.unbuffered();
// TODO: Un-buffer
// TODO: Do the upgrade
}
Ok(self)
}
#[cfg(not(feature = "ssl"))]
fn enable_ssl(self) -> crate::Result<Self> {
Ok(self)
}
}

View file

@ -60,7 +60,7 @@ impl TcpConnection {
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self {
config,
stream: ConnectionStream::Plain(tcp_stream),
stream: ConnectionStream::Tcp(tcp_stream),
};
connection = connection.enable_ssl()?;
Ok(connection)

View file

@ -7,7 +7,7 @@ use std::net::TcpStream;
/// A wrapper for various synchronous stream types.
pub enum ConnectionStream {
/// A plain TCP stream.
Plain(TcpStream),
Tcp(TcpStream),
/// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")]
@ -17,6 +17,12 @@ pub enum ConnectionStream {
#[cfg(feature = "ssl")]
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
/// A stream wrapped with `BufReader`.
///
/// Use `.buffered()` to wrap any stream with `BufReader`.
/// It can then be un-wrapped with `.unbuffered()`.
Buffered(Box<BufReader<ConnectionStream>>),
/// A mock stream, used for testing.
#[cfg(test)]
Mock(mockstream::SharedMockStream),
@ -74,16 +80,20 @@ impl ConnectionStream {
}
/// Reads a literal string from the current stream.
/// Note: the literal string will ends with a new-line character (`\n`).
pub fn read_literal(reader: &mut BufReader<&mut Self>) -> crate::Result<String> {
///
/// Note: the literal string will ends with a new-line character (`\n`)
/// Note: requires stream to be buffered.
pub fn read_literal(&mut self) -> crate::Result<String> {
let mut raw = String::new();
reader.read_line(&mut raw)?;
self.read_line(&mut raw)?;
Ok(raw)
}
/// Reads a sentence from the given `BufReader`.
pub fn read_sentence<T: Sentence>(reader: &mut BufReader<&mut Self>) -> crate::Result<T> {
let raw = Self::read_literal(reader)?;
/// Reads a sentence from the current stream.
///
/// Note: requires stream to be buffered.
pub fn read_sentence<T: Sentence>(&mut self) -> crate::Result<T> {
let raw = self.read_literal()?;
if raw.is_empty() {
return Err(Error::eof());
}
@ -93,17 +103,19 @@ impl ConnectionStream {
.into()
}
/// Reads all sentences in the buffer until the given `matcher` function evaluates to `true`.
/// Reads all sentences in the stream until the given `matcher` function evaluates to `true`.
///
/// The final sentence is excluded.
///
/// Note: requires stream to be buffered.
pub fn read_sentences_until<T: Sentence, F: Fn(&T) -> bool>(
reader: &mut BufReader<&mut Self>,
&mut self,
matcher: F,
) -> crate::Result<Vec<T>> {
let mut result = Vec::new();
let max_iter = 1000; // Exit after 1000 lines to prevent overflow.
for _ in 0..max_iter {
let sentence: T = Self::read_sentence(reader)?;
let sentence: T = self.read_sentence()?;
if matcher(&sentence) {
return Ok(result);
} else {
@ -116,16 +128,28 @@ impl ConnectionStream {
)))
}
/// Initializes a new `BufReader` for the current stream.
pub fn buffer(&mut self) -> BufReader<&mut Self> {
BufReader::new(self)
/// Wraps the current stream with a `BufReader`.
pub fn buffered(mut self) -> ConnectionStream {
Self::Buffered(Box::new(BufReader::new(self)))
}
/// Unwraps the underlying stream from the current `BufReader`.
/// If the current stream is not buffered, it returns itself (no-op).
///
/// Note that, if the stream is buffered, any un-consumed data will be discarded.
pub fn unbuffered(mut self) -> ConnectionStream {
match self {
Self::Buffered(buf) => buf.into_inner(),
_ => self,
}
}
}
impl Read for ConnectionStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Self::Plain(stream) => stream.read(buf),
Self::Tcp(stream) => stream.read(buf),
Self::Buffered(reader) => reader.read(buf),
#[cfg(feature = "ssl")]
Self::SslClient(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
@ -136,10 +160,30 @@ impl Read for ConnectionStream {
}
}
impl BufRead for ConnectionStream {
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
match self {
Self::Buffered(reader) => reader.fill_buf(),
_ => Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"Stream is not buffered",
)),
}
}
fn consume(&mut self, amt: usize) {
match self {
Self::Buffered(reader) => reader.consume(amt),
_ => (),
}
}
}
impl Write for ConnectionStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Self::Plain(stream) => stream.write(buf),
Self::Tcp(stream) => stream.write(buf),
Self::Buffered(reader) => reader.get_mut().write(buf),
#[cfg(feature = "ssl")]
Self::SslClient(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
@ -155,7 +199,8 @@ impl Write for ConnectionStream {
fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Plain(stream) => stream.flush(),
Self::Tcp(stream) => stream.flush(),
Self::Buffered(reader) => reader.get_mut().flush(),
#[cfg(feature = "ssl")]
Self::SslClient(stream) => stream.flush(),
#[cfg(feature = "ssl")]
@ -169,7 +214,7 @@ impl Write for ConnectionStream {
#[cfg(test)]
mod tests {
use super::ConnectionStream;
use crate::proto::{ClientSentences, ServerSentences};
use crate::proto::{ClientSentences, Sentence, ServerSentences};
use std::io::{Read, Write};
#[test]
@ -177,8 +222,8 @@ mod tests {
let mut client_stream = mockstream::SharedMockStream::new();
let mut server_stream = client_stream.clone();
let mut client_stream = ConnectionStream::Mock(client_stream);
let mut server_stream = ConnectionStream::Mock(server_stream);
let mut client_stream = ConnectionStream::Mock(client_stream).buffered();
let mut server_stream = ConnectionStream::Mock(server_stream).buffered();
// Client requests list of UPS devices
client_stream
@ -186,9 +231,9 @@ mod tests {
.expect("Failed to write LIST UPS");
// Server reads query for list of UPS devices
let mut server_buffer = server_stream.buffer();
let sentence: ServerSentences =
ConnectionStream::read_sentence(&mut server_buffer).expect("Failed to read LIST UPS");
let sentence = server_stream
.read_sentence::<ServerSentences>()
.expect("Failed to read LIST UPS");
assert_eq!(sentence, ServerSentences::QueryListUps {});
// Server sends list of UPS devices.
@ -208,15 +253,14 @@ mod tests {
.expect("Failed to write UPS LIST");
// Client reads list of UPS devices.
let mut client_buffer = client_stream.buffer();
let sentence: ClientSentences = ConnectionStream::read_sentence(&mut client_buffer)
.expect("Failed to read BEGIN LIST UPS");
assert_eq!(sentence, ClientSentences::BeginListUps {});
let sentence = client_stream
.read_sentence::<ClientSentences>()
.expect("Failed to read BEGIN LIST UPS")
.as_exactly(ClientSentences::BeginListUps {})
.unwrap();
let sentences: Vec<ClientSentences> =
ConnectionStream::read_sentences_until(&mut client_buffer, |s| {
matches!(s, ClientSentences::EndListUps {})
})
let sentences: Vec<ClientSentences> = client_stream
.read_sentences_until(|s| matches!(s, ClientSentences::EndListUps {}))
.expect("Failed to read UPS items");
assert_eq!(
@ -241,15 +285,13 @@ mod tests {
.expect("Failed to write LOGIN nutdev0");
// Server receives login
let mut server_buffer = server_stream.buffer();
let sentence: ServerSentences = ConnectionStream::read_sentence(&mut server_buffer)
.expect("Failed to read LOGIN nutdev0");
assert_eq!(
sentence,
ServerSentences::ExecLogin {
ups_name: "nutdev0".into()
}
);
let sentence = server_stream
.read_sentence::<ServerSentences>()
.expect("Failed to read LOGIN nutdev0")
.as_exactly(ServerSentences::ExecLogin {
ups_name: "nutdev0".into(),
})
.unwrap();
// Server rejects login
server_stream
@ -260,10 +302,9 @@ mod tests {
.expect("Failed to write ERR USERNAME-REQUIRED");
// Client expects error
let mut client_buffer = client_stream.buffer();
let error: crate::Error =
ConnectionStream::read_sentence::<ClientSentences>(&mut client_buffer)
.expect_err("Failed to read ERR");
let error: crate::Error = client_stream
.read_sentence::<ClientSentences>()
.expect_err("Failed to read ERR");
assert!(matches!(
error,
crate::Error::Nut(crate::NutError::UsernameRequired)

View file

@ -105,12 +105,30 @@ macro_rules! impl_words {
}
/// A NUT protocol sentence that can be encoded and decoded from a Vector of strings.
pub trait Sentence: Sized + Into<crate::Result<Self>> {
pub trait Sentence: Eq + Sized + Into<crate::Result<Self>> {
/// Decodes a sentence. Returns `None` if the pattern cannot be recognized.
fn decode(raw: Vec<String>) -> Option<Self>;
/// Encodes the sentence.
fn encode(&self) -> Vec<&str>;
/// Returns an error if the sentence does not match what was expected.
fn as_matching<F: FnOnce(&Self) -> bool>(self, matcher: F) -> crate::Result<Self> {
if matcher(&self) {
Ok(self)
} else {
Err(Error::Nut(NutError::UnexpectedResponse))
}
}
/// Returns an error if the sentence is not equal to what was expected.
fn as_exactly(self, expected: Self) -> crate::Result<Self> {
if expected == self {
Ok(self)
} else {
Err(Error::Nut(NutError::UnexpectedResponse))
}
}
}
/// Implements the list of sentences, which are combinations
@ -313,6 +331,7 @@ impl_words! {
Version("VERSION"),
}
use crate::{Error, NutError};
pub(crate) use impl_sentences;
#[cfg(test)]
pub(crate) use test_encode_decode;