Add tokio buffered stream

This commit is contained in:
Aram 🍐 2021-08-04 22:06:15 -04:00
parent 07034d2cec
commit 97e3731df2
3 changed files with 57 additions and 13 deletions

View file

@ -9,6 +9,12 @@ pub enum ConnectionStream {
/// A plain TCP stream. /// A plain TCP stream.
Tcp(TcpStream), Tcp(TcpStream),
/// A stream wrapped with `BufReader`.
///
/// Use `.buffered()` to wrap any stream with `BufReader`.
/// It can then be un-wrapped with `.unbuffered()`.
Buffered(Box<BufReader<ConnectionStream>>),
/// A client stream wrapped with SSL using `rustls`. /// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>), SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
@ -17,12 +23,6 @@ pub enum ConnectionStream {
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>), SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
/// A stream wrapped with `BufReader`.
///
/// Use `.buffered()` to wrap any stream with `BufReader`.
/// It can then be un-wrapped with `.unbuffered()`.
Buffered(Box<BufReader<ConnectionStream>>),
/// A mock stream, used for testing. /// A mock stream, used for testing.
#[cfg(test)] #[cfg(test)]
Mock(mockstream::SharedMockStream), Mock(mockstream::SharedMockStream),

View file

@ -60,7 +60,7 @@ impl TcpConnection {
let tcp_stream = TcpStream::connect(socket_addr).await?; let tcp_stream = TcpStream::connect(socket_addr).await?;
let mut connection = Self { let mut connection = Self {
config, config,
stream: ConnectionStream::Plain(tcp_stream), stream: ConnectionStream::Tcp(tcp_stream),
}; };
connection = connection.enable_ssl().await?; connection = connection.enable_ssl().await?;
Ok(connection) Ok(connection)

View file

@ -1,13 +1,19 @@
use crate::Error; use crate::Error;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, ReadBuf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
/// A wrapper for various Tokio stream types. /// A wrapper for various Tokio stream types.
pub enum ConnectionStream { pub enum ConnectionStream {
/// A plain TCP stream. /// A plain TCP stream.
Plain(TcpStream), Tcp(TcpStream),
/// A stream wrapped with `BufReader`.
///
/// Use `.buffered()` to wrap any stream with `BufReader`.
/// It can then be un-wrapped with `.unbuffered()`.
Buffered(Box<BufReader<ConnectionStream>>),
/// A client stream wrapped with SSL using `rustls`. /// A client stream wrapped with SSL using `rustls`.
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
@ -50,10 +56,14 @@ impl AsyncRead for ConnectionStream {
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> { ) -> Poll<std::io::Result<()>> {
match self.get_mut() { match self.get_mut() {
Self::Plain(stream) => { Self::Tcp(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_read(cx, buf) pinned.poll_read(cx, buf)
} }
Self::Buffered(reader) => {
let pinned = Pin::new(reader.get_mut());
pinned.poll_read(cx, buf)
}
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::SslClient(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
@ -68,6 +78,28 @@ impl AsyncRead for ConnectionStream {
} }
} }
impl AsyncBufRead for ConnectionStream {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
match self.get_mut() {
Self::Buffered(reader) => {
let pinned = Pin::new(reader.get_mut());
pinned.poll_fill_buf(cx)
}
_ => core::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"Stream is not buffered",
))),
}
}
fn consume(self: Pin<&mut Self>, amt: usize) {
if let Self::Buffered(reader) = self.get_mut() {
let pinned = Pin::new(reader.get_mut());
pinned.consume(amt)
}
}
}
impl AsyncWrite for ConnectionStream { impl AsyncWrite for ConnectionStream {
fn poll_write( fn poll_write(
self: Pin<&mut Self>, self: Pin<&mut Self>,
@ -75,10 +107,14 @@ impl AsyncWrite for ConnectionStream {
buf: &[u8], buf: &[u8],
) -> Poll<std::io::Result<usize>> { ) -> Poll<std::io::Result<usize>> {
match self.get_mut() { match self.get_mut() {
Self::Plain(stream) => { Self::Tcp(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_write(cx, buf) pinned.poll_write(cx, buf)
} }
Self::Buffered(reader) => {
let pinned = Pin::new(reader.get_mut());
pinned.poll_write(cx, buf)
}
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::SslClient(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
@ -94,10 +130,14 @@ impl AsyncWrite for ConnectionStream {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() { match self.get_mut() {
Self::Plain(stream) => { Self::Tcp(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_flush(cx) pinned.poll_flush(cx)
} }
Self::Buffered(reader) => {
let pinned = Pin::new(reader.get_mut());
pinned.poll_flush(cx)
}
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::SslClient(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
@ -113,10 +153,14 @@ impl AsyncWrite for ConnectionStream {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() { match self.get_mut() {
Self::Plain(stream) => { Self::Tcp(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);
pinned.poll_shutdown(cx) pinned.poll_shutdown(cx)
} }
Self::Buffered(reader) => {
let pinned = Pin::new(reader.get_mut());
pinned.poll_shutdown(cx)
}
#[cfg(feature = "async-ssl")] #[cfg(feature = "async-ssl")]
Self::SslClient(stream) => { Self::SslClient(stream) => {
let pinned = Pin::new(stream); let pinned = Pin::new(stream);