UDP virtual interface skeleton

This commit is contained in:
Aram 🍐 2021-10-20 18:06:35 -04:00
parent cc91cce169
commit fb50ee7113
6 changed files with 203 additions and 35 deletions

View file

@ -1,6 +1,7 @@
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::ops::Range; use std::ops::Range;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -10,7 +11,8 @@ use rand::thread_rng;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::config::{PortForwardConfig, PortProtocol}; use crate::config::{PortForwardConfig, PortProtocol};
use crate::virtual_iface::VirtualPort; use crate::virtual_iface::udp::UdpVirtualInterface;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel; use crate::wg::WireGuardTunnel;
const MAX_PACKET: usize = 65536; const MAX_PACKET: usize = 65536;
@ -31,20 +33,98 @@ pub async fn udp_proxy_server(
port_pool: UdpPortPool, port_pool: UdpPortPool,
wg: Arc<WireGuardTunnel>, wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Abort signal
let abort = Arc::new(AtomicBool::new(false));
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) =
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) =
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
{
// Spawn virtual interface
// Note: contrary to TCP, there is only one UDP virtual interface
let virtual_interface = UdpVirtualInterface::new(
port_forward,
wg,
data_to_real_client_tx,
data_to_virtual_server_rx,
);
let abort = abort.clone();
tokio::spawn(async move {
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
error!("Virtual interface poll loop failed unexpectedly: {}", e);
abort.store(true, Ordering::Relaxed);
});
});
}
let socket = UdpSocket::bind(port_forward.source) let socket = UdpSocket::bind(port_forward.source)
.await .await
.with_context(|| "Failed to bind on UDP proxy address")?; .with_context(|| "Failed to bind on UDP proxy address")?;
let mut buffer = [0u8; MAX_PACKET]; let mut buffer = [0u8; MAX_PACKET];
loop { loop {
if abort.load(Ordering::Relaxed) {
break;
}
tokio::select! {
to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => {
match to_send_result {
Ok(Some((port, data))) => {
data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| {
error!(
"Failed to dispatch data to UDP virtual interface: {:?}",
e
);
});
}
Ok(None) => {
continue;
}
Err(e) => {
error!(
"Failed to read from client UDP socket: {:?}",
e
);
break;
}
}
}
data_recv_result = data_to_real_client_rx.recv() => {
if let Some((port, data)) = data_recv_result {
if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await {
if let Err(e) = socket.send_to(&data, peer_addr).await {
error!(
"[{}] Failed to send UDP datagram to real client ({}): {:?}",
port,
peer_addr,
e,
);
}
}
}
}
}
}
Ok(())
}
async fn next_udp_datagram(
socket: &UdpSocket,
buffer: &mut [u8],
port_pool: UdpPortPool,
) -> anyhow::Result<Option<(VirtualPort, Vec<u8>)>> {
let (size, peer_addr) = socket let (size, peer_addr) = socket
.recv_from(&mut buffer) .recv_from(buffer)
.await .await
.with_context(|| "Failed to accept incoming UDP datagram")?; .with_context(|| "Failed to accept incoming UDP datagram")?;
let _wg = wg.clone();
let _data = &buffer[..size].to_vec();
// Assign a 'virtual port': this is a unique port number used to route IP packets // Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will // received from the WireGuard tunnel. It is the port number that the virtual client will
// listen on. // listen on.
@ -55,16 +135,18 @@ pub async fn udp_proxy_server(
"Failed to assign virtual port number for UDP datagram from [{}]: {:?}", "Failed to assign virtual port number for UDP datagram from [{}]: {:?}",
peer_addr, e peer_addr, e
); );
continue; return Ok(None);
} }
}; };
let port = VirtualPort(port, PortProtocol::Udp); let port = VirtualPort(port, PortProtocol::Udp);
debug!( debug!(
"[{}] Received datagram of {} bytes from {}", "[{}] Received datagram of {} bytes from {}",
port, size, peer_addr port, size, peer_addr
); );
}
let data = buffer[..size].to_vec();
Ok(Some((port, data)))
} }
/// A pool of virtual ports available for TCP connections. /// A pool of virtual ports available for TCP connections.
@ -108,8 +190,14 @@ impl UdpPortPool {
.pop_front() .pop_front()
.with_context(|| "UDP virtual port pool is exhausted")?; .with_context(|| "UDP virtual port pool is exhausted")?;
inner.port_by_peer_addr.insert(peer_addr, port); inner.port_by_peer_addr.insert(peer_addr, port);
inner.peer_addr_by_port.insert(port, peer_addr);
Ok(port) Ok(port)
} }
pub async fn get_peer_addr(&self, port: u16) -> Option<SocketAddr> {
let inner = self.inner.read().await;
inner.peer_addr_by_port.get(&port).copied()
}
} }
/// Non thread-safe inner logic for UDP port pool. /// Non thread-safe inner logic for UDP port pool.
@ -119,4 +207,6 @@ struct UdpPortPoolInner {
queue: VecDeque<u16>, queue: VecDeque<u16>,
/// The port assigned by peer IP/port. /// The port assigned by peer IP/port.
port_by_peer_addr: HashMap<SocketAddr, u16>, port_by_peer_addr: HashMap<SocketAddr, u16>,
/// The socket address assigned to a peer IP/port.
peer_addr_by_port: HashMap<u16, SocketAddr>,
} }

View file

@ -1,5 +1,5 @@
use crate::virtual_iface::VirtualPort; use crate::virtual_iface::VirtualPort;
use crate::wg::WireGuardTunnel; use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
use anyhow::Context; use anyhow::Context;
use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant; use smoltcp::time::Instant;
@ -16,9 +16,11 @@ pub struct VirtualIpDevice {
} }
impl VirtualIpDevice { impl VirtualIpDevice {
/// Registers a virtual IP device for a single virtual client.
pub fn new(virtual_port: VirtualPort, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> { pub fn new(virtual_port: VirtualPort, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let ip_dispatch_rx = wg let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
.register_virtual_interface(virtual_port)
wg.register_virtual_interface(virtual_port, ip_dispatch_tx)
.with_context(|| "Failed to register IP dispatch for virtual interface")?; .with_context(|| "Failed to register IP dispatch for virtual interface")?;
Ok(Self { wg, ip_dispatch_rx }) Ok(Self { wg, ip_dispatch_rx })

View file

@ -1,4 +1,5 @@
pub mod tcp; pub mod tcp;
pub mod udp;
use crate::config::PortProtocol; use crate::config::PortProtocol;
use async_trait::async_trait; use async_trait::async_trait;

View file

@ -266,8 +266,15 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
break; break;
} }
match virtual_interface.poll_delay(&socket_set, loop_start) {
Some(smoltcp::time::Duration::ZERO) => {
continue;
}
_ => {
tokio::time::sleep(Duration::from_millis(1)).await; tokio::time::sleep(Duration::from_millis(1)).await;
} }
}
}
trace!("[{}] Virtual interface task terminated", self.virtual_port); trace!("[{}] Virtual interface task terminated", self.virtual_port);
self.abort.store(true, Ordering::Relaxed); self.abort.store(true, Ordering::Relaxed);
Ok(()) Ok(())

68
src/virtual_iface/udp.rs Normal file
View file

@ -0,0 +1,68 @@
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use crate::config::PortForwardConfig;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
pub struct UdpVirtualInterface {
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
}
impl UdpVirtualInterface {
pub fn new(
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
) -> Self {
Self {
port_forward,
wg,
data_to_real_client_tx,
data_to_virtual_server_rx,
}
}
}
#[async_trait]
impl VirtualInterfacePoll for UdpVirtualInterface {
async fn poll_loop(self) -> anyhow::Result<()> {
// Data receiver to dispatch using virtual client sockets
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
// The IP to bind client sockets to
let _source_peer_ip = self.wg.source_peer_ip;
// The IP/port to bind the server socket to
let _destination = self.port_forward.destination;
loop {
let _loop_start = smoltcp::time::Instant::now();
// TODO: smoltcp UDP
if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() {
// TODO: Find the matching client socket and send
// Echo for now
self.data_to_real_client_tx
.send((client_port, data))
.await
.unwrap_or_else(|e| {
error!(
"[{}] Failed to dispatch data from virtual client to real client: {:?}",
client_port, e
);
});
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
// Ok(())
}
}

View file

@ -12,7 +12,7 @@ use crate::config::{Config, PortProtocol};
use crate::virtual_iface::VirtualPort; use crate::virtual_iface::VirtualPort;
/// The capacity of the channel for received IP packets. /// The capacity of the channel for received IP packets.
const DISPATCH_CAPACITY: usize = 1_000; pub const DISPATCH_CAPACITY: usize = 1_000;
const MAX_PACKET: usize = 65536; const MAX_PACKET: usize = 65536;
/// A WireGuard tunnel. Encapsulates and decapsulates IP packets /// A WireGuard tunnel. Encapsulates and decapsulates IP packets
@ -88,14 +88,14 @@ impl WireGuardTunnel {
pub fn register_virtual_interface( pub fn register_virtual_interface(
&self, &self,
virtual_port: VirtualPort, virtual_port: VirtualPort,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> { sender: tokio::sync::mpsc::Sender<Vec<u8>>,
) -> anyhow::Result<()> {
let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); let existing = self.virtual_port_ip_tx.contains_key(&virtual_port);
if existing { if existing {
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
} else { } else {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
self.virtual_port_ip_tx.insert(virtual_port, sender); self.virtual_port_ip_tx.insert(virtual_port, sender);
Ok(receiver) Ok(())
} }
} }