mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 23:18:31 -04:00
281 lines
10 KiB
Rust
281 lines
10 KiB
Rust
use std::collections::{BTreeMap, HashMap, VecDeque};
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::ops::Range;
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use anyhow::Context;
|
|
use priority_queue::double_priority_queue::DoublePriorityQueue;
|
|
use priority_queue::priority_queue::PriorityQueue;
|
|
use rand::seq::SliceRandom;
|
|
use rand::thread_rng;
|
|
use tokio::net::UdpSocket;
|
|
|
|
use crate::config::{PortForwardConfig, PortProtocol};
|
|
use crate::virtual_iface::udp::UdpVirtualInterface;
|
|
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
|
|
use crate::wg::WireGuardTunnel;
|
|
|
|
const MAX_PACKET: usize = 65536;
|
|
const MIN_PORT: u16 = 1000;
|
|
const MAX_PORT: u16 = 60999;
|
|
const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
|
|
|
|
/// How long to keep the UDP peer address assigned to its virtual specified port, in seconds.
|
|
/// TODO: Make this configurable by the CLI
|
|
const UDP_TIMEOUT_SECONDS: u64 = 60;
|
|
|
|
/// To prevent port-flooding, we set a limit on the amount of open ports per IP address.
|
|
/// TODO: Make this configurable by the CLI
|
|
const PORTS_PER_IP: usize = 100;
|
|
|
|
pub async fn udp_proxy_server(
|
|
port_forward: PortForwardConfig,
|
|
port_pool: UdpPortPool,
|
|
wg: Arc<WireGuardTunnel>,
|
|
) -> anyhow::Result<()> {
|
|
// Abort signal
|
|
let abort = Arc::new(AtomicBool::new(false));
|
|
|
|
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
|
|
// to the real client.
|
|
let (data_to_real_client_tx, mut data_to_real_client_rx) =
|
|
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
|
|
|
|
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
|
|
// virtual interface (virtual server socket).
|
|
let (data_to_virtual_server_tx, data_to_virtual_server_rx) =
|
|
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
|
|
|
|
{
|
|
// Spawn virtual interface
|
|
// Note: contrary to TCP, there is only one UDP virtual interface
|
|
let virtual_interface = UdpVirtualInterface::new(
|
|
port_forward,
|
|
wg,
|
|
data_to_real_client_tx,
|
|
data_to_virtual_server_rx,
|
|
);
|
|
let abort = abort.clone();
|
|
tokio::spawn(async move {
|
|
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
|
|
error!("Virtual interface poll loop failed unexpectedly: {}", e);
|
|
abort.store(true, Ordering::Relaxed);
|
|
});
|
|
});
|
|
}
|
|
|
|
let socket = UdpSocket::bind(port_forward.source)
|
|
.await
|
|
.with_context(|| "Failed to bind on UDP proxy address")?;
|
|
|
|
let mut buffer = [0u8; MAX_PACKET];
|
|
loop {
|
|
if abort.load(Ordering::Relaxed) {
|
|
break;
|
|
}
|
|
tokio::select! {
|
|
to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => {
|
|
match to_send_result {
|
|
Ok(Some((port, data))) => {
|
|
data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| {
|
|
error!(
|
|
"Failed to dispatch data to UDP virtual interface: {:?}",
|
|
e
|
|
);
|
|
});
|
|
}
|
|
Ok(None) => {
|
|
continue;
|
|
}
|
|
Err(e) => {
|
|
error!(
|
|
"Failed to read from client UDP socket: {:?}",
|
|
e
|
|
);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
data_recv_result = data_to_real_client_rx.recv() => {
|
|
if let Some((port, data)) = data_recv_result {
|
|
if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await {
|
|
if let Err(e) = socket.send_to(&data, peer_addr).await {
|
|
error!(
|
|
"[{}] Failed to send UDP datagram to real client ({}): {:?}",
|
|
port,
|
|
peer_addr,
|
|
e,
|
|
);
|
|
}
|
|
port_pool.update_last_transmit(port.0).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn next_udp_datagram(
|
|
socket: &UdpSocket,
|
|
buffer: &mut [u8],
|
|
port_pool: UdpPortPool,
|
|
) -> anyhow::Result<Option<(VirtualPort, Vec<u8>)>> {
|
|
let (size, peer_addr) = socket
|
|
.recv_from(buffer)
|
|
.await
|
|
.with_context(|| "Failed to accept incoming UDP datagram")?;
|
|
|
|
// Assign a 'virtual port': this is a unique port number used to route IP packets
|
|
// received from the WireGuard tunnel. It is the port number that the virtual client will
|
|
// listen on.
|
|
let port = match port_pool.next(peer_addr).await {
|
|
Ok(port) => port,
|
|
Err(e) => {
|
|
error!(
|
|
"Failed to assign virtual port number for UDP datagram from [{}]: {:?}",
|
|
peer_addr, e
|
|
);
|
|
return Ok(None);
|
|
}
|
|
};
|
|
let port = VirtualPort(port, PortProtocol::Udp);
|
|
|
|
debug!(
|
|
"[{}] Received datagram of {} bytes from {}",
|
|
port, size, peer_addr
|
|
);
|
|
|
|
port_pool.update_last_transmit(port.0).await;
|
|
|
|
let data = buffer[..size].to_vec();
|
|
Ok(Some((port, data)))
|
|
}
|
|
|
|
/// A pool of virtual ports available for TCP connections.
|
|
#[derive(Clone)]
|
|
pub struct UdpPortPool {
|
|
inner: Arc<tokio::sync::RwLock<UdpPortPoolInner>>,
|
|
}
|
|
|
|
impl Default for UdpPortPool {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl UdpPortPool {
|
|
/// Initializes a new pool of virtual ports.
|
|
pub fn new() -> Self {
|
|
let mut inner = UdpPortPoolInner::default();
|
|
let mut ports: Vec<u16> = PORT_RANGE.collect();
|
|
ports.shuffle(&mut thread_rng());
|
|
ports
|
|
.into_iter()
|
|
.for_each(|p| inner.queue.push_back(p) as ());
|
|
Self {
|
|
inner: Arc::new(tokio::sync::RwLock::new(inner)),
|
|
}
|
|
}
|
|
|
|
/// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
|
|
pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<u16> {
|
|
// A port found to be reused. This is outside of the block because the read lock cannot be upgraded to a write lock.
|
|
let mut port_reuse: Option<u16> = None;
|
|
|
|
{
|
|
let inner = self.inner.read().await;
|
|
if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) {
|
|
return Ok(*port);
|
|
}
|
|
|
|
// Count how many ports are being used by the peer IP
|
|
let peer_ip = peer_addr.ip();
|
|
let peer_port_count = inner
|
|
.peer_port_usage
|
|
.get(&peer_ip)
|
|
.map(|v| v.len())
|
|
.unwrap_or_default();
|
|
|
|
if peer_port_count >= PORTS_PER_IP {
|
|
// Return least recently used port in this IP's pool
|
|
port_reuse = Some(
|
|
*(inner
|
|
.peer_port_usage
|
|
.get(&peer_ip)
|
|
.unwrap()
|
|
.peek_min()
|
|
.unwrap()
|
|
.0),
|
|
);
|
|
warn!(
|
|
"Peer [{}] is re-using active virtual port {} due to self-exhaustion.",
|
|
peer_addr,
|
|
port_reuse.unwrap()
|
|
);
|
|
}
|
|
}
|
|
|
|
let mut inner = self.inner.write().await;
|
|
|
|
let port = port_reuse
|
|
.or_else(|| inner.queue.pop_front())
|
|
.or_else(|| {
|
|
// If there is no port to reuse, and the port pool is exhausted, take the last recently used port overall,
|
|
// as long as the last transmission exceeds the deadline
|
|
let last: (&u16, &Instant) = inner.port_usage.peek_min().unwrap();
|
|
if Instant::now().duration_since(*last.1).as_secs() > UDP_TIMEOUT_SECONDS {
|
|
warn!(
|
|
"Peer [{}] is re-using inactive virtual port {} due to global exhaustion.",
|
|
peer_addr, last.0
|
|
);
|
|
Some(*last.0)
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.with_context(|| "virtual port pool is exhausted")?;
|
|
|
|
inner.port_by_peer_addr.insert(peer_addr, port);
|
|
inner.peer_addr_by_port.insert(port, peer_addr);
|
|
Ok(port)
|
|
}
|
|
|
|
/// Notify that the given virtual port has received or transmitted a UDP datagram.
|
|
pub async fn update_last_transmit(&self, port: u16) {
|
|
let mut inner = self.inner.write().await;
|
|
if let Some(peer) = inner.peer_addr_by_port.get(&port).copied() {
|
|
let mut pq: &mut DoublePriorityQueue<u16, Instant> = inner
|
|
.peer_port_usage
|
|
.entry(peer.ip())
|
|
.or_insert_with(Default::default);
|
|
pq.push(port, Instant::now());
|
|
}
|
|
let mut pq: &mut DoublePriorityQueue<u16, Instant> = &mut inner.port_usage;
|
|
pq.push(port, Instant::now());
|
|
}
|
|
|
|
pub async fn get_peer_addr(&self, port: u16) -> Option<SocketAddr> {
|
|
let inner = self.inner.read().await;
|
|
inner.peer_addr_by_port.get(&port).copied()
|
|
}
|
|
}
|
|
|
|
/// Non thread-safe inner logic for UDP port pool.
|
|
#[derive(Debug, Default)]
|
|
struct UdpPortPoolInner {
|
|
/// Remaining ports in the pool.
|
|
queue: VecDeque<u16>,
|
|
/// The port assigned by peer IP/port. This is used to lookup an existing virtual port
|
|
/// for an incoming UDP datagram.
|
|
port_by_peer_addr: HashMap<SocketAddr, u16>,
|
|
/// The socket address assigned to a peer IP/port. This is used to send a UDP datagram to
|
|
/// the real peer address, given the virtual port.
|
|
peer_addr_by_port: HashMap<u16, SocketAddr>,
|
|
/// Keeps an ordered map of the most recently used virtual ports by a peer (client) IP.
|
|
peer_port_usage: HashMap<IpAddr, DoublePriorityQueue<u16, Instant>>,
|
|
/// Keeps an ordered map of the most recently used virtual ports in general.
|
|
port_usage: DoublePriorityQueue<u16, Instant>,
|
|
}
|