mirror of
https://github.com/aramperes/nut-rs.git
synced 2025-09-09 13:38:30 -04:00
Add tokio buffered stream
This commit is contained in:
parent
07034d2cec
commit
97e3731df2
3 changed files with 57 additions and 13 deletions
|
@ -9,6 +9,12 @@ pub enum ConnectionStream {
|
||||||
/// A plain TCP stream.
|
/// A plain TCP stream.
|
||||||
Tcp(TcpStream),
|
Tcp(TcpStream),
|
||||||
|
|
||||||
|
/// A stream wrapped with `BufReader`.
|
||||||
|
///
|
||||||
|
/// Use `.buffered()` to wrap any stream with `BufReader`.
|
||||||
|
/// It can then be un-wrapped with `.unbuffered()`.
|
||||||
|
Buffered(Box<BufReader<ConnectionStream>>),
|
||||||
|
|
||||||
/// A client stream wrapped with SSL using `rustls`.
|
/// A client stream wrapped with SSL using `rustls`.
|
||||||
#[cfg(feature = "ssl")]
|
#[cfg(feature = "ssl")]
|
||||||
SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
|
SslClient(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
|
||||||
|
@ -17,12 +23,6 @@ pub enum ConnectionStream {
|
||||||
#[cfg(feature = "ssl")]
|
#[cfg(feature = "ssl")]
|
||||||
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
|
SslServer(Box<rustls::StreamOwned<rustls::ServerSession, ConnectionStream>>),
|
||||||
|
|
||||||
/// A stream wrapped with `BufReader`.
|
|
||||||
///
|
|
||||||
/// Use `.buffered()` to wrap any stream with `BufReader`.
|
|
||||||
/// It can then be un-wrapped with `.unbuffered()`.
|
|
||||||
Buffered(Box<BufReader<ConnectionStream>>),
|
|
||||||
|
|
||||||
/// A mock stream, used for testing.
|
/// A mock stream, used for testing.
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
Mock(mockstream::SharedMockStream),
|
Mock(mockstream::SharedMockStream),
|
||||||
|
|
|
@ -60,7 +60,7 @@ impl TcpConnection {
|
||||||
let tcp_stream = TcpStream::connect(socket_addr).await?;
|
let tcp_stream = TcpStream::connect(socket_addr).await?;
|
||||||
let mut connection = Self {
|
let mut connection = Self {
|
||||||
config,
|
config,
|
||||||
stream: ConnectionStream::Plain(tcp_stream),
|
stream: ConnectionStream::Tcp(tcp_stream),
|
||||||
};
|
};
|
||||||
connection = connection.enable_ssl().await?;
|
connection = connection.enable_ssl().await?;
|
||||||
Ok(connection)
|
Ok(connection)
|
||||||
|
|
|
@ -1,13 +1,19 @@
|
||||||
use crate::Error;
|
use crate::Error;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, ReadBuf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
|
||||||
/// A wrapper for various Tokio stream types.
|
/// A wrapper for various Tokio stream types.
|
||||||
pub enum ConnectionStream {
|
pub enum ConnectionStream {
|
||||||
/// A plain TCP stream.
|
/// A plain TCP stream.
|
||||||
Plain(TcpStream),
|
Tcp(TcpStream),
|
||||||
|
|
||||||
|
/// A stream wrapped with `BufReader`.
|
||||||
|
///
|
||||||
|
/// Use `.buffered()` to wrap any stream with `BufReader`.
|
||||||
|
/// It can then be un-wrapped with `.unbuffered()`.
|
||||||
|
Buffered(Box<BufReader<ConnectionStream>>),
|
||||||
|
|
||||||
/// A client stream wrapped with SSL using `rustls`.
|
/// A client stream wrapped with SSL using `rustls`.
|
||||||
#[cfg(feature = "async-ssl")]
|
#[cfg(feature = "async-ssl")]
|
||||||
|
@ -50,10 +56,14 @@ impl AsyncRead for ConnectionStream {
|
||||||
buf: &mut ReadBuf<'_>,
|
buf: &mut ReadBuf<'_>,
|
||||||
) -> Poll<std::io::Result<()>> {
|
) -> Poll<std::io::Result<()>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(stream) => {
|
Self::Tcp(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
pinned.poll_read(cx, buf)
|
pinned.poll_read(cx, buf)
|
||||||
}
|
}
|
||||||
|
Self::Buffered(reader) => {
|
||||||
|
let pinned = Pin::new(reader.get_mut());
|
||||||
|
pinned.poll_read(cx, buf)
|
||||||
|
}
|
||||||
#[cfg(feature = "async-ssl")]
|
#[cfg(feature = "async-ssl")]
|
||||||
Self::SslClient(stream) => {
|
Self::SslClient(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
|
@ -68,6 +78,28 @@ impl AsyncRead for ConnectionStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AsyncBufRead for ConnectionStream {
|
||||||
|
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
Self::Buffered(reader) => {
|
||||||
|
let pinned = Pin::new(reader.get_mut());
|
||||||
|
pinned.poll_fill_buf(cx)
|
||||||
|
}
|
||||||
|
_ => core::task::Poll::Ready(Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Unsupported,
|
||||||
|
"Stream is not buffered",
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||||
|
if let Self::Buffered(reader) = self.get_mut() {
|
||||||
|
let pinned = Pin::new(reader.get_mut());
|
||||||
|
pinned.consume(amt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AsyncWrite for ConnectionStream {
|
impl AsyncWrite for ConnectionStream {
|
||||||
fn poll_write(
|
fn poll_write(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
|
@ -75,10 +107,14 @@ impl AsyncWrite for ConnectionStream {
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<std::io::Result<usize>> {
|
) -> Poll<std::io::Result<usize>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(stream) => {
|
Self::Tcp(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
pinned.poll_write(cx, buf)
|
pinned.poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
|
Self::Buffered(reader) => {
|
||||||
|
let pinned = Pin::new(reader.get_mut());
|
||||||
|
pinned.poll_write(cx, buf)
|
||||||
|
}
|
||||||
#[cfg(feature = "async-ssl")]
|
#[cfg(feature = "async-ssl")]
|
||||||
Self::SslClient(stream) => {
|
Self::SslClient(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
|
@ -94,10 +130,14 @@ impl AsyncWrite for ConnectionStream {
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(stream) => {
|
Self::Tcp(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
pinned.poll_flush(cx)
|
pinned.poll_flush(cx)
|
||||||
}
|
}
|
||||||
|
Self::Buffered(reader) => {
|
||||||
|
let pinned = Pin::new(reader.get_mut());
|
||||||
|
pinned.poll_flush(cx)
|
||||||
|
}
|
||||||
#[cfg(feature = "async-ssl")]
|
#[cfg(feature = "async-ssl")]
|
||||||
Self::SslClient(stream) => {
|
Self::SslClient(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
|
@ -113,10 +153,14 @@ impl AsyncWrite for ConnectionStream {
|
||||||
|
|
||||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(stream) => {
|
Self::Tcp(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
pinned.poll_shutdown(cx)
|
pinned.poll_shutdown(cx)
|
||||||
}
|
}
|
||||||
|
Self::Buffered(reader) => {
|
||||||
|
let pinned = Pin::new(reader.get_mut());
|
||||||
|
pinned.poll_shutdown(cx)
|
||||||
|
}
|
||||||
#[cfg(feature = "async-ssl")]
|
#[cfg(feature = "async-ssl")]
|
||||||
Self::SslClient(stream) => {
|
Self::SslClient(stream) => {
|
||||||
let pinned = Pin::new(stream);
|
let pinned = Pin::new(stream);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue