diff --git a/src/main.rs b/src/main.rs index 3c2fc8d..0975f45 100644 --- a/src/main.rs +++ b/src/main.rs @@ -132,10 +132,11 @@ async fn handle_tcp_proxy_connection( // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back // to the real client. - let (data_to_real_client_tx, mut data_to_real_client_rx) = - tokio::sync::mpsc::channel(1_000_000); + let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000); - let (data_to_real_server_tx, data_to_real_server_rx) = tokio::sync::mpsc::channel(1_000_000); + // data_to_real_server_(tx/rx): This task sends the data received from the real client to the + // virtual interface (virtual server socket). + let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000); // Spawn virtual interface { @@ -148,7 +149,7 @@ async fn handle_tcp_proxy_connection( wg, abort, data_to_real_client_tx, - data_to_real_server_rx, + data_to_virtual_server_rx, ) .await }); @@ -171,14 +172,11 @@ async fn handle_tcp_proxy_connection( "[{}] Read {} bytes of TCP data from real client", virtual_port, size ); - match data_to_real_server_tx.send(data.to_vec()).await { - Err(e) => { - error!( - "[{}] Failed to dispatch data to virtual interface: {:?}", - virtual_port, e - ); - } - _ => {} + if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await { + error!( + "[{}] Failed to dispatch data to virtual interface: {:?}", + virtual_port, e + ); } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { @@ -239,7 +237,7 @@ async fn virtual_tcp_interface( wg: Arc, abort: Arc, data_to_real_client_tx: tokio::sync::mpsc::Sender>, - mut data_to_real_server_rx: tokio::sync::mpsc::Receiver>, + mut data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, ) -> anyhow::Result<()> { // Create a device and interface to simulate IP packets // In essence: @@ -326,11 +324,8 @@ async fn virtual_tcp_interface( match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { Ok(data) => { // Send it to the real client - match data_to_real_client_tx.send(data).await { - Err(e) => { - error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e); - } - _ => {} + if let Err(e) = data_to_real_client_tx.send(data).await { + error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e); } } Err(e) => { @@ -343,17 +338,13 @@ async fn virtual_tcp_interface( } if client_socket.can_send() { // Check if there is anything to send - match data_to_real_server_rx.try_recv() { - Ok(data) => match client_socket.send_slice(&data) { - Err(e) => { - error!( - "[{}] Failed to send slice via virtual client socket: {:?}", - virtual_port, e - ); - } - _ => {} - }, - Err(_) => {} + if let Ok(data) = data_to_virtual_server_rx.try_recv() { + if let Err(e) = client_socket.send_slice(&data) { + error!( + "[{}] Failed to send slice via virtual client socket: {:?}", + virtual_port, e + ); + } } } } diff --git a/src/wg.rs b/src/wg.rs index fd71ce6..7a15555 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -1,4 +1,4 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::time::Duration; @@ -17,7 +17,7 @@ use crate::port_pool::PortPool; use crate::MAX_PACKET; /// The capacity of the broadcast channel for received IP packets. -const BROADCAST_CAPACITY: usize = 1_000_000; +const BROADCAST_CAPACITY: usize = 1_000; pub struct WireGuardTunnel { source_peer_ip: IpAddr, @@ -189,7 +189,7 @@ impl WireGuardTunnel { // For debugging purposes: parse packet trace_ip_packet("Received IP packet", packet); - match self.route(packet) { + match self.route_ip_packet(packet) { RouteResult::Broadcast => { // Broadcast IP packet if self.ip_broadcast_tx.receiver_count() > 1 { @@ -238,41 +238,78 @@ impl WireGuardTunnel { .with_context(|| "Failed to initialize boringtun Tunn") } - fn route(&self, packet: &[u8]) -> RouteResult { + /// Makes a decision on the handling of an incoming IP packet. + fn route_ip_packet(&self, packet: &[u8]) -> RouteResult { match IpVersion::of_packet(packet) { Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .map(|packet| match packet.protocol() { - IpProtocol::Tcp => TcpPacket::new_checked(packet.payload()).ok().map(|tcp| { - if self.port_pool.is_in_use(tcp.dst_port()) { - RouteResult::Broadcast - } else if tcp.rst() { - RouteResult::Drop - } else { - // Port is not in use, but it's a TCP packet so we'll craft a RST. - RouteResult::TcpReset(craft_tcp_rst_reply( - IpVersion::Ipv4, - packet.src_addr().into(), - tcp.src_port(), - packet.dst_addr().into(), - tcp.dst_port(), - tcp.ack_number(), - )) - } - }), + IpProtocol::Tcp => Some( + self.route_tcp_segment( + packet.src_addr().into(), + packet.dst_addr().into(), + packet.payload(), + ) + // Note: Ipv4 drops invalid TCP packets when the specified protocol says that it should be TCP + .unwrap_or(RouteResult::Drop), + ), // Unrecognized protocol, so we'll allow it. _ => Some(RouteResult::Broadcast), }) .flatten() .unwrap_or(RouteResult::Drop), - // TODO: IPv6 + Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet) + .ok() + // Only care if the packet is destined for this tunnel + .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) + .map(|packet| { + self.route_tcp_segment( + packet.src_addr().into(), + packet.dst_addr().into(), + packet.payload(), + ) + // Note: Since Ipv6 doesn't inform us of the protocol at this layer, + // we should broadcast unrecognized packets. + .unwrap_or(RouteResult::Broadcast) + }) + .unwrap_or(RouteResult::Drop), _ => RouteResult::Drop, } } + + /// Makes a decision on the handling of an incoming TCP segment. + /// When the given segment is an invalid TCP packet, it returns `None`. + fn route_tcp_segment( + &self, + src_addr: IpAddress, + dst_addr: IpAddress, + segment: &[u8], + ) -> Option { + TcpPacket::new_checked(segment).ok().map(|tcp| { + if self.port_pool.is_in_use(tcp.dst_port()) { + RouteResult::Broadcast + } else if tcp.rst() { + RouteResult::Drop + } else { + // Port is not in use, but it's a TCP packet so we'll craft a RST. + RouteResult::TcpReset(craft_tcp_rst_reply( + IpVersion::Ipv4, + src_addr, + tcp.src_port(), + dst_addr, + tcp.dst_port(), + tcp.ack_number(), + )) + } + }) + } } +/// Craft an IP packet containing a TCP RST segment, given an IP version, +/// source address (the one to reply to), destination address (the one the reply comes from), +/// and the ACK number received in the initiating TCP segment. fn craft_tcp_rst_reply( ip_version: IpVersion, source_addr: IpAddress,