Merge pull request #28 from aramperes/bus-based

This commit is contained in:
Aram 🍐 2022-01-08 03:42:29 -05:00 committed by GitHub
commit 5b388f2ea3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 771 additions and 780 deletions

View file

@ -235,7 +235,7 @@ fn is_file_insecurely_readable(path: &str) -> Option<(bool, bool)> {
}
#[cfg(not(unix))]
fn is_file_insecurely_readable(path: &str) -> Option<(bool, bool)> {
fn is_file_insecurely_readable(_path: &str) -> Option<(bool, bool)> {
// No good way to determine permissions on non-Unix target
None
}
@ -399,9 +399,12 @@ impl Display for PortForwardConfig {
}
}
/// Layer 7 protocols for ports.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub enum PortProtocol {
/// TCP
Tcp,
/// UDP
Udp,
}

150
src/events.rs Normal file
View file

@ -0,0 +1,150 @@
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use crate::config::PortForwardConfig;
use crate::virtual_iface::VirtualPort;
use crate::PortProtocol;
/// Events that go on the bus between the local server, smoltcp, and WireGuard.
#[derive(Debug, Clone)]
pub enum Event {
/// Dumb event with no data.
Dumb,
/// A new connection with the local server was initiated, and the given virtual port was assigned.
ClientConnectionInitiated(PortForwardConfig, VirtualPort),
/// A connection was dropped from the pool and should be closed in all interfaces.
ClientConnectionDropped(VirtualPort),
/// Data received by the local server that should be sent to the virtual server.
LocalData(PortForwardConfig, VirtualPort, Vec<u8>),
/// Data received by the remote server that should be sent to the local client.
RemoteData(VirtualPort, Vec<u8>),
/// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device.
InboundInternetPacket(PortProtocol, Vec<u8>),
/// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device.
OutboundInternetPacket(Vec<u8>),
/// Notifies that a virtual device read an IP packet.
VirtualDeviceFed(PortProtocol),
}
#[derive(Clone)]
pub struct Bus {
counter: Arc<AtomicU32>,
bus: Arc<tokio::sync::broadcast::Sender<(u32, Event)>>,
}
impl Bus {
/// Creates a new event bus.
pub fn new() -> Self {
let (bus, _) = tokio::sync::broadcast::channel(1000);
let bus = Arc::new(bus);
let counter = Arc::new(AtomicU32::default());
Self { bus, counter }
}
/// Creates a new endpoint on the event bus.
pub fn new_endpoint(&self) -> BusEndpoint {
let id = self.counter.fetch_add(1, Ordering::Relaxed);
let tx = (*self.bus).clone();
let rx = self.bus.subscribe();
let tx = BusSender { id, tx };
BusEndpoint { id, tx, rx }
}
}
impl Default for Bus {
fn default() -> Self {
Self::new()
}
}
pub struct BusEndpoint {
id: u32,
tx: BusSender,
rx: tokio::sync::broadcast::Receiver<(u32, Event)>,
}
impl BusEndpoint {
/// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself.
pub fn send(&self, event: Event) {
self.tx.send(event)
}
/// Returns the unique sequential ID of this endpoint.
pub fn id(&self) -> u32 {
self.id
}
/// Awaits the next `Event` on the bus to be read.
pub async fn recv(&mut self) -> Event {
loop {
match self.rx.recv().await {
Ok((id, event)) => {
if id == self.id {
// If the event was sent by this endpoint, it is skipped
continue;
} else {
trace!("#{} <- {:?}", self.id, event);
return event;
}
}
Err(_) => {
error!("Failed to read event bus from endpoint #{}", self.id);
return futures::future::pending().await;
}
}
}
}
/// Creates a new sender for this endpoint that can be cloned.
pub fn sender(&self) -> BusSender {
self.tx.clone()
}
}
#[derive(Clone)]
pub struct BusSender {
id: u32,
tx: tokio::sync::broadcast::Sender<(u32, Event)>,
}
impl BusSender {
/// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself.
pub fn send(&self, event: Event) {
trace!("#{} -> {:?}", self.id, event);
match self.tx.send((self.id, event)) {
Ok(_) => {}
Err(_) => error!("Failed to send event to bus from endpoint #{}", self.id),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_bus() {
let bus = Bus::new();
let mut endpoint_1 = bus.new_endpoint();
let mut endpoint_2 = bus.new_endpoint();
let mut endpoint_3 = bus.new_endpoint();
assert_eq!(endpoint_1.id(), 0);
assert_eq!(endpoint_2.id(), 1);
assert_eq!(endpoint_3.id(), 2);
endpoint_1.send(Event::Dumb);
let recv_2 = endpoint_2.recv().await;
let recv_3 = endpoint_3.recv().await;
assert!(matches!(recv_2, Event::Dumb));
assert!(matches!(recv_3, Event::Dumb));
endpoint_2.send(Event::Dumb);
let recv_1 = endpoint_1.recv().await;
let recv_3 = endpoint_3.recv().await;
assert!(matches!(recv_1, Event::Dumb));
assert!(matches!(recv_3, Event::Dumb));
}
}

View file

@ -1,35 +0,0 @@
use crate::virtual_device::VirtualIpDevice;
use crate::wg::WireGuardTunnel;
use smoltcp::iface::InterfaceBuilder;
use std::sync::Arc;
use tokio::time::Duration;
/// A repeating task that processes unroutable IP packets.
pub async fn run_ip_sink_interface(wg: Arc<WireGuardTunnel>) -> ! {
// Initialize interface
let device = VirtualIpDevice::new_sink(wg)
.await
.expect("Failed to initialize VirtualIpDevice for sink interface");
// No sockets on sink interface
let mut sockets: [_; 0] = Default::default();
let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..])
.ip_addrs([])
.finalize();
loop {
let loop_start = smoltcp::time::Instant::now();
match virtual_interface.poll(loop_start) {
Ok(processed) if processed => {
trace!("[SINK] Virtual interface polled some packets to be processed",);
tokio::time::sleep(Duration::from_millis(1)).await;
}
Err(e) => {
error!("[SINK] Virtual interface poll error: {:?}", e);
}
_ => {
tokio::time::sleep(Duration::from_millis(5)).await;
}
}
}
}

View file

@ -5,13 +5,18 @@ use std::sync::Arc;
use anyhow::Context;
use crate::config::Config;
use crate::config::{Config, PortProtocol};
use crate::events::Bus;
use crate::tunnel::tcp::TcpPortPool;
use crate::tunnel::udp::UdpPortPool;
use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::tcp::TcpVirtualInterface;
use crate::virtual_iface::udp::UdpVirtualInterface;
use crate::virtual_iface::VirtualInterfacePoll;
use crate::wg::WireGuardTunnel;
pub mod config;
pub mod ip_sink;
pub mod events;
pub mod tunnel;
pub mod virtual_device;
pub mod virtual_iface;
@ -30,7 +35,9 @@ async fn main() -> anyhow::Result<()> {
let tcp_port_pool = TcpPortPool::new();
let udp_port_pool = UdpPortPool::new();
let wg = WireGuardTunnel::new(&config)
let bus = Bus::default();
let wg = WireGuardTunnel::new(&config, bus.clone())
.await
.with_context(|| "Failed to initialize WireGuard tunnel")?;
let wg = Arc::new(wg);
@ -48,9 +55,41 @@ async fn main() -> anyhow::Result<()> {
}
{
// Start IP sink task for incoming IP packets
// Start production task for WireGuard
let wg = wg.clone();
tokio::spawn(async move { ip_sink::run_ip_sink_interface(wg).await });
tokio::spawn(async move { wg.produce_task().await });
}
if config
.port_forwards
.iter()
.any(|pf| pf.protocol == PortProtocol::Tcp)
{
// TCP device
let bus = bus.clone();
let device =
VirtualIpDevice::new(PortProtocol::Tcp, bus.clone(), config.max_transmission_unit);
// Start TCP Virtual Interface
let port_forwards = config.port_forwards.clone();
let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip);
tokio::spawn(async move { iface.poll_loop(device).await });
}
if config
.port_forwards
.iter()
.any(|pf| pf.protocol == PortProtocol::Udp)
{
// UDP device
let bus = bus.clone();
let device =
VirtualIpDevice::new(PortProtocol::Udp, bus.clone(), config.max_transmission_unit);
// Start UDP Virtual Interface
let port_forwards = config.port_forwards.clone();
let iface = UdpVirtualInterface::new(port_forwards, bus, config.source_peer_ip);
tokio::spawn(async move { iface.poll_loop(device).await });
}
{
@ -59,10 +98,18 @@ async fn main() -> anyhow::Result<()> {
port_forwards
.into_iter()
.map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone()))
.for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| {
.map(|pf| {
(
pf,
wg.clone(),
tcp_port_pool.clone(),
udp_port_pool.clone(),
bus.clone(),
)
})
.for_each(move |(pf, wg, tcp_port_pool, udp_port_pool, bus)| {
tokio::spawn(async move {
tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg)
tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg, bus)
.await
.unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e))
});
@ -73,7 +120,7 @@ async fn main() -> anyhow::Result<()> {
}
fn init_logger(config: &Config) -> anyhow::Result<()> {
let mut builder = pretty_env_logger::formatted_builder();
let mut builder = pretty_env_logger::formatted_timed_builder();
builder.parse_filters(&config.log);
builder
.try_init()

View file

@ -2,6 +2,7 @@ use std::net::IpAddr;
use std::sync::Arc;
use crate::config::{PortForwardConfig, PortProtocol};
use crate::events::Bus;
use crate::tunnel::tcp::TcpPortPool;
use crate::tunnel::udp::UdpPortPool;
use crate::wg::WireGuardTunnel;
@ -16,6 +17,7 @@ pub async fn port_forward(
tcp_port_pool: TcpPortPool,
udp_port_pool: UdpPortPool,
wg: Arc<WireGuardTunnel>,
bus: Bus,
) -> anyhow::Result<()> {
info!(
"Tunneling {} [{}]->[{}] (via [{}] as peer {})",
@ -27,7 +29,7 @@ pub async fn port_forward(
);
match port_forward.protocol {
PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await,
PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, wg).await,
PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, bus).await,
PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, bus).await,
}
}

View file

@ -1,17 +1,17 @@
use crate::config::{PortForwardConfig, PortProtocol};
use crate::virtual_iface::tcp::TcpVirtualInterface;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
use crate::virtual_iface::VirtualPort;
use anyhow::Context;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use std::ops::Range;
use std::time::Duration;
use crate::events::{Bus, Event};
use rand::seq::SliceRandom;
use rand::thread_rng;
use tokio::io::AsyncWriteExt;
const MAX_PACKET: usize = 65536;
const MIN_PORT: u16 = 1000;
@ -22,14 +22,13 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
pub async fn tcp_proxy_server(
port_forward: PortForwardConfig,
port_pool: TcpPortPool,
wg: Arc<WireGuardTunnel>,
bus: Bus,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(port_forward.source)
.await
.with_context(|| "Failed to listen on TCP proxy server")?;
loop {
let wg = wg.clone();
let port_pool = port_pool.clone();
let (socket, peer_addr) = listener
.accept()
@ -52,10 +51,10 @@ pub async fn tcp_proxy_server(
info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
let bus = bus.clone();
tokio::spawn(async move {
let port_pool = port_pool.clone();
let result =
handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
let result = handle_tcp_proxy_connection(socket, virtual_port, port_forward, bus).await;
if let Err(e) = result {
error!(
@ -66,8 +65,7 @@ pub async fn tcp_proxy_server(
info!("[{}] Connection closed by client", virtual_port);
}
// Release port when connection drops
wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp));
tokio::time::sleep(Duration::from_millis(100)).await; // Make sure the other tasks have time to process the event
port_pool.release(virtual_port).await;
});
}
@ -75,72 +73,26 @@ pub async fn tcp_proxy_server(
/// Handles a new TCP connection with its assigned virtual port.
async fn handle_tcp_proxy_connection(
socket: TcpStream,
virtual_port: u16,
mut socket: TcpStream,
virtual_port: VirtualPort,
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
bus: Bus,
) -> anyhow::Result<()> {
// Abort signal for stopping the Virtual Interface
let abort = Arc::new(AtomicBool::new(false));
// Signals that the Virtual Client is ready to send data
let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>();
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000);
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000);
// Spawn virtual interface
{
let abort = abort.clone();
let virtual_interface = TcpVirtualInterface::new(
virtual_port,
port_forward,
wg,
abort.clone(),
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
);
tokio::spawn(async move {
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
error!("Virtual interface poll loop failed unexpectedly: {}", e);
abort.store(true, Ordering::Relaxed);
})
});
}
// Wait for virtual client to be ready.
virtual_client_ready_rx
.await
.with_context(|| "Virtual client dropped before being ready.")?;
trace!("[{}] Virtual client is ready to send data", virtual_port);
let mut endpoint = bus.new_endpoint();
endpoint.send(Event::ClientConnectionInitiated(port_forward, virtual_port));
let mut buffer = Vec::with_capacity(MAX_PACKET);
loop {
tokio::select! {
readable_result = socket.readable() => {
match readable_result {
Ok(_) => {
// Buffer for the individual TCP segment.
let mut buffer = Vec::with_capacity(MAX_PACKET);
match socket.try_read_buf(&mut buffer) {
Ok(size) if size > 0 => {
let data = &buffer[..size];
debug!(
"[{}] Read {} bytes of TCP data from real client",
virtual_port, size
);
if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await {
error!(
"[{}] Failed to dispatch data to virtual interface: {:?}",
virtual_port, e
);
}
let data = Vec::from(&buffer[..size]);
endpoint.send(Event::LocalData(port_forward, virtual_port, data));
// Reset buffer
buffer.clear();
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue;
@ -163,43 +115,32 @@ async fn handle_tcp_proxy_connection(
}
}
}
data_recv_result = data_to_real_client_rx.recv() => {
match data_recv_result {
Some(data) => match socket.try_write(&data) {
Ok(size) => {
debug!(
"[{}] Wrote {} bytes of TCP data to real client",
virtual_port, size
);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if abort.load(Ordering::Relaxed) {
event = endpoint.recv() => {
match event {
Event::ClientConnectionDropped(e_vp) if e_vp == virtual_port => {
// This connection is supposed to be closed, stop the task.
break;
}
Event::RemoteData(e_vp, data) if e_vp == virtual_port => {
// Have remote data to send to the local client
let size = data.len();
match socket.write(&data).await {
Ok(size) => debug!("[{}] Sent {} bytes to local client", virtual_port, size),
Err(e) => {
error!("[{}] Failed to send {} bytes to local client: {:?}", virtual_port, size, e);
break;
} else {
continue;
}
}
Err(e) => {
error!(
"[{}] Failed to write to client TCP socket: {:?}",
virtual_port, e
);
}
},
None => {
if abort.load(Ordering::Relaxed) {
break;
} else {
continue;
}
},
}
_ => {}
}
}
}
}
trace!("[{}] TCP socket handler task terminated", virtual_port);
abort.store(true, Ordering::Relaxed);
// Notify other endpoints that this task has closed and no more data is to be sent to the local client
endpoint.send(Event::ClientConnectionDropped(virtual_port));
Ok(())
}
@ -230,19 +171,19 @@ impl TcpPortPool {
}
/// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity).
pub async fn next(&self) -> anyhow::Result<u16> {
pub async fn next(&self) -> anyhow::Result<VirtualPort> {
let mut inner = self.inner.write().await;
let port = inner
.queue
.pop_front()
.with_context(|| "TCP virtual port pool is exhausted")?;
Ok(port)
Ok(VirtualPort::new(port, PortProtocol::Tcp))
}
/// Releases a port back into the pool.
pub async fn release(&self, port: u16) {
pub async fn release(&self, port: VirtualPort) {
let mut inner = self.inner.write().await;
inner.queue.push_back(port);
inner.queue.push_back(port.num());
}
}

View file

@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use crate::events::{Bus, Event};
use anyhow::Context;
use priority_queue::double_priority_queue::DoublePriorityQueue;
use priority_queue::priority_queue::PriorityQueue;
@ -30,61 +31,24 @@ const UDP_TIMEOUT_SECONDS: u64 = 60;
/// TODO: Make this configurable by the CLI
const PORTS_PER_IP: usize = 100;
/// Starts the server that listens on UDP datagrams.
pub async fn udp_proxy_server(
port_forward: PortForwardConfig,
port_pool: UdpPortPool,
wg: Arc<WireGuardTunnel>,
bus: Bus,
) -> anyhow::Result<()> {
// Abort signal
let abort = Arc::new(AtomicBool::new(false));
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) =
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) =
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
{
// Spawn virtual interface
// Note: contrary to TCP, there is only one UDP virtual interface
let virtual_interface = UdpVirtualInterface::new(
port_forward,
wg,
data_to_real_client_tx,
data_to_virtual_server_rx,
);
let abort = abort.clone();
tokio::spawn(async move {
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
error!("Virtual interface poll loop failed unexpectedly: {}", e);
abort.store(true, Ordering::Relaxed);
});
});
}
let mut endpoint = bus.new_endpoint();
let socket = UdpSocket::bind(port_forward.source)
.await
.with_context(|| "Failed to bind on UDP proxy address")?;
let mut buffer = [0u8; MAX_PACKET];
loop {
if abort.load(Ordering::Relaxed) {
break;
}
tokio::select! {
to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => {
match to_send_result {
Ok(Some((port, data))) => {
data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| {
error!(
"Failed to dispatch data to UDP virtual interface: {:?}",
e
);
});
endpoint.send(Event::LocalData(port_forward, port, data));
}
Ok(None) => {
continue;
@ -98,18 +62,18 @@ pub async fn udp_proxy_server(
}
}
}
data_recv_result = data_to_real_client_rx.recv() => {
if let Some((port, data)) = data_recv_result {
if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await {
if let Err(e) = socket.send_to(&data, peer_addr).await {
event = endpoint.recv() => {
if let Event::RemoteData(port, data) = event {
if let Some(peer) = port_pool.get_peer_addr(port).await {
if let Err(e) = socket.send_to(&data, peer).await {
error!(
"[{}] Failed to send UDP datagram to real client ({}): {:?}",
port,
peer_addr,
peer,
e,
);
}
port_pool.update_last_transmit(port.0).await;
port_pool.update_last_transmit(port).await;
}
}
}
@ -141,14 +105,13 @@ async fn next_udp_datagram(
return Ok(None);
}
};
let port = VirtualPort(port, PortProtocol::Udp);
debug!(
"[{}] Received datagram of {} bytes from {}",
port, size, peer_addr
);
port_pool.update_last_transmit(port.0).await;
port_pool.update_last_transmit(port).await;
let data = buffer[..size].to_vec();
Ok(Some((port, data)))
@ -181,14 +144,14 @@ impl UdpPortPool {
}
/// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<u16> {
pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<VirtualPort> {
// A port found to be reused. This is outside of the block because the read lock cannot be upgraded to a write lock.
let mut port_reuse: Option<u16> = None;
{
let inner = self.inner.read().await;
if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) {
return Ok(*port);
return Ok(VirtualPort::new(*port, PortProtocol::Udp));
}
// Count how many ports are being used by the peer IP
@ -240,26 +203,26 @@ impl UdpPortPool {
inner.port_by_peer_addr.insert(peer_addr, port);
inner.peer_addr_by_port.insert(port, peer_addr);
Ok(port)
Ok(VirtualPort::new(port, PortProtocol::Udp))
}
/// Notify that the given virtual port has received or transmitted a UDP datagram.
pub async fn update_last_transmit(&self, port: u16) {
pub async fn update_last_transmit(&self, port: VirtualPort) {
let mut inner = self.inner.write().await;
if let Some(peer) = inner.peer_addr_by_port.get(&port).copied() {
if let Some(peer) = inner.peer_addr_by_port.get(&port.num()).copied() {
let mut pq: &mut DoublePriorityQueue<u16, Instant> = inner
.peer_port_usage
.entry(peer.ip())
.or_insert_with(Default::default);
pq.push(port, Instant::now());
pq.push(port.num(), Instant::now());
}
let mut pq: &mut DoublePriorityQueue<u16, Instant> = &mut inner.port_usage;
pq.push(port, Instant::now());
pq.push(port.num(), Instant::now());
}
pub async fn get_peer_addr(&self, port: u16) -> Option<SocketAddr> {
pub async fn get_peer_addr(&self, port: VirtualPort) -> Option<SocketAddr> {
let inner = self.inner.read().await;
inner.peer_addr_by_port.get(&port).copied()
inner.peer_addr_by_port.get(&port.num()).copied()
}
}

View file

@ -1,45 +1,51 @@
use crate::virtual_iface::VirtualPort;
use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
use anyhow::Context;
use crate::config::PortProtocol;
use crate::events::{BusSender, Event};
use crate::Bus;
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use std::sync::Arc;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
/// are made available to this device using a channel receiver. IP packets sent from this device
/// are asynchronously sent out to the WireGuard tunnel.
/// A virtual device that processes IP packets through smoltcp and WireGuard.
pub struct VirtualIpDevice {
/// Tunnel to send IP packets to.
wg: Arc<WireGuardTunnel>,
/// Max transmission unit (bytes)
max_transmission_unit: usize,
/// Channel receiver for received IP packets.
ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
bus_sender: BusSender,
/// Local queue for packets received from the bus that need to go through the smoltcp interface.
process_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
}
impl VirtualIpDevice {
/// Initializes a new virtual IP device.
pub fn new(
wg: Arc<WireGuardTunnel>,
ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
) -> Self {
Self { wg, ip_dispatch_rx }
}
pub fn new(protocol: PortProtocol, bus: Bus, max_transmission_unit: usize) -> Self {
let mut bus_endpoint = bus.new_endpoint();
let bus_sender = bus_endpoint.sender();
let process_queue = Arc::new(Mutex::new(VecDeque::new()));
/// Registers a virtual IP device for a single virtual client.
pub fn new_direct(virtual_port: VirtualPort, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
{
let process_queue = process_queue.clone();
tokio::spawn(async move {
loop {
match bus_endpoint.recv().await {
Event::InboundInternetPacket(ip_proto, data) if ip_proto == protocol => {
let mut queue = process_queue
.lock()
.expect("Failed to acquire process queue lock");
queue.push_back(data);
bus_endpoint.send(Event::VirtualDeviceFed(ip_proto));
}
_ => {}
}
}
});
}
wg.register_virtual_interface(virtual_port, ip_dispatch_tx)
.with_context(|| "Failed to register IP dispatch for virtual interface")?;
Ok(Self { wg, ip_dispatch_rx })
}
pub async fn new_sink(wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let ip_dispatch_rx = wg
.register_sink_interface()
.await
.with_context(|| "Failed to register IP dispatch for sink virtual interface")?;
Ok(Self { wg, ip_dispatch_rx })
Self {
bus_sender,
process_queue,
max_transmission_unit,
}
}
}
@ -48,27 +54,34 @@ impl<'a> Device<'a> for VirtualIpDevice {
type TxToken = TxToken;
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
match self.ip_dispatch_rx.try_recv() {
Ok(buffer) => Some((
let next = {
let mut queue = self
.process_queue
.lock()
.expect("Failed to acquire process queue lock");
queue.pop_front()
};
match next {
Some(buffer) => Some((
Self::RxToken { buffer },
Self::TxToken {
wg: self.wg.clone(),
sender: self.bus_sender.clone(),
},
)),
Err(_) => None,
None => None,
}
}
fn transmit(&'a mut self) -> Option<Self::TxToken> {
Some(TxToken {
wg: self.wg.clone(),
sender: self.bus_sender.clone(),
})
}
fn capabilities(&self) -> DeviceCapabilities {
let mut cap = DeviceCapabilities::default();
cap.medium = Medium::Ip;
cap.max_transmission_unit = self.wg.max_transmission_unit;
cap.max_transmission_unit = self.max_transmission_unit;
cap
}
}
@ -89,7 +102,7 @@ impl smoltcp::phy::RxToken for RxToken {
#[doc(hidden)]
pub struct TxToken {
wg: Arc<WireGuardTunnel>,
sender: BusSender,
}
impl smoltcp::phy::TxToken for TxToken {
@ -100,14 +113,7 @@ impl smoltcp::phy::TxToken for TxToken {
let mut buffer = Vec::new();
buffer.resize(len, 0);
let result = f(&mut buffer);
tokio::spawn(async move {
match self.wg.send_ip_packet(&buffer).await {
Ok(_) => {}
Err(e) => {
error!("Failed to send IP packet to WireGuard endpoint: {:?}", e);
}
}
});
self.sender.send(Event::OutboundInternetPacket(buffer));
result
}
}

View file

@ -2,6 +2,7 @@ pub mod tcp;
pub mod udp;
use crate::config::PortProtocol;
use crate::VirtualIpDevice;
use async_trait::async_trait;
use std::fmt::{Display, Formatter};
@ -9,15 +10,56 @@ use std::fmt::{Display, Formatter};
pub trait VirtualInterfacePoll {
/// Initializes the virtual interface and processes incoming data to be dispatched
/// to the WireGuard tunnel and to the real client.
async fn poll_loop(mut self) -> anyhow::Result<()>;
async fn poll_loop(mut self, device: VirtualIpDevice) -> anyhow::Result<()>;
}
/// Virtual port.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct VirtualPort(pub u16, pub PortProtocol);
pub struct VirtualPort(u16, PortProtocol);
impl VirtualPort {
/// Create a new `VirtualPort` instance, with the given port number and associated protocol.
pub fn new(port: u16, proto: PortProtocol) -> Self {
VirtualPort(port, proto)
}
/// The port number
pub fn num(&self) -> u16 {
self.0
}
/// The protocol of this port.
pub fn proto(&self) -> PortProtocol {
self.1
}
}
impl From<VirtualPort> for u16 {
fn from(port: VirtualPort) -> Self {
port.num()
}
}
impl From<&VirtualPort> for u16 {
fn from(port: &VirtualPort) -> Self {
port.num()
}
}
impl From<VirtualPort> for PortProtocol {
fn from(port: VirtualPort) -> Self {
port.proto()
}
}
impl From<&VirtualPort> for PortProtocol {
fn from(port: &VirtualPort) -> Self {
port.proto()
}
}
impl Display for VirtualPort {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[{}:{}]", self.0, self.1)
write!(f, "[{}:{}]", self.num(), self.proto())
}
}

View file

@ -1,284 +1,250 @@
use crate::config::{PortForwardConfig, PortProtocol};
use crate::events::Event;
use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
use crate::Bus;
use anyhow::Context;
use async_trait::async_trait;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState};
use smoltcp::wire::{IpAddress, IpCidr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use std::time::Duration;
const MAX_PACKET: usize = 65536;
/// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa.
pub struct TcpVirtualInterface {
/// The virtual port assigned to the virtual client, used to
/// route Layer 4 segments/datagrams to and from the WireGuard tunnel.
virtual_port: u16,
/// The overall port-forward configuration: used for the destination address (on which
/// the virtual server listens) and the protocol in use.
port_forward: PortForwardConfig,
/// The WireGuard tunnel to send IP packets to.
wg: Arc<WireGuardTunnel>,
/// Abort signal to shutdown the virtual interface and its parent task.
abort: Arc<AtomicBool>,
/// Channel sender for pushing Layer 7 data back to the real client.
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
/// Channel receiver for processing Layer 7 data through the virtual interface.
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
/// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data.
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
source_peer_ip: IpAddr,
port_forwards: Vec<PortForwardConfig>,
bus: Bus,
}
impl TcpVirtualInterface {
/// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(
virtual_port: u16,
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
abort: Arc<AtomicBool>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
) -> Self {
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
Self {
virtual_port,
port_forward,
wg,
abort,
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
port_forwards: port_forwards
.into_iter()
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
.collect(),
source_peer_ip,
bus,
}
}
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<TcpSocket<'static>> {
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((
IpAddress::from(port_forward.destination.ip()),
port_forward.destination.port(),
))
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
}
fn new_client_socket() -> anyhow::Result<TcpSocket<'static>> {
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
Ok(socket)
}
fn addresses(&self) -> Vec<IpCidr> {
let mut addresses = HashSet::new();
addresses.insert(IpAddress::from(self.source_peer_ip));
for config in self.port_forwards.iter() {
addresses.insert(IpAddress::from(config.destination.ip()));
}
addresses
.into_iter()
.map(|addr| IpCidr::new(addr, 32))
.collect()
}
}
#[async_trait]
impl VirtualInterfacePoll for TcpVirtualInterface {
async fn poll_loop(self) -> anyhow::Result<()> {
let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx);
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
let source_peer_ip = self.wg.source_peer_ip;
async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> {
// Create CIDR block for source peer IP + each port forward IP
let addresses = self.addresses();
// Create a device and interface to simulate IP packets
// In essence:
// * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
// * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
// * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
// * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
// * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device =
VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg)
.with_context(|| "Failed to initialize TCP VirtualIpDevice")?;
// there are always 2 sockets: 1 virtual client and 1 virtual server.
let mut sockets: [_; 2] = Default::default();
let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..])
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32),
])
// Create virtual interface (contains smoltcp state machine)
let mut iface = InterfaceBuilder::new(device, vec![])
.ip_addrs(addresses)
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.
let server_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
// Create virtual server for each port forward
for port_forward in self.port_forwards.iter() {
let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
iface.add_socket(server_socket);
}
socket
.listen((
IpAddress::from(self.port_forward.destination.ip()),
self.port_forward.destination.port(),
))
.with_context(|| "Virtual server socket failed to listen")?;
// The next time to poll the interface. Can be None for instant poll.
let mut next_poll: Option<tokio::time::Instant> = None;
Ok(socket)
};
// Bus endpoint to read events
let mut endpoint = self.bus.new_endpoint();
let client_socket: anyhow::Result<TcpSocket> = {
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
Ok(socket)
};
// Maps virtual port to its client socket handle
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
let _server_handle = virtual_interface.add_socket(server_socket?);
let client_handle = virtual_interface.add_socket(client_socket?);
// Any data that wasn't sent because it was over the sending buffer limit
let mut tx_extra = Vec::new();
// Counts the connection attempts by the virtual client
let mut connection_attempts = 0;
// Whether the client has successfully connected before. Prevents the case of connecting again.
let mut has_connected = false;
// Data packets to send from a virtual client
let mut send_queue: HashMap<VirtualPort, VecDeque<Vec<u8>>> = HashMap::new();
loop {
let loop_start = smoltcp::time::Instant::now();
tokio::select! {
_ = match (next_poll, port_client_handle_map.len()) {
(None, 0) => tokio::time::sleep(Duration::MAX),
(None, _) => tokio::time::sleep(Duration::ZERO),
(Some(until), _) => tokio::time::sleep_until(until),
} => {
let loop_start = smoltcp::time::Instant::now();
// Shutdown occurs when the real client closes the connection,
// or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore.
// One last poll-loop iteration is executed so that the RST segment can be dispatched.
let shutdown = self.abort.load(Ordering::Relaxed);
if shutdown {
// Shutdown: sends a RST packet.
trace!("[{}] Shutting down virtual interface", self.virtual_port);
let client_socket = virtual_interface.get_socket::<TcpSocket>(client_handle);
client_socket.abort();
}
match virtual_interface.poll(loop_start) {
Ok(processed) if processed => {
trace!(
"[{}] Virtual interface polled some packets to be processed",
self.virtual_port
);
}
Err(e) => {
error!(
"[{}] Virtual interface poll error: {:?}",
self.virtual_port, e
);
}
_ => {}
}
{
let (client_socket, context) =
virtual_interface.get_socket_and_context::<TcpSocket>(client_handle);
if !shutdown && client_socket.state() == TcpState::Closed && !has_connected {
// Not shutting down, but the client socket is closed, and the client never successfully connected.
if connection_attempts < 10 {
// Try to connect
client_socket
.connect(
context,
(
IpAddress::from(self.port_forward.destination.ip()),
self.port_forward.destination.port(),
),
(IpAddress::from(source_peer_ip), self.virtual_port),
)
.with_context(|| "Virtual server socket failed to listen")?;
if connection_attempts > 0 {
debug!(
"[{}] Virtual client retrying connection in 500ms",
self.virtual_port
);
// Not our first connection attempt, wait a little bit.
tokio::time::sleep(Duration::from_millis(500)).await;
match iface.poll(loop_start) {
Ok(processed) if processed => {
trace!("TCP virtual interface polled some packets to be processed");
}
} else {
// Too many connection attempts
self.abort.store(true, Ordering::Relaxed);
}
connection_attempts += 1;
continue;
}
if client_socket.state() == TcpState::Established {
// Prevent reconnection if the server later closes.
has_connected = true;
}
if client_socket.can_recv() {
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => {
trace!(
"[{}] Virtual client received {} bytes of data",
self.virtual_port,
data.len()
);
// Send it to the real client
if let Err(e) = self.data_to_real_client_tx.send(data).await {
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e);
}
}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
self.virtual_port, e
);
}
}
}
if client_socket.can_send() {
if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() {
virtual_client_ready_tx
.send(())
.expect("Failed to notify real client that virtual client is ready");
Err(e) => error!("TCP virtual interface poll error: {:?}", e),
_ => {}
}
let mut to_transfer = None;
if tx_extra.is_empty() {
// The payload segment from the previous loop is complete,
// we can now read the next payload in the queue.
if let Ok(data) = data_to_virtual_server_rx.try_recv() {
to_transfer = Some(data);
} else if client_socket.state() == TcpState::CloseWait {
// No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN),
// the interface is shutdown.
trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port);
self.abort.store(true, Ordering::Relaxed);
continue;
}
}
let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
if !to_transfer_slice.is_empty() {
let total = to_transfer_slice.len();
match client_socket.send_slice(to_transfer_slice) {
Ok(sent) => {
trace!(
"[{}] Sent {}/{} bytes via virtual client socket",
self.virtual_port,
sent,
total,
);
tx_extra = Vec::from(&to_transfer_slice[sent..total]);
}
Err(e) => {
error!(
"[{}] Failed to send slice via virtual client socket: {:?}",
self.virtual_port, e
);
// Find client socket send data to
for (virtual_port, client_handle) in port_client_handle_map.iter() {
let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
if client_socket.can_send() {
if let Some(send_queue) = send_queue.get_mut(virtual_port) {
let to_transfer = send_queue.pop_front();
if let Some(to_transfer_slice) = to_transfer.as_deref() {
let total = to_transfer_slice.len();
match client_socket.send_slice(to_transfer_slice) {
Ok(sent) => {
if sent < total {
// Sometimes only a subset is sent, so the rest needs to be sent on the next poll
let tx_extra = Vec::from(&to_transfer_slice[sent..total]);
send_queue.push_front(tx_extra);
}
}
Err(e) => {
error!(
"Failed to send slice via virtual client socket: {:?}", e
);
}
}
break;
} else if client_socket.state() == TcpState::CloseWait {
client_socket.close();
break;
}
}
}
}
}
}
if shutdown {
break;
}
// Find client socket recv data from
for (virtual_port, client_handle) in port_client_handle_map.iter() {
let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
if client_socket.can_recv() {
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => {
if !data.is_empty() {
endpoint.send(Event::RemoteData(*virtual_port, data));
break;
}
}
Err(e) => {
error!(
"Failed to read from virtual client socket: {:?}", e
);
}
}
}
}
match virtual_interface.poll_delay(loop_start) {
Some(smoltcp::time::Duration::ZERO) => {
continue;
// Find closed sockets
port_client_handle_map.retain(|virtual_port, client_handle| {
let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
if client_socket.state() == TcpState::Closed {
endpoint.send(Event::ClientConnectionDropped(*virtual_port));
send_queue.remove(virtual_port);
iface.remove_socket(*client_handle);
false
} else {
// Not closed, retain
true
}
});
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
next_poll = match iface.poll_delay(loop_start) {
Some(smoltcp::time::Duration::ZERO) => None,
Some(delay) => {
trace!("TCP Virtual interface delayed next poll by {}", delay);
Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis()))
},
None => None,
};
}
_ => {
tokio::time::sleep(Duration::from_millis(1)).await;
event = endpoint.recv() => {
match event {
Event::ClientConnectionInitiated(port_forward, virtual_port) => {
let client_socket = TcpVirtualInterface::new_client_socket()?;
let client_handle = iface.add_socket(client_socket);
// Add handle to map
port_client_handle_map.insert(virtual_port, client_handle);
send_queue.insert(virtual_port, VecDeque::new());
let (client_socket, context) = iface.get_socket_and_context::<TcpSocket>(client_handle);
client_socket
.connect(
context,
(
IpAddress::from(port_forward.destination.ip()),
port_forward.destination.port(),
),
(IpAddress::from(self.source_peer_ip), virtual_port.num()),
)
.with_context(|| "Virtual server socket failed to listen")?;
next_poll = None;
}
Event::ClientConnectionDropped(virtual_port) => {
if let Some(client_handle) = port_client_handle_map.get(&virtual_port) {
let client_handle = *client_handle;
port_client_handle_map.remove(&virtual_port);
send_queue.remove(&virtual_port);
let client_socket = iface.get_socket::<TcpSocket>(client_handle);
client_socket.close();
next_poll = None;
}
}
Event::LocalData(_, virtual_port, data) if send_queue.contains_key(&virtual_port) => {
if let Some(send_queue) = send_queue.get_mut(&virtual_port) {
send_queue.push_back(data);
next_poll = None;
}
}
Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Tcp => {
next_poll = None;
}
_ => {}
}
}
}
}
trace!("[{}] Virtual interface task terminated", self.virtual_port);
self.abort.store(true, Ordering::Relaxed);
Ok(())
}
}

View file

@ -1,112 +1,134 @@
use anyhow::Context;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
#![allow(dead_code)]
use anyhow::Context;
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use crate::events::Event;
use crate::{Bus, PortProtocol};
use async_trait::async_trait;
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
use smoltcp::wire::{IpAddress, IpCidr};
use std::time::Duration;
use crate::config::PortForwardConfig;
use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
const MAX_PACKET: usize = 65536;
pub struct UdpVirtualInterface {
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
source_peer_ip: IpAddr,
port_forwards: Vec<PortForwardConfig>,
bus: Bus,
}
impl UdpVirtualInterface {
pub fn new(
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
) -> Self {
/// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
Self {
port_forward,
wg,
data_to_real_client_tx,
data_to_virtual_server_rx,
port_forwards: port_forwards
.into_iter()
.filter(|f| matches!(f.protocol, PortProtocol::Udp))
.collect(),
source_peer_ip,
bus,
}
}
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<UdpSocket<'static>> {
static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = [];
static mut UDP_SERVER_RX_DATA: [u8; 0] = [];
static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = [];
static mut UDP_SERVER_TX_DATA: [u8; 0] = [];
let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
&mut UDP_SERVER_RX_DATA[..]
});
let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
&mut UDP_SERVER_TX_DATA[..]
});
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
socket
.bind((
IpAddress::from(port_forward.destination.ip()),
port_forward.destination.port(),
))
.with_context(|| "UDP virtual server socket failed to bind")?;
Ok(socket)
}
fn new_client_socket(
source_peer_ip: IpAddr,
client_port: VirtualPort,
) -> anyhow::Result<UdpSocket<'static>> {
let rx_meta = vec![UdpPacketMetadata::EMPTY; 10];
let tx_meta = vec![UdpPacketMetadata::EMPTY; 10];
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
socket
.bind((IpAddress::from(source_peer_ip), client_port.num()))
.with_context(|| "UDP virtual client failed to bind")?;
Ok(socket)
}
fn addresses(&self) -> Vec<IpCidr> {
let mut addresses = HashSet::new();
addresses.insert(IpAddress::from(self.source_peer_ip));
for config in self.port_forwards.iter() {
addresses.insert(IpAddress::from(config.destination.ip()));
}
addresses
.into_iter()
.map(|addr| IpCidr::new(addr, 32))
.collect()
}
}
#[async_trait]
impl VirtualInterfacePoll for UdpVirtualInterface {
async fn poll_loop(self) -> anyhow::Result<()> {
// Data receiver to dispatch using virtual client sockets
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> {
// Create CIDR block for source peer IP + each port forward IP
let addresses = self.addresses();
// The IP to bind client sockets to
let source_peer_ip = self.wg.source_peer_ip;
// The IP/port to bind the server socket to
let destination = self.port_forward.destination;
// Initialize a channel for IP packets.
// The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel.
// The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel.
let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx);
let mut virtual_interface = InterfaceBuilder::new(device, vec![])
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(source_peer_ip.into(), 32),
IpCidr::new(destination.ip().into(), 32),
])
// Create virtual interface (contains smoltcp state machine)
let mut iface = InterfaceBuilder::new(device, vec![])
.ip_addrs(addresses)
.finalize();
// Server socket: this is a placeholder for the interface.
let server_socket: anyhow::Result<UdpSocket> = {
static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = [];
static mut UDP_SERVER_RX_DATA: [u8; 0] = [];
static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = [];
static mut UDP_SERVER_TX_DATA: [u8; 0] = [];
let udp_rx_buffer =
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
&mut UDP_SERVER_RX_DATA[..]
});
let udp_tx_buffer =
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
&mut UDP_SERVER_TX_DATA[..]
});
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
// Create virtual server for each port forward
for port_forward in self.port_forwards.iter() {
let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?;
iface.add_socket(server_socket);
}
socket
.bind((IpAddress::from(destination.ip()), destination.port()))
.with_context(|| "UDP virtual server socket failed to listen")?;
Ok(socket)
};
let _server_handle = virtual_interface.add_socket(server_socket?);
// A map of virtual port to client socket.
let mut client_sockets: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// The next instant required to poll the virtual interface
// None means "immediate poll required".
// The next time to poll the interface. Can be None for instant poll.
let mut next_poll: Option<tokio::time::Instant> = None;
// Bus endpoint to read events
let mut endpoint = self.bus.new_endpoint();
// Maps virtual port to its client socket handle
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// Data packets to send from a virtual client
let mut send_queue: HashMap<VirtualPort, VecDeque<(PortForwardConfig, Vec<u8>)>> =
HashMap::new();
loop {
let wg = self.wg.clone();
tokio::select! {
// Wait the recommended amount of time by smoltcp, and poll again.
_ = match next_poll {
None => tokio::time::sleep(Duration::ZERO),
Some(until) => tokio::time::sleep_until(until)
_ = match (next_poll, port_client_handle_map.len()) {
(None, 0) => tokio::time::sleep(Duration::MAX),
(None, _) => tokio::time::sleep(Duration::ZERO),
(Some(until), _) => tokio::time::sleep_until(until),
} => {
let loop_start = smoltcp::time::Instant::now();
match virtual_interface.poll(loop_start) {
match iface.poll(loop_start) {
Ok(processed) if processed => {
trace!("UDP virtual interface polled some packets to be processed");
}
@ -114,84 +136,81 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
_ => {}
}
// Loop through each client socket and check if there is any data to send back
// to the real client.
for (virtual_port, client_socket_handle) in client_sockets.iter() {
let client_socket = virtual_interface.get_socket::<UdpSocket>(*client_socket_handle);
match client_socket.recv() {
Ok((data, _peer)) => {
// Send the data back to the real client using MPSC channel
self.data_to_real_client_tx
.send((*virtual_port, data.to_vec()))
.await
.unwrap_or_else(|e| {
error!(
"[{}] Failed to dispatch data from virtual client to real client: {:?}",
virtual_port, e
);
});
}
Err(smoltcp::Error::Exhausted) => {}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
virtual_port, e
);
// Find client socket send data to
for (virtual_port, client_handle) in port_client_handle_map.iter() {
let client_socket = iface.get_socket::<UdpSocket>(*client_handle);
if client_socket.can_send() {
if let Some(send_queue) = send_queue.get_mut(virtual_port) {
let to_transfer = send_queue.pop_front();
if let Some((port_forward, data)) = to_transfer {
client_socket
.send_slice(
&data,
(IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(),
)
.unwrap_or_else(|e| {
error!(
"[{}] Failed to send data to virtual server: {:?}",
virtual_port, e
);
});
break;
}
}
}
}
next_poll = match virtual_interface.poll_delay(loop_start) {
Some(smoltcp::time::Duration::ZERO) => None,
Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())),
None => None,
}
}
// Wait for data to be received from the real client
data_recv_result = data_to_virtual_server_rx.recv() => {
if let Some((client_port, data)) = data_recv_result {
// Register the socket in WireGuard Tunnel (overrides any previous registration as well)
wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
.unwrap_or_else(|e| {
error!(
"[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
client_port, e
);
});
let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
let rx_meta = vec![UdpPacketMetadata::EMPTY; 10];
let tx_meta = vec![UdpPacketMetadata::EMPTY; 10];
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
socket
.bind((IpAddress::from(wg.source_peer_ip), client_port.0))
.unwrap_or_else(|e| {
// Find client socket recv data from
for (virtual_port, client_handle) in port_client_handle_map.iter() {
let client_socket = iface.get_socket::<UdpSocket>(*client_handle);
if client_socket.can_recv() {
match client_socket.recv() {
Ok((data, _peer)) => {
if !data.is_empty() {
endpoint.send(Event::RemoteData(*virtual_port, data.to_vec()));
break;
}
}
Err(e) => {
error!(
"[{}] UDP virtual client socket failed to bind: {:?}",
client_port, e
"Failed to read from virtual client socket: {:?}", e
);
});
}
}
}
}
virtual_interface.add_socket(socket)
});
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
next_poll = match iface.poll_delay(loop_start) {
Some(smoltcp::time::Duration::ZERO) => None,
Some(delay) => {
trace!("UDP Virtual interface delayed next poll by {}", delay);
Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis()))
},
None => None,
};
}
event = endpoint.recv() => {
match event {
Event::LocalData(port_forward, virtual_port, data) => {
if let Some(send_queue) = send_queue.get_mut(&virtual_port) {
// Client socket already exists
send_queue.push_back((port_forward, data));
} else {
// Client socket does not exist
let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?;
let client_handle = iface.add_socket(client_socket);
let client_socket = virtual_interface.get_socket::<UdpSocket>(*client_socket_handle);
client_socket
.send_slice(
&data,
(IpAddress::from(destination.ip()), destination.port()).into(),
)
.unwrap_or_else(|e| {
error!(
"[{}] Failed to send data to virtual server: {:?}",
client_port, e
);
});
// Add handle to map
port_client_handle_map.insert(virtual_port, client_handle);
send_queue.insert(virtual_port, VecDeque::from(vec![(port_forward, data)]));
}
next_poll = None;
}
Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Udp => {
next_poll = None;
}
_ => {}
}
}
}

181
src/wg.rs
View file

@ -1,15 +1,15 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;
use crate::Bus;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use log::Level;
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket};
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet};
use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use crate::config::{Config, PortProtocol};
use crate::virtual_iface::VirtualPort;
use crate::events::Event;
/// The capacity of the channel for received IP packets.
pub const DISPATCH_CAPACITY: usize = 1_000;
@ -26,17 +26,13 @@ pub struct WireGuardTunnel {
udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP).
pub(crate) endpoint: SocketAddr,
/// Maps virtual ports to the corresponding IP packet dispatcher.
virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
/// The max transmission unit for WireGuard.
pub(crate) max_transmission_unit: usize,
/// Event bus
bus: Bus,
}
impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config) -> anyhow::Result<Self> {
pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?;
let endpoint = config.endpoint_addr;
@ -46,16 +42,13 @@ impl WireGuardTunnel {
})
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let virtual_port_ip_tx = Default::default();
Ok(Self {
source_peer_ip,
peer,
udp,
endpoint,
virtual_port_ip_tx,
sink_ip_tx: RwLock::new(None),
max_transmission_unit: config.max_transmission_unit,
bus,
})
}
@ -90,31 +83,20 @@ impl WireGuardTunnel {
Ok(())
}
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub fn register_virtual_interface(
&self,
virtual_port: VirtualPort,
sender: tokio::sync::mpsc::Sender<Vec<u8>>,
) -> anyhow::Result<()> {
self.virtual_port_ip_tx.insert(virtual_port, sender);
Ok(())
}
pub async fn produce_task(&self) -> ! {
trace!("Starting WireGuard production task");
let mut endpoint = self.bus.new_endpoint();
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub async fn register_sink_interface(
&self,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
let mut sink_ip_tx = self.sink_ip_tx.write().await;
*sink_ip_tx = Some(sender);
Ok(receiver)
}
/// Releases the virtual interface from IP dispatch.
pub fn release_virtual_interface(&self, virtual_port: VirtualPort) {
self.virtual_port_ip_tx.remove(&virtual_port);
loop {
if let Event::OutboundInternetPacket(data) = endpoint.recv().await {
match self.send_ip_packet(&data).await {
Ok(_) => {}
Err(e) => {
error!("{:?}", e);
}
}
}
}
}
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
@ -160,6 +142,7 @@ impl WireGuardTunnel {
/// decapsulates them, and dispatches newly received IP packets.
pub async fn consume_task(&self) -> ! {
trace!("Starting WireGuard consumption task");
let endpoint = self.bus.new_endpoint();
loop {
let mut recv_buf = [0u8; MAX_PACKET];
@ -212,38 +195,8 @@ impl WireGuardTunnel {
// For debugging purposes: parse packet
trace_ip_packet("Received IP packet", packet);
match self.route_ip_packet(packet) {
RouteResult::Dispatch(port) => {
let sender = self.virtual_port_ip_tx.get(&port);
if let Some(sender_guard) = sender {
let sender = sender_guard.value();
match sender.send(packet.to_vec()).await {
Ok(_) => {
trace!(
"Dispatched received IP packet to virtual port {}",
port
);
}
Err(e) => {
error!(
"Failed to dispatch received IP packet to virtual port {}: {}",
port, e
);
}
}
} else {
warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
}
}
RouteResult::Sink => {
trace!("Sending unroutable IP packet received from WireGuard endpoint to sink interface");
self.route_ip_sink(packet).await.unwrap_or_else(|e| {
error!("Failed to send unroutable IP packet to sink: {:?}", e)
});
}
RouteResult::Drop => {
trace!("Dropped unroutable IP packet received from WireGuard endpoint");
}
if let Some(proto) = self.route_protocol(packet) {
endpoint.send(Event::InboundInternetPacket(proto, packet.into()));
}
}
_ => {}
@ -264,89 +217,32 @@ impl WireGuardTunnel {
.with_context(|| "Failed to initialize boringtun Tunn")
}
/// Makes a decision on the handling of an incoming IP packet.
fn route_ip_packet(&self, packet: &[u8]) -> RouteResult {
/// Determine the inner protocol of the incoming IP packet (TCP/UDP).
fn route_protocol(&self, packet: &[u8]) -> Option<PortProtocol> {
match IpVersion::of_packet(packet) {
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
.ok()
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() {
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
_ => None,
})
.flatten()
.unwrap_or(RouteResult::Drop),
.flatten(),
Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet)
.ok()
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.next_header() {
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
_ => None,
})
.flatten()
.unwrap_or(RouteResult::Drop),
_ => RouteResult::Drop,
}
}
/// Makes a decision on the handling of an incoming TCP segment.
fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
TcpPacket::new_checked(segment)
.ok()
.map(|tcp| {
if self
.virtual_port_ip_tx
.get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
.is_some()
{
RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
} else if tcp.rst() {
RouteResult::Drop
} else {
RouteResult::Sink
}
})
.unwrap_or(RouteResult::Drop)
}
/// Makes a decision on the handling of an incoming UDP datagram.
fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult {
UdpPacket::new_checked(datagram)
.ok()
.map(|udp| {
if self
.virtual_port_ip_tx
.get(&VirtualPort(udp.dst_port(), PortProtocol::Udp))
.is_some()
{
RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp))
} else {
RouteResult::Drop
}
})
.unwrap_or(RouteResult::Drop)
}
/// Route a packet to the IP sink interface.
async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
let ip_sink_tx = self.sink_ip_tx.read().await;
if let Some(ip_sink_tx) = &*ip_sink_tx {
ip_sink_tx
.send(packet.to_vec())
.await
.with_context(|| "Failed to dispatch IP packet to sink interface")
} else {
warn!(
"Could not dispatch unroutable IP packet to sink because interface is not active."
);
Ok(())
.flatten(),
_ => None,
}
}
}
@ -370,12 +266,3 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
}
}
}
enum RouteResult {
/// Dispatch the packet to the virtual port.
Dispatch(VirtualPort),
/// The packet is not routable, and should be sent to the sink interface.
Sink,
/// The packet is not routable, and can be safely ignored.
Drop,
}