From cfdbdc8f51327e7721e6ceac9630032cd900fecb Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Sat, 16 Oct 2021 19:16:10 -0400 Subject: [PATCH] Remove broadcasting logic to fix simultaneous connection issues. --- README.md | 2 +- src/main.rs | 23 +++--- src/port_pool.rs | 2 +- src/virtual_device.rs | 20 +++--- src/wg.rs | 162 ++++++++++++++++-------------------------- 5 files changed, 86 insertions(+), 123 deletions(-) diff --git a/README.md b/README.md index 20f32e7..8b7681e 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ forward it to the server's port, which handles the TCP segment. The server respo the peer's local WireGuard interface, gets encrypted, forwarded to the WireGuard endpoint, and then finally back to onetun's UDP port. When onetun receives an encrypted packet from the WireGuard endpoint, it decrypts it using boringtun. -The resulting IP packet is broadcasted to all virtual interfaces running inside onetun; once the corresponding +The resulting IP packet is dispatched to the corresponding virtual interface running inside onetun; once the corresponding interface is matched, the IP packet is read and unpacked, and the virtual client's TCP state is updated. Whenever data is sent by the real client, it is simply "sent" by the virtual client, which kicks off the whole IP encapsulation diff --git a/src/main.rs b/src/main.rs index 6158c6f..7ae3e0f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> { init_logger(&config)?; let port_pool = Arc::new(PortPool::new()); - let wg = WireGuardTunnel::new(&config, port_pool.clone()) + let wg = WireGuardTunnel::new(&config) .await .with_context(|| "Failed to initialize WireGuard tunnel")?; let wg = Arc::new(wg); @@ -47,12 +47,6 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { wg.consume_task().await }); } - { - // Start IP broadcast drain task for WireGuard - let wg = wg.clone(); - tokio::spawn(async move { wg.broadcast_drain_task().await }); - } - info!( "Tunnelling [{}]->[{}] (via [{}] as peer {})", &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip @@ -106,9 +100,14 @@ async fn tcp_proxy_server( tokio::spawn(async move { let port_pool = Arc::clone(&port_pool); - let result = - handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg) - .await; + let result = handle_tcp_proxy_connection( + socket, + virtual_port, + source_peer_ip, + dest_addr, + wg.clone(), + ) + .await; if let Err(e) = result { error!( @@ -120,6 +119,7 @@ async fn tcp_proxy_server( } // Release port when connection drops + wg.release_virtual_interface(virtual_port); port_pool.release(virtual_port); }); } @@ -270,7 +270,8 @@ async fn virtual_tcp_interface( // Consumer for IP packets to send through the virtual interface // Initialize the interface - let device = VirtualIpDevice::new(wg); + let device = VirtualIpDevice::new(virtual_port, wg) + .with_context(|| "Failed to initialize VirtualIpDevice")?; let mut virtual_interface = InterfaceBuilder::new(device) .ip_addrs([ // Interface handles IP packets for the sender and recipient diff --git a/src/port_pool.rs b/src/port_pool.rs index a932b7c..7bff712 100644 --- a/src/port_pool.rs +++ b/src/port_pool.rs @@ -13,7 +13,7 @@ const PORT_RANGE: Range = MIN_PORT..MAX_PORT; pub struct PortPool { /// Remaining ports inner: lockfree::queue::Queue, - /// Ports in use + /// Ports in use, with their associated IP channel sender. taken: lockfree::set::Set, } diff --git a/src/virtual_device.rs b/src/virtual_device.rs index ad60b63..3ca2416 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,26 +1,26 @@ use crate::wg::WireGuardTunnel; +use anyhow::Context; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; use std::sync::Arc; /// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint -/// are made available to this device using a broadcast channel receiver. IP packets sent from this device +/// are made available to this device using a channel receiver. IP packets sent from this device /// are asynchronously sent out to the WireGuard tunnel. pub struct VirtualIpDevice { /// Tunnel to send IP packets to. wg: Arc, - /// Broadcast channel receiver for received IP packets. - ip_broadcast_rx: tokio::sync::broadcast::Receiver>, + /// Channel receiver for received IP packets. + ip_dispatch_rx: tokio::sync::mpsc::Receiver>, } impl VirtualIpDevice { - pub fn new(wg: Arc) -> Self { - let ip_broadcast_rx = wg.subscribe(); + pub fn new(virtual_port: u16, wg: Arc) -> anyhow::Result { + let ip_dispatch_rx = wg + .register_virtual_interface(virtual_port) + .with_context(|| "Failed to register IP dispatch for virtual interface")?; - Self { - wg, - ip_broadcast_rx, - } + Ok(Self { wg, ip_dispatch_rx }) } } @@ -29,7 +29,7 @@ impl<'a> Device<'a> for VirtualIpDevice { type TxToken = TxToken; fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - match self.ip_broadcast_rx.try_recv() { + match self.ip_dispatch_rx.try_recv() { Ok(buffer) => Some(( Self::RxToken { buffer }, Self::TxToken { diff --git a/src/wg.rs b/src/wg.rs index 451afca..f748794 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -1,10 +1,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::sync::Arc; use std::time::Duration; use anyhow::Context; use boringtun::noise::{Tunn, TunnResult}; -use futures::lock::Mutex; use log::Level; use smoltcp::phy::ChecksumCapabilities; use smoltcp::wire::{ @@ -12,14 +10,12 @@ use smoltcp::wire::{ TcpPacket, TcpRepr, TcpSeqNumber, }; use tokio::net::UdpSocket; -use tokio::sync::broadcast::error::RecvError; use crate::config::Config; -use crate::port_pool::PortPool; use crate::MAX_PACKET; -/// The capacity of the broadcast channel for received IP packets. -const BROADCAST_CAPACITY: usize = 1_000; +/// The capacity of the channel for received IP packets. +const DISPATCH_CAPACITY: usize = 1_000; /// A WireGuard tunnel. Encapsulates and decapsulates IP packets /// to be sent to and received from a remote UDP endpoint. @@ -32,34 +28,27 @@ pub struct WireGuardTunnel { udp: UdpSocket, /// The address of the public WireGuard endpoint (UDP). endpoint: SocketAddr, - /// Broadcast sender for received IP packets. - ip_broadcast_tx: tokio::sync::broadcast::Sender>, - /// Sink so that the broadcaster doesn't close. A repeating task should drain this as much as possible. - ip_broadcast_rx_sink: Mutex>>, - /// Port pool. - port_pool: Arc, + /// Maps virtual ports to the corresponding IP packet dispatcher. + virtual_port_ip_tx: lockfree::map::Map>>, } impl WireGuardTunnel { /// Initialize a new WireGuard tunnel. - pub async fn new(config: &Config, port_pool: Arc) -> anyhow::Result { + pub async fn new(config: &Config) -> anyhow::Result { let source_peer_ip = config.source_peer_ip; let peer = Self::create_tunnel(config)?; let udp = UdpSocket::bind("0.0.0.0:0") .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; let endpoint = config.endpoint_addr; - let (ip_broadcast_tx, ip_broadcast_rx_sink) = - tokio::sync::broadcast::channel(BROADCAST_CAPACITY); + let virtual_port_ip_tx = lockfree::map::Map::new(); Ok(Self { source_peer_ip, peer, udp, endpoint, - ip_broadcast_tx, - ip_broadcast_rx_sink: Mutex::new(ip_broadcast_rx_sink), - port_pool, + virtual_port_ip_tx, }) } @@ -94,9 +83,24 @@ impl WireGuardTunnel { Ok(()) } - /// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint. - pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver> { - self.ip_broadcast_tx.subscribe() + /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. + pub fn register_virtual_interface( + &self, + virtual_port: u16, + ) -> anyhow::Result>> { + let existing = self.virtual_port_ip_tx.get(&virtual_port); + if existing.is_some() { + Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) + } else { + let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); + self.virtual_port_ip_tx.insert(virtual_port, sender); + Ok(receiver) + } + } + + /// Releases the virtual interface from IP dispatch. + pub fn release_virtual_interface(&self, virtual_port: u16) { + self.virtual_port_ip_tx.remove(&virtual_port); } /// WireGuard Routine task. Handles Handshake, keep-alive, etc. @@ -139,7 +143,7 @@ impl WireGuardTunnel { } /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint, - /// decapsulates them, and broadcasts newly received IP packets. + /// decapsulates them, and dispatches newly received IP packets. pub async fn consume_task(&self) -> ! { trace!("Starting WireGuard consumption task"); @@ -195,30 +199,33 @@ impl WireGuardTunnel { trace_ip_packet("Received IP packet", packet); match self.route_ip_packet(packet) { - RouteResult::Broadcast => { - // Broadcast IP packet - if self.ip_broadcast_tx.receiver_count() > 1 { - match self.ip_broadcast_tx.send(packet.to_vec()) { - Ok(n) => { + RouteResult::Dispatch(port) => { + let sender = self.virtual_port_ip_tx.get(&port); + if let Some(sender_guard) = sender { + let sender = sender_guard.val(); + match sender.send(packet.to_vec()).await { + Ok(_) => { trace!( - "Broadcasted received IP packet to {} virtual interfaces", - n - 1 + "Dispatched received IP packet to virtual port {}", + port ); } Err(e) => { error!( - "Failed to broadcast received IP packet to recipients: {}", - e + "Failed to dispatch received IP packet to virtual port {}: {}", + port, e ); } } + } else { + warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port); } } - RouteResult::TcpReset(packet) => { + RouteResult::TcpReset => { trace!("Resetting dead TCP connection after packet from WireGuard endpoint"); - self.send_ip_packet(&packet) - .await - .unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e)); + self.route_tcp_sink(packet).await.unwrap_or_else(|e| { + error!("Failed to send TCP reset to sink: {:?}", e) + }); } RouteResult::Drop => { trace!("Dropped incoming IP packet from WireGuard endpoint"); @@ -230,33 +237,6 @@ impl WireGuardTunnel { } } - /// A repeating task that drains the default IP broadcast channel receiver. - /// It is necessary to keep this receiver alive to prevent the overall channel from closing, - /// so draining its backlog regularly is required to avoid memory leaks. - pub async fn broadcast_drain_task(&self) { - trace!("Starting IP broadcast sink drain task"); - - loop { - let mut sink = self.ip_broadcast_rx_sink.lock().await; - match sink.recv().await { - Ok(_) => { - trace!("Drained a packet from IP broadcast sink"); - } - Err(e) => match e { - RecvError::Closed => { - trace!("IP broadcast sink finished draining: channel closed"); - break; - } - RecvError::Lagged(_) => { - warn!("IP broadcast sink is falling behind"); - } - }, - } - } - - trace!("Stopped IP broadcast sink drain"); - } - fn create_tunnel(config: &Config) -> anyhow::Result> { Tunn::new( config.private_key.clone(), @@ -278,14 +258,9 @@ impl WireGuardTunnel { // Only care if the packet is destined for this tunnel .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.protocol() { - IpProtocol::Tcp => Some(self.route_tcp_segment( - IpVersion::Ipv4, - packet.src_addr().into(), - packet.dst_addr().into(), - packet.payload(), - )), - // Unrecognized protocol, so we'll allow it. - _ => Some(RouteResult::Broadcast), + IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), + // Unrecognized protocol, so we cannot determine where to route + _ => Some(RouteResult::Drop), }) .flatten() .unwrap_or(RouteResult::Drop), @@ -294,14 +269,9 @@ impl WireGuardTunnel { // Only care if the packet is destined for this tunnel .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.next_header() { - IpProtocol::Tcp => Some(self.route_tcp_segment( - IpVersion::Ipv6, - packet.src_addr().into(), - packet.dst_addr().into(), - packet.payload(), - )), - // Unrecognized protocol, so we'll allow it. - _ => Some(RouteResult::Broadcast), + IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), + // Unrecognized protocol, so we cannot determine where to route + _ => Some(RouteResult::Drop), }) .flatten() .unwrap_or(RouteResult::Drop), @@ -310,40 +280,32 @@ impl WireGuardTunnel { } /// Makes a decision on the handling of an incoming TCP segment. - fn route_tcp_segment( - &self, - ip_version: IpVersion, - src_addr: IpAddress, - dst_addr: IpAddress, - segment: &[u8], - ) -> RouteResult { + fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult { TcpPacket::new_checked(segment) .ok() .map(|tcp| { - if self.port_pool.is_in_use(tcp.dst_port()) { - RouteResult::Broadcast + if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() { + RouteResult::Dispatch(tcp.dst_port()) } else if tcp.rst() { RouteResult::Drop } else { - // Port is not in use, but it's a TCP packet so we'll craft a RST. - RouteResult::TcpReset(craft_tcp_rst_reply( - ip_version, - src_addr, - tcp.src_port(), - dst_addr, - tcp.dst_port(), - tcp.ack_number(), - )) + RouteResult::TcpReset } }) .unwrap_or(RouteResult::Drop) } + + /// Route a packet to the TCP sink interface. + async fn route_tcp_sink(&self, _packet: &[u8]) -> anyhow::Result<()> { + // TODO + Ok(()) + } } /// Craft an IP packet containing a TCP RST segment, given an IP version, /// source address (the one to reply to), destination address (the one the reply comes from), /// and the ACK number received in the initiating TCP segment. -fn craft_tcp_rst_reply( +fn _craft_tcp_rst_reply( ip_version: IpVersion, source_addr: IpAddress, source_port: u16, @@ -450,10 +412,10 @@ fn trace_ip_packet(message: &str, packet: &[u8]) { } enum RouteResult { - /// The packet can be broadcasted to the virtual interfaces - Broadcast, + /// Dispatch the packet to the virtual port. + Dispatch(u16), /// The packet is not routable so it may be reset. - TcpReset(Vec), + TcpReset, /// The packet can be safely ignored. Drop, }