From 5cec6d4943557541fa5eff0afef4022e87987b3b Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Tue, 19 Oct 2021 01:55:04 -0400 Subject: [PATCH] Index ports with protocol in WG. Start writing UDP tunnel code with plans. --- src/config.rs | 2 +- src/main.rs | 14 ++++---- src/port_pool.rs | 62 ----------------------------------- src/tunnel/mod.rs | 12 ++++--- src/tunnel/tcp.rs | 70 +++++++++++++++++++++++++++++++++++++--- src/tunnel/udp.rs | 44 +++++++++++++++++++++++++ src/virtual_device.rs | 3 +- src/virtual_iface/mod.rs | 12 +++++++ src/virtual_iface/tcp.rs | 9 +++--- src/wg.rs | 19 +++++++---- 10 files changed, 156 insertions(+), 91 deletions(-) delete mode 100644 src/port_pool.rs create mode 100644 src/tunnel/udp.rs diff --git a/src/config.rs b/src/config.rs index 3a46881..0e3a727 100644 --- a/src/config.rs +++ b/src/config.rs @@ -332,7 +332,7 @@ impl Display for PortForwardConfig { } } -#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum PortProtocol { Tcp, Udp, diff --git a/src/main.rs b/src/main.rs index b5da3ad..0d301d0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,11 @@ use std::sync::Arc; use anyhow::Context; use crate::config::Config; -use crate::port_pool::PortPool; +use crate::tunnel::tcp::TcpPortPool; use crate::wg::WireGuardTunnel; pub mod config; pub mod ip_sink; -pub mod port_pool; pub mod tunnel; pub mod virtual_device; pub mod virtual_iface; @@ -21,7 +20,10 @@ pub mod wg; async fn main() -> anyhow::Result<()> { let config = Config::from_args().with_context(|| "Failed to read config")?; init_logger(&config)?; - let port_pool = Arc::new(PortPool::new()); + + // Initialize the port pool for each protocol + let tcp_port_pool = Arc::new(TcpPortPool::new()); + // TODO: udp_port_pool let wg = WireGuardTunnel::new(&config) .await @@ -52,12 +54,12 @@ async fn main() -> anyhow::Result<()> { port_forwards .into_iter() - .map(|pf| (pf, wg.clone(), port_pool.clone())) - .for_each(move |(pf, wg, port_pool)| { + .map(|pf| (pf, wg.clone(), tcp_port_pool.clone())) + .for_each(move |(pf, wg, tcp_port_pool)| { std::thread::spawn(move || { let cpu_pool = tokio::runtime::Runtime::new().unwrap(); cpu_pool.block_on(async move { - tunnel::port_forward(pf, source_peer_ip, port_pool, wg) + tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, wg) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }); diff --git a/src/port_pool.rs b/src/port_pool.rs deleted file mode 100644 index 7bff712..0000000 --- a/src/port_pool.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::ops::Range; - -use anyhow::Context; -use rand::seq::SliceRandom; -use rand::thread_rng; - -const MIN_PORT: u16 = 32768; -const MAX_PORT: u16 = 60999; -const PORT_RANGE: Range = MIN_PORT..MAX_PORT; - -/// A pool of virtual ports available. -/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`. -pub struct PortPool { - /// Remaining ports - inner: lockfree::queue::Queue, - /// Ports in use, with their associated IP channel sender. - taken: lockfree::set::Set, -} - -impl Default for PortPool { - fn default() -> Self { - Self::new() - } -} - -impl PortPool { - /// Initializes a new pool of virtual ports. - pub fn new() -> Self { - let inner = lockfree::queue::Queue::default(); - let mut ports: Vec = PORT_RANGE.collect(); - ports.shuffle(&mut thread_rng()); - ports.into_iter().for_each(|p| inner.push(p) as ()); - Self { - inner, - taken: lockfree::set::Set::new(), - } - } - - /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). - pub fn next(&self) -> anyhow::Result { - let port = self - .inner - .pop() - .with_context(|| "Virtual port pool is exhausted")?; - self.taken - .insert(port) - .ok() - .with_context(|| "Failed to insert taken")?; - Ok(port) - } - - /// Releases a port back into the pool. - pub fn release(&self, port: u16) { - self.inner.push(port); - self.taken.remove(&port); - } - - /// Whether the given port is in use by a virtual interface. - pub fn is_in_use(&self, port: u16) -> bool { - self.taken.contains(&port) - } -} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 7eab856..00fe4e8 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,15 +2,17 @@ use std::net::IpAddr; use std::sync::Arc; use crate::config::{PortForwardConfig, PortProtocol}; -use crate::port_pool::PortPool; +use crate::tunnel::tcp::TcpPortPool; use crate::wg::WireGuardTunnel; -mod tcp; +pub mod tcp; +#[allow(unused)] +pub mod udp; pub async fn port_forward( port_forward: PortForwardConfig, source_peer_ip: IpAddr, - port_pool: Arc, + tcp_port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { info!( @@ -23,7 +25,7 @@ pub async fn port_forward( ); match port_forward.protocol { - PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, port_pool, wg).await, - PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), + PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await, + PortProtocol::Udp => udp::udp_proxy_server(port_forward, /* udp_port_pool, */ wg).await, } } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 987aa5e..1f01642 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -1,19 +1,26 @@ -use crate::config::PortForwardConfig; -use crate::port_pool::PortPool; +use crate::config::{PortForwardConfig, PortProtocol}; use crate::virtual_iface::tcp::TcpVirtualInterface; -use crate::virtual_iface::VirtualInterfacePoll; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; +use std::ops::Range; + +use rand::seq::SliceRandom; +use rand::thread_rng; + const MAX_PACKET: usize = 65536; +const MIN_PORT: u16 = 1000; +const MAX_PORT: u16 = 60999; +const PORT_RANGE: Range = MIN_PORT..MAX_PORT; /// Starts the server that listens on TCP connections. pub async fn tcp_proxy_server( port_forward: PortForwardConfig, - port_pool: Arc, + port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { let listener = TcpListener::bind(port_forward.source) @@ -59,7 +66,7 @@ pub async fn tcp_proxy_server( } // Release port when connection drops - wg.release_virtual_interface(virtual_port); + wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp)); port_pool.release(virtual_port); }); } @@ -194,3 +201,56 @@ async fn handle_tcp_proxy_connection( abort.store(true, Ordering::Relaxed); Ok(()) } + +/// A pool of virtual ports available for TCP connections. +/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`. +pub struct TcpPortPool { + /// Remaining ports + inner: lockfree::queue::Queue, + /// Ports in use, with their associated IP channel sender. + taken: lockfree::set::Set, +} + +impl Default for TcpPortPool { + fn default() -> Self { + Self::new() + } +} + +impl TcpPortPool { + /// Initializes a new pool of virtual ports. + pub fn new() -> Self { + let inner = lockfree::queue::Queue::default(); + let mut ports: Vec = PORT_RANGE.collect(); + ports.shuffle(&mut thread_rng()); + ports.into_iter().for_each(|p| inner.push(p) as ()); + Self { + inner, + taken: lockfree::set::Set::new(), + } + } + + /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). + pub fn next(&self) -> anyhow::Result { + let port = self + .inner + .pop() + .with_context(|| "Virtual port pool is exhausted")?; + self.taken + .insert(port) + .ok() + .with_context(|| "Failed to insert taken")?; + Ok(port) + } + + /// Releases a port back into the pool. + pub fn release(&self, port: u16) { + self.inner.push(port); + self.taken.remove(&port); + } + + /// Whether the given port is in use by a virtual interface. + pub fn is_in_use(&self, port: u16) -> bool { + self.taken.contains(&port) + } +} diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs new file mode 100644 index 0000000..7f9fda4 --- /dev/null +++ b/src/tunnel/udp.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use anyhow::Context; +use tokio::net::UdpSocket; + +use crate::config::PortForwardConfig; +use crate::wg::WireGuardTunnel; + +const MAX_PACKET: usize = 65536; + +/// How long to keep the UDP peer address assigned to its virtual specified port, in seconds. +const UDP_TIMEOUT_SECONDS: u64 = 60; + +/// To prevent port-flooding, we set a limit on the amount of open ports per IP address. +const PORTS_PER_IP: usize = 100; + +pub async fn udp_proxy_server( + port_forward: PortForwardConfig, + wg: Arc, +) -> anyhow::Result<()> { + let socket = UdpSocket::bind(port_forward.source) + .await + .with_context(|| "Failed to bind on UDP proxy address")?; + + let mut buffer = [0u8; MAX_PACKET]; + loop { + let (size, peer_addr) = socket + .recv_from(&mut buffer) + .await + .with_context(|| "Failed to accept incoming UDP datagram")?; + + let _wg = wg.clone(); + let _data = &buffer[..size].to_vec(); + debug!("Received datagram of {} bytes from {}", size, peer_addr); + + // Assign a 'virtual port': this is a unique port number used to route IP packets + // received from the WireGuard tunnel. It is the port number that the virtual client will + // listen on. + // Since UDP is connection-less, the port is assigned to the source SocketAddr for up to `UDP_TIMEOUT_SECONDS`; + // every datagram resets the timer for that SocketAddr. Each IP address also has a limit of active connections, + // discarding the LRU ports. + // TODO: UDP Port Pool + } +} diff --git a/src/virtual_device.rs b/src/virtual_device.rs index b8a61ce..992e7d0 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,3 +1,4 @@ +use crate::virtual_iface::VirtualPort; use crate::wg::WireGuardTunnel; use anyhow::Context; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; @@ -15,7 +16,7 @@ pub struct VirtualIpDevice { } impl VirtualIpDevice { - pub fn new(virtual_port: u16, wg: Arc) -> anyhow::Result { + pub fn new(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { let ip_dispatch_rx = wg .register_virtual_interface(virtual_port) .with_context(|| "Failed to register IP dispatch for virtual interface")?; diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index b9d3354..d2ceb53 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -1,6 +1,8 @@ pub mod tcp; +use crate::config::PortProtocol; use async_trait::async_trait; +use std::fmt::{Display, Formatter}; #[async_trait] pub trait VirtualInterfacePoll { @@ -8,3 +10,13 @@ pub trait VirtualInterfacePoll { /// to the WireGuard tunnel and to the real client. async fn poll_loop(mut self) -> anyhow::Result<()>; } + +/// Virtual port. +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub struct VirtualPort(pub u16, pub PortProtocol); + +impl Display for VirtualPort { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}:{}]", self.0, self.1) + } +} diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index c9fee95..fc8348c 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -1,6 +1,6 @@ -use crate::config::PortForwardConfig; +use crate::config::{PortForwardConfig, PortProtocol}; use crate::virtual_device::VirtualIpDevice; -use crate::virtual_iface::VirtualInterfacePoll; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; use async_trait::async_trait; @@ -74,8 +74,9 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Consumer for IP packets to send through the virtual interface // Initialize the interface - let device = VirtualIpDevice::new(self.virtual_port, self.wg) - .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; + let device = + VirtualIpDevice::new(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) + .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; let mut virtual_interface = InterfaceBuilder::new(device) .ip_addrs([ // Interface handles IP packets for the sender and recipient diff --git a/src/wg.rs b/src/wg.rs index 498fb98..d10ad50 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -8,7 +8,8 @@ use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket}; use tokio::net::UdpSocket; use tokio::sync::RwLock; -use crate::config::Config; +use crate::config::{Config, PortProtocol}; +use crate::virtual_iface::VirtualPort; /// The capacity of the channel for received IP packets. const DISPATCH_CAPACITY: usize = 1_000; @@ -26,7 +27,7 @@ pub struct WireGuardTunnel { /// The address of the public WireGuard endpoint (UDP). pub(crate) endpoint: SocketAddr, /// Maps virtual ports to the corresponding IP packet dispatcher. - virtual_port_ip_tx: lockfree::map::Map>>, + virtual_port_ip_tx: lockfree::map::Map>>, /// IP packet dispatcher for unroutable packets. `None` if not initialized. sink_ip_tx: RwLock>>>, } @@ -86,7 +87,7 @@ impl WireGuardTunnel { /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. pub fn register_virtual_interface( &self, - virtual_port: u16, + virtual_port: VirtualPort, ) -> anyhow::Result>> { let existing = self.virtual_port_ip_tx.get(&virtual_port); if existing.is_some() { @@ -111,7 +112,7 @@ impl WireGuardTunnel { } /// Releases the virtual interface from IP dispatch. - pub fn release_virtual_interface(&self, virtual_port: u16) { + pub fn release_virtual_interface(&self, virtual_port: VirtualPort) { self.virtual_port_ip_tx.remove(&virtual_port); } @@ -296,8 +297,12 @@ impl WireGuardTunnel { TcpPacket::new_checked(segment) .ok() .map(|tcp| { - if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() { - RouteResult::Dispatch(tcp.dst_port()) + if self + .virtual_port_ip_tx + .get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp)) + .is_some() + { + RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp)) } else if tcp.rst() { RouteResult::Drop } else { @@ -347,7 +352,7 @@ fn trace_ip_packet(message: &str, packet: &[u8]) { enum RouteResult { /// Dispatch the packet to the virtual port. - Dispatch(u16), + Dispatch(VirtualPort), /// The packet is not routable, and should be sent to the sink interface. Sink, /// The packet is not routable, and can be safely ignored.