diff --git a/Cargo.lock b/Cargo.lock index 495b8db..bc6362d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -272,6 +272,100 @@ dependencies = [ "version_check", ] +[[package]] +name = "futures" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12aa0eb539080d55c3f2d45a67c3b58b6b0773c1a3ca2dfec66d58c97fd66ca" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" + +[[package]] +name = "futures-executor" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "522de2a0fe3e380f1bc577ba0474108faf3f6b18321dbf60b3b9c39a75073377" + +[[package]] +name = "futures-macro" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb" +dependencies = [ + "autocfg", + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11" + +[[package]] +name = "futures-task" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99" + +[[package]] +name = "futures-util" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481" +dependencies = [ + "autocfg", + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "proc-macro-hack", + "proc-macro-nested", + "slab", +] + [[package]] name = "getrandom" version = "0.2.3" @@ -518,6 +612,7 @@ dependencies = [ "clap", "crossbeam-channel", "dashmap", + "futures", "lockfree", "log", "pretty_env_logger", @@ -562,6 +657,12 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pretty_env_logger" version = "0.3.1" @@ -573,6 +674,18 @@ dependencies = [ "log", ] +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + +[[package]] +name = "proc-macro-nested" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" + [[package]] name = "proc-macro2" version = "1.0.29" @@ -684,6 +797,12 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" + [[package]] name = "slog" version = "2.7.0" diff --git a/Cargo.toml b/Cargo.toml index 6edccb9..8b9c266 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" } dashmap = "4.0.2" tokio = { version = "1", features = ["full"] } lockfree = "0.5.1" +futures = "0.3.17" diff --git a/src/main.rs b/src/main.rs index f2509ee..9bbac59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,22 @@ #[macro_use] extern crate log; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; use anyhow::Context; +use smoltcp::iface::InterfaceBuilder; +use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; +use smoltcp::wire::{IpAddress, IpCidr}; use tokio::io::Interest; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::error::TryRecvError; use crate::config::Config; use crate::port_pool::PortPool; +use crate::virtual_device::VirtualIpDevice; use crate::wg::WireGuardTunnel; pub mod client; @@ -51,12 +55,21 @@ async fn main() -> anyhow::Result<()> { &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip ); - tcp_proxy_server(config.source_addr.clone(), port_pool.clone(), wg).await + tcp_proxy_server( + config.source_addr, + config.source_peer_ip, + config.dest_addr, + port_pool.clone(), + wg, + ) + .await } /// Starts the server that listens on TCP connections. async fn tcp_proxy_server( listen_addr: SocketAddr, + source_peer_ip: IpAddr, + dest_addr: SocketAddr, port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { @@ -90,7 +103,9 @@ async fn tcp_proxy_server( tokio::spawn(async move { let port_pool = Arc::clone(&port_pool); - let result = handle_tcp_proxy_connection(socket, virtual_port, wg).await; + let result = + handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg) + .await; if let Err(e) = result { error!( @@ -111,6 +126,8 @@ async fn tcp_proxy_server( async fn handle_tcp_proxy_connection( socket: TcpStream, virtual_port: u16, + source_peer_ip: IpAddr, + dest_addr: SocketAddr, wg: Arc, ) -> anyhow::Result<()> { // Abort signal for stopping the Virtual Interface @@ -121,11 +138,22 @@ async fn handle_tcp_proxy_connection( let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000_000); + let (data_to_real_server_tx, data_to_real_server_rx) = tokio::sync::mpsc::channel(1_000_000); + // Spawn virtual interface { let abort = abort.clone(); tokio::spawn(async move { - virtual_tcp_interface(virtual_port, wg, abort, data_to_real_client_tx).await + virtual_tcp_interface( + virtual_port, + source_peer_ip, + dest_addr, + wg, + abort, + data_to_real_client_tx, + data_to_real_server_rx, + ) + .await }); } @@ -149,7 +177,15 @@ async fn handle_tcp_proxy_connection( "[{}] Read {} bytes of TCP data from real client", virtual_port, size ); - trace!("[{}] Read: {:?}", virtual_port, data); + match data_to_real_server_tx.send(data.to_vec()).await { + Err(e) => { + error!( + "[{}] Failed to dispatch data to virtual interface: {:?}", + virtual_port, e + ); + } + _ => {} + } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { continue; @@ -210,37 +246,135 @@ async fn handle_tcp_proxy_connection( async fn virtual_tcp_interface( virtual_port: u16, + source_peer_ip: IpAddr, + dest_addr: SocketAddr, wg: Arc, abort: Arc, data_to_real_client_tx: tokio::sync::mpsc::Sender>, + mut data_to_real_server_rx: tokio::sync::mpsc::Receiver>, ) -> anyhow::Result<()> { // Create a device and interface to simulate IP packets // In essence: - // * TCP packets received from the 'real' client are 'sent' via the 'virtual client' + // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client' // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface' // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded) // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client + + // Consumer for IP packets to send through the virtual interface + // Initialize the interface + let device = VirtualIpDevice::new(wg); + let mut virtual_interface = InterfaceBuilder::new(device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(IpAddress::from(source_peer_ip), 32), + IpCidr::new(IpAddress::from(dest_addr.ip()), 32), + ]) + .any_ip(true) + .finalize(); + + // Server socket: this is a placeholder for the interface to route new connections to. + // TODO: Determine if we even need buffers here. + let server_socket: anyhow::Result = { + static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .listen((IpAddress::from(dest_addr.ip()), dest_addr.port())) + .with_context(|| "Virtual server socket failed to listen")?; + + Ok(socket) + }; + + let client_socket: anyhow::Result = { + static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .connect( + (IpAddress::from(dest_addr.ip()), dest_addr.port()), + (IpAddress::from(source_peer_ip), virtual_port), + ) + .with_context(|| "Virtual server socket failed to listen")?; + + Ok(socket) + }; + + // Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server. + let mut socket_set_entries: [_; 2] = Default::default(); + let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); + let _server_handle = socket_set.add(server_socket?); + let client_handle = socket_set.add(client_socket?); + loop { + let loop_start = smoltcp::time::Instant::now(); + if abort.load(Ordering::Relaxed) { break; } - // Test START - tokio::time::sleep(Duration::from_millis(1000)).await; - match data_to_real_client_tx.send(b"pong".to_vec()).await { - Ok(_) => { - trace!("Wrote stuff in the data_to_real_client_tx") - } - Err(e) => { + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { trace!( - "[{}] Virtual interface failed to dispatch data to parent task: {:?}", - virtual_port, - e + "[{}] Virtual interface polled some packets to be processed", + virtual_port ); } + Err(e) => { + error!("[{}] Virtual interface poll error: {:?}", virtual_port, e); + } + _ => {} } - // Test END + + { + let mut client_socket = socket_set.get::(client_handle); + if client_socket.can_recv() { + match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { + Ok(data) => { + // Send it to the real client + match data_to_real_client_tx.send(data).await { + Err(e) => { + error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e); + } + _ => {} + } + } + Err(e) => { + error!( + "[{}] Failed to read from virtual client socket: {:?}", + virtual_port, e + ); + } + } + } + if client_socket.can_send() { + // Check if there is anything to send + match data_to_real_server_rx.try_recv() { + Ok(data) => match client_socket.send_slice(&data) { + Err(e) => { + error!( + "[{}] Failed to send slice via virtual client socket: {:?}", + virtual_port, e + ); + } + _ => {} + }, + Err(_) => {} + } + } + } + + match virtual_interface.poll_delay(&socket_set, loop_start) { + None => tokio::time::sleep(Duration::from_millis(1)).await, + Some(smoltcp::time::Duration::ZERO) => {} + Some(delay) => tokio::time::sleep(Duration::from_millis(delay.millis())).await, + }; } trace!("[{}] Virtual interface task terminated", virtual_port); Ok(()) diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 24ddfe4..ec43128 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,19 +1,17 @@ +use crate::wg::WireGuardTunnel; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; +use std::sync::Arc; #[derive(Clone)] pub struct VirtualIpDevice { - /// Channel for packets sent by the interface. - ip_tx: crossbeam_channel::Sender>, - ip_rx: crossbeam_channel::Receiver>, + /// Tunnel to send IP packets to. + wg: Arc, } impl VirtualIpDevice { - pub fn new( - ip_tx: crossbeam_channel::Sender>, - ip_rx: crossbeam_channel::Receiver>, - ) -> Self { - Self { ip_tx, ip_rx } + pub fn new(wg: Arc) -> Self { + Self { wg } } } @@ -22,22 +20,21 @@ impl<'a> Device<'a> for VirtualIpDevice { type TxToken = TxToken; fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - if !self.ip_rx.is_empty() { - let buffer = self.ip_rx.recv().expect("failed to read ip_rx"); - Some(( - RxToken { buffer }, - TxToken { - ip_tx: self.ip_tx.clone(), + let mut consumer = self.wg.subscribe(); + match consumer.try_recv() { + Ok(buffer) => Some(( + Self::RxToken { buffer }, + Self::TxToken { + wg: self.wg.clone(), }, - )) - } else { - None + )), + Err(_) => None, } } fn transmit(&'a mut self) -> Option { Some(TxToken { - ip_tx: self.ip_tx.clone(), + wg: self.wg.clone(), }) } @@ -65,10 +62,10 @@ impl smoltcp::phy::RxToken for RxToken { #[doc(hidden)] pub struct TxToken { - ip_tx: crossbeam_channel::Sender>, + wg: Arc, } -impl<'a> smoltcp::phy::TxToken for TxToken { +impl smoltcp::phy::TxToken for TxToken { fn consume(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result where F: FnOnce(&mut [u8]) -> smoltcp::Result, @@ -76,9 +73,12 @@ impl<'a> smoltcp::phy::TxToken for TxToken { let mut buffer = Vec::new(); buffer.resize(len, 0); let result = f(&mut buffer); - self.ip_tx - .send(buffer.clone()) - .expect("failed to send to ip_tx"); + match futures::executor::block_on(self.wg.send_ip_packet(&buffer)) { + Ok(_) => {} + Err(e) => { + error!("Failed to send IP packet to WireGuard endpoint: {:?}", e); + } + } result } } diff --git a/src/wg.rs b/src/wg.rs index 0622558..ed92a77 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -21,6 +21,7 @@ pub struct WireGuardTunnel { endpoint: SocketAddr, /// Broadcast sender for received IP packets. ip_broadcast_tx: tokio::sync::broadcast::Sender>, + ip_broadcast_rx: tokio::sync::broadcast::Receiver>, } impl WireGuardTunnel { @@ -31,13 +32,15 @@ impl WireGuardTunnel { .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; let endpoint = config.endpoint_addr; - let (ip_broadcast_tx, _) = tokio::sync::broadcast::channel(BROADCAST_CAPACITY); + let (ip_broadcast_tx, ip_broadcast_rx) = + tokio::sync::broadcast::channel(BROADCAST_CAPACITY); Ok(Self { peer, udp, endpoint, ip_broadcast_tx, + ip_broadcast_rx, }) } @@ -178,7 +181,7 @@ impl WireGuardTunnel { } Err(e) => { error!( - "Failed to broadcast received IP packet to recipients: {:?}", + "Failed to broadcast received IP packet to recipients: {}", e ); }