From d975efefaf787934bfe7f6f817ce3230f246c589 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Mon, 25 Oct 2021 19:05:40 -0400 Subject: [PATCH] End-to-end UDP implementation Port re-use still needs to be implemented to prevent exhaustion over time, and flooding. --- src/tunnel/udp.rs | 5 +- src/virtual_iface/udp.rs | 102 ++++++++++++++++++++++++++++++--------- src/wg.rs | 35 +++++++++----- 3 files changed, 106 insertions(+), 36 deletions(-) diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index db232cf..6eb7e02 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -175,7 +175,7 @@ impl UdpPortPool { } } - /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). + /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity). pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result { { let inner = self.inner.read().await; @@ -184,6 +184,9 @@ impl UdpPortPool { } } + // TODO: When the port pool is exhausted, it should re-queue the least recently used port. + // TODO: Limit number of ports in use by peer IP + let mut inner = self.inner.write().await; let port = inner .queue diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 139610d..382d7c3 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -4,9 +4,8 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use dashmap::DashMap; use smoltcp::iface::InterfaceBuilder; -use smoltcp::socket::{SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; +use smoltcp::socket::{SocketHandle, SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; use smoltcp::wire::{IpAddress, IpCidr}; use crate::config::PortForwardConfig; @@ -14,6 +13,8 @@ use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; +const MAX_PACKET: usize = 65536; + pub struct UdpVirtualInterface { port_forward: PortForwardConfig, wg: Arc, @@ -89,31 +90,88 @@ impl VirtualInterfacePoll for UdpVirtualInterface { let mut socket_set = SocketSet::new(vec![]); let _server_handle = socket_set.add(server_socket?); + // A map of virtual port to client socket. + let mut client_sockets: HashMap = HashMap::new(); + loop { - let _loop_start = smoltcp::time::Instant::now(); + let loop_start = smoltcp::time::Instant::now(); let wg = self.wg.clone(); - // TODO: smoltcp UDP + + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { + trace!("UDP virtual interface polled some packets to be processed"); + } + Err(e) => error!("UDP virtual interface poll error: {:?}", e), + _ => {} + } + + // Loop through each client socket and check if there is any data to send back + // to the real client. + for (virtual_port, client_socket_handle) in client_sockets.iter() { + let mut client_socket = socket_set.get::(*client_socket_handle); + match client_socket.recv() { + Ok((data, _peer)) => { + // Send the data back to the real client using MPSC channel + self.data_to_real_client_tx + .send((*virtual_port, data.to_vec())) + .await + .unwrap_or_else(|e| { + error!( + "[{}] Failed to dispatch data from virtual client to real client: {:?}", + virtual_port, e + ); + }); + } + Err(smoltcp::Error::Exhausted) => {} + Err(e) => { + error!( + "[{}] Failed to read from virtual client socket: {:?}", + virtual_port, e + ); + } + } + } if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() { - // Register the socket in WireGuard Tunnel if not already - if !wg.is_registered(client_port) { - wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) - .unwrap_or_else(|e| { - error!( - "[{}] Failed to register UDP socket in WireGuard tunnel", - client_port - ); - }); - } - - // TODO: Find the matching client socket and send - // Echo for now - self.data_to_real_client_tx - .send((client_port, data)) - .await + // Register the socket in WireGuard Tunnel (overrides any previous registration as well) + wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) .unwrap_or_else(|e| { error!( - "[{}] Failed to dispatch data from virtual client to real client: {:?}", + "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}", + client_port, e + ); + }); + + let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| { + let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET]; + let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET]; + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); + let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + + socket + .bind((IpAddress::from(wg.source_peer_ip), client_port.0)) + .unwrap_or_else(|e| { + error!( + "[{}] UDP virtual client socket failed to bind: {:?}", + client_port, e + ); + }); + + socket_set.add(socket) + }); + + let mut client_socket = socket_set.get::(*client_socket_handle); + client_socket + .send_slice( + &data, + (IpAddress::from(destination.ip()), destination.port()).into(), + ) + .unwrap_or_else(|e| { + error!( + "[{}] Failed to send data to virtual server: {:?}", client_port, e ); }); @@ -121,7 +179,5 @@ impl VirtualInterfacePoll for UdpVirtualInterface { tokio::time::sleep(Duration::from_millis(1)).await; } - - // Ok(()) } } diff --git a/src/wg.rs b/src/wg.rs index 80c7272..2dc20d3 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -4,7 +4,7 @@ use std::time::Duration; use anyhow::Context; use boringtun::noise::{Tunn, TunnResult}; use log::Level; -use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket}; +use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket}; use tokio::net::UdpSocket; use tokio::sync::RwLock; @@ -90,17 +90,8 @@ impl WireGuardTunnel { virtual_port: VirtualPort, sender: tokio::sync::mpsc::Sender>, ) -> anyhow::Result<()> { - let existing = self.is_registered(virtual_port); - if existing { - Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) - } else { - self.virtual_port_ip_tx.insert(virtual_port, sender); - Ok(()) - } - } - - pub fn is_registered(&self, virtual_port: VirtualPort) -> bool { - self.virtual_port_ip_tx.contains_key(&virtual_port) + self.virtual_port_ip_tx.insert(virtual_port, sender); + Ok(()) } /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. @@ -276,6 +267,7 @@ impl WireGuardTunnel { .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.protocol() { IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), + IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())), // Unrecognized protocol, so we cannot determine where to route _ => Some(RouteResult::Drop), }) @@ -287,6 +279,7 @@ impl WireGuardTunnel { .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.next_header() { IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), + IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())), // Unrecognized protocol, so we cannot determine where to route _ => Some(RouteResult::Drop), }) @@ -316,6 +309,24 @@ impl WireGuardTunnel { .unwrap_or(RouteResult::Drop) } + /// Makes a decision on the handling of an incoming UDP datagram. + fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult { + UdpPacket::new_checked(datagram) + .ok() + .map(|udp| { + if self + .virtual_port_ip_tx + .get(&VirtualPort(udp.dst_port(), PortProtocol::Udp)) + .is_some() + { + RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp)) + } else { + RouteResult::Drop + } + }) + .unwrap_or(RouteResult::Drop) + } + /// Route a packet to the IP sink interface. async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> { let ip_sink_tx = self.sink_ip_tx.read().await;