diff --git a/src/main.rs b/src/main.rs index 4432113..7b3a1e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,7 +33,7 @@ async fn main() -> anyhow::Result<()> { let config = Config::from_args().with_context(|| "Failed to read config")?; let port_pool = Arc::new(PortPool::new()); - let wg = WireGuardTunnel::new(&config) + let wg = WireGuardTunnel::new(&config, port_pool.clone()) .await .with_context(|| "Failed to initialize WireGuard tunnel")?; let wg = Arc::new(wg); diff --git a/src/port_pool.rs b/src/port_pool.rs index 0538b9c..58c3727 100644 --- a/src/port_pool.rs +++ b/src/port_pool.rs @@ -7,23 +7,37 @@ const MAX_PORT: u16 = 60999; const PORT_RANGE: Range = MIN_PORT..MAX_PORT; pub struct PortPool { + /// Remaining ports inner: lockfree::queue::Queue, + /// Ports in use + taken: lockfree::set::Set, } impl PortPool { pub fn new() -> Self { let inner = lockfree::queue::Queue::default(); PORT_RANGE.for_each(|p| inner.push(p) as ()); - Self { inner } + Self { + inner, + taken: lockfree::set::Set::new(), + } } pub fn next(&self) -> anyhow::Result { - self.inner + let port = self + .inner .pop() - .with_context(|| "Virtual port pool is exhausted") + .with_context(|| "Virtual port pool is exhausted")?; + self.taken.insert(port); + Ok(port) } pub fn release(&self, port: u16) { self.inner.push(port); + self.taken.remove(&port); + } + + pub fn is_in_use(&self, port: u16) -> bool { + self.taken.contains(&port) } } diff --git a/src/wg.rs b/src/wg.rs index 56852ab..27b8a50 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -1,18 +1,26 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; use std::time::Duration; use anyhow::Context; use boringtun::noise::{Tunn, TunnResult}; use log::Level; +use smoltcp::phy::ChecksumCapabilities; +use smoltcp::wire::{ + IpAddress, IpProtocol, IpRepr, IpVersion, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address, + Ipv6Packet, Ipv6Repr, TcpControl, TcpPacket, TcpRepr, TcpSeqNumber, +}; use tokio::net::UdpSocket; use crate::config::Config; +use crate::port_pool::PortPool; use crate::MAX_PACKET; /// The capacity of the broadcast channel for received IP packets. const BROADCAST_CAPACITY: usize = 1_000_000; pub struct WireGuardTunnel { + source_peer_ip: IpAddr, /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. peer: Box, /// The UDP socket for the public WireGuard endpoint to connect to. @@ -21,12 +29,16 @@ pub struct WireGuardTunnel { endpoint: SocketAddr, /// Broadcast sender for received IP packets. ip_broadcast_tx: tokio::sync::broadcast::Sender>, + /// Placeholder so that the broadcaster doesn't close. ip_broadcast_rx: tokio::sync::broadcast::Receiver>, + /// Port pool. + port_pool: Arc, } impl WireGuardTunnel { /// Initialize a new WireGuard tunnel. - pub async fn new(config: &Config) -> anyhow::Result { + pub async fn new(config: &Config, port_pool: Arc) -> anyhow::Result { + let source_peer_ip = config.source_peer_ip; let peer = Self::create_tunnel(&config)?; let udp = UdpSocket::bind("0.0.0.0:0") .await @@ -36,11 +48,13 @@ impl WireGuardTunnel { tokio::sync::broadcast::channel(BROADCAST_CAPACITY); Ok(Self { + source_peer_ip, peer, udp, endpoint, ip_broadcast_tx, ip_broadcast_rx, + port_pool, }) } @@ -175,16 +189,34 @@ impl WireGuardTunnel { // For debugging purposes: parse packet trace_ip_packet("Received IP packet", packet); - // Broadcast IP packet - match self.ip_broadcast_tx.send(packet.to_vec()) { - Ok(n) => { - trace!("Broadcasted received IP packet to {} recipients", n); + match self.route(packet) { + RouteResult::Broadcast => { + // Broadcast IP packet + if self.ip_broadcast_tx.receiver_count() > 1 { + match self.ip_broadcast_tx.send(packet.to_vec()) { + Ok(n) => { + trace!( + "Broadcasted received IP packet to {} virtual interfaces", + n - 1 + ); + } + Err(e) => { + error!( + "Failed to broadcast received IP packet to recipients: {}", + e + ); + } + } + } } - Err(e) => { - error!( - "Failed to broadcast received IP packet to recipients: {}", - e - ); + RouteResult::TcpReset(packet) => { + trace!("Resetting dead TCP connection after packet from WireGuard endpoint"); + self.send_ip_packet(&packet) + .await + .unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e)); + } + RouteResult::Drop => { + trace!("Dropped incoming IP packet from WireGuard endpoint"); } } } @@ -205,6 +237,126 @@ impl WireGuardTunnel { .map_err(|s| anyhow::anyhow!("{}", s)) .with_context(|| "Failed to initialize boringtun Tunn") } + + fn route(&self, packet: &[u8]) -> RouteResult { + match IpVersion::of_packet(&packet) { + Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) + .ok() + // Only care if the packet is destined for this tunnel + .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) + .map(|packet| match packet.protocol() { + IpProtocol::Tcp => TcpPacket::new_checked(packet.payload()).ok().map(|tcp| { + if self.port_pool.is_in_use(tcp.dst_port()) { + RouteResult::Broadcast + } else if tcp.rst() { + RouteResult::Drop + } else { + // Port is not in use, but it's a TCP packet so we'll craft a RST. + RouteResult::TcpReset(craft_tcp_rst_reply( + IpVersion::Ipv4, + packet.src_addr().into(), + tcp.src_port(), + packet.dst_addr().into(), + tcp.dst_port(), + tcp.ack_number(), + )) + } + }), + // Unrecognized protocol, so we'll allow it. + _ => Some(RouteResult::Broadcast), + }) + .flatten() + .unwrap_or(RouteResult::Drop), + // TODO: IPv6 + _ => RouteResult::Drop, + } + } +} + +fn craft_tcp_rst_reply( + ip_version: IpVersion, + source_addr: IpAddress, + source_port: u16, + dest_addr: IpAddress, + dest_port: u16, + ack_number: TcpSeqNumber, +) -> Vec { + let tcp_repr = TcpRepr { + src_port: dest_port, + dst_port: source_port, + control: TcpControl::Rst, + seq_number: ack_number, + ack_number: None, + window_len: 0, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }; + + let mut tcp_buffer = vec![0u8; 20]; + let mut tcp_packet = &mut TcpPacket::new_unchecked(&mut tcp_buffer); + tcp_repr.emit( + &mut tcp_packet, + &dest_addr, + &source_addr, + &ChecksumCapabilities::default(), + ); + + let mut ip_buffer = vec![0u8; MAX_PACKET]; + + let (header_len, total_len) = match ip_version { + IpVersion::Ipv4 => { + let dest_addr = match dest_addr { + IpAddress::Ipv4(dest_addr) => dest_addr, + _ => panic!(), + }; + let source_addr = match source_addr { + IpAddress::Ipv4(source_addr) => source_addr, + _ => panic!(), + }; + + let mut ip_packet = &mut Ipv4Packet::new_unchecked(&mut ip_buffer); + let ip_repr = Ipv4Repr { + src_addr: dest_addr, + dst_addr: source_addr, + protocol: IpProtocol::Tcp, + payload_len: tcp_buffer.len(), + hop_limit: 64, + }; + ip_repr.emit(&mut ip_packet, &ChecksumCapabilities::default()); + ( + ip_packet.header_len() as usize, + ip_packet.total_len() as usize, + ) + } + IpVersion::Ipv6 => { + let dest_addr = match dest_addr { + IpAddress::Ipv6(dest_addr) => dest_addr, + _ => panic!(), + }; + let source_addr = match source_addr { + IpAddress::Ipv6(source_addr) => source_addr, + _ => panic!(), + }; + let mut ip_packet = &mut Ipv6Packet::new_unchecked(&mut ip_buffer); + let ip_repr = Ipv6Repr { + src_addr: dest_addr, + dst_addr: source_addr, + next_header: IpProtocol::Tcp, + payload_len: tcp_buffer.len(), + hop_limit: 64, + }; + ip_repr.emit(&mut ip_packet); + (ip_packet.header_len(), ip_packet.total_len()) + } + _ => panic!(), + }; + + ip_buffer[header_len..total_len].copy_from_slice(&tcp_buffer); + let packet: &[u8] = &ip_buffer[..total_len]; + packet.to_vec() } fn trace_ip_packet(message: &str, packet: &[u8]) { @@ -226,3 +378,12 @@ fn trace_ip_packet(message: &str, packet: &[u8]) { } } } + +enum RouteResult { + /// The packet can be broadcasted to the virtual interfaces + Broadcast, + /// The packet is not routable so it may be reset. + TcpReset(Vec), + /// The packet can be safely ignored. + Drop, +}