From abd9df6be42f6d1fbc219250921dcb6c2e87a517 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Sat, 8 Jan 2022 02:25:56 -0500 Subject: [PATCH] Implement event-based UDP interface --- src/events.rs | 2 +- src/main.rs | 8 +- src/tunnel/tcp.rs | 2 +- src/tunnel/udp.rs | 8 +- src/virtual_iface/mod.rs | 3 +- src/virtual_iface/tcp.rs | 47 +++++----- src/virtual_iface/udp.rs | 198 ++++++++++++++++++++++++++++++++++++--- 7 files changed, 218 insertions(+), 50 deletions(-) diff --git a/src/events.rs b/src/events.rs index 1437bc2..33f8ab9 100644 --- a/src/events.rs +++ b/src/events.rs @@ -15,7 +15,7 @@ pub enum Event { /// A connection was dropped from the pool and should be closed in all interfaces. ClientConnectionDropped(VirtualPort), /// Data received by the local server that should be sent to the virtual server. - LocalData(VirtualPort, Vec), + LocalData(PortForwardConfig, VirtualPort, Vec), /// Data received by the remote server that should be sent to the local client. RemoteData(VirtualPort, Vec), /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. diff --git a/src/main.rs b/src/main.rs index 1a01b51..2377245 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,8 +72,8 @@ async fn main() -> anyhow::Result<()> { // Start TCP Virtual Interface let port_forwards = config.port_forwards.clone(); - let iface = TcpVirtualInterface::new(port_forwards, bus, device, config.source_peer_ip); - tokio::spawn(async move { iface.poll_loop().await }); + let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip); + tokio::spawn(async move { iface.poll_loop(device).await }); } if config @@ -88,8 +88,8 @@ async fn main() -> anyhow::Result<()> { // Start UDP Virtual Interface let port_forwards = config.port_forwards.clone(); - let iface = UdpVirtualInterface::new(port_forwards, bus, device, config.source_peer_ip); - tokio::spawn(async move { iface.poll_loop().await }); + let iface = UdpVirtualInterface::new(port_forwards, bus, config.source_peer_ip); + tokio::spawn(async move { iface.poll_loop(device).await }); } { diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 718cf57..4633e78 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -90,7 +90,7 @@ async fn handle_tcp_proxy_connection( match socket.try_read_buf(&mut buffer) { Ok(size) if size > 0 => { let data = Vec::from(&buffer[..size]); - endpoint.send(Event::LocalData(virtual_port, data)); + endpoint.send(Event::LocalData(port_forward, virtual_port, data)); // Reset buffer buffer.clear(); } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 34fb26d..d9bd43d 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -48,7 +48,7 @@ pub async fn udp_proxy_server( to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => { match to_send_result { Ok(Some((port, data))) => { - endpoint.send(Event::LocalData(port, data)); + endpoint.send(Event::LocalData(port_forward, port, data)); } Ok(None) => { continue; @@ -64,12 +64,12 @@ pub async fn udp_proxy_server( } event = endpoint.recv() => { if let Event::RemoteData(port, data) = event { - if let Some(peer_addr) = port_pool.get_peer_addr(port).await { - if let Err(e) = socket.send_to(&data, peer_addr).await { + if let Some(peer) = port_pool.get_peer_addr(port).await { + if let Err(e) = socket.send_to(&data, peer).await { error!( "[{}] Failed to send UDP datagram to real client ({}): {:?}", port, - peer_addr, + peer, e, ); } diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index ea3cd74..4fd6f47 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -2,6 +2,7 @@ pub mod tcp; pub mod udp; use crate::config::PortProtocol; +use crate::VirtualIpDevice; use async_trait::async_trait; use std::fmt::{Display, Formatter}; @@ -9,7 +10,7 @@ use std::fmt::{Display, Formatter}; pub trait VirtualInterfacePoll { /// Initializes the virtual interface and processes incoming data to be dispatched /// to the WireGuard tunnel and to the real client. - async fn poll_loop(mut self) -> anyhow::Result<()>; + async fn poll_loop(mut self, device: VirtualIpDevice) -> anyhow::Result<()>; } /// Virtual port. diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 6c283f1..190af1f 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -18,25 +18,18 @@ const MAX_PACKET: usize = 65536; pub struct TcpVirtualInterface { source_peer_ip: IpAddr, port_forwards: Vec, - device: VirtualIpDevice, bus: Bus, } impl TcpVirtualInterface { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. - pub fn new( - port_forwards: Vec, - bus: Bus, - device: VirtualIpDevice, - source_peer_ip: IpAddr, - ) -> Self { + pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { Self { port_forwards: port_forwards .into_iter() .filter(|f| matches!(f.protocol, PortProtocol::Tcp)) .collect(), - device, source_peer_ip, bus, } @@ -68,25 +61,28 @@ impl TcpVirtualInterface { let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); Ok(socket) } + + fn addresses(&self) -> Vec { + let mut addresses = HashSet::new(); + addresses.insert(IpAddress::from(self.source_peer_ip)); + for config in self.port_forwards.iter() { + addresses.insert(IpAddress::from(config.destination.ip())); + } + addresses + .into_iter() + .map(|addr| IpCidr::new(addr, 32)) + .collect() + } } #[async_trait] impl VirtualInterfacePoll for TcpVirtualInterface { - async fn poll_loop(self) -> anyhow::Result<()> { + async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { // Create CIDR block for source peer IP + each port forward IP - let addresses: Vec = { - let mut addresses = HashSet::new(); - addresses.insert(IpAddress::from(self.source_peer_ip)); - for config in self.port_forwards.iter() { - addresses.insert(IpAddress::from(config.destination.ip())); - } - addresses - .into_iter() - .map(|addr| IpCidr::new(addr, 32)) - .collect() - }; + let addresses = self.addresses(); - let mut iface = InterfaceBuilder::new(self.device, vec![]) + // Create virtual interface (contains smoltcp state machine) + let mut iface = InterfaceBuilder::new(device, vec![]) .ip_addrs(addresses) .finalize(); @@ -102,6 +98,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Bus endpoint to read events let mut endpoint = self.bus.new_endpoint(); + // Maps virtual port to its client socket handle let mut port_client_handle_map: HashMap = HashMap::new(); // Data packets to send from a virtual client @@ -146,10 +143,11 @@ impl VirtualInterfacePoll for TcpVirtualInterface { ); } } + break; } else if client_socket.state() == TcpState::CloseWait { client_socket.close(); + break; } - break; } } } @@ -163,8 +161,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface { if !data.is_empty() { endpoint.send(Event::RemoteData(*virtual_port, data)); break; - } else { - continue; } } Err(e) => { @@ -182,6 +178,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { if client_socket.state() == TcpState::Closed { endpoint.send(Event::ClientConnectionDropped(*virtual_port)); send_queue.remove(virtual_port); + iface.remove_socket(*client_handle); false } else { // Not closed, retain @@ -235,7 +232,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { next_poll = None; } } - Event::LocalData(virtual_port, data) if send_queue.contains_key(&virtual_port) => { + Event::LocalData(_, virtual_port, data) if send_queue.contains_key(&virtual_port) => { if let Some(send_queue) = send_queue.get_mut(&virtual_port) { send_queue.push_back(data); next_poll = None; diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 022a2c2..f1725e7 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,49 +1,219 @@ #![allow(dead_code)] + +use anyhow::Context; +use std::collections::{HashMap, HashSet, VecDeque}; use std::net::IpAddr; +use crate::events::Event; use crate::{Bus, PortProtocol}; use async_trait::async_trait; +use smoltcp::iface::{InterfaceBuilder, SocketHandle}; +use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; +use smoltcp::wire::{IpAddress, IpCidr}; +use std::time::Duration; use crate::config::PortForwardConfig; use crate::virtual_device::VirtualIpDevice; -use crate::virtual_iface::VirtualInterfacePoll; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; const MAX_PACKET: usize = 65536; pub struct UdpVirtualInterface { source_peer_ip: IpAddr, port_forwards: Vec, - device: VirtualIpDevice, bus: Bus, } impl UdpVirtualInterface { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. - pub fn new( - port_forwards: Vec, - bus: Bus, - device: VirtualIpDevice, - source_peer_ip: IpAddr, - ) -> Self { + pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { Self { port_forwards: port_forwards .into_iter() .filter(|f| matches!(f.protocol, PortProtocol::Udp)) .collect(), - device, source_peer_ip, bus, } } + + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_RX_DATA: [u8; 0] = []; + static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_TX_DATA: [u8; 0] = []; + let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { + &mut UDP_SERVER_RX_DATA[..] + }); + let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { + &mut UDP_SERVER_TX_DATA[..] + }); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + socket + .bind(( + IpAddress::from(port_forward.destination.ip()), + port_forward.destination.port(), + )) + .with_context(|| "UDP virtual server socket failed to bind")?; + Ok(socket) + } + + fn new_client_socket( + source_peer_ip: IpAddr, + client_port: VirtualPort, + ) -> anyhow::Result> { + let rx_meta = vec![UdpPacketMetadata::EMPTY; 10]; + let tx_meta = vec![UdpPacketMetadata::EMPTY; 10]; + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); + let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + socket + .bind((IpAddress::from(source_peer_ip), client_port.num())) + .with_context(|| "UDP virtual client failed to bind")?; + Ok(socket) + } + + fn addresses(&self) -> Vec { + let mut addresses = HashSet::new(); + addresses.insert(IpAddress::from(self.source_peer_ip)); + for config in self.port_forwards.iter() { + addresses.insert(IpAddress::from(config.destination.ip())); + } + addresses + .into_iter() + .map(|addr| IpCidr::new(addr, 32)) + .collect() + } } #[async_trait] impl VirtualInterfacePoll for UdpVirtualInterface { - async fn poll_loop(self) -> anyhow::Result<()> { - // TODO: Create smoltcp virtual device and interface - // TODO: Create smoltcp virtual servers for `port_forwards` - // TODO: listen on events - futures::future::pending().await + async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { + // Create CIDR block for source peer IP + each port forward IP + let addresses = self.addresses(); + + // Create virtual interface (contains smoltcp state machine) + let mut iface = InterfaceBuilder::new(device, vec![]) + .ip_addrs(addresses) + .finalize(); + + // Create virtual server for each port forward + for port_forward in self.port_forwards.iter() { + let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?; + iface.add_socket(server_socket); + } + + // The next time to poll the interface. Can be None for instant poll. + let mut next_poll: Option = None; + + // Bus endpoint to read events + let mut endpoint = self.bus.new_endpoint(); + + // Maps virtual port to its client socket handle + let mut port_client_handle_map: HashMap = HashMap::new(); + + // Data packets to send from a virtual client + let mut send_queue: HashMap)>> = + HashMap::new(); + + loop { + tokio::select! { + _ = match (next_poll, port_client_handle_map.len()) { + (None, 0) => tokio::time::sleep(Duration::MAX), + (None, _) => tokio::time::sleep(Duration::ZERO), + (Some(until), _) => tokio::time::sleep_until(until), + } => { + let loop_start = smoltcp::time::Instant::now(); + + match iface.poll(loop_start) { + Ok(processed) if processed => { + trace!("UDP virtual interface polled some packets to be processed"); + } + Err(e) => error!("UDP virtual interface poll error: {:?}", e), + _ => {} + } + + // Find client socket send data to + for (virtual_port, client_handle) in port_client_handle_map.iter() { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.can_send() { + if let Some(send_queue) = send_queue.get_mut(virtual_port) { + let to_transfer = send_queue.pop_front(); + if let Some((port_forward, data)) = to_transfer { + client_socket + .send_slice( + &data, + (IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(), + ) + .unwrap_or_else(|e| { + error!( + "[{}] Failed to send data to virtual server: {:?}", + virtual_port, e + ); + }); + break; + } + } + } + } + + // Find client socket recv data from + for (virtual_port, client_handle) in port_client_handle_map.iter() { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.can_recv() { + match client_socket.recv() { + Ok((data, _peer)) => { + if !data.is_empty() { + endpoint.send(Event::RemoteData(*virtual_port, data.to_vec())); + break; + } + } + Err(e) => { + error!( + "Failed to read from virtual client socket: {:?}", e + ); + } + } + } + } + + // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) + next_poll = match iface.poll_delay(loop_start) { + Some(smoltcp::time::Duration::ZERO) => None, + Some(delay) => { + trace!("UDP Virtual interface delayed next poll by {}", delay); + Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis())) + }, + None => None, + }; + } + event = endpoint.recv() => { + match event { + Event::LocalData(port_forward, virtual_port, data) => { + if let Some(send_queue) = send_queue.get_mut(&virtual_port) { + // Client socket already exists + send_queue.push_back((port_forward, data)); + } else { + // Client socket does not exist + let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?; + let client_handle = iface.add_socket(client_socket); + + // Add handle to map + port_client_handle_map.insert(virtual_port, client_handle); + send_queue.insert(virtual_port, VecDeque::from(vec![(port_forward, data)])); + } + next_poll = None; + } + Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Udp => { + next_poll = None; + } + _ => {} + } + } + } + } } }