From c2d0b9719a2264ae6231dd52c53b030c7a5c1076 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Tue, 19 Oct 2021 00:43:59 -0400 Subject: [PATCH] Refactor TCP virtual interface code out of main. Removed unused server socket buffer. --- Cargo.lock | 12 ++ Cargo.toml | 1 + src/main.rs | 278 ++++----------------------------------- src/virtual_iface/mod.rs | 10 ++ src/virtual_iface/tcp.rs | 274 ++++++++++++++++++++++++++++++++++++++ src/wg.rs | 2 +- 6 files changed, 322 insertions(+), 255 deletions(-) create mode 100644 src/virtual_iface/mod.rs create mode 100644 src/virtual_iface/tcp.rs diff --git a/Cargo.lock b/Cargo.lock index e87a015..f3ba00d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,17 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e" +[[package]] +name = "async-trait" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44318e776df68115a881de9a8fd1b9e53368d7a4a5ce4cc48517da3393233a5e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atty" version = "0.2.14" @@ -595,6 +606,7 @@ name = "onetun" version = "0.1.11" dependencies = [ "anyhow", + "async-trait", "boringtun", "clap", "futures", diff --git a/Cargo.toml b/Cargo.toml index a6f927f..d552c2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,4 @@ lockfree = "0.5.1" futures = "0.3.17" rand = "0.8.4" nom = "7" +async-trait = "0.1.51" diff --git a/src/main.rs b/src/main.rs index a117a99..28846fa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,26 +1,24 @@ #[macro_use] extern crate log; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::Duration; use anyhow::Context; -use smoltcp::iface::InterfaceBuilder; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState}; -use smoltcp::wire::{IpAddress, IpCidr}; use tokio::net::{TcpListener, TcpStream}; use crate::config::{Config, PortForwardConfig, PortProtocol}; use crate::port_pool::PortPool; -use crate::virtual_device::VirtualIpDevice; +use crate::virtual_iface::tcp::TcpVirtualInterface; +use crate::virtual_iface::VirtualInterfacePoll; use crate::wg::WireGuardTunnel; pub mod config; pub mod ip_sink; pub mod port_pool; pub mod virtual_device; +pub mod virtual_iface; pub mod wg; pub const MAX_PACKET: usize = 65536; @@ -92,29 +90,18 @@ async fn port_forward( ); match port_forward.protocol { - PortProtocol::Tcp => { - tcp_proxy_server( - port_forward.source, - port_forward.destination, - source_peer_ip, - port_pool, - wg, - ) - .await - } + PortProtocol::Tcp => tcp_proxy_server(port_forward, port_pool, wg).await, PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), } } /// Starts the server that listens on TCP connections. async fn tcp_proxy_server( - listen_addr: SocketAddr, - dest_addr: SocketAddr, - source_peer_ip: IpAddr, + port_forward: PortForwardConfig, port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { - let listener = TcpListener::bind(listen_addr) + let listener = TcpListener::bind(port_forward.source) .await .with_context(|| "Failed to listen on TCP proxy server")?; @@ -144,14 +131,8 @@ async fn tcp_proxy_server( tokio::spawn(async move { let port_pool = Arc::clone(&port_pool); - let result = handle_tcp_proxy_connection( - socket, - virtual_port, - source_peer_ip, - dest_addr, - wg.clone(), - ) - .await; + let result = + handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; if let Err(e) = result { error!( @@ -173,8 +154,7 @@ async fn tcp_proxy_server( async fn handle_tcp_proxy_connection( socket: TcpStream, virtual_port: u16, - source_peer_ip: IpAddr, - dest_addr: SocketAddr, + port_forward: PortForwardConfig, wg: Arc, ) -> anyhow::Result<()> { // Abort signal for stopping the Virtual Interface @@ -194,18 +174,21 @@ async fn handle_tcp_proxy_connection( // Spawn virtual interface { let abort = abort.clone(); + let virtual_interface = TcpVirtualInterface::new( + virtual_port, + port_forward, + wg, + abort.clone(), + data_to_real_client_tx, + data_to_virtual_server_rx, + virtual_client_ready_tx, + ); + tokio::spawn(async move { - virtual_tcp_interface( - virtual_port, - source_peer_ip, - dest_addr, - wg, - abort, - data_to_real_client_tx, - data_to_virtual_server_rx, - virtual_client_ready_tx, - ) - .await + virtual_interface.poll_loop().await.unwrap_or_else(|e| { + error!("Virtual interface poll loop failed unexpectedly: {}", e); + abort.store(true, Ordering::Relaxed); + }) }); } @@ -297,219 +280,6 @@ async fn handle_tcp_proxy_connection( Ok(()) } -#[allow(clippy::too_many_arguments)] -async fn virtual_tcp_interface( - virtual_port: u16, - source_peer_ip: IpAddr, - dest_addr: SocketAddr, - wg: Arc, - abort: Arc, - data_to_real_client_tx: tokio::sync::mpsc::Sender>, - mut data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, - virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, -) -> anyhow::Result<()> { - let mut virtual_client_ready_tx = Some(virtual_client_ready_tx); - - // Create a device and interface to simulate IP packets - // In essence: - // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client' - // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel - // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface' - // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded) - // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client - - // Consumer for IP packets to send through the virtual interface - // Initialize the interface - let device = VirtualIpDevice::new(virtual_port, wg) - .with_context(|| "Failed to initialize VirtualIpDevice")?; - let mut virtual_interface = InterfaceBuilder::new(device) - .ip_addrs([ - // Interface handles IP packets for the sender and recipient - IpCidr::new(IpAddress::from(source_peer_ip), 32), - IpCidr::new(IpAddress::from(dest_addr.ip()), 32), - ]) - .finalize(); - - // Server socket: this is a placeholder for the interface to route new connections to. - // TODO: Determine if we even need buffers here. - let server_socket: anyhow::Result = { - static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - - socket - .listen((IpAddress::from(dest_addr.ip()), dest_addr.port())) - .with_context(|| "Virtual server socket failed to listen")?; - - Ok(socket) - }; - - let client_socket: anyhow::Result = { - let rx_data = vec![0u8; MAX_PACKET]; - let tx_data = vec![0u8; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); - let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); - let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - Ok(socket) - }; - - // Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server. - let mut socket_set_entries: [_; 2] = Default::default(); - let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); - let _server_handle = socket_set.add(server_socket?); - let client_handle = socket_set.add(client_socket?); - - // Any data that wasn't sent because it was over the sending buffer limit - let mut tx_extra = Vec::new(); - - // Counts the connection attempts by the virtual client - let mut connection_attempts = 0; - // Whether the client has successfully connected before. Prevents the case of connecting again. - let mut has_connected = false; - - loop { - let loop_start = smoltcp::time::Instant::now(); - - // Shutdown occurs when the real client closes the connection, - // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore. - // One last poll-loop iteration is executed so that the RST segment can be dispatched. - let shutdown = abort.load(Ordering::Relaxed); - - if shutdown { - // Shutdown: sends a RST packet. - trace!("[{}] Shutting down virtual interface", virtual_port); - let mut client_socket = socket_set.get::(client_handle); - client_socket.abort(); - } - - match virtual_interface.poll(&mut socket_set, loop_start) { - Ok(processed) if processed => { - trace!( - "[{}] Virtual interface polled some packets to be processed", - virtual_port - ); - } - Err(e) => { - error!("[{}] Virtual interface poll error: {:?}", virtual_port, e); - } - _ => {} - } - - { - let mut client_socket = socket_set.get::(client_handle); - - if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { - // Not shutting down, but the client socket is closed, and the client never successfully connected. - if connection_attempts < 10 { - // Try to connect - client_socket - .connect( - (IpAddress::from(dest_addr.ip()), dest_addr.port()), - (IpAddress::from(source_peer_ip), virtual_port), - ) - .with_context(|| "Virtual server socket failed to listen")?; - if connection_attempts > 0 { - debug!( - "[{}] Virtual client retrying connection in 500ms", - virtual_port - ); - // Not our first connection attempt, wait a little bit. - tokio::time::sleep(Duration::from_millis(500)).await; - } - } else { - // Too many connection attempts - abort.store(true, Ordering::Relaxed); - } - connection_attempts += 1; - continue; - } - - if client_socket.state() == TcpState::Established { - // Prevent reconnection if the server later closes. - has_connected = true; - } - - if client_socket.can_recv() { - match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { - Ok(data) => { - trace!( - "[{}] Virtual client received {} bytes of data", - virtual_port, - data.len() - ); - // Send it to the real client - if let Err(e) = data_to_real_client_tx.send(data).await { - error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e); - } - } - Err(e) => { - error!( - "[{}] Failed to read from virtual client socket: {:?}", - virtual_port, e - ); - } - } - } - if client_socket.can_send() { - if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() { - virtual_client_ready_tx - .send(()) - .expect("Failed to notify real client that virtual client is ready"); - } - - let mut to_transfer = None; - - if tx_extra.is_empty() { - // The payload segment from the previous loop is complete, - // we can now read the next payload in the queue. - if let Ok(data) = data_to_virtual_server_rx.try_recv() { - to_transfer = Some(data); - } else if client_socket.state() == TcpState::CloseWait { - // No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN), - // the interface is shutdown. - trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", virtual_port); - abort.store(true, Ordering::Relaxed); - continue; - } - } - - let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice(); - if !to_transfer_slice.is_empty() { - let total = to_transfer_slice.len(); - match client_socket.send_slice(to_transfer_slice) { - Ok(sent) => { - trace!( - "[{}] Sent {}/{} bytes via virtual client socket", - virtual_port, - sent, - total, - ); - tx_extra = Vec::from(&to_transfer_slice[sent..total]); - } - Err(e) => { - error!( - "[{}] Failed to send slice via virtual client socket: {:?}", - virtual_port, e - ); - } - } - } - } - } - - if shutdown { - break; - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - trace!("[{}] Virtual interface task terminated", virtual_port); - abort.store(true, Ordering::Relaxed); - Ok(()) -} - fn init_logger(config: &Config) -> anyhow::Result<()> { let mut builder = pretty_env_logger::formatted_builder(); builder.parse_filters(&config.log); diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs new file mode 100644 index 0000000..b9d3354 --- /dev/null +++ b/src/virtual_iface/mod.rs @@ -0,0 +1,10 @@ +pub mod tcp; + +use async_trait::async_trait; + +#[async_trait] +pub trait VirtualInterfacePoll { + /// Initializes the virtual interface and processes incoming data to be dispatched + /// to the WireGuard tunnel and to the real client. + async fn poll_loop(mut self) -> anyhow::Result<()>; +} diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs new file mode 100644 index 0000000..c9fee95 --- /dev/null +++ b/src/virtual_iface/tcp.rs @@ -0,0 +1,274 @@ +use crate::config::PortForwardConfig; +use crate::virtual_device::VirtualIpDevice; +use crate::virtual_iface::VirtualInterfacePoll; +use crate::wg::WireGuardTunnel; +use anyhow::Context; +use async_trait::async_trait; +use smoltcp::iface::InterfaceBuilder; +use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState}; +use smoltcp::wire::{IpAddress, IpCidr}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +const MAX_PACKET: usize = 65536; + +/// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa. +pub struct TcpVirtualInterface { + /// The virtual port assigned to the virtual client, used to + /// route Layer 4 segments/datagrams to and from the WireGuard tunnel. + virtual_port: u16, + /// The overall port-forward configuration: used for the destination address (on which + /// the virtual server listens) and the protocol in use. + port_forward: PortForwardConfig, + /// The WireGuard tunnel to send IP packets to. + wg: Arc, + /// Abort signal to shutdown the virtual interface and its parent task. + abort: Arc, + /// Channel sender for pushing Layer 7 data back to the real client. + data_to_real_client_tx: tokio::sync::mpsc::Sender>, + /// Channel receiver for processing Layer 7 data through the virtual interface. + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, + /// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data. + virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, +} + +impl TcpVirtualInterface { + /// Initialize the parameters for a new virtual interface. + /// Use the `poll_loop()` future to start the virtual interface poll loop. + pub fn new( + virtual_port: u16, + port_forward: PortForwardConfig, + wg: Arc, + abort: Arc, + data_to_real_client_tx: tokio::sync::mpsc::Sender>, + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, + virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, + ) -> Self { + Self { + virtual_port, + port_forward, + wg, + abort, + data_to_real_client_tx, + data_to_virtual_server_rx, + virtual_client_ready_tx, + } + } +} + +#[async_trait] +impl VirtualInterfacePoll for TcpVirtualInterface { + async fn poll_loop(self) -> anyhow::Result<()> { + let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx); + let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; + let source_peer_ip = self.wg.source_peer_ip; + + // Create a device and interface to simulate IP packets + // In essence: + // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client' + // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel + // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface' + // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded) + // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client + + // Consumer for IP packets to send through the virtual interface + // Initialize the interface + let device = VirtualIpDevice::new(self.virtual_port, self.wg) + .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; + let mut virtual_interface = InterfaceBuilder::new(device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(IpAddress::from(source_peer_ip), 32), + IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32), + ]) + .finalize(); + + // Server socket: this is a placeholder for the interface to route new connections to. + let server_socket: anyhow::Result = { + static mut TCP_SERVER_RX_DATA: [u8; 0] = []; + static mut TCP_SERVER_TX_DATA: [u8; 0] = []; + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .listen(( + IpAddress::from(self.port_forward.destination.ip()), + self.port_forward.destination.port(), + )) + .with_context(|| "Virtual server socket failed to listen")?; + + Ok(socket) + }; + + let client_socket: anyhow::Result = { + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); + let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); + let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + Ok(socket) + }; + + // Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server. + let mut socket_set_entries: [_; 2] = Default::default(); + let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); + let _server_handle = socket_set.add(server_socket?); + let client_handle = socket_set.add(client_socket?); + + // Any data that wasn't sent because it was over the sending buffer limit + let mut tx_extra = Vec::new(); + + // Counts the connection attempts by the virtual client + let mut connection_attempts = 0; + // Whether the client has successfully connected before. Prevents the case of connecting again. + let mut has_connected = false; + + loop { + let loop_start = smoltcp::time::Instant::now(); + + // Shutdown occurs when the real client closes the connection, + // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore. + // One last poll-loop iteration is executed so that the RST segment can be dispatched. + let shutdown = self.abort.load(Ordering::Relaxed); + + if shutdown { + // Shutdown: sends a RST packet. + trace!("[{}] Shutting down virtual interface", self.virtual_port); + let mut client_socket = socket_set.get::(client_handle); + client_socket.abort(); + } + + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { + trace!( + "[{}] Virtual interface polled some packets to be processed", + self.virtual_port + ); + } + Err(e) => { + error!( + "[{}] Virtual interface poll error: {:?}", + self.virtual_port, e + ); + } + _ => {} + } + + { + let mut client_socket = socket_set.get::(client_handle); + + if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { + // Not shutting down, but the client socket is closed, and the client never successfully connected. + if connection_attempts < 10 { + // Try to connect + client_socket + .connect( + ( + IpAddress::from(self.port_forward.destination.ip()), + self.port_forward.destination.port(), + ), + (IpAddress::from(source_peer_ip), self.virtual_port), + ) + .with_context(|| "Virtual server socket failed to listen")?; + if connection_attempts > 0 { + debug!( + "[{}] Virtual client retrying connection in 500ms", + self.virtual_port + ); + // Not our first connection attempt, wait a little bit. + tokio::time::sleep(Duration::from_millis(500)).await; + } + } else { + // Too many connection attempts + self.abort.store(true, Ordering::Relaxed); + } + connection_attempts += 1; + continue; + } + + if client_socket.state() == TcpState::Established { + // Prevent reconnection if the server later closes. + has_connected = true; + } + + if client_socket.can_recv() { + match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { + Ok(data) => { + trace!( + "[{}] Virtual client received {} bytes of data", + self.virtual_port, + data.len() + ); + // Send it to the real client + if let Err(e) = self.data_to_real_client_tx.send(data).await { + error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e); + } + } + Err(e) => { + error!( + "[{}] Failed to read from virtual client socket: {:?}", + self.virtual_port, e + ); + } + } + } + if client_socket.can_send() { + if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() { + virtual_client_ready_tx + .send(()) + .expect("Failed to notify real client that virtual client is ready"); + } + + let mut to_transfer = None; + + if tx_extra.is_empty() { + // The payload segment from the previous loop is complete, + // we can now read the next payload in the queue. + if let Ok(data) = data_to_virtual_server_rx.try_recv() { + to_transfer = Some(data); + } else if client_socket.state() == TcpState::CloseWait { + // No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN), + // the interface is shutdown. + trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port); + self.abort.store(true, Ordering::Relaxed); + continue; + } + } + + let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice(); + if !to_transfer_slice.is_empty() { + let total = to_transfer_slice.len(); + match client_socket.send_slice(to_transfer_slice) { + Ok(sent) => { + trace!( + "[{}] Sent {}/{} bytes via virtual client socket", + self.virtual_port, + sent, + total, + ); + tx_extra = Vec::from(&to_transfer_slice[sent..total]); + } + Err(e) => { + error!( + "[{}] Failed to send slice via virtual client socket: {:?}", + self.virtual_port, e + ); + } + } + } + } + } + + if shutdown { + break; + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + trace!("[{}] Virtual interface task terminated", self.virtual_port); + self.abort.store(true, Ordering::Relaxed); + Ok(()) + } +} diff --git a/src/wg.rs b/src/wg.rs index e2740e8..3d3cc5f 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -18,7 +18,7 @@ const DISPATCH_CAPACITY: usize = 1_000; /// to be sent to and received from a remote UDP endpoint. /// This tunnel supports at most 1 peer IP at a time, but supports simultaneous ports. pub struct WireGuardTunnel { - source_peer_ip: IpAddr, + pub(crate) source_peer_ip: IpAddr, /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. peer: Box, /// The UDP socket for the public WireGuard endpoint to connect to.