From 282d4f48eb0341d24fb2f5090d8771a0cfec3014 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 20 Oct 2021 19:04:56 -0400 Subject: [PATCH] Checkpoint --- src/virtual_device.rs | 15 ++++++++-- src/virtual_iface/tcp.rs | 2 +- src/virtual_iface/udp.rs | 65 ++++++++++++++++++++++++++++++++++++++-- src/wg.rs | 6 +++- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 8cd0cdf..02d60f9 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -5,6 +5,9 @@ use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; use std::sync::Arc; +/// The max transmission unit for WireGuard. +const WG_MTU: usize = 1420; + /// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint /// are made available to this device using a channel receiver. IP packets sent from this device /// are asynchronously sent out to the WireGuard tunnel. @@ -16,8 +19,16 @@ pub struct VirtualIpDevice { } impl VirtualIpDevice { + /// Initializes a new virtual IP device. + pub fn new( + wg: Arc, + ip_dispatch_rx: tokio::sync::mpsc::Receiver>, + ) -> Self { + Self { wg, ip_dispatch_rx } + } + /// Registers a virtual IP device for a single virtual client. - pub fn new(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { + pub fn new_direct(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); wg.register_virtual_interface(virtual_port, ip_dispatch_tx) @@ -60,7 +71,7 @@ impl<'a> Device<'a> for VirtualIpDevice { fn capabilities(&self) -> DeviceCapabilities { let mut cap = DeviceCapabilities::default(); cap.medium = Medium::Ip; - cap.max_transmission_unit = 1420; + cap.max_transmission_unit = WG_MTU; cap } } diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index baa92a7..a632792 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -75,7 +75,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Consumer for IP packets to send through the virtual interface // Initialize the interface let device = - VirtualIpDevice::new(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) + VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; let mut virtual_interface = InterfaceBuilder::new(device) .ip_addrs([ diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 88f13b6..139610d 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,11 +1,18 @@ +use anyhow::Context; +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; +use dashmap::DashMap; +use smoltcp::iface::InterfaceBuilder; +use smoltcp::socket::{SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; +use smoltcp::wire::{IpAddress, IpCidr}; use crate::config::PortForwardConfig; +use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; -use crate::wg::WireGuardTunnel; +use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; pub struct UdpVirtualInterface { port_forward: PortForwardConfig, @@ -37,16 +44,68 @@ impl VirtualInterfacePoll for UdpVirtualInterface { let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; // The IP to bind client sockets to - let _source_peer_ip = self.wg.source_peer_ip; + let source_peer_ip = self.wg.source_peer_ip; // The IP/port to bind the server socket to - let _destination = self.port_forward.destination; + let destination = self.port_forward.destination; + + // Initialize a channel for IP packets. + // The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel. + // The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel. + let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); + + let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx); + let mut virtual_interface = InterfaceBuilder::new(device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(source_peer_ip.into(), 32), + IpCidr::new(destination.ip().into(), 32), + ]) + .finalize(); + + // Server socket: this is a placeholder for the interface. + let server_socket: anyhow::Result = { + static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_RX_DATA: [u8; 0] = []; + static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_TX_DATA: [u8; 0] = []; + let udp_rx_buffer = + UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { + &mut UDP_SERVER_RX_DATA[..] + }); + let udp_tx_buffer = + UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { + &mut UDP_SERVER_TX_DATA[..] + }); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + + socket + .bind((IpAddress::from(destination.ip()), destination.port())) + .with_context(|| "UDP virtual server socket failed to listen")?; + + Ok(socket) + }; + + let mut socket_set = SocketSet::new(vec![]); + let _server_handle = socket_set.add(server_socket?); loop { let _loop_start = smoltcp::time::Instant::now(); + let wg = self.wg.clone(); // TODO: smoltcp UDP if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() { + // Register the socket in WireGuard Tunnel if not already + if !wg.is_registered(client_port) { + wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) + .unwrap_or_else(|e| { + error!( + "[{}] Failed to register UDP socket in WireGuard tunnel", + client_port + ); + }); + } + // TODO: Find the matching client socket and send // Echo for now self.data_to_real_client_tx diff --git a/src/wg.rs b/src/wg.rs index 673da1c..80c7272 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -90,7 +90,7 @@ impl WireGuardTunnel { virtual_port: VirtualPort, sender: tokio::sync::mpsc::Sender>, ) -> anyhow::Result<()> { - let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); + let existing = self.is_registered(virtual_port); if existing { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { @@ -99,6 +99,10 @@ impl WireGuardTunnel { } } + pub fn is_registered(&self, virtual_port: VirtualPort) -> bool { + self.virtual_port_ip_tx.contains_key(&virtual_port) + } + /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. pub async fn register_sink_interface( &self,