[ADD] virtual_iface/tcp: remote port forwarding logic

Implements remote TCP port forwarding using previously added config parsing (#73671a4).

Resolves #6 by listening on a specified virtual IP/port and forwarding incoming connections
to a local service.

Example: `--source-peer-ip 10.10.10.10 --remote 8080:127.0.0.1:49369` forwards
`10.10.10.10:8080` to `127.0.0.1:49369`.
This commit is contained in:
Louis Travaux 2025-05-19 16:01:01 +02:00
parent 3ca5ae2181
commit 543a01d69f
2 changed files with 115 additions and 17 deletions

View file

@ -74,7 +74,9 @@ pub async fn start_tunnels(config: Config, bus: Bus) -> anyhow::Result<()> {
// Start TCP Virtual Interface
let port_forwards = config.port_forwards.clone();
let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip);
let remote_port_forwards = config.remote_port_forwards.clone();
let iface =
TcpVirtualInterface::new(port_forwards, remote_port_forwards, bus, config.source_peer_ip);
tokio::spawn(async move { iface.poll_loop(device).await });
}

View file

@ -17,6 +17,12 @@ use std::{
collections::{HashMap, HashSet, VecDeque},
net::IpAddr,
time::Duration,
sync::Arc,
};
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
sync::Mutex,
};
const MAX_PACKET: usize = 65536;
@ -25,6 +31,7 @@ const MAX_PACKET: usize = 65536;
pub struct TcpVirtualInterface {
source_peer_ip: IpAddr,
port_forwards: Vec<PortForwardConfig>,
remote_port_forwards: Vec<PortForwardConfig>,
bus: Bus,
sockets: SocketSet<'static>,
}
@ -32,30 +39,38 @@ pub struct TcpVirtualInterface {
impl TcpVirtualInterface {
/// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
pub fn new(
port_forwards: Vec<PortForwardConfig>,
remote_port_forwards: Vec<PortForwardConfig>,
bus: Bus,
source_peer_ip: IpAddr
) -> Self {
Self {
port_forwards: port_forwards
.into_iter()
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
.collect(),
remote_port_forwards: remote_port_forwards
.into_iter()
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
.collect(),
source_peer_ip,
bus,
sockets: SocketSet::new([]),
}
}
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> {
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
fn new_server_socket(remote_port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> {
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data);
let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data);
let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((
IpAddress::from(port_forward.destination.ip()),
port_forward.destination.port(),
IpAddress::from(remote_port_forward.source.ip()),
remote_port_forward.source.port(),
))
.context("Virtual server socket failed to listen")?;
@ -101,10 +116,16 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
});
});
// Create virtual server for each port forward
for port_forward in self.port_forwards.iter() {
let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
self.sockets.add(server_socket);
let mut server_handle_map: HashMap<SocketHandle, (Arc<Mutex<TcpStream>>, PortForwardConfig)> =
HashMap::new();
// Create virtual server for each remote port forward
for remote_port_forward in self.remote_port_forwards.iter() {
let (handle, stream) =
connect_to_local_service(&mut self.sockets, remote_port_forward).await?;
server_handle_map
.insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward.clone()));
}
// The next time to poll the interface. Can be None for instant poll.
@ -121,9 +142,9 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
loop {
tokio::select! {
_ = match (next_poll, port_client_handle_map.len()) {
(None, 0) => tokio::time::sleep(Duration::MAX),
(None, _) => tokio::time::sleep(Duration::ZERO),
_ = match (next_poll, port_client_handle_map.is_empty() && server_handle_map.is_empty()) {
(None, true) => tokio::time::sleep(Duration::MAX),
(None, false) => tokio::time::sleep(Duration::ZERO),
(Some(until), _) => tokio::time::sleep_until(until),
} => {
let loop_start = smoltcp::time::Instant::now();
@ -142,6 +163,25 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
}
});
let mut updated_server_handle_map = HashMap::new();
for (server_handle, (stream, remote_port_forward)) in server_handle_map.drain() {
let server_socket = self.sockets.get_mut::<tcp::Socket>(server_handle);
if server_socket.state() == tcp::State::Closed {
self.sockets.remove(server_handle);
// Recreate listening socket and TCP stream
let (handle, stream) =
connect_to_local_service(&mut self.sockets, &remote_port_forward).await?;
updated_server_handle_map
.insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward));
} else {
updated_server_handle_map
.insert(server_handle, (stream, remote_port_forward));
}
}
server_handle_map = updated_server_handle_map;
if iface.poll(loop_start, &mut device, &mut self.sockets) == PollResult::SocketStateChanged {
log::trace!("TCP virtual interface polled some packets to be processed");
}
@ -189,6 +229,47 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
}
}
for (server_handle, (stream, _)) in server_handle_map.iter() {
let server_socket = self.sockets.get_mut::<tcp::Socket>(*server_handle);
if server_socket.state() == tcp::State::CloseWait {
server_socket.close();
}
if server_socket.can_recv() {
let data = server_socket.recv(|buffer| {
(buffer.len(), Bytes::copy_from_slice(buffer))
})?;
if !data.is_empty() {
let stream = Arc::clone(stream);
tokio::spawn(async move {
let mut stream = stream.lock().await;
log::trace!("Forwarding remote data to listening local service");
if let Err(e) = stream.write(&data).await {
error!("Failed to forward to local service: {}", e);
}
if let Err(e) = stream.flush().await {
error!("Failed to flush to local stream: {}", e);
}
if let Err(e) = stream.shutdown().await {
error!("Failed to shutdown local stream: {}", e);
}
});
}
}
if server_socket.can_send() {
let stream = stream.lock().await;
let mut buf = [0u8; 1024];
if let Ok(n) = stream.try_read(&mut buf) {
if n > 0 {
let _ = server_socket.send_slice(&buf[..n]);
}
}
}
}
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
next_poll = match iface.poll_delay(loop_start, &self.sockets) {
Some(smoltcp::time::Duration::ZERO) => None,
@ -255,3 +336,18 @@ const fn addr_length(addr: &IpAddress) -> u8 {
IpVersion::Ipv6 => 128,
}
}
/// Connect to a local service via tcp using the provided port forward config.
/// Returns both the socket handle and the TcpStream.
async fn connect_to_local_service(
sockets: &mut SocketSet<'static>,
remote_port_forward: &PortForwardConfig,
) -> anyhow::Result<(SocketHandle, TcpStream)> {
let server_socket = TcpVirtualInterface::new_server_socket(*remote_port_forward)?;
let handle = sockets.add(server_socket);
let stream = TcpStream::connect(remote_port_forward.destination)
.await
.context("Failed to connect to listening locally listening server.")?;
Ok((handle, stream))
}