use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::Duration; use anyhow::Context; use boringtun::noise::{Tunn, TunnResult}; use log::Level; use smoltcp::phy::ChecksumCapabilities; use smoltcp::wire::{ IpAddress, IpProtocol, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, TcpControl, TcpPacket, TcpRepr, TcpSeqNumber, }; use tokio::net::UdpSocket; use tokio::sync::RwLock; use crate::config::Config; use crate::MAX_PACKET; /// The capacity of the channel for received IP packets. const DISPATCH_CAPACITY: usize = 1_000; /// A WireGuard tunnel. Encapsulates and decapsulates IP packets /// to be sent to and received from a remote UDP endpoint. /// This tunnel supports at most 1 peer IP at a time, but supports simultaneous ports. pub struct WireGuardTunnel { source_peer_ip: IpAddr, /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. peer: Box, /// The UDP socket for the public WireGuard endpoint to connect to. udp: UdpSocket, /// The address of the public WireGuard endpoint (UDP). endpoint: SocketAddr, /// Maps virtual ports to the corresponding IP packet dispatcher. virtual_port_ip_tx: lockfree::map::Map>>, /// IP packet dispatcher for unroutable packets. `None` if not initialized. sink_ip_tx: RwLock>>>, } impl WireGuardTunnel { /// Initialize a new WireGuard tunnel. pub async fn new(config: &Config) -> anyhow::Result { let source_peer_ip = config.source_peer_ip; let peer = Self::create_tunnel(config)?; let udp = UdpSocket::bind("0.0.0.0:0") .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; let endpoint = config.endpoint_addr; let virtual_port_ip_tx = lockfree::map::Map::new(); Ok(Self { source_peer_ip, peer, udp, endpoint, virtual_port_ip_tx, sink_ip_tx: RwLock::new(None), }) } /// Encapsulates and sends an IP packet through to the WireGuard endpoint. pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> { trace_ip_packet("Sending IP packet", packet); let mut send_buf = [0u8; MAX_PACKET]; match self.peer.encapsulate(packet, &mut send_buf) { TunnResult::WriteToNetwork(packet) => { self.udp .send_to(packet, self.endpoint) .await .with_context(|| "Failed to send encrypted IP packet to WireGuard endpoint.")?; debug!( "Sent {} bytes to WireGuard endpoint (encrypted IP packet)", packet.len() ); } TunnResult::Err(e) => { error!("Failed to encapsulate IP packet: {:?}", e); } TunnResult::Done => { // Ignored } other => { error!( "Unexpected WireGuard state during encapsulation: {:?}", other ); } }; Ok(()) } /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. pub fn register_virtual_interface( &self, virtual_port: u16, ) -> anyhow::Result>> { let existing = self.virtual_port_ip_tx.get(&virtual_port); if existing.is_some() { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); self.virtual_port_ip_tx.insert(virtual_port, sender); Ok(receiver) } } /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. pub async fn register_sink_interface( &self, ) -> anyhow::Result>> { let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); let mut sink_ip_tx = self.sink_ip_tx.write().await; *sink_ip_tx = Some(sender); Ok(receiver) } /// Releases the virtual interface from IP dispatch. pub fn release_virtual_interface(&self, virtual_port: u16) { self.virtual_port_ip_tx.remove(&virtual_port); } /// WireGuard Routine task. Handles Handshake, keep-alive, etc. pub async fn routine_task(&self) -> ! { trace!("Starting WireGuard routine task"); loop { let mut send_buf = [0u8; MAX_PACKET]; match self.peer.update_timers(&mut send_buf) { TunnResult::WriteToNetwork(packet) => { debug!( "Sending routine packet of {} bytes to WireGuard endpoint", packet.len() ); match self.udp.send_to(packet, self.endpoint).await { Ok(_) => {} Err(e) => { error!( "Failed to send routine packet to WireGuard endpoint: {:?}", e ); } }; } TunnResult::Err(e) => { error!( "Failed to prepare routine packet for WireGuard endpoint: {:?}", e ); } TunnResult::Done => { // Sleep for a bit tokio::time::sleep(Duration::from_millis(1)).await; } other => { warn!("Unexpected WireGuard routine task state: {:?}", other); } } } } /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint, /// decapsulates them, and dispatches newly received IP packets. pub async fn consume_task(&self) -> ! { trace!("Starting WireGuard consumption task"); loop { let mut recv_buf = [0u8; MAX_PACKET]; let mut send_buf = [0u8; MAX_PACKET]; let size = match self.udp.recv(&mut recv_buf).await { Ok(size) => size, Err(e) => { error!("Failed to read from WireGuard endpoint: {:?}", e); // Sleep a little bit and try again tokio::time::sleep(Duration::from_millis(1)).await; continue; } }; let data = &recv_buf[..size]; match self.peer.decapsulate(None, data, &mut send_buf) { TunnResult::WriteToNetwork(packet) => { match self.udp.send_to(packet, self.endpoint).await { Ok(_) => {} Err(e) => { error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); continue; } }; loop { let mut send_buf = [0u8; MAX_PACKET]; match self.peer.decapsulate(None, &[], &mut send_buf) { TunnResult::WriteToNetwork(packet) => { match self.udp.send_to(packet, self.endpoint).await { Ok(_) => {} Err(e) => { error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); break; } }; } _ => { break; } } } } TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { debug!( "WireGuard endpoint sent IP packet of {} bytes", packet.len() ); // For debugging purposes: parse packet trace_ip_packet("Received IP packet", packet); match self.route_ip_packet(packet) { RouteResult::Dispatch(port) => { let sender = self.virtual_port_ip_tx.get(&port); if let Some(sender_guard) = sender { let sender = sender_guard.val(); match sender.send(packet.to_vec()).await { Ok(_) => { trace!( "Dispatched received IP packet to virtual port {}", port ); } Err(e) => { error!( "Failed to dispatch received IP packet to virtual port {}: {}", port, e ); } } } else { warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port); } } RouteResult::TcpReset => { trace!("Resetting dead TCP connection after packet from WireGuard endpoint"); self.route_ip_sink(packet).await.unwrap_or_else(|e| { error!("Failed to send TCP reset to sink: {:?}", e) }); } RouteResult::Drop => { trace!("Dropped incoming IP packet from WireGuard endpoint"); } } } _ => {} } } } fn create_tunnel(config: &Config) -> anyhow::Result> { Tunn::new( config.private_key.clone(), config.endpoint_public_key.clone(), None, config.keepalive_seconds, 0, None, ) .map_err(|s| anyhow::anyhow!("{}", s)) .with_context(|| "Failed to initialize boringtun Tunn") } /// Makes a decision on the handling of an incoming IP packet. fn route_ip_packet(&self, packet: &[u8]) -> RouteResult { match IpVersion::of_packet(packet) { Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.protocol() { IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), // Unrecognized protocol, so we cannot determine where to route _ => Some(RouteResult::Drop), }) .flatten() .unwrap_or(RouteResult::Drop), Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet) .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.next_header() { IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())), // Unrecognized protocol, so we cannot determine where to route _ => Some(RouteResult::Drop), }) .flatten() .unwrap_or(RouteResult::Drop), _ => RouteResult::Drop, } } /// Makes a decision on the handling of an incoming TCP segment. fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult { TcpPacket::new_checked(segment) .ok() .map(|tcp| { if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() { RouteResult::Dispatch(tcp.dst_port()) } else if tcp.rst() { RouteResult::Drop } else { RouteResult::TcpReset } }) .unwrap_or(RouteResult::Drop) } /// Route a packet to the IP sink interface. async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> { let ip_sink_tx = self.sink_ip_tx.read().await; if let Some(ip_sink_tx) = &*ip_sink_tx { ip_sink_tx .send(packet.to_vec()) .await .with_context(|| "Failed to dispatch IP packet to sink interface") } else { warn!( "Could not dispatch unroutable IP packet to sink because interface is not active." ); Ok(()) } } } /// Craft an IP packet containing a TCP RST segment, given an IP version, /// source address (the one to reply to), destination address (the one the reply comes from), /// and the ACK number received in the initiating TCP segment. fn _craft_tcp_rst_reply( ip_version: IpVersion, source_addr: IpAddress, source_port: u16, dest_addr: IpAddress, dest_port: u16, ack_number: TcpSeqNumber, ) -> Vec { let tcp_repr = TcpRepr { src_port: dest_port, dst_port: source_port, control: TcpControl::Rst, seq_number: ack_number, ack_number: None, window_len: 0, window_scale: None, max_seg_size: None, sack_permitted: false, sack_ranges: [None, None, None], payload: &[], }; let mut tcp_buffer = vec![0u8; 20]; let mut tcp_packet = &mut TcpPacket::new_unchecked(&mut tcp_buffer); tcp_repr.emit( &mut tcp_packet, &dest_addr, &source_addr, &ChecksumCapabilities::default(), ); let mut ip_buffer = vec![0u8; MAX_PACKET]; let (header_len, total_len) = match ip_version { IpVersion::Ipv4 => { let dest_addr = match dest_addr { IpAddress::Ipv4(dest_addr) => dest_addr, _ => panic!(), }; let source_addr = match source_addr { IpAddress::Ipv4(source_addr) => source_addr, _ => panic!(), }; let mut ip_packet = &mut Ipv4Packet::new_unchecked(&mut ip_buffer); let ip_repr = Ipv4Repr { src_addr: dest_addr, dst_addr: source_addr, protocol: IpProtocol::Tcp, payload_len: tcp_buffer.len(), hop_limit: 64, }; ip_repr.emit(&mut ip_packet, &ChecksumCapabilities::default()); ( ip_packet.header_len() as usize, ip_packet.total_len() as usize, ) } IpVersion::Ipv6 => { let dest_addr = match dest_addr { IpAddress::Ipv6(dest_addr) => dest_addr, _ => panic!(), }; let source_addr = match source_addr { IpAddress::Ipv6(source_addr) => source_addr, _ => panic!(), }; let mut ip_packet = &mut Ipv6Packet::new_unchecked(&mut ip_buffer); let ip_repr = Ipv6Repr { src_addr: dest_addr, dst_addr: source_addr, next_header: IpProtocol::Tcp, payload_len: tcp_buffer.len(), hop_limit: 64, }; ip_repr.emit(&mut ip_packet); (ip_packet.header_len(), ip_packet.total_len()) } _ => panic!(), }; ip_buffer[header_len..total_len].copy_from_slice(&tcp_buffer); let packet: &[u8] = &ip_buffer[..total_len]; packet.to_vec() } fn trace_ip_packet(message: &str, packet: &[u8]) { if log_enabled!(Level::Trace) { use smoltcp::wire::*; match IpVersion::of_packet(packet) { Ok(IpVersion::Ipv4) => trace!( "{}: {}", message, PrettyPrinter::>::new("", &packet) ), Ok(IpVersion::Ipv6) => trace!( "{}: {}", message, PrettyPrinter::>::new("", &packet) ), _ => {} } } } enum RouteResult { /// Dispatch the packet to the virtual port. Dispatch(u16), /// The packet is not routable so it may be reset. TcpReset, /// The packet can be safely ignored. Drop, }