diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index a92bd5e..db232cf 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, VecDeque}; use std::net::{IpAddr, SocketAddr}; use std::ops::Range; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Instant; @@ -10,7 +11,8 @@ use rand::thread_rng; use tokio::net::UdpSocket; use crate::config::{PortForwardConfig, PortProtocol}; -use crate::virtual_iface::VirtualPort; +use crate::virtual_iface::udp::UdpVirtualInterface; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; const MAX_PACKET: usize = 65536; @@ -31,40 +33,120 @@ pub async fn udp_proxy_server( port_pool: UdpPortPool, wg: Arc, ) -> anyhow::Result<()> { + // Abort signal + let abort = Arc::new(AtomicBool::new(false)); + + // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back + // to the real client. + let (data_to_real_client_tx, mut data_to_real_client_rx) = + tokio::sync::mpsc::channel::<(VirtualPort, Vec)>(1_000); + + // data_to_real_server_(tx/rx): This task sends the data received from the real client to the + // virtual interface (virtual server socket). + let (data_to_virtual_server_tx, data_to_virtual_server_rx) = + tokio::sync::mpsc::channel::<(VirtualPort, Vec)>(1_000); + + { + // Spawn virtual interface + // Note: contrary to TCP, there is only one UDP virtual interface + let virtual_interface = UdpVirtualInterface::new( + port_forward, + wg, + data_to_real_client_tx, + data_to_virtual_server_rx, + ); + let abort = abort.clone(); + tokio::spawn(async move { + virtual_interface.poll_loop().await.unwrap_or_else(|e| { + error!("Virtual interface poll loop failed unexpectedly: {}", e); + abort.store(true, Ordering::Relaxed); + }); + }); + } + let socket = UdpSocket::bind(port_forward.source) .await .with_context(|| "Failed to bind on UDP proxy address")?; let mut buffer = [0u8; MAX_PACKET]; loop { - let (size, peer_addr) = socket - .recv_from(&mut buffer) - .await - .with_context(|| "Failed to accept incoming UDP datagram")?; - - let _wg = wg.clone(); - let _data = &buffer[..size].to_vec(); - - // Assign a 'virtual port': this is a unique port number used to route IP packets - // received from the WireGuard tunnel. It is the port number that the virtual client will - // listen on. - let port = match port_pool.next(peer_addr).await { - Ok(port) => port, - Err(e) => { - error!( - "Failed to assign virtual port number for UDP datagram from [{}]: {:?}", - peer_addr, e - ); - continue; + if abort.load(Ordering::Relaxed) { + break; + } + tokio::select! { + to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => { + match to_send_result { + Ok(Some((port, data))) => { + data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| { + error!( + "Failed to dispatch data to UDP virtual interface: {:?}", + e + ); + }); + } + Ok(None) => { + continue; + } + Err(e) => { + error!( + "Failed to read from client UDP socket: {:?}", + e + ); + break; + } + } } - }; - - let port = VirtualPort(port, PortProtocol::Udp); - debug!( - "[{}] Received datagram of {} bytes from {}", - port, size, peer_addr - ); + data_recv_result = data_to_real_client_rx.recv() => { + if let Some((port, data)) = data_recv_result { + if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await { + if let Err(e) = socket.send_to(&data, peer_addr).await { + error!( + "[{}] Failed to send UDP datagram to real client ({}): {:?}", + port, + peer_addr, + e, + ); + } + } + } + } + } } + Ok(()) +} + +async fn next_udp_datagram( + socket: &UdpSocket, + buffer: &mut [u8], + port_pool: UdpPortPool, +) -> anyhow::Result)>> { + let (size, peer_addr) = socket + .recv_from(buffer) + .await + .with_context(|| "Failed to accept incoming UDP datagram")?; + + // Assign a 'virtual port': this is a unique port number used to route IP packets + // received from the WireGuard tunnel. It is the port number that the virtual client will + // listen on. + let port = match port_pool.next(peer_addr).await { + Ok(port) => port, + Err(e) => { + error!( + "Failed to assign virtual port number for UDP datagram from [{}]: {:?}", + peer_addr, e + ); + return Ok(None); + } + }; + let port = VirtualPort(port, PortProtocol::Udp); + + debug!( + "[{}] Received datagram of {} bytes from {}", + port, size, peer_addr + ); + + let data = buffer[..size].to_vec(); + Ok(Some((port, data))) } /// A pool of virtual ports available for TCP connections. @@ -108,8 +190,14 @@ impl UdpPortPool { .pop_front() .with_context(|| "UDP virtual port pool is exhausted")?; inner.port_by_peer_addr.insert(peer_addr, port); + inner.peer_addr_by_port.insert(port, peer_addr); Ok(port) } + + pub async fn get_peer_addr(&self, port: u16) -> Option { + let inner = self.inner.read().await; + inner.peer_addr_by_port.get(&port).copied() + } } /// Non thread-safe inner logic for UDP port pool. @@ -119,4 +207,6 @@ struct UdpPortPoolInner { queue: VecDeque, /// The port assigned by peer IP/port. port_by_peer_addr: HashMap, + /// The socket address assigned to a peer IP/port. + peer_addr_by_port: HashMap, } diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 992e7d0..8cd0cdf 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,5 +1,5 @@ use crate::virtual_iface::VirtualPort; -use crate::wg::WireGuardTunnel; +use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; use anyhow::Context; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; @@ -16,9 +16,11 @@ pub struct VirtualIpDevice { } impl VirtualIpDevice { + /// Registers a virtual IP device for a single virtual client. pub fn new(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { - let ip_dispatch_rx = wg - .register_virtual_interface(virtual_port) + let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); + + wg.register_virtual_interface(virtual_port, ip_dispatch_tx) .with_context(|| "Failed to register IP dispatch for virtual interface")?; Ok(Self { wg, ip_dispatch_rx }) diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index d2ceb53..f3796bd 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -1,4 +1,5 @@ pub mod tcp; +pub mod udp; use crate::config::PortProtocol; use async_trait::async_trait; diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index fc8348c..baa92a7 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -266,7 +266,14 @@ impl VirtualInterfacePoll for TcpVirtualInterface { break; } - tokio::time::sleep(Duration::from_millis(1)).await; + match virtual_interface.poll_delay(&socket_set, loop_start) { + Some(smoltcp::time::Duration::ZERO) => { + continue; + } + _ => { + tokio::time::sleep(Duration::from_millis(1)).await; + } + } } trace!("[{}] Virtual interface task terminated", self.virtual_port); self.abort.store(true, Ordering::Relaxed); diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs new file mode 100644 index 0000000..88f13b6 --- /dev/null +++ b/src/virtual_iface/udp.rs @@ -0,0 +1,68 @@ +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; + +use crate::config::PortForwardConfig; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; +use crate::wg::WireGuardTunnel; + +pub struct UdpVirtualInterface { + port_forward: PortForwardConfig, + wg: Arc, + data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec)>, + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec)>, +} + +impl UdpVirtualInterface { + pub fn new( + port_forward: PortForwardConfig, + wg: Arc, + data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec)>, + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec)>, + ) -> Self { + Self { + port_forward, + wg, + data_to_real_client_tx, + data_to_virtual_server_rx, + } + } +} + +#[async_trait] +impl VirtualInterfacePoll for UdpVirtualInterface { + async fn poll_loop(self) -> anyhow::Result<()> { + // Data receiver to dispatch using virtual client sockets + let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; + + // The IP to bind client sockets to + let _source_peer_ip = self.wg.source_peer_ip; + + // The IP/port to bind the server socket to + let _destination = self.port_forward.destination; + + loop { + let _loop_start = smoltcp::time::Instant::now(); + // TODO: smoltcp UDP + + if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() { + // TODO: Find the matching client socket and send + // Echo for now + self.data_to_real_client_tx + .send((client_port, data)) + .await + .unwrap_or_else(|e| { + error!( + "[{}] Failed to dispatch data from virtual client to real client: {:?}", + client_port, e + ); + }); + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + // Ok(()) + } +} diff --git a/src/wg.rs b/src/wg.rs index 653568e..673da1c 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -12,7 +12,7 @@ use crate::config::{Config, PortProtocol}; use crate::virtual_iface::VirtualPort; /// The capacity of the channel for received IP packets. -const DISPATCH_CAPACITY: usize = 1_000; +pub const DISPATCH_CAPACITY: usize = 1_000; const MAX_PACKET: usize = 65536; /// A WireGuard tunnel. Encapsulates and decapsulates IP packets @@ -88,14 +88,14 @@ impl WireGuardTunnel { pub fn register_virtual_interface( &self, virtual_port: VirtualPort, - ) -> anyhow::Result>> { + sender: tokio::sync::mpsc::Sender>, + ) -> anyhow::Result<()> { let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); if existing { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { - let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); self.virtual_port_ip_tx.insert(virtual_port, sender); - Ok(receiver) + Ok(()) } }