mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 06:58:31 -04:00
Merge 543a01d69f
into 3ca5ae2181
This commit is contained in:
commit
84d3cfce83
2 changed files with 115 additions and 17 deletions
|
@ -74,7 +74,9 @@ pub async fn start_tunnels(config: Config, bus: Bus) -> anyhow::Result<()> {
|
|||
|
||||
// Start TCP Virtual Interface
|
||||
let port_forwards = config.port_forwards.clone();
|
||||
let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip);
|
||||
let remote_port_forwards = config.remote_port_forwards.clone();
|
||||
let iface =
|
||||
TcpVirtualInterface::new(port_forwards, remote_port_forwards, bus, config.source_peer_ip);
|
||||
tokio::spawn(async move { iface.poll_loop(device).await });
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,12 @@ use std::{
|
|||
collections::{HashMap, HashSet, VecDeque},
|
||||
net::IpAddr,
|
||||
time::Duration,
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
net::TcpStream,
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
const MAX_PACKET: usize = 65536;
|
||||
|
@ -25,6 +31,7 @@ const MAX_PACKET: usize = 65536;
|
|||
pub struct TcpVirtualInterface {
|
||||
source_peer_ip: IpAddr,
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
remote_port_forwards: Vec<PortForwardConfig>,
|
||||
bus: Bus,
|
||||
sockets: SocketSet<'static>,
|
||||
}
|
||||
|
@ -32,30 +39,38 @@ pub struct TcpVirtualInterface {
|
|||
impl TcpVirtualInterface {
|
||||
/// Initialize the parameters for a new virtual interface.
|
||||
/// Use the `poll_loop()` future to start the virtual interface poll loop.
|
||||
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
|
||||
pub fn new(
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
remote_port_forwards: Vec<PortForwardConfig>,
|
||||
bus: Bus,
|
||||
source_peer_ip: IpAddr
|
||||
) -> Self {
|
||||
Self {
|
||||
port_forwards: port_forwards
|
||||
.into_iter()
|
||||
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
|
||||
.collect(),
|
||||
remote_port_forwards: remote_port_forwards
|
||||
.into_iter()
|
||||
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
|
||||
.collect(),
|
||||
source_peer_ip,
|
||||
bus,
|
||||
sockets: SocketSet::new([]),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> {
|
||||
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
|
||||
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
|
||||
|
||||
let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
|
||||
let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
|
||||
fn new_server_socket(remote_port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> {
|
||||
let rx_data = vec![0u8; MAX_PACKET];
|
||||
let tx_data = vec![0u8; MAX_PACKET];
|
||||
let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data);
|
||||
let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data);
|
||||
let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
|
||||
|
||||
socket
|
||||
.listen((
|
||||
IpAddress::from(port_forward.destination.ip()),
|
||||
port_forward.destination.port(),
|
||||
IpAddress::from(remote_port_forward.source.ip()),
|
||||
remote_port_forward.source.port(),
|
||||
))
|
||||
.context("Virtual server socket failed to listen")?;
|
||||
|
||||
|
@ -101,10 +116,16 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
|
|||
});
|
||||
});
|
||||
|
||||
// Create virtual server for each port forward
|
||||
for port_forward in self.port_forwards.iter() {
|
||||
let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
|
||||
self.sockets.add(server_socket);
|
||||
let mut server_handle_map: HashMap<SocketHandle, (Arc<Mutex<TcpStream>>, PortForwardConfig)> =
|
||||
HashMap::new();
|
||||
|
||||
// Create virtual server for each remote port forward
|
||||
for remote_port_forward in self.remote_port_forwards.iter() {
|
||||
let (handle, stream) =
|
||||
connect_to_local_service(&mut self.sockets, remote_port_forward).await?;
|
||||
|
||||
server_handle_map
|
||||
.insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward.clone()));
|
||||
}
|
||||
|
||||
// The next time to poll the interface. Can be None for instant poll.
|
||||
|
@ -121,9 +142,9 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
|
|||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = match (next_poll, port_client_handle_map.len()) {
|
||||
(None, 0) => tokio::time::sleep(Duration::MAX),
|
||||
(None, _) => tokio::time::sleep(Duration::ZERO),
|
||||
_ = match (next_poll, port_client_handle_map.is_empty() && server_handle_map.is_empty()) {
|
||||
(None, true) => tokio::time::sleep(Duration::MAX),
|
||||
(None, false) => tokio::time::sleep(Duration::ZERO),
|
||||
(Some(until), _) => tokio::time::sleep_until(until),
|
||||
} => {
|
||||
let loop_start = smoltcp::time::Instant::now();
|
||||
|
@ -142,6 +163,25 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
|
|||
}
|
||||
});
|
||||
|
||||
let mut updated_server_handle_map = HashMap::new();
|
||||
|
||||
for (server_handle, (stream, remote_port_forward)) in server_handle_map.drain() {
|
||||
let server_socket = self.sockets.get_mut::<tcp::Socket>(server_handle);
|
||||
if server_socket.state() == tcp::State::Closed {
|
||||
self.sockets.remove(server_handle);
|
||||
|
||||
// Recreate listening socket and TCP stream
|
||||
let (handle, stream) =
|
||||
connect_to_local_service(&mut self.sockets, &remote_port_forward).await?;
|
||||
updated_server_handle_map
|
||||
.insert(handle, (Arc::new(Mutex::new(stream)), remote_port_forward));
|
||||
} else {
|
||||
updated_server_handle_map
|
||||
.insert(server_handle, (stream, remote_port_forward));
|
||||
}
|
||||
}
|
||||
server_handle_map = updated_server_handle_map;
|
||||
|
||||
if iface.poll(loop_start, &mut device, &mut self.sockets) == PollResult::SocketStateChanged {
|
||||
log::trace!("TCP virtual interface polled some packets to be processed");
|
||||
}
|
||||
|
@ -189,6 +229,47 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
|
|||
}
|
||||
}
|
||||
|
||||
for (server_handle, (stream, _)) in server_handle_map.iter() {
|
||||
let server_socket = self.sockets.get_mut::<tcp::Socket>(*server_handle);
|
||||
|
||||
if server_socket.state() == tcp::State::CloseWait {
|
||||
server_socket.close();
|
||||
}
|
||||
|
||||
if server_socket.can_recv() {
|
||||
let data = server_socket.recv(|buffer| {
|
||||
(buffer.len(), Bytes::copy_from_slice(buffer))
|
||||
})?;
|
||||
|
||||
if !data.is_empty() {
|
||||
let stream = Arc::clone(stream);
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream.lock().await;
|
||||
log::trace!("Forwarding remote data to listening local service");
|
||||
if let Err(e) = stream.write(&data).await {
|
||||
error!("Failed to forward to local service: {}", e);
|
||||
}
|
||||
if let Err(e) = stream.flush().await {
|
||||
error!("Failed to flush to local stream: {}", e);
|
||||
}
|
||||
if let Err(e) = stream.shutdown().await {
|
||||
error!("Failed to shutdown local stream: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if server_socket.can_send() {
|
||||
let stream = stream.lock().await;
|
||||
let mut buf = [0u8; 1024];
|
||||
if let Ok(n) = stream.try_read(&mut buf) {
|
||||
if n > 0 {
|
||||
let _ = server_socket.send_slice(&buf[..n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
|
||||
next_poll = match iface.poll_delay(loop_start, &self.sockets) {
|
||||
Some(smoltcp::time::Duration::ZERO) => None,
|
||||
|
@ -255,3 +336,18 @@ const fn addr_length(addr: &IpAddress) -> u8 {
|
|||
IpVersion::Ipv6 => 128,
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a local service via tcp using the provided port forward config.
|
||||
/// Returns both the socket handle and the TcpStream.
|
||||
async fn connect_to_local_service(
|
||||
sockets: &mut SocketSet<'static>,
|
||||
remote_port_forward: &PortForwardConfig,
|
||||
) -> anyhow::Result<(SocketHandle, TcpStream)> {
|
||||
let server_socket = TcpVirtualInterface::new_server_socket(*remote_port_forward)?;
|
||||
let handle = sockets.add(server_socket);
|
||||
let stream = TcpStream::connect(remote_port_forward.destination)
|
||||
.await
|
||||
.context("Failed to connect to listening locally listening server.")?;
|
||||
|
||||
Ok((handle, stream))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue