diff --git a/rups/src/blocking/stream.rs b/rups/src/blocking/stream.rs index e2ccc5c..03fc1f1 100644 --- a/rups/src/blocking/stream.rs +++ b/rups/src/blocking/stream.rs @@ -9,6 +9,12 @@ pub enum ConnectionStream { /// A plain TCP stream. Tcp(TcpStream), + /// A stream wrapped with `BufReader`. + /// + /// Use `.buffered()` to wrap any stream with `BufReader`. + /// It can then be un-wrapped with `.unbuffered()`. + Buffered(Box>), + /// A client stream wrapped with SSL using `rustls`. #[cfg(feature = "ssl")] SslClient(Box>), @@ -17,12 +23,6 @@ pub enum ConnectionStream { #[cfg(feature = "ssl")] SslServer(Box>), - /// A stream wrapped with `BufReader`. - /// - /// Use `.buffered()` to wrap any stream with `BufReader`. - /// It can then be un-wrapped with `.unbuffered()`. - Buffered(Box>), - /// A mock stream, used for testing. #[cfg(test)] Mock(mockstream::SharedMockStream), diff --git a/rups/src/tokio/mod.rs b/rups/src/tokio/mod.rs index 374a0e3..f8d8e45 100644 --- a/rups/src/tokio/mod.rs +++ b/rups/src/tokio/mod.rs @@ -60,7 +60,7 @@ impl TcpConnection { let tcp_stream = TcpStream::connect(socket_addr).await?; let mut connection = Self { config, - stream: ConnectionStream::Plain(tcp_stream), + stream: ConnectionStream::Tcp(tcp_stream), }; connection = connection.enable_ssl().await?; Ok(connection) diff --git a/rups/src/tokio/stream.rs b/rups/src/tokio/stream.rs index 7f57e22..282a709 100644 --- a/rups/src/tokio/stream.rs +++ b/rups/src/tokio/stream.rs @@ -1,13 +1,19 @@ use crate::Error; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, ReadBuf}; use tokio::net::TcpStream; /// A wrapper for various Tokio stream types. pub enum ConnectionStream { /// A plain TCP stream. - Plain(TcpStream), + Tcp(TcpStream), + + /// A stream wrapped with `BufReader`. + /// + /// Use `.buffered()` to wrap any stream with `BufReader`. + /// It can then be un-wrapped with `.unbuffered()`. + Buffered(Box>), /// A client stream wrapped with SSL using `rustls`. #[cfg(feature = "async-ssl")] @@ -50,10 +56,14 @@ impl AsyncRead for ConnectionStream { buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { - Self::Plain(stream) => { + Self::Tcp(stream) => { let pinned = Pin::new(stream); pinned.poll_read(cx, buf) } + Self::Buffered(reader) => { + let pinned = Pin::new(reader.get_mut()); + pinned.poll_read(cx, buf) + } #[cfg(feature = "async-ssl")] Self::SslClient(stream) => { let pinned = Pin::new(stream); @@ -68,6 +78,28 @@ impl AsyncRead for ConnectionStream { } } +impl AsyncBufRead for ConnectionStream { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Buffered(reader) => { + let pinned = Pin::new(reader.get_mut()); + pinned.poll_fill_buf(cx) + } + _ => core::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Stream is not buffered", + ))), + } + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + if let Self::Buffered(reader) = self.get_mut() { + let pinned = Pin::new(reader.get_mut()); + pinned.consume(amt) + } + } +} + impl AsyncWrite for ConnectionStream { fn poll_write( self: Pin<&mut Self>, @@ -75,10 +107,14 @@ impl AsyncWrite for ConnectionStream { buf: &[u8], ) -> Poll> { match self.get_mut() { - Self::Plain(stream) => { + Self::Tcp(stream) => { let pinned = Pin::new(stream); pinned.poll_write(cx, buf) } + Self::Buffered(reader) => { + let pinned = Pin::new(reader.get_mut()); + pinned.poll_write(cx, buf) + } #[cfg(feature = "async-ssl")] Self::SslClient(stream) => { let pinned = Pin::new(stream); @@ -94,10 +130,14 @@ impl AsyncWrite for ConnectionStream { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Self::Plain(stream) => { + Self::Tcp(stream) => { let pinned = Pin::new(stream); pinned.poll_flush(cx) } + Self::Buffered(reader) => { + let pinned = Pin::new(reader.get_mut()); + pinned.poll_flush(cx) + } #[cfg(feature = "async-ssl")] Self::SslClient(stream) => { let pinned = Pin::new(stream); @@ -113,10 +153,14 @@ impl AsyncWrite for ConnectionStream { fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Self::Plain(stream) => { + Self::Tcp(stream) => { let pinned = Pin::new(stream); pinned.poll_shutdown(cx) } + Self::Buffered(reader) => { + let pinned = Pin::new(reader.get_mut()); + pinned.poll_shutdown(cx) + } #[cfg(feature = "async-ssl")] Self::SslClient(stream) => { let pinned = Pin::new(stream);