diff --git a/src/main.rs b/src/main.rs index b2f1f75..d20138d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use anyhow::Context; use crate::config::Config; use crate::tunnel::tcp::TcpPortPool; +use crate::tunnel::udp::UdpPortPool; use crate::wg::WireGuardTunnel; pub mod config; @@ -23,7 +24,7 @@ async fn main() -> anyhow::Result<()> { // Initialize the port pool for each protocol let tcp_port_pool = TcpPortPool::new(); - // TODO: udp_port_pool + let udp_port_pool = UdpPortPool::new(); let wg = WireGuardTunnel::new(&config) .await @@ -54,12 +55,12 @@ async fn main() -> anyhow::Result<()> { port_forwards .into_iter() - .map(|pf| (pf, wg.clone(), tcp_port_pool.clone())) - .for_each(move |(pf, wg, tcp_port_pool)| { + .map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone())) + .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| { std::thread::spawn(move || { let cpu_pool = tokio::runtime::Runtime::new().unwrap(); cpu_pool.block_on(async move { - tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, wg) + tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }); diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index a4042c9..c7f2a67 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::config::{PortForwardConfig, PortProtocol}; use crate::tunnel::tcp::TcpPortPool; +use crate::tunnel::udp::UdpPortPool; use crate::wg::WireGuardTunnel; pub mod tcp; @@ -13,6 +14,7 @@ pub async fn port_forward( port_forward: PortForwardConfig, source_peer_ip: IpAddr, tcp_port_pool: TcpPortPool, + udp_port_pool: UdpPortPool, wg: Arc, ) -> anyhow::Result<()> { info!( @@ -26,6 +28,6 @@ pub async fn port_forward( match port_forward.protocol { PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await, - PortProtocol::Udp => udp::udp_proxy_server(port_forward, /* udp_port_pool, */ wg).await, + PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, wg).await, } } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index fbbdbc2..f49aa7c 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -3,7 +3,7 @@ use crate::virtual_iface::tcp::TcpVirtualInterface; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; -use std::collections::{HashSet, VecDeque}; +use std::collections::VecDeque; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; @@ -235,8 +235,7 @@ impl TcpPortPool { let port = inner .queue .pop_front() - .with_context(|| "Virtual port pool is exhausted")?; - inner.taken.insert(port); + .with_context(|| "TCP virtual port pool is exhausted")?; Ok(port) } @@ -244,15 +243,12 @@ impl TcpPortPool { pub async fn release(&self, port: u16) { let mut inner = self.inner.write().await; inner.queue.push_back(port); - inner.taken.remove(&port); } } /// Non thread-safe inner logic for TCP port pool. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Default)] struct TcpPortPoolInner { /// Remaining ports in the pool. queue: VecDeque, - /// Ports taken out of the pool. - taken: HashSet, } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 2326351..a92bd5e 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -1,12 +1,22 @@ +use std::collections::{HashMap, VecDeque}; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Range; use std::sync::Arc; +use std::time::Instant; use anyhow::Context; +use rand::seq::SliceRandom; +use rand::thread_rng; use tokio::net::UdpSocket; -use crate::config::PortForwardConfig; +use crate::config::{PortForwardConfig, PortProtocol}; +use crate::virtual_iface::VirtualPort; use crate::wg::WireGuardTunnel; const MAX_PACKET: usize = 65536; +const MIN_PORT: u16 = 1000; +const MAX_PORT: u16 = 60999; +const PORT_RANGE: Range = MIN_PORT..MAX_PORT; /// How long to keep the UDP peer address assigned to its virtual specified port, in seconds. /// TODO: Make this configurable by the CLI @@ -18,6 +28,7 @@ const PORTS_PER_IP: usize = 100; pub async fn udp_proxy_server( port_forward: PortForwardConfig, + port_pool: UdpPortPool, wg: Arc, ) -> anyhow::Result<()> { let socket = UdpSocket::bind(port_forward.source) @@ -33,14 +44,79 @@ pub async fn udp_proxy_server( let _wg = wg.clone(); let _data = &buffer[..size].to_vec(); - debug!("Received datagram of {} bytes from {}", size, peer_addr); // Assign a 'virtual port': this is a unique port number used to route IP packets // received from the WireGuard tunnel. It is the port number that the virtual client will // listen on. - // Since UDP is connection-less, the port is assigned to the source SocketAddr for up to `UDP_TIMEOUT_SECONDS`; - // every datagram resets the timer for that SocketAddr. Each IP address also has a limit of active connections, - // discarding the LRU ports. - // TODO: UDP Port Pool + let port = match port_pool.next(peer_addr).await { + Ok(port) => port, + Err(e) => { + error!( + "Failed to assign virtual port number for UDP datagram from [{}]: {:?}", + peer_addr, e + ); + continue; + } + }; + + let port = VirtualPort(port, PortProtocol::Udp); + debug!( + "[{}] Received datagram of {} bytes from {}", + port, size, peer_addr + ); } } + +/// A pool of virtual ports available for TCP connections. +#[derive(Clone)] +pub struct UdpPortPool { + inner: Arc>, +} + +impl Default for UdpPortPool { + fn default() -> Self { + Self::new() + } +} + +impl UdpPortPool { + /// Initializes a new pool of virtual ports. + pub fn new() -> Self { + let mut inner = UdpPortPoolInner::default(); + let mut ports: Vec = PORT_RANGE.collect(); + ports.shuffle(&mut thread_rng()); + ports + .into_iter() + .for_each(|p| inner.queue.push_back(p) as ()); + Self { + inner: Arc::new(tokio::sync::RwLock::new(inner)), + } + } + + /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). + pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result { + { + let inner = self.inner.read().await; + if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) { + return Ok(*port); + } + } + + let mut inner = self.inner.write().await; + let port = inner + .queue + .pop_front() + .with_context(|| "UDP virtual port pool is exhausted")?; + inner.port_by_peer_addr.insert(peer_addr, port); + Ok(port) + } +} + +/// Non thread-safe inner logic for UDP port pool. +#[derive(Debug, Default)] +struct UdpPortPoolInner { + /// Remaining ports in the pool. + queue: VecDeque, + /// The port assigned by peer IP/port. + port_by_peer_addr: HashMap, +}