diff --git a/rups/src/blocking/client.rs b/rups/src/blocking/client.rs index ddb961c..096ec40 100644 --- a/rups/src/blocking/client.rs +++ b/rups/src/blocking/client.rs @@ -1,5 +1,7 @@ use crate::blocking::stream::ConnectionStream; -use crate::Config; +use crate::proto::{ClientSentences, Sentence, ServerSentences}; +use crate::{Config, Host, TcpHost}; +use std::net::TcpStream; /// A synchronous NUT client. pub struct Client { @@ -9,4 +11,53 @@ pub struct Client { stream: ConnectionStream, } -impl Client {} +impl Client { + /// Connects to a remote NUT server using a blocking connection. + pub fn new(config: &Config) -> crate::Result { + match &config.host { + Host::Tcp(host) => Self::new_tcp(config, host), + } + // TODO: Support Unix domain sockets + } + + /// Connects to a remote NUT server using a blocking TCP connection. + fn new_tcp(config: &Config, host: &TcpHost) -> crate::Result { + let tcp_stream = TcpStream::connect_timeout(&host.addr, config.timeout)?; + let mut client = Client { + config: config.clone(), + stream: ConnectionStream::Tcp(tcp_stream).buffered(), + }; + + client = client.enable_ssl()?; + + // TODO: Enable SSL + // TODO: Login + Ok(client) + } + + #[cfg(feature = "ssl")] + fn enable_ssl(mut self) -> crate::Result { + if self.config.ssl { + // Send STARTTLS + self.stream + .write_sentence(&ServerSentences::ExecStartTLS {})?; + + // Expect the OK + self.stream + .read_sentence::()? + .as_exactly(ClientSentences::StartTLSOk {})?; + + // Un-buffer to get back underlying stream + self.stream = self.stream.unbuffered(); + + // TODO: Un-buffer + // TODO: Do the upgrade + } + Ok(self) + } + + #[cfg(not(feature = "ssl"))] + fn enable_ssl(self) -> crate::Result { + Ok(self) + } +} diff --git a/rups/src/blocking/mod.rs b/rups/src/blocking/mod.rs index b3515b6..2019c9e 100644 --- a/rups/src/blocking/mod.rs +++ b/rups/src/blocking/mod.rs @@ -60,7 +60,7 @@ impl TcpConnection { let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?; let mut connection = Self { config, - stream: ConnectionStream::Plain(tcp_stream), + stream: ConnectionStream::Tcp(tcp_stream), }; connection = connection.enable_ssl()?; Ok(connection) diff --git a/rups/src/blocking/stream.rs b/rups/src/blocking/stream.rs index d1ba29a..5801ae2 100644 --- a/rups/src/blocking/stream.rs +++ b/rups/src/blocking/stream.rs @@ -7,7 +7,7 @@ use std::net::TcpStream; /// A wrapper for various synchronous stream types. pub enum ConnectionStream { /// A plain TCP stream. - Plain(TcpStream), + Tcp(TcpStream), /// A client stream wrapped with SSL using `rustls`. #[cfg(feature = "ssl")] @@ -17,6 +17,12 @@ pub enum ConnectionStream { #[cfg(feature = "ssl")] SslServer(Box>), + /// A stream wrapped with `BufReader`. + /// + /// Use `.buffered()` to wrap any stream with `BufReader`. + /// It can then be un-wrapped with `.unbuffered()`. + Buffered(Box>), + /// A mock stream, used for testing. #[cfg(test)] Mock(mockstream::SharedMockStream), @@ -74,16 +80,20 @@ impl ConnectionStream { } /// Reads a literal string from the current stream. - /// Note: the literal string will ends with a new-line character (`\n`). - pub fn read_literal(reader: &mut BufReader<&mut Self>) -> crate::Result { + /// + /// Note: the literal string will ends with a new-line character (`\n`) + /// Note: requires stream to be buffered. + pub fn read_literal(&mut self) -> crate::Result { let mut raw = String::new(); - reader.read_line(&mut raw)?; + self.read_line(&mut raw)?; Ok(raw) } - /// Reads a sentence from the given `BufReader`. - pub fn read_sentence(reader: &mut BufReader<&mut Self>) -> crate::Result { - let raw = Self::read_literal(reader)?; + /// Reads a sentence from the current stream. + /// + /// Note: requires stream to be buffered. + pub fn read_sentence(&mut self) -> crate::Result { + let raw = self.read_literal()?; if raw.is_empty() { return Err(Error::eof()); } @@ -93,17 +103,19 @@ impl ConnectionStream { .into() } - /// Reads all sentences in the buffer until the given `matcher` function evaluates to `true`. + /// Reads all sentences in the stream until the given `matcher` function evaluates to `true`. /// /// The final sentence is excluded. + /// + /// Note: requires stream to be buffered. pub fn read_sentences_until bool>( - reader: &mut BufReader<&mut Self>, + &mut self, matcher: F, ) -> crate::Result> { let mut result = Vec::new(); let max_iter = 1000; // Exit after 1000 lines to prevent overflow. for _ in 0..max_iter { - let sentence: T = Self::read_sentence(reader)?; + let sentence: T = self.read_sentence()?; if matcher(&sentence) { return Ok(result); } else { @@ -116,16 +128,28 @@ impl ConnectionStream { ))) } - /// Initializes a new `BufReader` for the current stream. - pub fn buffer(&mut self) -> BufReader<&mut Self> { - BufReader::new(self) + /// Wraps the current stream with a `BufReader`. + pub fn buffered(mut self) -> ConnectionStream { + Self::Buffered(Box::new(BufReader::new(self))) + } + + /// Unwraps the underlying stream from the current `BufReader`. + /// If the current stream is not buffered, it returns itself (no-op). + /// + /// Note that, if the stream is buffered, any un-consumed data will be discarded. + pub fn unbuffered(mut self) -> ConnectionStream { + match self { + Self::Buffered(buf) => buf.into_inner(), + _ => self, + } } } impl Read for ConnectionStream { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { match self { - Self::Plain(stream) => stream.read(buf), + Self::Tcp(stream) => stream.read(buf), + Self::Buffered(reader) => reader.read(buf), #[cfg(feature = "ssl")] Self::SslClient(stream) => stream.read(buf), #[cfg(feature = "ssl")] @@ -136,10 +160,30 @@ impl Read for ConnectionStream { } } +impl BufRead for ConnectionStream { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + match self { + Self::Buffered(reader) => reader.fill_buf(), + _ => Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Stream is not buffered", + )), + } + } + + fn consume(&mut self, amt: usize) { + match self { + Self::Buffered(reader) => reader.consume(amt), + _ => (), + } + } +} + impl Write for ConnectionStream { fn write(&mut self, buf: &[u8]) -> std::io::Result { match self { - Self::Plain(stream) => stream.write(buf), + Self::Tcp(stream) => stream.write(buf), + Self::Buffered(reader) => reader.get_mut().write(buf), #[cfg(feature = "ssl")] Self::SslClient(stream) => stream.write(buf), #[cfg(feature = "ssl")] @@ -155,7 +199,8 @@ impl Write for ConnectionStream { fn flush(&mut self) -> std::io::Result<()> { match self { - Self::Plain(stream) => stream.flush(), + Self::Tcp(stream) => stream.flush(), + Self::Buffered(reader) => reader.get_mut().flush(), #[cfg(feature = "ssl")] Self::SslClient(stream) => stream.flush(), #[cfg(feature = "ssl")] @@ -169,7 +214,7 @@ impl Write for ConnectionStream { #[cfg(test)] mod tests { use super::ConnectionStream; - use crate::proto::{ClientSentences, ServerSentences}; + use crate::proto::{ClientSentences, Sentence, ServerSentences}; use std::io::{Read, Write}; #[test] @@ -177,8 +222,8 @@ mod tests { let mut client_stream = mockstream::SharedMockStream::new(); let mut server_stream = client_stream.clone(); - let mut client_stream = ConnectionStream::Mock(client_stream); - let mut server_stream = ConnectionStream::Mock(server_stream); + let mut client_stream = ConnectionStream::Mock(client_stream).buffered(); + let mut server_stream = ConnectionStream::Mock(server_stream).buffered(); // Client requests list of UPS devices client_stream @@ -186,9 +231,9 @@ mod tests { .expect("Failed to write LIST UPS"); // Server reads query for list of UPS devices - let mut server_buffer = server_stream.buffer(); - let sentence: ServerSentences = - ConnectionStream::read_sentence(&mut server_buffer).expect("Failed to read LIST UPS"); + let sentence = server_stream + .read_sentence::() + .expect("Failed to read LIST UPS"); assert_eq!(sentence, ServerSentences::QueryListUps {}); // Server sends list of UPS devices. @@ -208,15 +253,14 @@ mod tests { .expect("Failed to write UPS LIST"); // Client reads list of UPS devices. - let mut client_buffer = client_stream.buffer(); - let sentence: ClientSentences = ConnectionStream::read_sentence(&mut client_buffer) - .expect("Failed to read BEGIN LIST UPS"); - assert_eq!(sentence, ClientSentences::BeginListUps {}); + let sentence = client_stream + .read_sentence::() + .expect("Failed to read BEGIN LIST UPS") + .as_exactly(ClientSentences::BeginListUps {}) + .unwrap(); - let sentences: Vec = - ConnectionStream::read_sentences_until(&mut client_buffer, |s| { - matches!(s, ClientSentences::EndListUps {}) - }) + let sentences: Vec = client_stream + .read_sentences_until(|s| matches!(s, ClientSentences::EndListUps {})) .expect("Failed to read UPS items"); assert_eq!( @@ -241,15 +285,13 @@ mod tests { .expect("Failed to write LOGIN nutdev0"); // Server receives login - let mut server_buffer = server_stream.buffer(); - let sentence: ServerSentences = ConnectionStream::read_sentence(&mut server_buffer) - .expect("Failed to read LOGIN nutdev0"); - assert_eq!( - sentence, - ServerSentences::ExecLogin { - ups_name: "nutdev0".into() - } - ); + let sentence = server_stream + .read_sentence::() + .expect("Failed to read LOGIN nutdev0") + .as_exactly(ServerSentences::ExecLogin { + ups_name: "nutdev0".into(), + }) + .unwrap(); // Server rejects login server_stream @@ -260,10 +302,9 @@ mod tests { .expect("Failed to write ERR USERNAME-REQUIRED"); // Client expects error - let mut client_buffer = client_stream.buffer(); - let error: crate::Error = - ConnectionStream::read_sentence::(&mut client_buffer) - .expect_err("Failed to read ERR"); + let error: crate::Error = client_stream + .read_sentence::() + .expect_err("Failed to read ERR"); assert!(matches!( error, crate::Error::Nut(crate::NutError::UsernameRequired) diff --git a/rups/src/proto/mod.rs b/rups/src/proto/mod.rs index 01e1589..856cc46 100644 --- a/rups/src/proto/mod.rs +++ b/rups/src/proto/mod.rs @@ -105,12 +105,30 @@ macro_rules! impl_words { } /// A NUT protocol sentence that can be encoded and decoded from a Vector of strings. -pub trait Sentence: Sized + Into> { +pub trait Sentence: Eq + Sized + Into> { /// Decodes a sentence. Returns `None` if the pattern cannot be recognized. fn decode(raw: Vec) -> Option; /// Encodes the sentence. fn encode(&self) -> Vec<&str>; + + /// Returns an error if the sentence does not match what was expected. + fn as_matching bool>(self, matcher: F) -> crate::Result { + if matcher(&self) { + Ok(self) + } else { + Err(Error::Nut(NutError::UnexpectedResponse)) + } + } + + /// Returns an error if the sentence is not equal to what was expected. + fn as_exactly(self, expected: Self) -> crate::Result { + if expected == self { + Ok(self) + } else { + Err(Error::Nut(NutError::UnexpectedResponse)) + } + } } /// Implements the list of sentences, which are combinations @@ -313,6 +331,7 @@ impl_words! { Version("VERSION"), } +use crate::{Error, NutError}; pub(crate) use impl_sentences; #[cfg(test)] pub(crate) use test_encode_decode;