This commit is contained in:
Louis Travaux 2025-05-20 14:01:16 +00:00 committed by GitHub
commit 84d3cfce83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 115 additions and 17 deletions

View file

@ -74,7 +74,9 @@ pub async fn start_tunnels(config: Config, bus: Bus) -> anyhow::Result<()> {
// Start TCP Virtual Interface // Start TCP Virtual Interface
let port_forwards = config.port_forwards.clone(); let port_forwards = config.port_forwards.clone();
let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip); let remote_port_forwards = config.remote_port_forwards.clone();
let iface =
TcpVirtualInterface::new(port_forwards, remote_port_forwards, bus, config.source_peer_ip);
tokio::spawn(async move { iface.poll_loop(device).await }); tokio::spawn(async move { iface.poll_loop(device).await });
} }

View file

@ -17,6 +17,12 @@ use std::{
collections::{HashMap, HashSet, VecDeque}, collections::{HashMap, HashSet, VecDeque},
net::IpAddr, net::IpAddr,
time::Duration, time::Duration,
sync::Arc,
};
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
sync::Mutex,
}; };
const MAX_PACKET: usize = 65536; const MAX_PACKET: usize = 65536;
@ -25,6 +31,7 @@ const MAX_PACKET: usize = 65536;
pub struct TcpVirtualInterface { pub struct TcpVirtualInterface {
source_peer_ip: IpAddr, source_peer_ip: IpAddr,
port_forwards: Vec<PortForwardConfig>, port_forwards: Vec<PortForwardConfig>,
remote_port_forwards: Vec<PortForwardConfig>,
bus: Bus, bus: Bus,
sockets: SocketSet<'static>, sockets: SocketSet<'static>,
} }
@ -32,30 +39,38 @@ pub struct TcpVirtualInterface {
impl TcpVirtualInterface { impl TcpVirtualInterface {
/// Initialize the parameters for a new virtual interface. /// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop. /// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self { pub fn new(
port_forwards: Vec<PortForwardConfig>,
remote_port_forwards: Vec<PortForwardConfig>,
bus: Bus,
source_peer_ip: IpAddr
) -> Self {
Self { Self {
port_forwards: port_forwards port_forwards: port_forwards
.into_iter() .into_iter()
.filter(|f| matches!(f.protocol, PortProtocol::Tcp)) .filter(|f| matches!(f.protocol, PortProtocol::Tcp))
.collect(), .collect(),
remote_port_forwards: remote_port_forwards
.into_iter()
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
.collect(),
source_peer_ip, source_peer_ip,
bus, bus,
sockets: SocketSet::new([]), sockets: SocketSet::new([]),
} }
} }
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> { fn new_server_socket(remote_port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> {
static mut TCP_SERVER_RX_DATA: [u8; 0] = []; let rx_data = vec![0u8; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; 0] = []; let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data);
let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data);
let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
socket socket
.listen(( .listen((
IpAddress::from(port_forward.destination.ip()), IpAddress::from(remote_port_forward.source.ip()),
port_forward.destination.port(), remote_port_forward.source.port(),
)) ))
.context("Virtual server socket failed to listen")?; .context("Virtual server socket failed to listen")?;
@ -101,10 +116,16 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
}); });
}); });
// Create virtual server for each port forward let mut server_handle_map: HashMap<SocketHandle, (Arc<Mutex<TcpStream>>, PortForwardConfig)> =
for port_forward in self.port_forwards.iter() { HashMap::new();
let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
self.sockets.add(server_socket); // Create virtual server for each remote port forward
for remote_port_forward in self.remote_port_forwards.iter() {
let (handle, stream) =
connect_to_local_service(&mut self.sockets, remote_port_forward).await?;
server_handle_map
.insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward.clone()));
} }
// The next time to poll the interface. Can be None for instant poll. // The next time to poll the interface. Can be None for instant poll.
@ -121,9 +142,9 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
loop { loop {
tokio::select! { tokio::select! {
_ = match (next_poll, port_client_handle_map.len()) { _ = match (next_poll, port_client_handle_map.is_empty() && server_handle_map.is_empty()) {
(None, 0) => tokio::time::sleep(Duration::MAX), (None, true) => tokio::time::sleep(Duration::MAX),
(None, _) => tokio::time::sleep(Duration::ZERO), (None, false) => tokio::time::sleep(Duration::ZERO),
(Some(until), _) => tokio::time::sleep_until(until), (Some(until), _) => tokio::time::sleep_until(until),
} => { } => {
let loop_start = smoltcp::time::Instant::now(); let loop_start = smoltcp::time::Instant::now();
@ -142,6 +163,25 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
} }
}); });
let mut updated_server_handle_map = HashMap::new();
for (server_handle, (stream, remote_port_forward)) in server_handle_map.drain() {
let server_socket = self.sockets.get_mut::<tcp::Socket>(server_handle);
if server_socket.state() == tcp::State::Closed {
self.sockets.remove(server_handle);
// Recreate listening socket and TCP stream
let (handle, stream) =
connect_to_local_service(&mut self.sockets, &remote_port_forward).await?;
updated_server_handle_map
.insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward));
} else {
updated_server_handle_map
.insert(server_handle, (stream, remote_port_forward));
}
}
server_handle_map = updated_server_handle_map;
if iface.poll(loop_start, &mut device, &mut self.sockets) == PollResult::SocketStateChanged { if iface.poll(loop_start, &mut device, &mut self.sockets) == PollResult::SocketStateChanged {
log::trace!("TCP virtual interface polled some packets to be processed"); log::trace!("TCP virtual interface polled some packets to be processed");
} }
@ -189,6 +229,47 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
} }
} }
for (server_handle, (stream, _)) in server_handle_map.iter() {
let server_socket = self.sockets.get_mut::<tcp::Socket>(*server_handle);
if server_socket.state() == tcp::State::CloseWait {
server_socket.close();
}
if server_socket.can_recv() {
let data = server_socket.recv(|buffer| {
(buffer.len(), Bytes::copy_from_slice(buffer))
})?;
if !data.is_empty() {
let stream = Arc::clone(stream);
tokio::spawn(async move {
let mut stream = stream.lock().await;
log::trace!("Forwarding remote data to listening local service");
if let Err(e) = stream.write(&data).await {
error!("Failed to forward to local service: {}", e);
}
if let Err(e) = stream.flush().await {
error!("Failed to flush to local stream: {}", e);
}
if let Err(e) = stream.shutdown().await {
error!("Failed to shutdown local stream: {}", e);
}
});
}
}
if server_socket.can_send() {
let stream = stream.lock().await;
let mut buf = [0u8; 1024];
if let Ok(n) = stream.try_read(&mut buf) {
if n > 0 {
let _ = server_socket.send_slice(&buf[..n]);
}
}
}
}
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls) // The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
next_poll = match iface.poll_delay(loop_start, &self.sockets) { next_poll = match iface.poll_delay(loop_start, &self.sockets) {
Some(smoltcp::time::Duration::ZERO) => None, Some(smoltcp::time::Duration::ZERO) => None,
@ -255,3 +336,18 @@ const fn addr_length(addr: &IpAddress) -> u8 {
IpVersion::Ipv6 => 128, IpVersion::Ipv6 => 128,
} }
} }
/// Connect to a local service via tcp using the provided port forward config.
/// Returns both the socket handle and the TcpStream.
async fn connect_to_local_service(
sockets: &mut SocketSet<'static>,
remote_port_forward: &PortForwardConfig,
) -> anyhow::Result<(SocketHandle, TcpStream)> {
let server_socket = TcpVirtualInterface::new_server_socket(*remote_port_forward)?;
let handle = sockets.add(server_socket);
let stream = TcpStream::connect(remote_port_forward.destination)
.await
.context("Failed to connect to listening locally listening server.")?;
Ok((handle, stream))
}