From 51788c95577109082b2f73cda3587bd910a2b040 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Sat, 8 Jan 2022 01:05:51 -0500 Subject: [PATCH 1/2] Improve reliability using event-based synchronization --- src/config.rs | 5 +- src/events.rs | 150 ++++++++++++++ src/ip_sink.rs | 35 ---- src/main.rs | 65 +++++- src/tunnel/mod.rs | 6 +- src/tunnel/tcp.rs | 139 ++++--------- src/tunnel/udp.rs | 75 ++----- src/virtual_device.rs | 98 ++++----- src/virtual_iface/mod.rs | 45 +++- src/virtual_iface/tcp.rs | 433 ++++++++++++++++++--------------------- src/virtual_iface/udp.rs | 201 +++--------------- src/wg.rs | 181 +++------------- 12 files changed, 628 insertions(+), 805 deletions(-) create mode 100644 src/events.rs delete mode 100644 src/ip_sink.rs diff --git a/src/config.rs b/src/config.rs index ed6757c..8df9511 100644 --- a/src/config.rs +++ b/src/config.rs @@ -235,7 +235,7 @@ fn is_file_insecurely_readable(path: &str) -> Option<(bool, bool)> { } #[cfg(not(unix))] -fn is_file_insecurely_readable(path: &str) -> Option<(bool, bool)> { +fn is_file_insecurely_readable(_path: &str) -> Option<(bool, bool)> { // No good way to determine permissions on non-Unix target None } @@ -399,9 +399,12 @@ impl Display for PortForwardConfig { } } +/// Layer 7 protocols for ports. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum PortProtocol { + /// TCP Tcp, + /// UDP Udp, } diff --git a/src/events.rs b/src/events.rs new file mode 100644 index 0000000..1437bc2 --- /dev/null +++ b/src/events.rs @@ -0,0 +1,150 @@ +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Arc; + +use crate::config::PortForwardConfig; +use crate::virtual_iface::VirtualPort; +use crate::PortProtocol; + +/// Events that go on the bus between the local server, smoltcp, and WireGuard. +#[derive(Debug, Clone)] +pub enum Event { + /// Dumb event with no data. + Dumb, + /// A new connection with the local server was initiated, and the given virtual port was assigned. + ClientConnectionInitiated(PortForwardConfig, VirtualPort), + /// A connection was dropped from the pool and should be closed in all interfaces. + ClientConnectionDropped(VirtualPort), + /// Data received by the local server that should be sent to the virtual server. + LocalData(VirtualPort, Vec), + /// Data received by the remote server that should be sent to the local client. + RemoteData(VirtualPort, Vec), + /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. + InboundInternetPacket(PortProtocol, Vec), + /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device. + OutboundInternetPacket(Vec), + /// Notifies that a virtual device read an IP packet. + VirtualDeviceFed(PortProtocol), +} + +#[derive(Clone)] +pub struct Bus { + counter: Arc, + bus: Arc>, +} + +impl Bus { + /// Creates a new event bus. + pub fn new() -> Self { + let (bus, _) = tokio::sync::broadcast::channel(1000); + let bus = Arc::new(bus); + let counter = Arc::new(AtomicU32::default()); + Self { bus, counter } + } + + /// Creates a new endpoint on the event bus. + pub fn new_endpoint(&self) -> BusEndpoint { + let id = self.counter.fetch_add(1, Ordering::Relaxed); + let tx = (*self.bus).clone(); + let rx = self.bus.subscribe(); + + let tx = BusSender { id, tx }; + BusEndpoint { id, tx, rx } + } +} + +impl Default for Bus { + fn default() -> Self { + Self::new() + } +} + +pub struct BusEndpoint { + id: u32, + tx: BusSender, + rx: tokio::sync::broadcast::Receiver<(u32, Event)>, +} + +impl BusEndpoint { + /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. + pub fn send(&self, event: Event) { + self.tx.send(event) + } + + /// Returns the unique sequential ID of this endpoint. + pub fn id(&self) -> u32 { + self.id + } + + /// Awaits the next `Event` on the bus to be read. + pub async fn recv(&mut self) -> Event { + loop { + match self.rx.recv().await { + Ok((id, event)) => { + if id == self.id { + // If the event was sent by this endpoint, it is skipped + continue; + } else { + trace!("#{} <- {:?}", self.id, event); + return event; + } + } + Err(_) => { + error!("Failed to read event bus from endpoint #{}", self.id); + return futures::future::pending().await; + } + } + } + } + + /// Creates a new sender for this endpoint that can be cloned. + pub fn sender(&self) -> BusSender { + self.tx.clone() + } +} + +#[derive(Clone)] +pub struct BusSender { + id: u32, + tx: tokio::sync::broadcast::Sender<(u32, Event)>, +} + +impl BusSender { + /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. + pub fn send(&self, event: Event) { + trace!("#{} -> {:?}", self.id, event); + match self.tx.send((self.id, event)) { + Ok(_) => {} + Err(_) => error!("Failed to send event to bus from endpoint #{}", self.id), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_bus() { + let bus = Bus::new(); + + let mut endpoint_1 = bus.new_endpoint(); + let mut endpoint_2 = bus.new_endpoint(); + let mut endpoint_3 = bus.new_endpoint(); + + assert_eq!(endpoint_1.id(), 0); + assert_eq!(endpoint_2.id(), 1); + assert_eq!(endpoint_3.id(), 2); + + endpoint_1.send(Event::Dumb); + let recv_2 = endpoint_2.recv().await; + let recv_3 = endpoint_3.recv().await; + assert!(matches!(recv_2, Event::Dumb)); + assert!(matches!(recv_3, Event::Dumb)); + + endpoint_2.send(Event::Dumb); + let recv_1 = endpoint_1.recv().await; + let recv_3 = endpoint_3.recv().await; + assert!(matches!(recv_1, Event::Dumb)); + assert!(matches!(recv_3, Event::Dumb)); + } +} diff --git a/src/ip_sink.rs b/src/ip_sink.rs deleted file mode 100644 index b9226f7..0000000 --- a/src/ip_sink.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::virtual_device::VirtualIpDevice; -use crate::wg::WireGuardTunnel; -use smoltcp::iface::InterfaceBuilder; -use std::sync::Arc; -use tokio::time::Duration; - -/// A repeating task that processes unroutable IP packets. -pub async fn run_ip_sink_interface(wg: Arc) -> ! { - // Initialize interface - let device = VirtualIpDevice::new_sink(wg) - .await - .expect("Failed to initialize VirtualIpDevice for sink interface"); - - // No sockets on sink interface - let mut sockets: [_; 0] = Default::default(); - let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..]) - .ip_addrs([]) - .finalize(); - - loop { - let loop_start = smoltcp::time::Instant::now(); - match virtual_interface.poll(loop_start) { - Ok(processed) if processed => { - trace!("[SINK] Virtual interface polled some packets to be processed",); - tokio::time::sleep(Duration::from_millis(1)).await; - } - Err(e) => { - error!("[SINK] Virtual interface poll error: {:?}", e); - } - _ => { - tokio::time::sleep(Duration::from_millis(5)).await; - } - } - } -} diff --git a/src/main.rs b/src/main.rs index d14fd06..1a01b51 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,13 +5,18 @@ use std::sync::Arc; use anyhow::Context; -use crate::config::Config; +use crate::config::{Config, PortProtocol}; +use crate::events::Bus; use crate::tunnel::tcp::TcpPortPool; use crate::tunnel::udp::UdpPortPool; +use crate::virtual_device::VirtualIpDevice; +use crate::virtual_iface::tcp::TcpVirtualInterface; +use crate::virtual_iface::udp::UdpVirtualInterface; +use crate::virtual_iface::VirtualInterfacePoll; use crate::wg::WireGuardTunnel; pub mod config; -pub mod ip_sink; +pub mod events; pub mod tunnel; pub mod virtual_device; pub mod virtual_iface; @@ -30,7 +35,9 @@ async fn main() -> anyhow::Result<()> { let tcp_port_pool = TcpPortPool::new(); let udp_port_pool = UdpPortPool::new(); - let wg = WireGuardTunnel::new(&config) + let bus = Bus::default(); + + let wg = WireGuardTunnel::new(&config, bus.clone()) .await .with_context(|| "Failed to initialize WireGuard tunnel")?; let wg = Arc::new(wg); @@ -48,9 +55,41 @@ async fn main() -> anyhow::Result<()> { } { - // Start IP sink task for incoming IP packets + // Start production task for WireGuard let wg = wg.clone(); - tokio::spawn(async move { ip_sink::run_ip_sink_interface(wg).await }); + tokio::spawn(async move { wg.produce_task().await }); + } + + if config + .port_forwards + .iter() + .any(|pf| pf.protocol == PortProtocol::Tcp) + { + // TCP device + let bus = bus.clone(); + let device = + VirtualIpDevice::new(PortProtocol::Tcp, bus.clone(), config.max_transmission_unit); + + // Start TCP Virtual Interface + let port_forwards = config.port_forwards.clone(); + let iface = TcpVirtualInterface::new(port_forwards, bus, device, config.source_peer_ip); + tokio::spawn(async move { iface.poll_loop().await }); + } + + if config + .port_forwards + .iter() + .any(|pf| pf.protocol == PortProtocol::Udp) + { + // UDP device + let bus = bus.clone(); + let device = + VirtualIpDevice::new(PortProtocol::Udp, bus.clone(), config.max_transmission_unit); + + // Start UDP Virtual Interface + let port_forwards = config.port_forwards.clone(); + let iface = UdpVirtualInterface::new(port_forwards, bus, device, config.source_peer_ip); + tokio::spawn(async move { iface.poll_loop().await }); } { @@ -59,10 +98,18 @@ async fn main() -> anyhow::Result<()> { port_forwards .into_iter() - .map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone())) - .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| { + .map(|pf| { + ( + pf, + wg.clone(), + tcp_port_pool.clone(), + udp_port_pool.clone(), + bus.clone(), + ) + }) + .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool, bus)| { tokio::spawn(async move { - tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg) + tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg, bus) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }); @@ -73,7 +120,7 @@ async fn main() -> anyhow::Result<()> { } fn init_logger(config: &Config) -> anyhow::Result<()> { - let mut builder = pretty_env_logger::formatted_builder(); + let mut builder = pretty_env_logger::formatted_timed_builder(); builder.parse_filters(&config.log); builder .try_init() diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index c7f2a67..1676d7a 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,6 +2,7 @@ use std::net::IpAddr; use std::sync::Arc; use crate::config::{PortForwardConfig, PortProtocol}; +use crate::events::Bus; use crate::tunnel::tcp::TcpPortPool; use crate::tunnel::udp::UdpPortPool; use crate::wg::WireGuardTunnel; @@ -16,6 +17,7 @@ pub async fn port_forward( tcp_port_pool: TcpPortPool, udp_port_pool: UdpPortPool, wg: Arc, + bus: Bus, ) -> anyhow::Result<()> { info!( "Tunneling {} [{}]->[{}] (via [{}] as peer {})", @@ -27,7 +29,7 @@ pub async fn port_forward( ); match port_forward.protocol { - PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await, - PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, wg).await, + PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, bus).await, + PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, bus).await, } } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index f49aa7c..718cf57 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -1,17 +1,17 @@ use crate::config::{PortForwardConfig, PortProtocol}; -use crate::virtual_iface::tcp::TcpVirtualInterface; -use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; -use crate::wg::WireGuardTunnel; +use crate::virtual_iface::VirtualPort; use anyhow::Context; use std::collections::VecDeque; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use std::ops::Range; +use std::time::Duration; +use crate::events::{Bus, Event}; use rand::seq::SliceRandom; use rand::thread_rng; +use tokio::io::AsyncWriteExt; const MAX_PACKET: usize = 65536; const MIN_PORT: u16 = 1000; @@ -22,14 +22,13 @@ const PORT_RANGE: Range = MIN_PORT..MAX_PORT; pub async fn tcp_proxy_server( port_forward: PortForwardConfig, port_pool: TcpPortPool, - wg: Arc, + bus: Bus, ) -> anyhow::Result<()> { let listener = TcpListener::bind(port_forward.source) .await .with_context(|| "Failed to listen on TCP proxy server")?; loop { - let wg = wg.clone(); let port_pool = port_pool.clone(); let (socket, peer_addr) = listener .accept() @@ -52,10 +51,10 @@ pub async fn tcp_proxy_server( info!("[{}] Incoming connection from {}", virtual_port, peer_addr); + let bus = bus.clone(); tokio::spawn(async move { let port_pool = port_pool.clone(); - let result = - handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; + let result = handle_tcp_proxy_connection(socket, virtual_port, port_forward, bus).await; if let Err(e) = result { error!( @@ -66,8 +65,7 @@ pub async fn tcp_proxy_server( info!("[{}] Connection closed by client", virtual_port); } - // Release port when connection drops - wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp)); + tokio::time::sleep(Duration::from_millis(100)).await; // Make sure the other tasks have time to process the event port_pool.release(virtual_port).await; }); } @@ -75,72 +73,26 @@ pub async fn tcp_proxy_server( /// Handles a new TCP connection with its assigned virtual port. async fn handle_tcp_proxy_connection( - socket: TcpStream, - virtual_port: u16, + mut socket: TcpStream, + virtual_port: VirtualPort, port_forward: PortForwardConfig, - wg: Arc, + bus: Bus, ) -> anyhow::Result<()> { - // Abort signal for stopping the Virtual Interface - let abort = Arc::new(AtomicBool::new(false)); - - // Signals that the Virtual Client is ready to send data - let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>(); - - // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back - // to the real client. - let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000); - - // data_to_real_server_(tx/rx): This task sends the data received from the real client to the - // virtual interface (virtual server socket). - let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000); - - // Spawn virtual interface - { - let abort = abort.clone(); - let virtual_interface = TcpVirtualInterface::new( - virtual_port, - port_forward, - wg, - abort.clone(), - data_to_real_client_tx, - data_to_virtual_server_rx, - virtual_client_ready_tx, - ); - - tokio::spawn(async move { - virtual_interface.poll_loop().await.unwrap_or_else(|e| { - error!("Virtual interface poll loop failed unexpectedly: {}", e); - abort.store(true, Ordering::Relaxed); - }) - }); - } - - // Wait for virtual client to be ready. - virtual_client_ready_rx - .await - .with_context(|| "Virtual client dropped before being ready.")?; - trace!("[{}] Virtual client is ready to send data", virtual_port); + let mut endpoint = bus.new_endpoint(); + endpoint.send(Event::ClientConnectionInitiated(port_forward, virtual_port)); + let mut buffer = Vec::with_capacity(MAX_PACKET); loop { tokio::select! { readable_result = socket.readable() => { match readable_result { Ok(_) => { - // Buffer for the individual TCP segment. - let mut buffer = Vec::with_capacity(MAX_PACKET); match socket.try_read_buf(&mut buffer) { Ok(size) if size > 0 => { - let data = &buffer[..size]; - debug!( - "[{}] Read {} bytes of TCP data from real client", - virtual_port, size - ); - if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await { - error!( - "[{}] Failed to dispatch data to virtual interface: {:?}", - virtual_port, e - ); - } + let data = Vec::from(&buffer[..size]); + endpoint.send(Event::LocalData(virtual_port, data)); + // Reset buffer + buffer.clear(); } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { continue; @@ -163,43 +115,32 @@ async fn handle_tcp_proxy_connection( } } } - data_recv_result = data_to_real_client_rx.recv() => { - match data_recv_result { - Some(data) => match socket.try_write(&data) { - Ok(size) => { - debug!( - "[{}] Wrote {} bytes of TCP data to real client", - virtual_port, size - ); - } - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - if abort.load(Ordering::Relaxed) { + event = endpoint.recv() => { + match event { + Event::ClientConnectionDropped(e_vp) if e_vp == virtual_port => { + // This connection is supposed to be closed, stop the task. + break; + } + Event::RemoteData(e_vp, data) if e_vp == virtual_port => { + // Have remote data to send to the local client + let size = data.len(); + match socket.write(&data).await { + Ok(size) => debug!("[{}] Sent {} bytes to local client", virtual_port, size), + Err(e) => { + error!("[{}] Failed to send {} bytes to local client: {:?}", virtual_port, size, e); break; - } else { - continue; } } - Err(e) => { - error!( - "[{}] Failed to write to client TCP socket: {:?}", - virtual_port, e - ); - } - }, - None => { - if abort.load(Ordering::Relaxed) { - break; - } else { - continue; - } - }, + } + _ => {} } } } } - trace!("[{}] TCP socket handler task terminated", virtual_port); - abort.store(true, Ordering::Relaxed); + // Notify other endpoints that this task has closed and no more data is to be sent to the local client + endpoint.send(Event::ClientConnectionDropped(virtual_port)); + Ok(()) } @@ -230,19 +171,19 @@ impl TcpPortPool { } /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). - pub async fn next(&self) -> anyhow::Result { + pub async fn next(&self) -> anyhow::Result { let mut inner = self.inner.write().await; let port = inner .queue .pop_front() .with_context(|| "TCP virtual port pool is exhausted")?; - Ok(port) + Ok(VirtualPort::new(port, PortProtocol::Tcp)) } /// Releases a port back into the pool. - pub async fn release(&self, port: u16) { + pub async fn release(&self, port: VirtualPort) { let mut inner = self.inner.write().await; - inner.queue.push_back(port); + inner.queue.push_back(port.num()); } } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index bedcdf7..34fb26d 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Instant; +use crate::events::{Bus, Event}; use anyhow::Context; use priority_queue::double_priority_queue::DoublePriorityQueue; use priority_queue::priority_queue::PriorityQueue; @@ -30,61 +31,24 @@ const UDP_TIMEOUT_SECONDS: u64 = 60; /// TODO: Make this configurable by the CLI const PORTS_PER_IP: usize = 100; +/// Starts the server that listens on UDP datagrams. pub async fn udp_proxy_server( port_forward: PortForwardConfig, port_pool: UdpPortPool, - wg: Arc, + bus: Bus, ) -> anyhow::Result<()> { - // Abort signal - let abort = Arc::new(AtomicBool::new(false)); - - // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back - // to the real client. - let (data_to_real_client_tx, mut data_to_real_client_rx) = - tokio::sync::mpsc::channel::<(VirtualPort, Vec)>(1_000); - - // data_to_real_server_(tx/rx): This task sends the data received from the real client to the - // virtual interface (virtual server socket). - let (data_to_virtual_server_tx, data_to_virtual_server_rx) = - tokio::sync::mpsc::channel::<(VirtualPort, Vec)>(1_000); - - { - // Spawn virtual interface - // Note: contrary to TCP, there is only one UDP virtual interface - let virtual_interface = UdpVirtualInterface::new( - port_forward, - wg, - data_to_real_client_tx, - data_to_virtual_server_rx, - ); - let abort = abort.clone(); - tokio::spawn(async move { - virtual_interface.poll_loop().await.unwrap_or_else(|e| { - error!("Virtual interface poll loop failed unexpectedly: {}", e); - abort.store(true, Ordering::Relaxed); - }); - }); - } - + let mut endpoint = bus.new_endpoint(); let socket = UdpSocket::bind(port_forward.source) .await .with_context(|| "Failed to bind on UDP proxy address")?; let mut buffer = [0u8; MAX_PACKET]; loop { - if abort.load(Ordering::Relaxed) { - break; - } tokio::select! { to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => { match to_send_result { Ok(Some((port, data))) => { - data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| { - error!( - "Failed to dispatch data to UDP virtual interface: {:?}", - e - ); - }); + endpoint.send(Event::LocalData(port, data)); } Ok(None) => { continue; @@ -98,9 +62,9 @@ pub async fn udp_proxy_server( } } } - data_recv_result = data_to_real_client_rx.recv() => { - if let Some((port, data)) = data_recv_result { - if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await { + event = endpoint.recv() => { + if let Event::RemoteData(port, data) = event { + if let Some(peer_addr) = port_pool.get_peer_addr(port).await { if let Err(e) = socket.send_to(&data, peer_addr).await { error!( "[{}] Failed to send UDP datagram to real client ({}): {:?}", @@ -109,7 +73,7 @@ pub async fn udp_proxy_server( e, ); } - port_pool.update_last_transmit(port.0).await; + port_pool.update_last_transmit(port).await; } } } @@ -141,14 +105,13 @@ async fn next_udp_datagram( return Ok(None); } }; - let port = VirtualPort(port, PortProtocol::Udp); debug!( "[{}] Received datagram of {} bytes from {}", port, size, peer_addr ); - port_pool.update_last_transmit(port.0).await; + port_pool.update_last_transmit(port).await; let data = buffer[..size].to_vec(); Ok(Some((port, data))) @@ -181,14 +144,14 @@ impl UdpPortPool { } /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity). - pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result { + pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result { // A port found to be reused. This is outside of the block because the read lock cannot be upgraded to a write lock. let mut port_reuse: Option = None; { let inner = self.inner.read().await; if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) { - return Ok(*port); + return Ok(VirtualPort::new(*port, PortProtocol::Udp)); } // Count how many ports are being used by the peer IP @@ -240,26 +203,26 @@ impl UdpPortPool { inner.port_by_peer_addr.insert(peer_addr, port); inner.peer_addr_by_port.insert(port, peer_addr); - Ok(port) + Ok(VirtualPort::new(port, PortProtocol::Udp)) } /// Notify that the given virtual port has received or transmitted a UDP datagram. - pub async fn update_last_transmit(&self, port: u16) { + pub async fn update_last_transmit(&self, port: VirtualPort) { let mut inner = self.inner.write().await; - if let Some(peer) = inner.peer_addr_by_port.get(&port).copied() { + if let Some(peer) = inner.peer_addr_by_port.get(&port.num()).copied() { let mut pq: &mut DoublePriorityQueue = inner .peer_port_usage .entry(peer.ip()) .or_insert_with(Default::default); - pq.push(port, Instant::now()); + pq.push(port.num(), Instant::now()); } let mut pq: &mut DoublePriorityQueue = &mut inner.port_usage; - pq.push(port, Instant::now()); + pq.push(port.num(), Instant::now()); } - pub async fn get_peer_addr(&self, port: u16) -> Option { + pub async fn get_peer_addr(&self, port: VirtualPort) -> Option { let inner = self.inner.read().await; - inner.peer_addr_by_port.get(&port).copied() + inner.peer_addr_by_port.get(&port.num()).copied() } } diff --git a/src/virtual_device.rs b/src/virtual_device.rs index d06d7fa..e0e7e4d 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,45 +1,51 @@ -use crate::virtual_iface::VirtualPort; -use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; -use anyhow::Context; +use crate::config::PortProtocol; +use crate::events::{BusSender, Event}; +use crate::Bus; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; -use std::sync::Arc; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; -/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint -/// are made available to this device using a channel receiver. IP packets sent from this device -/// are asynchronously sent out to the WireGuard tunnel. +/// A virtual device that processes IP packets through smoltcp and WireGuard. pub struct VirtualIpDevice { - /// Tunnel to send IP packets to. - wg: Arc, + /// Max transmission unit (bytes) + max_transmission_unit: usize, /// Channel receiver for received IP packets. - ip_dispatch_rx: tokio::sync::mpsc::Receiver>, + bus_sender: BusSender, + /// Local queue for packets received from the bus that need to go through the smoltcp interface. + process_queue: Arc>>>, } impl VirtualIpDevice { /// Initializes a new virtual IP device. - pub fn new( - wg: Arc, - ip_dispatch_rx: tokio::sync::mpsc::Receiver>, - ) -> Self { - Self { wg, ip_dispatch_rx } - } + pub fn new(protocol: PortProtocol, bus: Bus, max_transmission_unit: usize) -> Self { + let mut bus_endpoint = bus.new_endpoint(); + let bus_sender = bus_endpoint.sender(); + let process_queue = Arc::new(Mutex::new(VecDeque::new())); - /// Registers a virtual IP device for a single virtual client. - pub fn new_direct(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { - let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); + { + let process_queue = process_queue.clone(); + tokio::spawn(async move { + loop { + match bus_endpoint.recv().await { + Event::InboundInternetPacket(ip_proto, data) if ip_proto == protocol => { + let mut queue = process_queue + .lock() + .expect("Failed to acquire process queue lock"); + queue.push_back(data); + bus_endpoint.send(Event::VirtualDeviceFed(ip_proto)); + } + _ => {} + } + } + }); + } - wg.register_virtual_interface(virtual_port, ip_dispatch_tx) - .with_context(|| "Failed to register IP dispatch for virtual interface")?; - - Ok(Self { wg, ip_dispatch_rx }) - } - - pub async fn new_sink(wg: Arc) -> anyhow::Result { - let ip_dispatch_rx = wg - .register_sink_interface() - .await - .with_context(|| "Failed to register IP dispatch for sink virtual interface")?; - Ok(Self { wg, ip_dispatch_rx }) + Self { + bus_sender, + process_queue, + max_transmission_unit, + } } } @@ -48,27 +54,34 @@ impl<'a> Device<'a> for VirtualIpDevice { type TxToken = TxToken; fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - match self.ip_dispatch_rx.try_recv() { - Ok(buffer) => Some(( + let next = { + let mut queue = self + .process_queue + .lock() + .expect("Failed to acquire process queue lock"); + queue.pop_front() + }; + match next { + Some(buffer) => Some(( Self::RxToken { buffer }, Self::TxToken { - wg: self.wg.clone(), + sender: self.bus_sender.clone(), }, )), - Err(_) => None, + None => None, } } fn transmit(&'a mut self) -> Option { Some(TxToken { - wg: self.wg.clone(), + sender: self.bus_sender.clone(), }) } fn capabilities(&self) -> DeviceCapabilities { let mut cap = DeviceCapabilities::default(); cap.medium = Medium::Ip; - cap.max_transmission_unit = self.wg.max_transmission_unit; + cap.max_transmission_unit = self.max_transmission_unit; cap } } @@ -89,7 +102,7 @@ impl smoltcp::phy::RxToken for RxToken { #[doc(hidden)] pub struct TxToken { - wg: Arc, + sender: BusSender, } impl smoltcp::phy::TxToken for TxToken { @@ -100,14 +113,7 @@ impl smoltcp::phy::TxToken for TxToken { let mut buffer = Vec::new(); buffer.resize(len, 0); let result = f(&mut buffer); - tokio::spawn(async move { - match self.wg.send_ip_packet(&buffer).await { - Ok(_) => {} - Err(e) => { - error!("Failed to send IP packet to WireGuard endpoint: {:?}", e); - } - } - }); + self.sender.send(Event::OutboundInternetPacket(buffer)); result } } diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index f3796bd..ea3cd74 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -14,10 +14,51 @@ pub trait VirtualInterfacePoll { /// Virtual port. #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub struct VirtualPort(pub u16, pub PortProtocol); +pub struct VirtualPort(u16, PortProtocol); + +impl VirtualPort { + /// Create a new `VirtualPort` instance, with the given port number and associated protocol. + pub fn new(port: u16, proto: PortProtocol) -> Self { + VirtualPort(port, proto) + } + + /// The port number + pub fn num(&self) -> u16 { + self.0 + } + + /// The protocol of this port. + pub fn proto(&self) -> PortProtocol { + self.1 + } +} + +impl From for u16 { + fn from(port: VirtualPort) -> Self { + port.num() + } +} + +impl From<&VirtualPort> for u16 { + fn from(port: &VirtualPort) -> Self { + port.num() + } +} + +impl From for PortProtocol { + fn from(port: VirtualPort) -> Self { + port.proto() + } +} + +impl From<&VirtualPort> for PortProtocol { + fn from(port: &VirtualPort) -> Self { + port.proto() + } +} impl Display for VirtualPort { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "[{}:{}]", self.0, self.1) + write!(f, "[{}:{}]", self.num(), self.proto()) } } diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 340242a..6c283f1 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -1,284 +1,253 @@ use crate::config::{PortForwardConfig, PortProtocol}; +use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; -use crate::wg::WireGuardTunnel; +use crate::Bus; use anyhow::Context; use async_trait::async_trait; -use smoltcp::iface::InterfaceBuilder; +use smoltcp::iface::{InterfaceBuilder, SocketHandle}; use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState}; use smoltcp::wire::{IpAddress, IpCidr}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::net::IpAddr; use std::time::Duration; const MAX_PACKET: usize = 65536; /// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa. pub struct TcpVirtualInterface { - /// The virtual port assigned to the virtual client, used to - /// route Layer 4 segments/datagrams to and from the WireGuard tunnel. - virtual_port: u16, - /// The overall port-forward configuration: used for the destination address (on which - /// the virtual server listens) and the protocol in use. - port_forward: PortForwardConfig, - /// The WireGuard tunnel to send IP packets to. - wg: Arc, - /// Abort signal to shutdown the virtual interface and its parent task. - abort: Arc, - /// Channel sender for pushing Layer 7 data back to the real client. - data_to_real_client_tx: tokio::sync::mpsc::Sender>, - /// Channel receiver for processing Layer 7 data through the virtual interface. - data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, - /// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data. - virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, + source_peer_ip: IpAddr, + port_forwards: Vec, + device: VirtualIpDevice, + bus: Bus, } impl TcpVirtualInterface { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. pub fn new( - virtual_port: u16, - port_forward: PortForwardConfig, - wg: Arc, - abort: Arc, - data_to_real_client_tx: tokio::sync::mpsc::Sender>, - data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, - virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, + port_forwards: Vec, + bus: Bus, + device: VirtualIpDevice, + source_peer_ip: IpAddr, ) -> Self { Self { - virtual_port, - port_forward, - wg, - abort, - data_to_real_client_tx, - data_to_virtual_server_rx, - virtual_client_ready_tx, + port_forwards: port_forwards + .into_iter() + .filter(|f| matches!(f.protocol, PortProtocol::Tcp)) + .collect(), + device, + source_peer_ip, + bus, } } + + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + static mut TCP_SERVER_RX_DATA: [u8; 0] = []; + static mut TCP_SERVER_TX_DATA: [u8; 0] = []; + + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .listen(( + IpAddress::from(port_forward.destination.ip()), + port_forward.destination.port(), + )) + .with_context(|| "Virtual server socket failed to listen")?; + + Ok(socket) + } + + fn new_client_socket() -> anyhow::Result> { + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); + let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); + let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + Ok(socket) + } } #[async_trait] impl VirtualInterfacePoll for TcpVirtualInterface { async fn poll_loop(self) -> anyhow::Result<()> { - let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx); - let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; - let source_peer_ip = self.wg.source_peer_ip; + // Create CIDR block for source peer IP + each port forward IP + let addresses: Vec = { + let mut addresses = HashSet::new(); + addresses.insert(IpAddress::from(self.source_peer_ip)); + for config in self.port_forwards.iter() { + addresses.insert(IpAddress::from(config.destination.ip())); + } + addresses + .into_iter() + .map(|addr| IpCidr::new(addr, 32)) + .collect() + }; - // Create a device and interface to simulate IP packets - // In essence: - // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client' - // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel - // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface' - // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded) - // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client - - // Consumer for IP packets to send through the virtual interface - // Initialize the interface - let device = - VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) - .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; - - // there are always 2 sockets: 1 virtual client and 1 virtual server. - let mut sockets: [_; 2] = Default::default(); - let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..]) - .ip_addrs([ - // Interface handles IP packets for the sender and recipient - IpCidr::new(IpAddress::from(source_peer_ip), 32), - IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32), - ]) + let mut iface = InterfaceBuilder::new(self.device, vec![]) + .ip_addrs(addresses) .finalize(); - // Server socket: this is a placeholder for the interface to route new connections to. - let server_socket: anyhow::Result = { - static mut TCP_SERVER_RX_DATA: [u8; 0] = []; - static mut TCP_SERVER_TX_DATA: [u8; 0] = []; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + // Create virtual server for each port forward + for port_forward in self.port_forwards.iter() { + let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?; + iface.add_socket(server_socket); + } - socket - .listen(( - IpAddress::from(self.port_forward.destination.ip()), - self.port_forward.destination.port(), - )) - .with_context(|| "Virtual server socket failed to listen")?; + // The next time to poll the interface. Can be None for instant poll. + let mut next_poll: Option = None; - Ok(socket) - }; + // Bus endpoint to read events + let mut endpoint = self.bus.new_endpoint(); - let client_socket: anyhow::Result = { - let rx_data = vec![0u8; MAX_PACKET]; - let tx_data = vec![0u8; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); - let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); - let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - Ok(socket) - }; + let mut port_client_handle_map: HashMap = HashMap::new(); - let _server_handle = virtual_interface.add_socket(server_socket?); - let client_handle = virtual_interface.add_socket(client_socket?); - - // Any data that wasn't sent because it was over the sending buffer limit - let mut tx_extra = Vec::new(); - - // Counts the connection attempts by the virtual client - let mut connection_attempts = 0; - // Whether the client has successfully connected before. Prevents the case of connecting again. - let mut has_connected = false; + // Data packets to send from a virtual client + let mut send_queue: HashMap>> = HashMap::new(); loop { - let loop_start = smoltcp::time::Instant::now(); + tokio::select! { + _ = match (next_poll, port_client_handle_map.len()) { + (None, 0) => tokio::time::sleep(Duration::MAX), + (None, _) => tokio::time::sleep(Duration::ZERO), + (Some(until), _) => tokio::time::sleep_until(until), + } => { + let loop_start = smoltcp::time::Instant::now(); - // Shutdown occurs when the real client closes the connection, - // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore. - // One last poll-loop iteration is executed so that the RST segment can be dispatched. - let shutdown = self.abort.load(Ordering::Relaxed); - - if shutdown { - // Shutdown: sends a RST packet. - trace!("[{}] Shutting down virtual interface", self.virtual_port); - let client_socket = virtual_interface.get_socket::(client_handle); - client_socket.abort(); - } - - match virtual_interface.poll(loop_start) { - Ok(processed) if processed => { - trace!( - "[{}] Virtual interface polled some packets to be processed", - self.virtual_port - ); - } - Err(e) => { - error!( - "[{}] Virtual interface poll error: {:?}", - self.virtual_port, e - ); - } - _ => {} - } - - { - let (client_socket, context) = - virtual_interface.get_socket_and_context::(client_handle); - - if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { - // Not shutting down, but the client socket is closed, and the client never successfully connected. - if connection_attempts < 10 { - // Try to connect - client_socket - .connect( - context, - ( - IpAddress::from(self.port_forward.destination.ip()), - self.port_forward.destination.port(), - ), - (IpAddress::from(source_peer_ip), self.virtual_port), - ) - .with_context(|| "Virtual server socket failed to listen")?; - if connection_attempts > 0 { - debug!( - "[{}] Virtual client retrying connection in 500ms", - self.virtual_port - ); - // Not our first connection attempt, wait a little bit. - tokio::time::sleep(Duration::from_millis(500)).await; + match iface.poll(loop_start) { + Ok(processed) if processed => { + trace!("TCP virtual interface polled some packets to be processed"); } - } else { - // Too many connection attempts - self.abort.store(true, Ordering::Relaxed); - } - connection_attempts += 1; - continue; - } - - if client_socket.state() == TcpState::Established { - // Prevent reconnection if the server later closes. - has_connected = true; - } - - if client_socket.can_recv() { - match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { - Ok(data) => { - trace!( - "[{}] Virtual client received {} bytes of data", - self.virtual_port, - data.len() - ); - // Send it to the real client - if let Err(e) = self.data_to_real_client_tx.send(data).await { - error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e); - } - } - Err(e) => { - error!( - "[{}] Failed to read from virtual client socket: {:?}", - self.virtual_port, e - ); - } - } - } - if client_socket.can_send() { - if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() { - virtual_client_ready_tx - .send(()) - .expect("Failed to notify real client that virtual client is ready"); + Err(e) => error!("TCP virtual interface poll error: {:?}", e), + _ => {} } - let mut to_transfer = None; - - if tx_extra.is_empty() { - // The payload segment from the previous loop is complete, - // we can now read the next payload in the queue. - if let Ok(data) = data_to_virtual_server_rx.try_recv() { - to_transfer = Some(data); - } else if client_socket.state() == TcpState::CloseWait { - // No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN), - // the interface is shutdown. - trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port); - self.abort.store(true, Ordering::Relaxed); - continue; - } - } - - let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice(); - if !to_transfer_slice.is_empty() { - let total = to_transfer_slice.len(); - match client_socket.send_slice(to_transfer_slice) { - Ok(sent) => { - trace!( - "[{}] Sent {}/{} bytes via virtual client socket", - self.virtual_port, - sent, - total, - ); - tx_extra = Vec::from(&to_transfer_slice[sent..total]); - } - Err(e) => { - error!( - "[{}] Failed to send slice via virtual client socket: {:?}", - self.virtual_port, e - ); + // Find client socket send data to + for (virtual_port, client_handle) in port_client_handle_map.iter() { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.can_send() { + if let Some(send_queue) = send_queue.get_mut(virtual_port) { + let to_transfer = send_queue.pop_front(); + if let Some(to_transfer_slice) = to_transfer.as_deref() { + let total = to_transfer_slice.len(); + match client_socket.send_slice(to_transfer_slice) { + Ok(sent) => { + if sent < total { + // Sometimes only a subset is sent, so the rest needs to be sent on the next poll + let tx_extra = Vec::from(&to_transfer_slice[sent..total]); + send_queue.push_front(tx_extra); + } + } + Err(e) => { + error!( + "Failed to send slice via virtual client socket: {:?}", e + ); + } + } + } else if client_socket.state() == TcpState::CloseWait { + client_socket.close(); + } + break; } } } - } - } - if shutdown { - break; - } + // Find client socket recv data from + for (virtual_port, client_handle) in port_client_handle_map.iter() { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.can_recv() { + match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { + Ok(data) => { + if !data.is_empty() { + endpoint.send(Event::RemoteData(*virtual_port, data)); + break; + } else { + continue; + } + } + Err(e) => { + error!( + "Failed to read from virtual client socket: {:?}", e + ); + } + } + } + } - match virtual_interface.poll_delay(loop_start) { - Some(smoltcp::time::Duration::ZERO) => { - continue; + // Find closed sockets + port_client_handle_map.retain(|virtual_port, client_handle| { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.state() == TcpState::Closed { + endpoint.send(Event::ClientConnectionDropped(*virtual_port)); + send_queue.remove(virtual_port); + false + } else { + // Not closed, retain + true + } + }); + + // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) + next_poll = match iface.poll_delay(loop_start) { + Some(smoltcp::time::Duration::ZERO) => None, + Some(delay) => { + trace!("TCP Virtual interface delayed next poll by {}", delay); + Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis())) + }, + None => None, + }; } - _ => { - tokio::time::sleep(Duration::from_millis(1)).await; + event = endpoint.recv() => { + match event { + Event::ClientConnectionInitiated(port_forward, virtual_port) => { + let client_socket = TcpVirtualInterface::new_client_socket()?; + let client_handle = iface.add_socket(client_socket); + + // Add handle to map + port_client_handle_map.insert(virtual_port, client_handle); + send_queue.insert(virtual_port, VecDeque::new()); + + let (client_socket, context) = iface.get_socket_and_context::(client_handle); + + client_socket + .connect( + context, + ( + IpAddress::from(port_forward.destination.ip()), + port_forward.destination.port(), + ), + (IpAddress::from(self.source_peer_ip), virtual_port.num()), + ) + .with_context(|| "Virtual server socket failed to listen")?; + + next_poll = None; + } + Event::ClientConnectionDropped(virtual_port) => { + if let Some(client_handle) = port_client_handle_map.get(&virtual_port) { + let client_handle = *client_handle; + port_client_handle_map.remove(&virtual_port); + send_queue.remove(&virtual_port); + + let client_socket = iface.get_socket::(client_handle); + client_socket.close(); + next_poll = None; + } + } + Event::LocalData(virtual_port, data) if send_queue.contains_key(&virtual_port) => { + if let Some(send_queue) = send_queue.get_mut(&virtual_port) { + send_queue.push_back(data); + next_poll = None; + } + } + Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Tcp => { + next_poll = None; + } + _ => {} + } } } } - trace!("[{}] Virtual interface task terminated", self.virtual_port); - self.abort.store(true, Ordering::Relaxed); - Ok(()) } } diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 1fdbb63..022a2c2 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,39 +1,39 @@ -use anyhow::Context; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; +#![allow(dead_code)] +use std::net::IpAddr; +use crate::{Bus, PortProtocol}; use async_trait::async_trait; -use smoltcp::iface::{InterfaceBuilder, SocketHandle}; -use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; -use smoltcp::wire::{IpAddress, IpCidr}; use crate::config::PortForwardConfig; use crate::virtual_device::VirtualIpDevice; -use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; -use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; +use crate::virtual_iface::VirtualInterfacePoll; const MAX_PACKET: usize = 65536; pub struct UdpVirtualInterface { - port_forward: PortForwardConfig, - wg: Arc, - data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec)>, - data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec)>, + source_peer_ip: IpAddr, + port_forwards: Vec, + device: VirtualIpDevice, + bus: Bus, } impl UdpVirtualInterface { + /// Initialize the parameters for a new virtual interface. + /// Use the `poll_loop()` future to start the virtual interface poll loop. pub fn new( - port_forward: PortForwardConfig, - wg: Arc, - data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec)>, - data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec)>, + port_forwards: Vec, + bus: Bus, + device: VirtualIpDevice, + source_peer_ip: IpAddr, ) -> Self { Self { - port_forward, - wg, - data_to_real_client_tx, - data_to_virtual_server_rx, + port_forwards: port_forwards + .into_iter() + .filter(|f| matches!(f.protocol, PortProtocol::Udp)) + .collect(), + device, + source_peer_ip, + bus, } } } @@ -41,160 +41,9 @@ impl UdpVirtualInterface { #[async_trait] impl VirtualInterfacePoll for UdpVirtualInterface { async fn poll_loop(self) -> anyhow::Result<()> { - // Data receiver to dispatch using virtual client sockets - let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; - - // The IP to bind client sockets to - let source_peer_ip = self.wg.source_peer_ip; - - // The IP/port to bind the server socket to - let destination = self.port_forward.destination; - - // Initialize a channel for IP packets. - // The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel. - // The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel. - let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); - - let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx); - let mut virtual_interface = InterfaceBuilder::new(device, vec![]) - .ip_addrs([ - // Interface handles IP packets for the sender and recipient - IpCidr::new(source_peer_ip.into(), 32), - IpCidr::new(destination.ip().into(), 32), - ]) - .finalize(); - - // Server socket: this is a placeholder for the interface. - let server_socket: anyhow::Result = { - static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; - static mut UDP_SERVER_RX_DATA: [u8; 0] = []; - static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; - static mut UDP_SERVER_TX_DATA: [u8; 0] = []; - let udp_rx_buffer = - UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { - &mut UDP_SERVER_RX_DATA[..] - }); - let udp_tx_buffer = - UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { - &mut UDP_SERVER_TX_DATA[..] - }); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - - socket - .bind((IpAddress::from(destination.ip()), destination.port())) - .with_context(|| "UDP virtual server socket failed to listen")?; - - Ok(socket) - }; - - let _server_handle = virtual_interface.add_socket(server_socket?); - - // A map of virtual port to client socket. - let mut client_sockets: HashMap = HashMap::new(); - - // The next instant required to poll the virtual interface - // None means "immediate poll required". - let mut next_poll: Option = None; - - loop { - let wg = self.wg.clone(); - tokio::select! { - // Wait the recommended amount of time by smoltcp, and poll again. - _ = match next_poll { - None => tokio::time::sleep(Duration::ZERO), - Some(until) => tokio::time::sleep_until(until) - } => { - let loop_start = smoltcp::time::Instant::now(); - - match virtual_interface.poll(loop_start) { - Ok(processed) if processed => { - trace!("UDP virtual interface polled some packets to be processed"); - } - Err(e) => error!("UDP virtual interface poll error: {:?}", e), - _ => {} - } - - // Loop through each client socket and check if there is any data to send back - // to the real client. - for (virtual_port, client_socket_handle) in client_sockets.iter() { - let client_socket = virtual_interface.get_socket::(*client_socket_handle); - match client_socket.recv() { - Ok((data, _peer)) => { - // Send the data back to the real client using MPSC channel - self.data_to_real_client_tx - .send((*virtual_port, data.to_vec())) - .await - .unwrap_or_else(|e| { - error!( - "[{}] Failed to dispatch data from virtual client to real client: {:?}", - virtual_port, e - ); - }); - } - Err(smoltcp::Error::Exhausted) => {} - Err(e) => { - error!( - "[{}] Failed to read from virtual client socket: {:?}", - virtual_port, e - ); - } - } - } - - next_poll = match virtual_interface.poll_delay(loop_start) { - Some(smoltcp::time::Duration::ZERO) => None, - Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())), - None => None, - } - } - // Wait for data to be received from the real client - data_recv_result = data_to_virtual_server_rx.recv() => { - if let Some((client_port, data)) = data_recv_result { - // Register the socket in WireGuard Tunnel (overrides any previous registration as well) - wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) - .unwrap_or_else(|e| { - error!( - "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}", - client_port, e - ); - }); - - let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| { - let rx_meta = vec![UdpPacketMetadata::EMPTY; 10]; - let tx_meta = vec![UdpPacketMetadata::EMPTY; 10]; - let rx_data = vec![0u8; MAX_PACKET]; - let tx_data = vec![0u8; MAX_PACKET]; - let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); - let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - - socket - .bind((IpAddress::from(wg.source_peer_ip), client_port.0)) - .unwrap_or_else(|e| { - error!( - "[{}] UDP virtual client socket failed to bind: {:?}", - client_port, e - ); - }); - - virtual_interface.add_socket(socket) - }); - - let client_socket = virtual_interface.get_socket::(*client_socket_handle); - client_socket - .send_slice( - &data, - (IpAddress::from(destination.ip()), destination.port()).into(), - ) - .unwrap_or_else(|e| { - error!( - "[{}] Failed to send data to virtual server: {:?}", - client_port, e - ); - }); - } - } - } - } + // TODO: Create smoltcp virtual device and interface + // TODO: Create smoltcp virtual servers for `port_forwards` + // TODO: listen on events + futures::future::pending().await } } diff --git a/src/wg.rs b/src/wg.rs index a7f56e5..60767e3 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -1,15 +1,15 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::Duration; +use crate::Bus; use anyhow::Context; use boringtun::noise::{Tunn, TunnResult}; use log::Level; -use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket}; +use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet}; use tokio::net::UdpSocket; -use tokio::sync::RwLock; use crate::config::{Config, PortProtocol}; -use crate::virtual_iface::VirtualPort; +use crate::events::Event; /// The capacity of the channel for received IP packets. pub const DISPATCH_CAPACITY: usize = 1_000; @@ -26,17 +26,13 @@ pub struct WireGuardTunnel { udp: UdpSocket, /// The address of the public WireGuard endpoint (UDP). pub(crate) endpoint: SocketAddr, - /// Maps virtual ports to the corresponding IP packet dispatcher. - virtual_port_ip_tx: dashmap::DashMap>>, - /// IP packet dispatcher for unroutable packets. `None` if not initialized. - sink_ip_tx: RwLock>>>, - /// The max transmission unit for WireGuard. - pub(crate) max_transmission_unit: usize, + /// Event bus + bus: Bus, } impl WireGuardTunnel { /// Initialize a new WireGuard tunnel. - pub async fn new(config: &Config) -> anyhow::Result { + pub async fn new(config: &Config, bus: Bus) -> anyhow::Result { let source_peer_ip = config.source_peer_ip; let peer = Self::create_tunnel(config)?; let endpoint = config.endpoint_addr; @@ -46,16 +42,13 @@ impl WireGuardTunnel { }) .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; - let virtual_port_ip_tx = Default::default(); Ok(Self { source_peer_ip, peer, udp, endpoint, - virtual_port_ip_tx, - sink_ip_tx: RwLock::new(None), - max_transmission_unit: config.max_transmission_unit, + bus, }) } @@ -90,31 +83,20 @@ impl WireGuardTunnel { Ok(()) } - /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. - pub fn register_virtual_interface( - &self, - virtual_port: VirtualPort, - sender: tokio::sync::mpsc::Sender>, - ) -> anyhow::Result<()> { - self.virtual_port_ip_tx.insert(virtual_port, sender); - Ok(()) - } + pub async fn produce_task(&self) -> ! { + trace!("Starting WireGuard production task"); + let mut endpoint = self.bus.new_endpoint(); - /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. - pub async fn register_sink_interface( - &self, - ) -> anyhow::Result>> { - let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); - - let mut sink_ip_tx = self.sink_ip_tx.write().await; - *sink_ip_tx = Some(sender); - - Ok(receiver) - } - - /// Releases the virtual interface from IP dispatch. - pub fn release_virtual_interface(&self, virtual_port: VirtualPort) { - self.virtual_port_ip_tx.remove(&virtual_port); + loop { + if let Event::OutboundInternetPacket(data) = endpoint.recv().await { + match self.send_ip_packet(&data).await { + Ok(_) => {} + Err(e) => { + error!("{:?}", e); + } + } + } + } } /// WireGuard Routine task. Handles Handshake, keep-alive, etc. @@ -160,6 +142,7 @@ impl WireGuardTunnel { /// decapsulates them, and dispatches newly received IP packets. pub async fn consume_task(&self) -> ! { trace!("Starting WireGuard consumption task"); + let endpoint = self.bus.new_endpoint(); loop { let mut recv_buf = [0u8; MAX_PACKET]; @@ -212,38 +195,8 @@ impl WireGuardTunnel { // For debugging purposes: parse packet trace_ip_packet("Received IP packet", packet); - match self.route_ip_packet(packet) { - RouteResult::Dispatch(port) => { - let sender = self.virtual_port_ip_tx.get(&port); - if let Some(sender_guard) = sender { - let sender = sender_guard.value(); - match sender.send(packet.to_vec()).await { - Ok(_) => { - trace!( - "Dispatched received IP packet to virtual port {}", - port - ); - } - Err(e) => { - error!( - "Failed to dispatch received IP packet to virtual port {}: {}", - port, e - ); - } - } - } else { - warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port); - } - } - RouteResult::Sink => { - trace!("Sending unroutable IP packet received from WireGuard endpoint to sink interface"); - self.route_ip_sink(packet).await.unwrap_or_else(|e| { - error!("Failed to send unroutable IP packet to sink: {:?}", e) - }); - } - RouteResult::Drop => { - trace!("Dropped unroutable IP packet received from WireGuard endpoint"); - } + if let Some(proto) = self.route_protocol(packet) { + endpoint.send(Event::InboundInternetPacket(proto, packet.into())); } } _ => {} @@ -264,89 +217,32 @@ impl WireGuardTunnel { .with_context(|| "Failed to initialize boringtun Tunn") } - /// Makes a decision on the handling of an incoming IP packet. - fn route_ip_packet(&self, packet: &[u8]) -> RouteResult { + /// Determine the inner protocol of the incoming IP packet (TCP/UDP). + fn route_protocol(&self, packet: &[u8]) -> Option { match IpVersion::of_packet(packet) { Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.protocol() { - IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), - IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())), + IpProtocol::Tcp => Some(PortProtocol::Tcp), + IpProtocol::Udp => Some(PortProtocol::Udp), // Unrecognized protocol, so we cannot determine where to route - _ => Some(RouteResult::Drop), + _ => None, }) - .flatten() - .unwrap_or(RouteResult::Drop), + .flatten(), Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet) .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.next_header() { - IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), - IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())), + IpProtocol::Tcp => Some(PortProtocol::Tcp), + IpProtocol::Udp => Some(PortProtocol::Udp), // Unrecognized protocol, so we cannot determine where to route - _ => Some(RouteResult::Drop), + _ => None, }) - .flatten() - .unwrap_or(RouteResult::Drop), - _ => RouteResult::Drop, - } - } - - /// Makes a decision on the handling of an incoming TCP segment. - fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult { - TcpPacket::new_checked(segment) - .ok() - .map(|tcp| { - if self - .virtual_port_ip_tx - .get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp)) - .is_some() - { - RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp)) - } else if tcp.rst() { - RouteResult::Drop - } else { - RouteResult::Sink - } - }) - .unwrap_or(RouteResult::Drop) - } - - /// Makes a decision on the handling of an incoming UDP datagram. - fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult { - UdpPacket::new_checked(datagram) - .ok() - .map(|udp| { - if self - .virtual_port_ip_tx - .get(&VirtualPort(udp.dst_port(), PortProtocol::Udp)) - .is_some() - { - RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp)) - } else { - RouteResult::Drop - } - }) - .unwrap_or(RouteResult::Drop) - } - - /// Route a packet to the IP sink interface. - async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> { - let ip_sink_tx = self.sink_ip_tx.read().await; - - if let Some(ip_sink_tx) = &*ip_sink_tx { - ip_sink_tx - .send(packet.to_vec()) - .await - .with_context(|| "Failed to dispatch IP packet to sink interface") - } else { - warn!( - "Could not dispatch unroutable IP packet to sink because interface is not active." - ); - Ok(()) + .flatten(), + _ => None, } } } @@ -370,12 +266,3 @@ fn trace_ip_packet(message: &str, packet: &[u8]) { } } } - -enum RouteResult { - /// Dispatch the packet to the virtual port. - Dispatch(VirtualPort), - /// The packet is not routable, and should be sent to the sink interface. - Sink, - /// The packet is not routable, and can be safely ignored. - Drop, -} From abd9df6be42f6d1fbc219250921dcb6c2e87a517 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Sat, 8 Jan 2022 02:25:56 -0500 Subject: [PATCH 2/2] Implement event-based UDP interface --- src/events.rs | 2 +- src/main.rs | 8 +- src/tunnel/tcp.rs | 2 +- src/tunnel/udp.rs | 8 +- src/virtual_iface/mod.rs | 3 +- src/virtual_iface/tcp.rs | 47 +++++----- src/virtual_iface/udp.rs | 198 ++++++++++++++++++++++++++++++++++++--- 7 files changed, 218 insertions(+), 50 deletions(-) diff --git a/src/events.rs b/src/events.rs index 1437bc2..33f8ab9 100644 --- a/src/events.rs +++ b/src/events.rs @@ -15,7 +15,7 @@ pub enum Event { /// A connection was dropped from the pool and should be closed in all interfaces. ClientConnectionDropped(VirtualPort), /// Data received by the local server that should be sent to the virtual server. - LocalData(VirtualPort, Vec), + LocalData(PortForwardConfig, VirtualPort, Vec), /// Data received by the remote server that should be sent to the local client. RemoteData(VirtualPort, Vec), /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. diff --git a/src/main.rs b/src/main.rs index 1a01b51..2377245 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,8 +72,8 @@ async fn main() -> anyhow::Result<()> { // Start TCP Virtual Interface let port_forwards = config.port_forwards.clone(); - let iface = TcpVirtualInterface::new(port_forwards, bus, device, config.source_peer_ip); - tokio::spawn(async move { iface.poll_loop().await }); + let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip); + tokio::spawn(async move { iface.poll_loop(device).await }); } if config @@ -88,8 +88,8 @@ async fn main() -> anyhow::Result<()> { // Start UDP Virtual Interface let port_forwards = config.port_forwards.clone(); - let iface = UdpVirtualInterface::new(port_forwards, bus, device, config.source_peer_ip); - tokio::spawn(async move { iface.poll_loop().await }); + let iface = UdpVirtualInterface::new(port_forwards, bus, config.source_peer_ip); + tokio::spawn(async move { iface.poll_loop(device).await }); } { diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 718cf57..4633e78 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -90,7 +90,7 @@ async fn handle_tcp_proxy_connection( match socket.try_read_buf(&mut buffer) { Ok(size) if size > 0 => { let data = Vec::from(&buffer[..size]); - endpoint.send(Event::LocalData(virtual_port, data)); + endpoint.send(Event::LocalData(port_forward, virtual_port, data)); // Reset buffer buffer.clear(); } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 34fb26d..d9bd43d 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -48,7 +48,7 @@ pub async fn udp_proxy_server( to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => { match to_send_result { Ok(Some((port, data))) => { - endpoint.send(Event::LocalData(port, data)); + endpoint.send(Event::LocalData(port_forward, port, data)); } Ok(None) => { continue; @@ -64,12 +64,12 @@ pub async fn udp_proxy_server( } event = endpoint.recv() => { if let Event::RemoteData(port, data) = event { - if let Some(peer_addr) = port_pool.get_peer_addr(port).await { - if let Err(e) = socket.send_to(&data, peer_addr).await { + if let Some(peer) = port_pool.get_peer_addr(port).await { + if let Err(e) = socket.send_to(&data, peer).await { error!( "[{}] Failed to send UDP datagram to real client ({}): {:?}", port, - peer_addr, + peer, e, ); } diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index ea3cd74..4fd6f47 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -2,6 +2,7 @@ pub mod tcp; pub mod udp; use crate::config::PortProtocol; +use crate::VirtualIpDevice; use async_trait::async_trait; use std::fmt::{Display, Formatter}; @@ -9,7 +10,7 @@ use std::fmt::{Display, Formatter}; pub trait VirtualInterfacePoll { /// Initializes the virtual interface and processes incoming data to be dispatched /// to the WireGuard tunnel and to the real client. - async fn poll_loop(mut self) -> anyhow::Result<()>; + async fn poll_loop(mut self, device: VirtualIpDevice) -> anyhow::Result<()>; } /// Virtual port. diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 6c283f1..190af1f 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -18,25 +18,18 @@ const MAX_PACKET: usize = 65536; pub struct TcpVirtualInterface { source_peer_ip: IpAddr, port_forwards: Vec, - device: VirtualIpDevice, bus: Bus, } impl TcpVirtualInterface { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. - pub fn new( - port_forwards: Vec, - bus: Bus, - device: VirtualIpDevice, - source_peer_ip: IpAddr, - ) -> Self { + pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { Self { port_forwards: port_forwards .into_iter() .filter(|f| matches!(f.protocol, PortProtocol::Tcp)) .collect(), - device, source_peer_ip, bus, } @@ -68,25 +61,28 @@ impl TcpVirtualInterface { let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); Ok(socket) } + + fn addresses(&self) -> Vec { + let mut addresses = HashSet::new(); + addresses.insert(IpAddress::from(self.source_peer_ip)); + for config in self.port_forwards.iter() { + addresses.insert(IpAddress::from(config.destination.ip())); + } + addresses + .into_iter() + .map(|addr| IpCidr::new(addr, 32)) + .collect() + } } #[async_trait] impl VirtualInterfacePoll for TcpVirtualInterface { - async fn poll_loop(self) -> anyhow::Result<()> { + async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { // Create CIDR block for source peer IP + each port forward IP - let addresses: Vec = { - let mut addresses = HashSet::new(); - addresses.insert(IpAddress::from(self.source_peer_ip)); - for config in self.port_forwards.iter() { - addresses.insert(IpAddress::from(config.destination.ip())); - } - addresses - .into_iter() - .map(|addr| IpCidr::new(addr, 32)) - .collect() - }; + let addresses = self.addresses(); - let mut iface = InterfaceBuilder::new(self.device, vec![]) + // Create virtual interface (contains smoltcp state machine) + let mut iface = InterfaceBuilder::new(device, vec![]) .ip_addrs(addresses) .finalize(); @@ -102,6 +98,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Bus endpoint to read events let mut endpoint = self.bus.new_endpoint(); + // Maps virtual port to its client socket handle let mut port_client_handle_map: HashMap = HashMap::new(); // Data packets to send from a virtual client @@ -146,10 +143,11 @@ impl VirtualInterfacePoll for TcpVirtualInterface { ); } } + break; } else if client_socket.state() == TcpState::CloseWait { client_socket.close(); + break; } - break; } } } @@ -163,8 +161,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface { if !data.is_empty() { endpoint.send(Event::RemoteData(*virtual_port, data)); break; - } else { - continue; } } Err(e) => { @@ -182,6 +178,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { if client_socket.state() == TcpState::Closed { endpoint.send(Event::ClientConnectionDropped(*virtual_port)); send_queue.remove(virtual_port); + iface.remove_socket(*client_handle); false } else { // Not closed, retain @@ -235,7 +232,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { next_poll = None; } } - Event::LocalData(virtual_port, data) if send_queue.contains_key(&virtual_port) => { + Event::LocalData(_, virtual_port, data) if send_queue.contains_key(&virtual_port) => { if let Some(send_queue) = send_queue.get_mut(&virtual_port) { send_queue.push_back(data); next_poll = None; diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 022a2c2..f1725e7 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,49 +1,219 @@ #![allow(dead_code)] + +use anyhow::Context; +use std::collections::{HashMap, HashSet, VecDeque}; use std::net::IpAddr; +use crate::events::Event; use crate::{Bus, PortProtocol}; use async_trait::async_trait; +use smoltcp::iface::{InterfaceBuilder, SocketHandle}; +use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; +use smoltcp::wire::{IpAddress, IpCidr}; +use std::time::Duration; use crate::config::PortForwardConfig; use crate::virtual_device::VirtualIpDevice; -use crate::virtual_iface::VirtualInterfacePoll; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; const MAX_PACKET: usize = 65536; pub struct UdpVirtualInterface { source_peer_ip: IpAddr, port_forwards: Vec, - device: VirtualIpDevice, bus: Bus, } impl UdpVirtualInterface { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. - pub fn new( - port_forwards: Vec, - bus: Bus, - device: VirtualIpDevice, - source_peer_ip: IpAddr, - ) -> Self { + pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { Self { port_forwards: port_forwards .into_iter() .filter(|f| matches!(f.protocol, PortProtocol::Udp)) .collect(), - device, source_peer_ip, bus, } } + + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_RX_DATA: [u8; 0] = []; + static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_TX_DATA: [u8; 0] = []; + let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { + &mut UDP_SERVER_RX_DATA[..] + }); + let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { + &mut UDP_SERVER_TX_DATA[..] + }); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + socket + .bind(( + IpAddress::from(port_forward.destination.ip()), + port_forward.destination.port(), + )) + .with_context(|| "UDP virtual server socket failed to bind")?; + Ok(socket) + } + + fn new_client_socket( + source_peer_ip: IpAddr, + client_port: VirtualPort, + ) -> anyhow::Result> { + let rx_meta = vec![UdpPacketMetadata::EMPTY; 10]; + let tx_meta = vec![UdpPacketMetadata::EMPTY; 10]; + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); + let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + socket + .bind((IpAddress::from(source_peer_ip), client_port.num())) + .with_context(|| "UDP virtual client failed to bind")?; + Ok(socket) + } + + fn addresses(&self) -> Vec { + let mut addresses = HashSet::new(); + addresses.insert(IpAddress::from(self.source_peer_ip)); + for config in self.port_forwards.iter() { + addresses.insert(IpAddress::from(config.destination.ip())); + } + addresses + .into_iter() + .map(|addr| IpCidr::new(addr, 32)) + .collect() + } } #[async_trait] impl VirtualInterfacePoll for UdpVirtualInterface { - async fn poll_loop(self) -> anyhow::Result<()> { - // TODO: Create smoltcp virtual device and interface - // TODO: Create smoltcp virtual servers for `port_forwards` - // TODO: listen on events - futures::future::pending().await + async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { + // Create CIDR block for source peer IP + each port forward IP + let addresses = self.addresses(); + + // Create virtual interface (contains smoltcp state machine) + let mut iface = InterfaceBuilder::new(device, vec![]) + .ip_addrs(addresses) + .finalize(); + + // Create virtual server for each port forward + for port_forward in self.port_forwards.iter() { + let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?; + iface.add_socket(server_socket); + } + + // The next time to poll the interface. Can be None for instant poll. + let mut next_poll: Option = None; + + // Bus endpoint to read events + let mut endpoint = self.bus.new_endpoint(); + + // Maps virtual port to its client socket handle + let mut port_client_handle_map: HashMap = HashMap::new(); + + // Data packets to send from a virtual client + let mut send_queue: HashMap)>> = + HashMap::new(); + + loop { + tokio::select! { + _ = match (next_poll, port_client_handle_map.len()) { + (None, 0) => tokio::time::sleep(Duration::MAX), + (None, _) => tokio::time::sleep(Duration::ZERO), + (Some(until), _) => tokio::time::sleep_until(until), + } => { + let loop_start = smoltcp::time::Instant::now(); + + match iface.poll(loop_start) { + Ok(processed) if processed => { + trace!("UDP virtual interface polled some packets to be processed"); + } + Err(e) => error!("UDP virtual interface poll error: {:?}", e), + _ => {} + } + + // Find client socket send data to + for (virtual_port, client_handle) in port_client_handle_map.iter() { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.can_send() { + if let Some(send_queue) = send_queue.get_mut(virtual_port) { + let to_transfer = send_queue.pop_front(); + if let Some((port_forward, data)) = to_transfer { + client_socket + .send_slice( + &data, + (IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(), + ) + .unwrap_or_else(|e| { + error!( + "[{}] Failed to send data to virtual server: {:?}", + virtual_port, e + ); + }); + break; + } + } + } + } + + // Find client socket recv data from + for (virtual_port, client_handle) in port_client_handle_map.iter() { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.can_recv() { + match client_socket.recv() { + Ok((data, _peer)) => { + if !data.is_empty() { + endpoint.send(Event::RemoteData(*virtual_port, data.to_vec())); + break; + } + } + Err(e) => { + error!( + "Failed to read from virtual client socket: {:?}", e + ); + } + } + } + } + + // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) + next_poll = match iface.poll_delay(loop_start) { + Some(smoltcp::time::Duration::ZERO) => None, + Some(delay) => { + trace!("UDP Virtual interface delayed next poll by {}", delay); + Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis())) + }, + None => None, + }; + } + event = endpoint.recv() => { + match event { + Event::LocalData(port_forward, virtual_port, data) => { + if let Some(send_queue) = send_queue.get_mut(&virtual_port) { + // Client socket already exists + send_queue.push_back((port_forward, data)); + } else { + // Client socket does not exist + let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?; + let client_handle = iface.add_socket(client_socket); + + // Add handle to map + port_client_handle_map.insert(virtual_port, client_handle); + send_queue.insert(virtual_port, VecDeque::from(vec![(port_forward, data)])); + } + next_poll = None; + } + Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Udp => { + next_poll = None; + } + _ => {} + } + } + } + } } }