From 543a01d69fdd23da76f90961e1b0f4ac87584046 Mon Sep 17 00:00:00 2001 From: Louis Travaux Date: Mon, 19 May 2025 16:01:01 +0200 Subject: [PATCH] [ADD] virtual_iface/tcp: remote port forwarding logic Implements remote TCP port forwarding using previously added config parsing (#73671a4). Resolves #6 by listening on a specified virtual IP/port and forwarding incoming connections to a local service. Example: `--source-peer-ip 10.10.10.10 --remote 8080:127.0.0.1:49369` forwards `10.10.10.10:8080` to `127.0.0.1:49369`. --- src/lib.rs | 4 +- src/virtual_iface/tcp.rs | 128 ++++++++++++++++++++++++++++++++++----- 2 files changed, 115 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a76fa18..d103835 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,9 @@ pub async fn start_tunnels(config: Config, bus: Bus) -> anyhow::Result<()> { // Start TCP Virtual Interface let port_forwards = config.port_forwards.clone(); - let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip); + let remote_port_forwards = config.remote_port_forwards.clone(); + let iface = + TcpVirtualInterface::new(port_forwards, remote_port_forwards, bus, config.source_peer_ip); tokio::spawn(async move { iface.poll_loop(device).await }); } diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 3a3fd8d..1393c09 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -17,6 +17,12 @@ use std::{ collections::{HashMap, HashSet, VecDeque}, net::IpAddr, time::Duration, + sync::Arc, +}; +use tokio::{ + io::AsyncWriteExt, + net::TcpStream, + sync::Mutex, }; const MAX_PACKET: usize = 65536; @@ -25,6 +31,7 @@ const MAX_PACKET: usize = 65536; pub struct TcpVirtualInterface { source_peer_ip: IpAddr, port_forwards: Vec, + remote_port_forwards: Vec, bus: Bus, sockets: SocketSet<'static>, } @@ -32,30 +39,38 @@ pub struct TcpVirtualInterface { impl TcpVirtualInterface { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. - pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { + pub fn new( + port_forwards: Vec, + remote_port_forwards: Vec, + bus: Bus, + source_peer_ip: IpAddr + ) -> Self { Self { port_forwards: port_forwards .into_iter() .filter(|f| matches!(f.protocol, PortProtocol::Tcp)) .collect(), + remote_port_forwards: remote_port_forwards + .into_iter() + .filter(|f| matches!(f.protocol, PortProtocol::Tcp)) + .collect(), source_peer_ip, bus, sockets: SocketSet::new([]), } } - fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { - static mut TCP_SERVER_RX_DATA: [u8; 0] = []; - static mut TCP_SERVER_TX_DATA: [u8; 0] = []; - - let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + fn new_server_socket(remote_port_forward: PortForwardConfig) -> anyhow::Result> { + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data); + let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data); let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); socket .listen(( - IpAddress::from(port_forward.destination.ip()), - port_forward.destination.port(), + IpAddress::from(remote_port_forward.source.ip()), + remote_port_forward.source.port(), )) .context("Virtual server socket failed to listen")?; @@ -101,10 +116,16 @@ impl VirtualInterfacePoll for TcpVirtualInterface { }); }); - // Create virtual server for each port forward - for port_forward in self.port_forwards.iter() { - let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?; - self.sockets.add(server_socket); + let mut server_handle_map: HashMap>, PortForwardConfig)> = + HashMap::new(); + + // Create virtual server for each remote port forward + for remote_port_forward in self.remote_port_forwards.iter() { + let (handle, stream) = + connect_to_local_service(&mut self.sockets, remote_port_forward).await?; + + server_handle_map + .insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward.clone())); } // The next time to poll the interface. Can be None for instant poll. @@ -121,9 +142,9 @@ impl VirtualInterfacePoll for TcpVirtualInterface { loop { tokio::select! { - _ = match (next_poll, port_client_handle_map.len()) { - (None, 0) => tokio::time::sleep(Duration::MAX), - (None, _) => tokio::time::sleep(Duration::ZERO), + _ = match (next_poll, port_client_handle_map.is_empty() && server_handle_map.is_empty()) { + (None, true) => tokio::time::sleep(Duration::MAX), + (None, false) => tokio::time::sleep(Duration::ZERO), (Some(until), _) => tokio::time::sleep_until(until), } => { let loop_start = smoltcp::time::Instant::now(); @@ -142,6 +163,25 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } }); + let mut updated_server_handle_map = HashMap::new(); + + for (server_handle, (stream, remote_port_forward)) in server_handle_map.drain() { + let server_socket = self.sockets.get_mut::(server_handle); + if server_socket.state() == tcp::State::Closed { + self.sockets.remove(server_handle); + + // Recreate listening socket and TCP stream + let (handle, stream) = + connect_to_local_service(&mut self.sockets, &remote_port_forward).await?; + updated_server_handle_map + .insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward)); + } else { + updated_server_handle_map + .insert(server_handle, (stream, remote_port_forward)); + } + } + server_handle_map = updated_server_handle_map; + if iface.poll(loop_start, &mut device, &mut self.sockets) == PollResult::SocketStateChanged { log::trace!("TCP virtual interface polled some packets to be processed"); } @@ -189,6 +229,47 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } } + for (server_handle, (stream, _)) in server_handle_map.iter() { + let server_socket = self.sockets.get_mut::(*server_handle); + + if server_socket.state() == tcp::State::CloseWait { + server_socket.close(); + } + + if server_socket.can_recv() { + let data = server_socket.recv(|buffer| { + (buffer.len(), Bytes::copy_from_slice(buffer)) + })?; + + if !data.is_empty() { + let stream = Arc::clone(stream); + tokio::spawn(async move { + let mut stream = stream.lock().await; + log::trace!("Forwarding remote data to listening local service"); + if let Err(e) = stream.write(&data).await { + error!("Failed to forward to local service: {}", e); + } + if let Err(e) = stream.flush().await { + error!("Failed to flush to local stream: {}", e); + } + if let Err(e) = stream.shutdown().await { + error!("Failed to shutdown local stream: {}", e); + } + }); + } + } + + if server_socket.can_send() { + let stream = stream.lock().await; + let mut buf = [0u8; 1024]; + if let Ok(n) = stream.try_read(&mut buf) { + if n > 0 { + let _ = server_socket.send_slice(&buf[..n]); + } + } + } + } + // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) next_poll = match iface.poll_delay(loop_start, &self.sockets) { Some(smoltcp::time::Duration::ZERO) => None, @@ -255,3 +336,18 @@ const fn addr_length(addr: &IpAddress) -> u8 { IpVersion::Ipv6 => 128, } } + +/// Connect to a local service via tcp using the provided port forward config. +/// Returns both the socket handle and the TcpStream. +async fn connect_to_local_service( + sockets: &mut SocketSet<'static>, + remote_port_forward: &PortForwardConfig, +) -> anyhow::Result<(SocketHandle, TcpStream)> { + let server_socket = TcpVirtualInterface::new_server_socket(*remote_port_forward)?; + let handle = sockets.add(server_socket); + let stream = TcpStream::connect(remote_port_forward.destination) + .await + .context("Failed to connect to listening locally listening server.")?; + + Ok((handle, stream)) +}