From bf489900e653bab9b9f752c21036698561e17446 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Thu, 14 Oct 2021 19:31:29 -0400 Subject: [PATCH] Use tokio select for polling client socket --- src/main.rs | 127 ++++++++++++++++++++++------------------------- src/port_pool.rs | 4 +- 2 files changed, 60 insertions(+), 71 deletions(-) diff --git a/src/main.rs b/src/main.rs index 02ff45e..4432113 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use anyhow::Context; use smoltcp::iface::InterfaceBuilder; use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; use smoltcp::wire::{IpAddress, IpCidr}; -use tokio::io::Interest; +use tokio::io::{AsyncReadExt, Interest}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::error::TryRecvError; @@ -158,85 +158,76 @@ async fn handle_tcp_proxy_connection( } loop { - let ready = socket - .ready(Interest::READABLE | Interest::WRITABLE) - .await - .with_context(|| "Failed to wait for TCP proxy socket readiness")?; - if abort.load(Ordering::Relaxed) { break; } - if ready.is_readable() { - let mut buffer = [0u8; MAX_PACKET]; - - match socket.try_read(&mut buffer) { - Ok(size) if size > 0 => { - let data = &buffer[..size]; - debug!( - "[{}] Read {} bytes of TCP data from real client", - virtual_port, size - ); - match data_to_real_server_tx.send(data.to_vec()).await { + tokio::select! { + readable_result = socket.readable() => { + match readable_result { + Ok(_) => { + let mut buffer = vec![]; + match socket.try_read_buf(&mut buffer) { + Ok(size) if size > 0 => { + let data = &buffer[..size]; + debug!( + "[{}] Read {} bytes of TCP data from real client", + virtual_port, size + ); + match data_to_real_server_tx.send(data.to_vec()).await { + Err(e) => { + error!( + "[{}] Failed to dispatch data to virtual interface: {:?}", + virtual_port, e + ); + } + _ => {} + } + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + error!( + "[{}] Failed to read from client TCP socket: {:?}", + virtual_port, e + ); + break; + } + _ => { + break; + } + } + } + Err(e) => { + error!("[{}] Failed to check if readable: {:?}", virtual_port, e); + break; + } + } + } + data_recv_result = data_to_real_client_rx.recv() => { + match data_recv_result { + Some(data) => match socket.try_write(&data) { + Ok(size) => { + debug!( + "[{}] Wrote {} bytes of TCP data to real client", + virtual_port, size + ); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } Err(e) => { error!( - "[{}] Failed to dispatch data to virtual interface: {:?}", + "[{}] Failed to write to client TCP socket: {:?}", virtual_port, e ); } - _ => {} - } + }, + None => continue, } - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - continue; - } - Err(e) => { - error!( - "[{}] Failed to read from client TCP socket: {:?}", - virtual_port, e - ); - break; - } - _ => {} } } - - if ready.is_writable() { - // Flush the data_to_real_client_rx channel - match data_to_real_client_rx.try_recv() { - Ok(data) => match socket.try_write(&data) { - Ok(size) => { - debug!( - "[{}] Wrote {} bytes of TCP data to real client", - virtual_port, size - ); - } - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - continue; - } - Err(e) => { - error!( - "[{}] Failed to write to client TCP socket: {:?}", - virtual_port, e - ); - } - }, - Err(e) => match e { - TryRecvError::Empty => { - // Nothing else to consume in the data channel. - } - TryRecvError::Disconnected => { - // Channel is broken, probably terminated. - } - }, - } - } - - if ready.is_read_closed() || ready.is_write_closed() { - break; - } - - tokio::time::sleep(Duration::from_millis(1)).await; } trace!("[{}] TCP socket handler task terminated", virtual_port); diff --git a/src/port_pool.rs b/src/port_pool.rs index b52dc84..0538b9c 100644 --- a/src/port_pool.rs +++ b/src/port_pool.rs @@ -14,9 +14,7 @@ impl PortPool { pub fn new() -> Self { let inner = lockfree::queue::Queue::default(); PORT_RANGE.for_each(|p| inner.push(p) as ()); - Self { - inner, - } + Self { inner } } pub fn next(&self) -> anyhow::Result {