Mirror of https://github.com/aramperes/onetun.git (synced 2025-09-09 06:58:31 -04:00)

Merge pull request #28 from aramperes/bus-based
Commit 5b388f2ea3: 12 changed files with 771 additions and 780 deletions
src/config.rs:
@@ -235,7 +235,7 @@ fn is_file_insecurely_readable(path: &str) -> Option<(bool, bool)> {
 }
 
 #[cfg(not(unix))]
-fn is_file_insecurely_readable(path: &str) -> Option<(bool, bool)> {
+fn is_file_insecurely_readable(_path: &str) -> Option<(bool, bool)> {
     // No good way to determine permissions on non-Unix target
     None
 }
@@ -399,9 +399,12 @@ impl Display for PortForwardConfig {
     }
 }
 
 /// Layer 7 protocols for ports.
 #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)]
 pub enum PortProtocol {
     /// TCP
     Tcp,
     /// UDP
     Udp,
 }
src/events.rs (new file, 150 lines added)
@@ -0,0 +1,150 @@
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use crate::config::PortForwardConfig;
use crate::virtual_iface::VirtualPort;
use crate::PortProtocol;

/// Events that go on the bus between the local server, smoltcp, and WireGuard.
#[derive(Debug, Clone)]
pub enum Event {
    /// Dumb event with no data.
    Dumb,
    /// A new connection with the local server was initiated, and the given virtual port was assigned.
    ClientConnectionInitiated(PortForwardConfig, VirtualPort),
    /// A connection was dropped from the pool and should be closed in all interfaces.
    ClientConnectionDropped(VirtualPort),
    /// Data received by the local server that should be sent to the virtual server.
    LocalData(PortForwardConfig, VirtualPort, Vec<u8>),
    /// Data received by the remote server that should be sent to the local client.
    RemoteData(VirtualPort, Vec<u8>),
    /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device.
    InboundInternetPacket(PortProtocol, Vec<u8>),
    /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device.
    OutboundInternetPacket(Vec<u8>),
    /// Notifies that a virtual device read an IP packet.
    VirtualDeviceFed(PortProtocol),
}

#[derive(Clone)]
pub struct Bus {
    counter: Arc<AtomicU32>,
    bus: Arc<tokio::sync::broadcast::Sender<(u32, Event)>>,
}

impl Bus {
    /// Creates a new event bus.
    pub fn new() -> Self {
        let (bus, _) = tokio::sync::broadcast::channel(1000);
        let bus = Arc::new(bus);
        let counter = Arc::new(AtomicU32::default());
        Self { bus, counter }
    }

    /// Creates a new endpoint on the event bus.
    pub fn new_endpoint(&self) -> BusEndpoint {
        let id = self.counter.fetch_add(1, Ordering::Relaxed);
        let tx = (*self.bus).clone();
        let rx = self.bus.subscribe();

        let tx = BusSender { id, tx };
        BusEndpoint { id, tx, rx }
    }
}

impl Default for Bus {
    fn default() -> Self {
        Self::new()
    }
}

pub struct BusEndpoint {
    id: u32,
    tx: BusSender,
    rx: tokio::sync::broadcast::Receiver<(u32, Event)>,
}

impl BusEndpoint {
    /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself.
    pub fn send(&self, event: Event) {
        self.tx.send(event)
    }

    /// Returns the unique sequential ID of this endpoint.
    pub fn id(&self) -> u32 {
        self.id
    }

    /// Awaits the next `Event` on the bus to be read.
    pub async fn recv(&mut self) -> Event {
        loop {
            match self.rx.recv().await {
                Ok((id, event)) => {
                    if id == self.id {
                        // If the event was sent by this endpoint, it is skipped
                        continue;
                    } else {
                        trace!("#{} <- {:?}", self.id, event);
                        return event;
                    }
                }
                Err(_) => {
                    error!("Failed to read event bus from endpoint #{}", self.id);
                    return futures::future::pending().await;
                }
            }
        }
    }

    /// Creates a new sender for this endpoint that can be cloned.
    pub fn sender(&self) -> BusSender {
        self.tx.clone()
    }
}

#[derive(Clone)]
pub struct BusSender {
    id: u32,
    tx: tokio::sync::broadcast::Sender<(u32, Event)>,
}

impl BusSender {
    /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself.
    pub fn send(&self, event: Event) {
        trace!("#{} -> {:?}", self.id, event);
        match self.tx.send((self.id, event)) {
            Ok(_) => {}
            Err(_) => error!("Failed to send event to bus from endpoint #{}", self.id),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bus() {
        let bus = Bus::new();

        let mut endpoint_1 = bus.new_endpoint();
        let mut endpoint_2 = bus.new_endpoint();
        let mut endpoint_3 = bus.new_endpoint();

        assert_eq!(endpoint_1.id(), 0);
        assert_eq!(endpoint_2.id(), 1);
        assert_eq!(endpoint_3.id(), 2);

        endpoint_1.send(Event::Dumb);
        let recv_2 = endpoint_2.recv().await;
        let recv_3 = endpoint_3.recv().await;
        assert!(matches!(recv_2, Event::Dumb));
        assert!(matches!(recv_3, Event::Dumb));

        endpoint_2.send(Event::Dumb);
        let recv_1 = endpoint_1.recv().await;
        let recv_3 = endpoint_3.recv().await;
        assert!(matches!(recv_1, Event::Dumb));
        assert!(matches!(recv_3, Event::Dumb));
    }
}
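Editor's note: the test above doubles as a usage reference. Every endpoint sees every other endpoint's events, and only its own sends are filtered out, so a typical consumer loops on recv() and pattern-matches for the events it cares about. A minimal sketch of that pattern, assuming only the Bus, Event, and VirtualPort types from this diff (the function itself is hypothetical, not part of the change):

    // Hypothetical consumer sketch (not in the diff): reacts only to
    // RemoteData events for one virtual port; all other bus traffic is ignored.
    async fn log_remote_data(bus: &Bus, port: VirtualPort) {
        let mut endpoint = bus.new_endpoint();
        loop {
            match endpoint.recv().await {
                Event::RemoteData(p, data) if p == port => {
                    println!("{}: received {} bytes", p, data.len());
                }
                _ => {} // events for other ports/protocols are skipped
            }
        }
    }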
src/ip_sink.rs (deleted file, 35 lines removed; filename inferred from the removed `pub mod ip_sink;` in main.rs)
@@ -1,35 +0,0 @@
use crate::virtual_device::VirtualIpDevice;
use crate::wg::WireGuardTunnel;
use smoltcp::iface::InterfaceBuilder;
use std::sync::Arc;
use tokio::time::Duration;

/// A repeating task that processes unroutable IP packets.
pub async fn run_ip_sink_interface(wg: Arc<WireGuardTunnel>) -> ! {
    // Initialize interface
    let device = VirtualIpDevice::new_sink(wg)
        .await
        .expect("Failed to initialize VirtualIpDevice for sink interface");

    // No sockets on sink interface
    let mut sockets: [_; 0] = Default::default();
    let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..])
        .ip_addrs([])
        .finalize();

    loop {
        let loop_start = smoltcp::time::Instant::now();
        match virtual_interface.poll(loop_start) {
            Ok(processed) if processed => {
                trace!("[SINK] Virtual interface polled some packets to be processed",);
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
            Err(e) => {
                error!("[SINK] Virtual interface poll error: {:?}", e);
            }
            _ => {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        }
    }
}
src/main.rs (65 lines changed)
@@ -5,13 +5,18 @@ use std::sync::Arc;
 
 use anyhow::Context;
 
-use crate::config::Config;
+use crate::config::{Config, PortProtocol};
+use crate::events::Bus;
 use crate::tunnel::tcp::TcpPortPool;
 use crate::tunnel::udp::UdpPortPool;
+use crate::virtual_device::VirtualIpDevice;
+use crate::virtual_iface::tcp::TcpVirtualInterface;
+use crate::virtual_iface::udp::UdpVirtualInterface;
+use crate::virtual_iface::VirtualInterfacePoll;
 use crate::wg::WireGuardTunnel;
 
 pub mod config;
-pub mod ip_sink;
+pub mod events;
 pub mod tunnel;
 pub mod virtual_device;
 pub mod virtual_iface;
@@ -30,7 +35,9 @@ async fn main() -> anyhow::Result<()> {
     let tcp_port_pool = TcpPortPool::new();
     let udp_port_pool = UdpPortPool::new();
 
-    let wg = WireGuardTunnel::new(&config)
+    let bus = Bus::default();
+
+    let wg = WireGuardTunnel::new(&config, bus.clone())
         .await
         .with_context(|| "Failed to initialize WireGuard tunnel")?;
     let wg = Arc::new(wg);
@@ -48,9 +55,41 @@ async fn main() -> anyhow::Result<()> {
     }
 
     {
-        // Start IP sink task for incoming IP packets
+        // Start production task for WireGuard
         let wg = wg.clone();
-        tokio::spawn(async move { ip_sink::run_ip_sink_interface(wg).await });
+        tokio::spawn(async move { wg.produce_task().await });
     }
 
+    if config
+        .port_forwards
+        .iter()
+        .any(|pf| pf.protocol == PortProtocol::Tcp)
+    {
+        // TCP device
+        let bus = bus.clone();
+        let device =
+            VirtualIpDevice::new(PortProtocol::Tcp, bus.clone(), config.max_transmission_unit);
+
+        // Start TCP Virtual Interface
+        let port_forwards = config.port_forwards.clone();
+        let iface = TcpVirtualInterface::new(port_forwards, bus, config.source_peer_ip);
+        tokio::spawn(async move { iface.poll_loop(device).await });
+    }
+
+    if config
+        .port_forwards
+        .iter()
+        .any(|pf| pf.protocol == PortProtocol::Udp)
+    {
+        // UDP device
+        let bus = bus.clone();
+        let device =
+            VirtualIpDevice::new(PortProtocol::Udp, bus.clone(), config.max_transmission_unit);
+
+        // Start UDP Virtual Interface
+        let port_forwards = config.port_forwards.clone();
+        let iface = UdpVirtualInterface::new(port_forwards, bus, config.source_peer_ip);
+        tokio::spawn(async move { iface.poll_loop(device).await });
+    }
+
     {
@@ -59,10 +98,18 @@ async fn main() -> anyhow::Result<()> {
 
         port_forwards
             .into_iter()
-            .map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone()))
-            .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| {
+            .map(|pf| {
+                (
+                    pf,
+                    wg.clone(),
+                    tcp_port_pool.clone(),
+                    udp_port_pool.clone(),
+                    bus.clone(),
+                )
+            })
+            .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool, bus)| {
                 tokio::spawn(async move {
-                    tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg)
+                    tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg, bus)
                         .await
                         .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e))
                 });
@@ -73,7 +120,7 @@ async fn main() -> anyhow::Result<()> {
 }
 
 fn init_logger(config: &Config) -> anyhow::Result<()> {
-    let mut builder = pretty_env_logger::formatted_builder();
+    let mut builder = pretty_env_logger::formatted_timed_builder();
     builder.parse_filters(&config.log);
     builder
         .try_init()
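Editor's note on the wiring above: Bus is cheap to clone (it only wraps Arc handles), so main hands an independent clone to the WireGuard tunnel, each virtual device and interface, and every port-forward task; each of those then derives its own endpoint. An illustrative sketch of the pattern, assuming only the Bus type from src/events.rs in this diff:

    // Illustrative sketch (not in the diff): each spawned task owns a Bus
    // clone and derives its own endpoint; clones share the same channel.
    let bus = Bus::default();
    for _ in 0..3 {
        let bus = bus.clone(); // cheap: only Arc pointers are cloned
        tokio::spawn(async move {
            let mut endpoint = bus.new_endpoint();
            let _event = endpoint.recv().await; // sees all other endpoints' events
        });
    }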
src/tunnel/mod.rs:
@@ -2,6 +2,7 @@ use std::net::IpAddr;
 use std::sync::Arc;
 
 use crate::config::{PortForwardConfig, PortProtocol};
+use crate::events::Bus;
 use crate::tunnel::tcp::TcpPortPool;
 use crate::tunnel::udp::UdpPortPool;
 use crate::wg::WireGuardTunnel;
@@ -16,6 +17,7 @@ pub async fn port_forward(
     tcp_port_pool: TcpPortPool,
     udp_port_pool: UdpPortPool,
     wg: Arc<WireGuardTunnel>,
+    bus: Bus,
 ) -> anyhow::Result<()> {
     info!(
         "Tunneling {} [{}]->[{}] (via [{}] as peer {})",
@@ -27,7 +29,7 @@ pub async fn port_forward(
     );
 
     match port_forward.protocol {
-        PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await,
-        PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, wg).await,
+        PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, bus).await,
+        PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, bus).await,
     }
 }
src/tunnel/tcp.rs:
@@ -1,17 +1,17 @@
 use crate::config::{PortForwardConfig, PortProtocol};
-use crate::virtual_iface::tcp::TcpVirtualInterface;
-use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
-use crate::wg::WireGuardTunnel;
+use crate::virtual_iface::VirtualPort;
 use anyhow::Context;
 use std::collections::VecDeque;
-use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use tokio::net::{TcpListener, TcpStream};
 
 use std::ops::Range;
+use std::time::Duration;
 
+use crate::events::{Bus, Event};
 use rand::seq::SliceRandom;
 use rand::thread_rng;
+use tokio::io::AsyncWriteExt;
 
 const MAX_PACKET: usize = 65536;
 const MIN_PORT: u16 = 1000;
@@ -22,14 +22,13 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
 pub async fn tcp_proxy_server(
     port_forward: PortForwardConfig,
     port_pool: TcpPortPool,
-    wg: Arc<WireGuardTunnel>,
+    bus: Bus,
 ) -> anyhow::Result<()> {
     let listener = TcpListener::bind(port_forward.source)
         .await
         .with_context(|| "Failed to listen on TCP proxy server")?;
 
     loop {
-        let wg = wg.clone();
         let port_pool = port_pool.clone();
         let (socket, peer_addr) = listener
             .accept()
@@ -52,10 +51,10 @@ pub async fn tcp_proxy_server(
 
         info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
 
+        let bus = bus.clone();
         tokio::spawn(async move {
             let port_pool = port_pool.clone();
-            let result =
-                handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
+            let result = handle_tcp_proxy_connection(socket, virtual_port, port_forward, bus).await;
 
             if let Err(e) = result {
                 error!(
@@ -66,8 +65,7 @@ pub async fn tcp_proxy_server(
                 info!("[{}] Connection closed by client", virtual_port);
             }
 
-            // Release port when connection drops
-            wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp));
+            tokio::time::sleep(Duration::from_millis(100)).await; // Make sure the other tasks have time to process the event
             port_pool.release(virtual_port).await;
         });
     }
@@ -75,72 +73,26 @@ pub async fn tcp_proxy_server(
 
 /// Handles a new TCP connection with its assigned virtual port.
 async fn handle_tcp_proxy_connection(
-    socket: TcpStream,
-    virtual_port: u16,
+    mut socket: TcpStream,
+    virtual_port: VirtualPort,
     port_forward: PortForwardConfig,
-    wg: Arc<WireGuardTunnel>,
+    bus: Bus,
 ) -> anyhow::Result<()> {
-    // Abort signal for stopping the Virtual Interface
-    let abort = Arc::new(AtomicBool::new(false));
-
-    // Signals that the Virtual Client is ready to send data
-    let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>();
-
-    // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
-    // to the real client.
-    let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000);
-
-    // data_to_real_server_(tx/rx): This task sends the data received from the real client to the
-    // virtual interface (virtual server socket).
-    let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000);
-
-    // Spawn virtual interface
-    {
-        let abort = abort.clone();
-        let virtual_interface = TcpVirtualInterface::new(
-            virtual_port,
-            port_forward,
-            wg,
-            abort.clone(),
-            data_to_real_client_tx,
-            data_to_virtual_server_rx,
-            virtual_client_ready_tx,
-        );
-
-        tokio::spawn(async move {
-            virtual_interface.poll_loop().await.unwrap_or_else(|e| {
-                error!("Virtual interface poll loop failed unexpectedly: {}", e);
-                abort.store(true, Ordering::Relaxed);
-            })
-        });
-    }
-
-    // Wait for virtual client to be ready.
-    virtual_client_ready_rx
-        .await
-        .with_context(|| "Virtual client dropped before being ready.")?;
-    trace!("[{}] Virtual client is ready to send data", virtual_port);
+    let mut endpoint = bus.new_endpoint();
+    endpoint.send(Event::ClientConnectionInitiated(port_forward, virtual_port));
 
+    let mut buffer = Vec::with_capacity(MAX_PACKET);
     loop {
         tokio::select! {
             readable_result = socket.readable() => {
                 match readable_result {
                     Ok(_) => {
-                        // Buffer for the individual TCP segment.
-                        let mut buffer = Vec::with_capacity(MAX_PACKET);
                         match socket.try_read_buf(&mut buffer) {
                             Ok(size) if size > 0 => {
-                                let data = &buffer[..size];
                                 debug!(
                                     "[{}] Read {} bytes of TCP data from real client",
                                     virtual_port, size
                                 );
-                                if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await {
-                                    error!(
-                                        "[{}] Failed to dispatch data to virtual interface: {:?}",
-                                        virtual_port, e
-                                    );
-                                }
+                                let data = Vec::from(&buffer[..size]);
+                                endpoint.send(Event::LocalData(port_forward, virtual_port, data));
+                                // Reset buffer
+                                buffer.clear();
                             }
                             Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                                 continue;
@@ -163,43 +115,32 @@ async fn handle_tcp_proxy_connection(
                     }
                 }
             }
-            data_recv_result = data_to_real_client_rx.recv() => {
-                match data_recv_result {
-                    Some(data) => match socket.try_write(&data) {
-                        Ok(size) => {
-                            debug!(
-                                "[{}] Wrote {} bytes of TCP data to real client",
-                                virtual_port, size
-                            );
-                        }
-                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
-                            if abort.load(Ordering::Relaxed) {
-                                break;
-                            } else {
-                                continue;
-                            }
-                        }
-                        Err(e) => {
-                            error!(
-                                "[{}] Failed to write to client TCP socket: {:?}",
-                                virtual_port, e
-                            );
-                        }
-                    },
-                    None => {
-                        if abort.load(Ordering::Relaxed) {
-                            break;
-                        } else {
-                            continue;
-                        }
-                    },
-                }
-            }
+            event = endpoint.recv() => {
+                match event {
+                    Event::ClientConnectionDropped(e_vp) if e_vp == virtual_port => {
+                        // This connection is supposed to be closed, stop the task.
+                        break;
+                    }
+                    Event::RemoteData(e_vp, data) if e_vp == virtual_port => {
+                        // Have remote data to send to the local client
+                        let size = data.len();
+                        match socket.write(&data).await {
+                            Ok(size) => debug!("[{}] Sent {} bytes to local client", virtual_port, size),
+                            Err(e) => {
+                                error!("[{}] Failed to send {} bytes to local client: {:?}", virtual_port, size, e);
+                                break;
+                            }
+                        }
+                    }
+                    _ => {}
+                }
+            }
         }
     }
 
     trace!("[{}] TCP socket handler task terminated", virtual_port);
-    abort.store(true, Ordering::Relaxed);
+    // Notify other endpoints that this task has closed and no more data is to be sent to the local client
+    endpoint.send(Event::ClientConnectionDropped(virtual_port));
 
     Ok(())
 }
@@ -230,19 +171,19 @@ impl TcpPortPool {
     }
 
     /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
-    pub async fn next(&self) -> anyhow::Result<u16> {
+    pub async fn next(&self) -> anyhow::Result<VirtualPort> {
         let mut inner = self.inner.write().await;
         let port = inner
             .queue
             .pop_front()
             .with_context(|| "TCP virtual port pool is exhausted")?;
-        Ok(port)
+        Ok(VirtualPort::new(port, PortProtocol::Tcp))
     }
 
     /// Releases a port back into the pool.
-    pub async fn release(&self, port: u16) {
+    pub async fn release(&self, port: VirtualPort) {
         let mut inner = self.inner.write().await;
-        inner.queue.push_back(port);
+        inner.queue.push_back(port.num());
     }
 }
src/tunnel/udp.rs:
@@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use std::time::Instant;
 
+use crate::events::{Bus, Event};
 use anyhow::Context;
 use priority_queue::double_priority_queue::DoublePriorityQueue;
 use priority_queue::priority_queue::PriorityQueue;
@@ -30,61 +31,24 @@ const UDP_TIMEOUT_SECONDS: u64 = 60;
 /// TODO: Make this configurable by the CLI
 const PORTS_PER_IP: usize = 100;
 
 /// Starts the server that listens on UDP datagrams.
 pub async fn udp_proxy_server(
     port_forward: PortForwardConfig,
     port_pool: UdpPortPool,
-    wg: Arc<WireGuardTunnel>,
+    bus: Bus,
 ) -> anyhow::Result<()> {
-    // Abort signal
-    let abort = Arc::new(AtomicBool::new(false));
-
-    // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
-    // to the real client.
-    let (data_to_real_client_tx, mut data_to_real_client_rx) =
-        tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
-
-    // data_to_real_server_(tx/rx): This task sends the data received from the real client to the
-    // virtual interface (virtual server socket).
-    let (data_to_virtual_server_tx, data_to_virtual_server_rx) =
-        tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
-
-    {
-        // Spawn virtual interface
-        // Note: contrary to TCP, there is only one UDP virtual interface
-        let virtual_interface = UdpVirtualInterface::new(
-            port_forward,
-            wg,
-            data_to_real_client_tx,
-            data_to_virtual_server_rx,
-        );
-        let abort = abort.clone();
-        tokio::spawn(async move {
-            virtual_interface.poll_loop().await.unwrap_or_else(|e| {
-                error!("Virtual interface poll loop failed unexpectedly: {}", e);
-                abort.store(true, Ordering::Relaxed);
-            });
-        });
-    }
-
+    let mut endpoint = bus.new_endpoint();
     let socket = UdpSocket::bind(port_forward.source)
         .await
         .with_context(|| "Failed to bind on UDP proxy address")?;
 
     let mut buffer = [0u8; MAX_PACKET];
     loop {
-        if abort.load(Ordering::Relaxed) {
-            break;
-        }
         tokio::select! {
             to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => {
                 match to_send_result {
                     Ok(Some((port, data))) => {
-                        data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| {
-                            error!(
-                                "Failed to dispatch data to UDP virtual interface: {:?}",
-                                e
-                            );
-                        });
+                        endpoint.send(Event::LocalData(port_forward, port, data));
                     }
                     Ok(None) => {
                         continue;
@@ -98,18 +62,18 @@ pub async fn udp_proxy_server(
                     }
                 }
             }
-            data_recv_result = data_to_real_client_rx.recv() => {
-                if let Some((port, data)) = data_recv_result {
-                    if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await {
-                        if let Err(e) = socket.send_to(&data, peer_addr).await {
+            event = endpoint.recv() => {
+                if let Event::RemoteData(port, data) = event {
+                    if let Some(peer) = port_pool.get_peer_addr(port).await {
+                        if let Err(e) = socket.send_to(&data, peer).await {
                             error!(
                                 "[{}] Failed to send UDP datagram to real client ({}): {:?}",
                                 port,
-                                peer_addr,
+                                peer,
                                 e,
                             );
                         }
-                        port_pool.update_last_transmit(port.0).await;
+                        port_pool.update_last_transmit(port).await;
                     }
                 }
             }
@@ -141,14 +105,13 @@ async fn next_udp_datagram(
             return Ok(None);
         }
     };
-    let port = VirtualPort(port, PortProtocol::Udp);
 
     debug!(
         "[{}] Received datagram of {} bytes from {}",
         port, size, peer_addr
     );
 
-    port_pool.update_last_transmit(port.0).await;
+    port_pool.update_last_transmit(port).await;
 
     let data = buffer[..size].to_vec();
     Ok(Some((port, data)))
@@ -181,14 +144,14 @@ impl UdpPortPool {
     }
 
     /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
-    pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<u16> {
+    pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<VirtualPort> {
         // A port found to be reused. This is outside of the block because the read lock cannot be upgraded to a write lock.
         let mut port_reuse: Option<u16> = None;
 
         {
             let inner = self.inner.read().await;
             if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) {
-                return Ok(*port);
+                return Ok(VirtualPort::new(*port, PortProtocol::Udp));
             }
 
             // Count how many ports are being used by the peer IP
@@ -240,26 +203,26 @@ impl UdpPortPool {
 
         inner.port_by_peer_addr.insert(peer_addr, port);
         inner.peer_addr_by_port.insert(port, peer_addr);
-        Ok(port)
+        Ok(VirtualPort::new(port, PortProtocol::Udp))
     }
 
     /// Notify that the given virtual port has received or transmitted a UDP datagram.
-    pub async fn update_last_transmit(&self, port: u16) {
+    pub async fn update_last_transmit(&self, port: VirtualPort) {
         let mut inner = self.inner.write().await;
-        if let Some(peer) = inner.peer_addr_by_port.get(&port).copied() {
+        if let Some(peer) = inner.peer_addr_by_port.get(&port.num()).copied() {
             let mut pq: &mut DoublePriorityQueue<u16, Instant> = inner
                 .peer_port_usage
                 .entry(peer.ip())
                 .or_insert_with(Default::default);
-            pq.push(port, Instant::now());
+            pq.push(port.num(), Instant::now());
         }
         let mut pq: &mut DoublePriorityQueue<u16, Instant> = &mut inner.port_usage;
-        pq.push(port, Instant::now());
+        pq.push(port.num(), Instant::now());
     }
 
-    pub async fn get_peer_addr(&self, port: u16) -> Option<SocketAddr> {
+    pub async fn get_peer_addr(&self, port: VirtualPort) -> Option<SocketAddr> {
         let inner = self.inner.read().await;
-        inner.peer_addr_by_port.get(&port).copied()
+        inner.peer_addr_by_port.get(&port.num()).copied()
     }
 }
src/virtual_device.rs:
@@ -1,45 +1,51 @@
-use crate::virtual_iface::VirtualPort;
-use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
-use anyhow::Context;
+use crate::config::PortProtocol;
+use crate::events::{BusSender, Event};
+use crate::Bus;
 use smoltcp::phy::{Device, DeviceCapabilities, Medium};
 use smoltcp::time::Instant;
-use std::sync::Arc;
+use std::collections::VecDeque;
+use std::sync::{Arc, Mutex};
 
-/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
-/// are made available to this device using a channel receiver. IP packets sent from this device
-/// are asynchronously sent out to the WireGuard tunnel.
+/// A virtual device that processes IP packets through smoltcp and WireGuard.
 pub struct VirtualIpDevice {
-    /// Tunnel to send IP packets to.
-    wg: Arc<WireGuardTunnel>,
-    /// Channel receiver for received IP packets.
-    ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
+    /// Max transmission unit (bytes)
+    max_transmission_unit: usize,
+    bus_sender: BusSender,
+    /// Local queue for packets received from the bus that need to go through the smoltcp interface.
+    process_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
 }
 
 impl VirtualIpDevice {
     /// Initializes a new virtual IP device.
-    pub fn new(
-        wg: Arc<WireGuardTunnel>,
-        ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
-    ) -> Self {
-        Self { wg, ip_dispatch_rx }
-    }
-
-    /// Registers a virtual IP device for a single virtual client.
-    pub fn new_direct(virtual_port: VirtualPort, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
-        let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
-
-        wg.register_virtual_interface(virtual_port, ip_dispatch_tx)
-            .with_context(|| "Failed to register IP dispatch for virtual interface")?;
-
-        Ok(Self { wg, ip_dispatch_rx })
-    }
-
-    pub async fn new_sink(wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
-        let ip_dispatch_rx = wg
-            .register_sink_interface()
-            .await
-            .with_context(|| "Failed to register IP dispatch for sink virtual interface")?;
-        Ok(Self { wg, ip_dispatch_rx })
-    }
+    pub fn new(protocol: PortProtocol, bus: Bus, max_transmission_unit: usize) -> Self {
+        let mut bus_endpoint = bus.new_endpoint();
+        let bus_sender = bus_endpoint.sender();
+        let process_queue = Arc::new(Mutex::new(VecDeque::new()));
+
+        {
+            let process_queue = process_queue.clone();
+            tokio::spawn(async move {
+                loop {
+                    match bus_endpoint.recv().await {
+                        Event::InboundInternetPacket(ip_proto, data) if ip_proto == protocol => {
+                            let mut queue = process_queue
+                                .lock()
+                                .expect("Failed to acquire process queue lock");
+                            queue.push_back(data);
+                            bus_endpoint.send(Event::VirtualDeviceFed(ip_proto));
+                        }
+                        _ => {}
+                    }
+                }
+            });
+        }
+
+        Self {
+            bus_sender,
+            process_queue,
+            max_transmission_unit,
+        }
+    }
 }
@@ -48,27 +54,34 @@ impl<'a> Device<'a> for VirtualIpDevice {
     type TxToken = TxToken;
 
     fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
-        match self.ip_dispatch_rx.try_recv() {
-            Ok(buffer) => Some((
+        let next = {
+            let mut queue = self
+                .process_queue
+                .lock()
+                .expect("Failed to acquire process queue lock");
+            queue.pop_front()
+        };
+        match next {
+            Some(buffer) => Some((
                 Self::RxToken { buffer },
                 Self::TxToken {
-                    wg: self.wg.clone(),
+                    sender: self.bus_sender.clone(),
                 },
             )),
-            Err(_) => None,
+            None => None,
         }
     }
 
     fn transmit(&'a mut self) -> Option<Self::TxToken> {
         Some(TxToken {
-            wg: self.wg.clone(),
+            sender: self.bus_sender.clone(),
         })
     }
 
     fn capabilities(&self) -> DeviceCapabilities {
         let mut cap = DeviceCapabilities::default();
         cap.medium = Medium::Ip;
-        cap.max_transmission_unit = self.wg.max_transmission_unit;
+        cap.max_transmission_unit = self.max_transmission_unit;
         cap
     }
 }
@@ -89,7 +102,7 @@ impl smoltcp::phy::RxToken for RxToken {
 
 #[doc(hidden)]
 pub struct TxToken {
-    wg: Arc<WireGuardTunnel>,
+    sender: BusSender,
 }
 
 impl smoltcp::phy::TxToken for TxToken {
@@ -100,14 +113,7 @@ impl smoltcp::phy::TxToken for TxToken {
         let mut buffer = Vec::new();
         buffer.resize(len, 0);
         let result = f(&mut buffer);
-        tokio::spawn(async move {
-            match self.wg.send_ip_packet(&buffer).await {
-                Ok(_) => {}
-                Err(e) => {
-                    error!("Failed to send IP packet to WireGuard endpoint: {:?}", e);
-                }
-            }
-        });
+        self.sender.send(Event::OutboundInternetPacket(buffer));
         result
     }
 }
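Editor's note: taken together, the device above is fed by InboundInternetPacket events matching its protocol and emits OutboundInternetPacket events whenever smoltcp transmits. A hypothetical round-trip sketch, assuming only the types in this diff (the 1420-byte MTU is an arbitrary example value, and this test is not part of the change):

    // Hypothetical sketch: feed one inbound IP packet to a TCP device and
    // observe the VirtualDeviceFed notification broadcast by its reader task.
    let bus = Bus::default();
    let mut observer = bus.new_endpoint();
    let _device = VirtualIpDevice::new(PortProtocol::Tcp, bus.clone(), 1420);
    observer.send(Event::InboundInternetPacket(PortProtocol::Tcp, vec![0u8; 20]));
    // observer.recv() skips the event it sent itself, so the next event it
    // sees should be the device's notification that the packet was queued.
    match observer.recv().await {
        Event::VirtualDeviceFed(PortProtocol::Tcp) => { /* packet was queued */ }
        other => panic!("unexpected event: {:?}", other),
    }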
src/virtual_iface/mod.rs:
@@ -2,6 +2,7 @@ pub mod tcp;
 pub mod udp;
 
 use crate::config::PortProtocol;
+use crate::VirtualIpDevice;
 use async_trait::async_trait;
 use std::fmt::{Display, Formatter};
@@ -9,15 +10,56 @@ use std::fmt::{Display, Formatter};
 pub trait VirtualInterfacePoll {
     /// Initializes the virtual interface and processes incoming data to be dispatched
     /// to the WireGuard tunnel and to the real client.
-    async fn poll_loop(mut self) -> anyhow::Result<()>;
+    async fn poll_loop(mut self, device: VirtualIpDevice) -> anyhow::Result<()>;
 }
 
 /// Virtual port.
 #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
-pub struct VirtualPort(pub u16, pub PortProtocol);
+pub struct VirtualPort(u16, PortProtocol);
+
+impl VirtualPort {
+    /// Create a new `VirtualPort` instance, with the given port number and associated protocol.
+    pub fn new(port: u16, proto: PortProtocol) -> Self {
+        VirtualPort(port, proto)
+    }
+
+    /// The port number
+    pub fn num(&self) -> u16 {
+        self.0
+    }
+
+    /// The protocol of this port.
+    pub fn proto(&self) -> PortProtocol {
+        self.1
+    }
+}
+
+impl From<VirtualPort> for u16 {
+    fn from(port: VirtualPort) -> Self {
+        port.num()
+    }
+}
+
+impl From<&VirtualPort> for u16 {
+    fn from(port: &VirtualPort) -> Self {
+        port.num()
+    }
+}
+
+impl From<VirtualPort> for PortProtocol {
+    fn from(port: VirtualPort) -> Self {
+        port.proto()
+    }
+}
+
+impl From<&VirtualPort> for PortProtocol {
+    fn from(port: &VirtualPort) -> Self {
+        port.proto()
+    }
+}
 
 impl Display for VirtualPort {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "[{}:{}]", self.0, self.1)
+        write!(f, "[{}:{}]", self.num(), self.proto())
     }
 }
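Editor's note: a quick sketch of the VirtualPort API added above (values are illustrative; PortProtocol's Display output is defined outside this diff, so the formatted string is not asserted):

    // Sketch of the accessor and conversion surface introduced above.
    let port = VirtualPort::new(8080, PortProtocol::Tcp);
    assert_eq!(port.num(), 8080);
    assert_eq!(u16::from(port), 8080);
    assert_eq!(PortProtocol::from(port), PortProtocol::Tcp);
    println!("{}", port); // uses the Display impl, e.g. "[8080:...]"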
src/virtual_iface/tcp.rs:
@@ -1,284 +1,250 @@
 use crate::config::{PortForwardConfig, PortProtocol};
+use crate::events::Event;
 use crate::virtual_device::VirtualIpDevice;
 use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
-use crate::wg::WireGuardTunnel;
+use crate::Bus;
 use anyhow::Context;
 use async_trait::async_trait;
-use smoltcp::iface::InterfaceBuilder;
+use smoltcp::iface::{InterfaceBuilder, SocketHandle};
 use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState};
 use smoltcp::wire::{IpAddress, IpCidr};
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::collections::{HashMap, HashSet, VecDeque};
+use std::net::IpAddr;
 use std::time::Duration;
 
 const MAX_PACKET: usize = 65536;
 
 /// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa.
 pub struct TcpVirtualInterface {
-    /// The virtual port assigned to the virtual client, used to
-    /// route Layer 4 segments/datagrams to and from the WireGuard tunnel.
-    virtual_port: u16,
-    /// The overall port-forward configuration: used for the destination address (on which
-    /// the virtual server listens) and the protocol in use.
-    port_forward: PortForwardConfig,
-    /// The WireGuard tunnel to send IP packets to.
-    wg: Arc<WireGuardTunnel>,
-    /// Abort signal to shutdown the virtual interface and its parent task.
-    abort: Arc<AtomicBool>,
-    /// Channel sender for pushing Layer 7 data back to the real client.
-    data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
-    /// Channel receiver for processing Layer 7 data through the virtual interface.
-    data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
-    /// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data.
-    virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
+    source_peer_ip: IpAddr,
+    port_forwards: Vec<PortForwardConfig>,
+    bus: Bus,
 }
 
 impl TcpVirtualInterface {
     /// Initialize the parameters for a new virtual interface.
     /// Use the `poll_loop()` future to start the virtual interface poll loop.
-    pub fn new(
-        virtual_port: u16,
-        port_forward: PortForwardConfig,
-        wg: Arc<WireGuardTunnel>,
-        abort: Arc<AtomicBool>,
-        data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
-        data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
-        virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
-    ) -> Self {
+    pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
         Self {
-            virtual_port,
-            port_forward,
-            wg,
-            abort,
-            data_to_real_client_tx,
-            data_to_virtual_server_rx,
-            virtual_client_ready_tx,
-        }
+            port_forwards: port_forwards
+                .into_iter()
+                .filter(|f| matches!(f.protocol, PortProtocol::Tcp))
+                .collect(),
+            source_peer_ip,
+            bus,
+        }
     }
-}
 
-#[async_trait]
-impl VirtualInterfacePoll for TcpVirtualInterface {
-    async fn poll_loop(self) -> anyhow::Result<()> {
-        let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx);
-        let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
-        let source_peer_ip = self.wg.source_peer_ip;
-
-        // Create a device and interface to simulate IP packets
-        // In essence:
-        // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
-        // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
-        // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
-        // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
-        // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
-
-        // Consumer for IP packets to send through the virtual interface
-        // Initialize the interface
-        let device =
-            VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg)
-                .with_context(|| "Failed to initialize TCP VirtualIpDevice")?;
-
-        // there are always 2 sockets: 1 virtual client and 1 virtual server.
-        let mut sockets: [_; 2] = Default::default();
-        let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..])
-            .ip_addrs([
-                // Interface handles IP packets for the sender and recipient
-                IpCidr::new(IpAddress::from(source_peer_ip), 32),
-                IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32),
-            ])
-            .finalize();
-
-        // Server socket: this is a placeholder for the interface to route new connections to.
-        let server_socket: anyhow::Result<TcpSocket> = {
+    fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<TcpSocket<'static>> {
         static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
         static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
 
         let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
         let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
         let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
 
         socket
             .listen((
-                IpAddress::from(self.port_forward.destination.ip()),
-                self.port_forward.destination.port(),
+                IpAddress::from(port_forward.destination.ip()),
+                port_forward.destination.port(),
             ))
             .with_context(|| "Virtual server socket failed to listen")?;
 
         Ok(socket)
-        };
+    }
 
-        let client_socket: anyhow::Result<TcpSocket> = {
+    fn new_client_socket() -> anyhow::Result<TcpSocket<'static>> {
         let rx_data = vec![0u8; MAX_PACKET];
         let tx_data = vec![0u8; MAX_PACKET];
         let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
         let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
         let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
         Ok(socket)
-        };
+    }
 
-        let _server_handle = virtual_interface.add_socket(server_socket?);
-        let client_handle = virtual_interface.add_socket(client_socket?);
+    fn addresses(&self) -> Vec<IpCidr> {
+        let mut addresses = HashSet::new();
+        addresses.insert(IpAddress::from(self.source_peer_ip));
+        for config in self.port_forwards.iter() {
+            addresses.insert(IpAddress::from(config.destination.ip()));
+        }
+        addresses
+            .into_iter()
+            .map(|addr| IpCidr::new(addr, 32))
+            .collect()
+    }
+}
 
-        // Any data that wasn't sent because it was over the sending buffer limit
-        let mut tx_extra = Vec::new();
+#[async_trait]
+impl VirtualInterfacePoll for TcpVirtualInterface {
+    async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> {
+        // Create CIDR block for source peer IP + each port forward IP
+        let addresses = self.addresses();
 
-        // Counts the connection attempts by the virtual client
-        let mut connection_attempts = 0;
-        // Whether the client has successfully connected before. Prevents the case of connecting again.
-        let mut has_connected = false;
+        // Create virtual interface (contains smoltcp state machine)
+        let mut iface = InterfaceBuilder::new(device, vec![])
+            .ip_addrs(addresses)
+            .finalize();
+
+        // Create virtual server for each port forward
+        for port_forward in self.port_forwards.iter() {
+            let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
+            iface.add_socket(server_socket);
+        }
+
+        // The next time to poll the interface. Can be None for instant poll.
+        let mut next_poll: Option<tokio::time::Instant> = None;
+
+        // Bus endpoint to read events
+        let mut endpoint = self.bus.new_endpoint();
+
+        // Maps virtual port to its client socket handle
+        let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
+
+        // Data packets to send from a virtual client
+        let mut send_queue: HashMap<VirtualPort, VecDeque<Vec<u8>>> = HashMap::new();
 
         loop {
+            tokio::select! {
+                _ = match (next_poll, port_client_handle_map.len()) {
+                    (None, 0) => tokio::time::sleep(Duration::MAX),
+                    (None, _) => tokio::time::sleep(Duration::ZERO),
+                    (Some(until), _) => tokio::time::sleep_until(until),
+                } => {
                     let loop_start = smoltcp::time::Instant::now();
 
-            // Shutdown occurs when the real client closes the connection,
-            // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore.
-            // One last poll-loop iteration is executed so that the RST segment can be dispatched.
-            let shutdown = self.abort.load(Ordering::Relaxed);
-
-            if shutdown {
-                // Shutdown: sends a RST packet.
-                trace!("[{}] Shutting down virtual interface", self.virtual_port);
-                let client_socket = virtual_interface.get_socket::<TcpSocket>(client_handle);
-                client_socket.abort();
-            }
-
-            match virtual_interface.poll(loop_start) {
+                    match iface.poll(loop_start) {
                         Ok(processed) if processed => {
-                    trace!(
-                        "[{}] Virtual interface polled some packets to be processed",
-                        self.virtual_port
-                    );
-                }
-                Err(e) => {
-                    error!(
-                        "[{}] Virtual interface poll error: {:?}",
-                        self.virtual_port, e
-                    );
+                            trace!("TCP virtual interface polled some packets to be processed");
                         }
+                        Err(e) => error!("TCP virtual interface poll error: {:?}", e),
                         _ => {}
                     }
 
-            {
-                let (client_socket, context) =
-                    virtual_interface.get_socket_and_context::<TcpSocket>(client_handle);
+                    // Find client socket send data to
+                    for (virtual_port, client_handle) in port_client_handle_map.iter() {
+                        let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
+                        if client_socket.can_send() {
+                            if let Some(send_queue) = send_queue.get_mut(virtual_port) {
+                                let to_transfer = send_queue.pop_front();
+                                if let Some(to_transfer_slice) = to_transfer.as_deref() {
+                                    let total = to_transfer_slice.len();
+                                    match client_socket.send_slice(to_transfer_slice) {
+                                        Ok(sent) => {
+                                            if sent < total {
+                                                // Sometimes only a subset is sent, so the rest needs to be sent on the next poll
+                                                let tx_extra = Vec::from(&to_transfer_slice[sent..total]);
+                                                send_queue.push_front(tx_extra);
+                                            }
+                                        }
+                                        Err(e) => {
+                                            error!("Failed to send slice via virtual client socket: {:?}", e);
+                                        }
+                                    }
+                                    break;
+                                } else if client_socket.state() == TcpState::CloseWait {
+                                    client_socket.close();
+                                    break;
+                                }
+                            }
+                        }
+                    }
 
+                    // Find client socket recv data from
+                    for (virtual_port, client_handle) in port_client_handle_map.iter() {
+                        let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
+                        if client_socket.can_recv() {
+                            match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
+                                Ok(data) => {
+                                    if !data.is_empty() {
+                                        endpoint.send(Event::RemoteData(*virtual_port, data));
+                                        break;
+                                    }
+                                }
+                                Err(e) => {
+                                    error!("Failed to read from virtual client socket: {:?}", e);
+                                }
+                            }
+                        }
+                    }
+
+                    // Find closed sockets
+                    port_client_handle_map.retain(|virtual_port, client_handle| {
+                        let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
+                        if client_socket.state() == TcpState::Closed {
+                            endpoint.send(Event::ClientConnectionDropped(*virtual_port));
+                            send_queue.remove(virtual_port);
+                            iface.remove_socket(*client_handle);
+                            false
+                        } else {
+                            // Not closed, retain
+                            true
+                        }
+                    });
+
+                    // The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
+                    next_poll = match iface.poll_delay(loop_start) {
+                        Some(smoltcp::time::Duration::ZERO) => None,
+                        Some(delay) => {
+                            trace!("TCP Virtual interface delayed next poll by {}", delay);
+                            Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis()))
+                        },
+                        None => None,
+                    };
+                }
+                event = endpoint.recv() => {
+                    match event {
+                        Event::ClientConnectionInitiated(port_forward, virtual_port) => {
+                            let client_socket = TcpVirtualInterface::new_client_socket()?;
+                            let client_handle = iface.add_socket(client_socket);
+
+                            // Add handle to map
+                            port_client_handle_map.insert(virtual_port, client_handle);
+                            send_queue.insert(virtual_port, VecDeque::new());
+
+                            let (client_socket, context) = iface.get_socket_and_context::<TcpSocket>(client_handle);
 
-                if !shutdown && client_socket.state() == TcpState::Closed && !has_connected {
-                    // Not shutting down, but the client socket is closed, and the client never successfully connected.
-                    if connection_attempts < 10 {
-                        // Try to connect
                             client_socket
                                 .connect(
                                     context,
                                     (
-                                IpAddress::from(self.port_forward.destination.ip()),
-                                self.port_forward.destination.port(),
+                                        IpAddress::from(port_forward.destination.ip()),
+                                        port_forward.destination.port(),
                                     ),
-                            (IpAddress::from(source_peer_ip), self.virtual_port),
+                                    (IpAddress::from(self.source_peer_ip), virtual_port.num()),
                                 )
                                 .with_context(|| "Virtual server socket failed to listen")?;
-                        if connection_attempts > 0 {
-                            debug!(
-                                "[{}] Virtual client retrying connection in 500ms",
-                                self.virtual_port
-                            );
-                            // Not our first connection attempt, wait a little bit.
-                            tokio::time::sleep(Duration::from_millis(500)).await;
-                        }
-                    } else {
-                        // Too many connection attempts
-                        self.abort.store(true, Ordering::Relaxed);
-                    }
-                    connection_attempts += 1;
-                    continue;
-                }
 
-                if client_socket.state() == TcpState::Established {
-                    // Prevent reconnection if the server later closes.
-                    has_connected = true;
-                }
+                            next_poll = None;
+                        }
+                        Event::ClientConnectionDropped(virtual_port) => {
+                            if let Some(client_handle) = port_client_handle_map.get(&virtual_port) {
+                                let client_handle = *client_handle;
+                                port_client_handle_map.remove(&virtual_port);
+                                send_queue.remove(&virtual_port);
 
-                if client_socket.can_recv() {
-                    match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
-                        Ok(data) => {
-                            trace!(
-                                "[{}] Virtual client received {} bytes of data",
-                                self.virtual_port,
-                                data.len()
-                            );
-                            // Send it to the real client
-                            if let Err(e) = self.data_to_real_client_tx.send(data).await {
-                                error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e);
-                            }
-                        }
-                        Err(e) => {
-                            error!(
-                                "[{}] Failed to read from virtual client socket: {:?}",
-                                self.virtual_port, e
-                            );
-                        }
-                    }
-                }
+                                let client_socket = iface.get_socket::<TcpSocket>(client_handle);
+                                client_socket.close();
+                                next_poll = None;
+                            }
+                        }
+                        Event::LocalData(_, virtual_port, data) if send_queue.contains_key(&virtual_port) => {
+                            if let Some(send_queue) = send_queue.get_mut(&virtual_port) {
+                                send_queue.push_back(data);
+                                next_poll = None;
+                            }
+                        }
+                        Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Tcp => {
+                            next_poll = None;
+                        }
 
-                if client_socket.can_send() {
-                    if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() {
-                        virtual_client_ready_tx
-                            .send(())
-                            .expect("Failed to notify real client that virtual client is ready");
-                    }
-
-                    let mut to_transfer = None;
-
-                    if tx_extra.is_empty() {
-                        // The payload segment from the previous loop is complete,
-                        // we can now read the next payload in the queue.
-                        if let Ok(data) = data_to_virtual_server_rx.try_recv() {
-                            to_transfer = Some(data);
-                        } else if client_socket.state() == TcpState::CloseWait {
-                            // No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN),
-                            // the interface is shutdown.
-                            trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port);
-                            self.abort.store(true, Ordering::Relaxed);
-                            continue;
-                        }
-                    }
-
-                    let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
-                    if !to_transfer_slice.is_empty() {
-                        let total = to_transfer_slice.len();
-                        match client_socket.send_slice(to_transfer_slice) {
-                            Ok(sent) => {
-                                trace!(
-                                    "[{}] Sent {}/{} bytes via virtual client socket",
-                                    self.virtual_port,
-                                    sent,
-                                    total,
-                                );
-                                tx_extra = Vec::from(&to_transfer_slice[sent..total]);
-                            }
-                            Err(e) => {
-                                error!(
-                                    "[{}] Failed to send slice via virtual client socket: {:?}",
-                                    self.virtual_port, e
-                                );
-                            }
-                        }
-                    }
-                }
-            }
-
-            if shutdown {
-                break;
-            }
-
-            match virtual_interface.poll_delay(loop_start) {
-                Some(smoltcp::time::Duration::ZERO) => {
-                    continue;
-                }
-                _ => {
-                    tokio::time::sleep(Duration::from_millis(1)).await;
-                }
-            }
-        }
-        trace!("[{}] Virtual interface task terminated", self.virtual_port);
-        self.abort.store(true, Ordering::Relaxed);
+                        _ => {}
+                    }
+                }
+            }
+        }
         Ok(())
     }
 }
@ -1,165 +1,68 @@
|
|||
use anyhow::Context;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::Context;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::net::IpAddr;
|
||||
|
||||
use crate::events::Event;
|
||||
use crate::{Bus, PortProtocol};
|
||||
use async_trait::async_trait;
|
||||
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
|
||||
use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
|
||||
use smoltcp::wire::{IpAddress, IpCidr};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::PortForwardConfig;
|
||||
use crate::virtual_device::VirtualIpDevice;
|
||||
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
|
||||
use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
|
||||
|
||||
const MAX_PACKET: usize = 65536;
|
||||
|
||||
pub struct UdpVirtualInterface {
|
||||
port_forward: PortForwardConfig,
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
|
||||
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
|
||||
source_peer_ip: IpAddr,
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
bus: Bus,
|
||||
}
|
||||
|
||||
impl UdpVirtualInterface {
|
||||
pub fn new(
|
||||
port_forward: PortForwardConfig,
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
|
||||
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
|
||||
) -> Self {
|
||||
/// Initialize the parameters for a new virtual interface.
|
||||
/// Use the `poll_loop()` future to start the virtual interface poll loop.
|
||||
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
|
||||
Self {
|
||||
port_forward,
|
||||
wg,
|
||||
data_to_real_client_tx,
|
||||
data_to_virtual_server_rx,
|
||||
}
|
||||
port_forwards: port_forwards
|
||||
.into_iter()
|
||||
.filter(|f| matches!(f.protocol, PortProtocol::Udp))
|
||||
.collect(),
|
||||
source_peer_ip,
|
||||
bus,
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VirtualInterfacePoll for UdpVirtualInterface {
|
||||
async fn poll_loop(self) -> anyhow::Result<()> {
|
||||
// Data receiver to dispatch using virtual client sockets
|
||||
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
|
||||
|
||||
// The IP to bind client sockets to
|
||||
let source_peer_ip = self.wg.source_peer_ip;
|
||||
|
||||
// The IP/port to bind the server socket to
|
||||
let destination = self.port_forward.destination;
|
||||
|
||||
// Initialize a channel for IP packets.
|
||||
// The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel.
|
||||
// The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel.
|
||||
let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
|
||||
|
||||
let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx);
|
||||
let mut virtual_interface = InterfaceBuilder::new(device, vec![])
|
||||
.ip_addrs([
|
||||
// Interface handles IP packets for the sender and recipient
|
||||
IpCidr::new(source_peer_ip.into(), 32),
|
||||
IpCidr::new(destination.ip().into(), 32),
|
||||
])
|
||||
.finalize();
|
||||
|
||||
// Server socket: this is a placeholder for the interface.
|
||||
let server_socket: anyhow::Result<UdpSocket> = {
|
||||
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<UdpSocket<'static>> {
|
||||
static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = [];
|
||||
static mut UDP_SERVER_RX_DATA: [u8; 0] = [];
|
||||
static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = [];
|
||||
static mut UDP_SERVER_TX_DATA: [u8; 0] = [];
|
||||
let udp_rx_buffer =
|
||||
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
|
||||
let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
|
||||
&mut UDP_SERVER_RX_DATA[..]
|
||||
});
|
||||
let udp_tx_buffer =
|
||||
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
|
||||
let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
|
||||
&mut UDP_SERVER_TX_DATA[..]
|
||||
});
|
||||
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
|
||||
|
||||
socket
|
||||
.bind((IpAddress::from(destination.ip()), destination.port()))
|
||||
.with_context(|| "UDP virtual server socket failed to listen")?;
|
||||
|
||||
.bind((
|
||||
IpAddress::from(port_forward.destination.ip()),
|
||||
port_forward.destination.port(),
|
||||
))
|
||||
.with_context(|| "UDP virtual server socket failed to bind")?;
|
||||
Ok(socket)
|
||||
};
|
||||
|
||||
let _server_handle = virtual_interface.add_socket(server_socket?);
        // A map of virtual port to client socket.
        let mut client_sockets: HashMap<VirtualPort, SocketHandle> = HashMap::new();

        // The next instant at which the virtual interface should be polled.
        // None means "immediate poll required".
        let mut next_poll: Option<tokio::time::Instant> = None;

        loop {
            let wg = self.wg.clone();
            tokio::select! {
                // Wait the amount of time recommended by smoltcp, then poll again.
                _ = match next_poll {
                    None => tokio::time::sleep(Duration::ZERO),
                    Some(until) => tokio::time::sleep_until(until)
                } => {
                    let loop_start = smoltcp::time::Instant::now();

                    match virtual_interface.poll(loop_start) {
                        Ok(processed) if processed => {
                            trace!("UDP virtual interface polled some packets to be processed");
                        }
                        Err(e) => error!("UDP virtual interface poll error: {:?}", e),
                        _ => {}
                    }

                    // Loop through each client socket and check if there is any data to send back
                    // to the real client.
                    for (virtual_port, client_socket_handle) in client_sockets.iter() {
                        let client_socket = virtual_interface.get_socket::<UdpSocket>(*client_socket_handle);
                        match client_socket.recv() {
                            Ok((data, _peer)) => {
                                // Send the data back to the real client using the MPSC channel
                                self.data_to_real_client_tx
                                    .send((*virtual_port, data.to_vec()))
                                    .await
                                    .unwrap_or_else(|e| {
                                        error!(
                                            "[{}] Failed to dispatch data from virtual client to real client: {:?}",
                                            virtual_port, e
                                        );
                                    });
                            }
                            Err(smoltcp::Error::Exhausted) => {}
                            Err(e) => {
                                error!(
                                    "[{}] Failed to read from virtual client socket: {:?}",
                                    virtual_port, e
                                );
                            }
                        }
                    }

                    next_poll = match virtual_interface.poll_delay(loop_start) {
                        Some(smoltcp::time::Duration::ZERO) => None,
                        Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())),
                        None => None,
                    }
                }
                // Wait for data to be received from the real client
                data_recv_result = data_to_virtual_server_rx.recv() => {
                    if let Some((client_port, data)) = data_recv_result {
                        // Register the socket in the WireGuard tunnel (overrides any previous registration as well)
                        wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
                            .unwrap_or_else(|e| {
                                error!(
                                    "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
                                    client_port, e
                                );
                            });

                        let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
    fn new_client_socket(
        source_peer_ip: IpAddr,
        client_port: VirtualPort,
    ) -> anyhow::Result<UdpSocket<'static>> {
            let rx_meta = vec![UdpPacketMetadata::EMPTY; 10];
            let tx_meta = vec![UdpPacketMetadata::EMPTY; 10];
            let rx_data = vec![0u8; MAX_PACKET];

@@ -167,31 +70,147 @@ impl VirtualInterfacePoll for UdpVirtualInterface
            let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
            let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
            let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);

            socket
                .bind((IpAddress::from(wg.source_peer_ip), client_port.0))
                .unwrap_or_else(|e| {
                    error!(
                        "[{}] UDP virtual client socket failed to bind: {:?}",
                        client_port, e
                    );
                });
            socket
                .bind((IpAddress::from(source_peer_ip), client_port.num()))
                .with_context(|| "UDP virtual client failed to bind")?;
            Ok(socket)
        }

                            virtual_interface.add_socket(socket)
                        });
    fn addresses(&self) -> Vec<IpCidr> {
        let mut addresses = HashSet::new();
        addresses.insert(IpAddress::from(self.source_peer_ip));
        for config in self.port_forwards.iter() {
            addresses.insert(IpAddress::from(config.destination.ip()));
        }
        addresses
            .into_iter()
            .map(|addr| IpCidr::new(addr, 32))
            .collect()
    }
}
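The `addresses()` helper gives the interface a 32-bit host route for the source peer IP and for every forwarded destination, deduplicated so overlapping forwards don't produce duplicate CIDR entries. A standalone sketch of the same dedup step, with plain std types in place of smoltcp's IpAddress/IpCidr (function name is illustrative):

    use std::collections::HashSet;
    use std::net::IpAddr;

    /// Collect the unique host routes an interface must answer for:
    /// the source peer IP plus every forwarded destination IP.
    fn host_routes(source_peer_ip: IpAddr, destinations: &[IpAddr]) -> Vec<(IpAddr, u8)> {
        let mut addresses: HashSet<IpAddr> = HashSet::new();
        addresses.insert(source_peer_ip);
        addresses.extend(destinations.iter().copied());
        // 32-bit prefix length, mirroring the code above.
        addresses.into_iter().map(|addr| (addr, 32)).collect()
    }

    fn main() {
        let src: IpAddr = "192.168.4.3".parse().unwrap();
        let dsts: Vec<IpAddr> = vec!["192.168.4.2".parse().unwrap(), "192.168.4.2".parse().unwrap()];
        // Duplicate destinations collapse into a single route.
        assert_eq!(host_routes(src, &dsts).len(), 2);
    }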
#[async_trait]
impl VirtualInterfacePoll for UdpVirtualInterface {
    async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> {
        // Create a CIDR block for the source peer IP plus each port-forward IP
        let addresses = self.addresses();

        // Create the virtual interface (contains the smoltcp state machine)
        let mut iface = InterfaceBuilder::new(device, vec![])
            .ip_addrs(addresses)
            .finalize();

        // Create a virtual server socket for each port forward
        for port_forward in self.port_forwards.iter() {
            let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?;
            iface.add_socket(server_socket);
        }

        // The next time to poll the interface. None means "poll immediately".
        let mut next_poll: Option<tokio::time::Instant> = None;

        // Bus endpoint to read events
        let mut endpoint = self.bus.new_endpoint();

        // Maps a virtual port to its client socket handle
        let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();

        // Data packets queued to be sent from a virtual client
        let mut send_queue: HashMap<VirtualPort, VecDeque<(PortForwardConfig, Vec<u8>)>> =
            HashMap::new();
        loop {
            tokio::select! {
                _ = match (next_poll, port_client_handle_map.len()) {
                    (None, 0) => tokio::time::sleep(Duration::MAX),
                    (None, _) => tokio::time::sleep(Duration::ZERO),
                    (Some(until), _) => tokio::time::sleep_until(until),
                } => {
                    let loop_start = smoltcp::time::Instant::now();

                    match iface.poll(loop_start) {
                        Ok(processed) if processed => {
                            trace!("UDP virtual interface polled some packets to be processed");
                        }
                        Err(e) => error!("UDP virtual interface poll error: {:?}", e),
                        _ => {}
                    }

                    // Find the client sockets that have data queued to send
                    for (virtual_port, client_handle) in port_client_handle_map.iter() {
                        let client_socket = iface.get_socket::<UdpSocket>(*client_handle);
                        if client_socket.can_send() {
                            if let Some(send_queue) = send_queue.get_mut(virtual_port) {
                                let to_transfer = send_queue.pop_front();
                                if let Some((port_forward, data)) = to_transfer {
                                    client_socket
                                        .send_slice(
                                            &data,
                                            (IpAddress::from(destination.ip()), destination.port()).into(),
                                            (IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(),
                                        )
                                        .unwrap_or_else(|e| {
                                            error!(
                                                "[{}] Failed to send data to virtual server: {:?}",
                                                client_port, e
                                                virtual_port, e
                                            );
                                        });
                                    break;
                                }
                            }
                        }
                    }
                    // Find the client sockets with received data to forward
                    for (virtual_port, client_handle) in port_client_handle_map.iter() {
                        let client_socket = iface.get_socket::<UdpSocket>(*client_handle);
                        if client_socket.can_recv() {
                            match client_socket.recv() {
                                Ok((data, _peer)) => {
                                    if !data.is_empty() {
                                        endpoint.send(Event::RemoteData(*virtual_port, data.to_vec()));
                                        break;
                                    }
                                }
                                Err(e) => {
                                    error!(
                                        "Failed to read from virtual client socket: {:?}", e
                                    );
                                }
                            }
                        }
                    }

                    // The virtual interface determines the next time to poll (this reduces unnecessary polls)
                    next_poll = match iface.poll_delay(loop_start) {
                        Some(smoltcp::time::Duration::ZERO) => None,
                        Some(delay) => {
                            trace!("UDP virtual interface delayed next poll by {}", delay);
                            Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis()))
                        },
                        None => None,
                    };
                }
                event = endpoint.recv() => {
                    match event {
                        Event::LocalData(port_forward, virtual_port, data) => {
                            if let Some(send_queue) = send_queue.get_mut(&virtual_port) {
                                // The client socket already exists
                                send_queue.push_back((port_forward, data));
                            } else {
                                // The client socket does not exist yet
                                let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?;
                                let client_handle = iface.add_socket(client_socket);

                                // Add the handle to the map
                                port_client_handle_map.insert(virtual_port, client_handle);
                                send_queue.insert(virtual_port, VecDeque::from(vec![(port_forward, data)]));
                            }
                            next_poll = None;
                        }
                        Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Udp => {
                            next_poll = None;
                        }
                        _ => {}
                    }
                }
            }
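The first select! arm turns poll scheduling into a conditional sleep: park when there are no client sockets and no deadline, poll immediately when a poll is due, and otherwise wait until the suggested instant (tokio caps oversized sleep durations, so Duration::MAX effectively parks the branch until another arm fires). A standalone sketch of the pattern, assuming tokio with the `time` and `macros` features; the deadline update is simulated:

    use std::time::Duration;

    #[tokio::main]
    async fn main() {
        let mut next_poll: Option<tokio::time::Instant> = None;
        let active_sockets = 1usize;
        let mut polls = 0;

        loop {
            tokio::select! {
                _ = match (next_poll, active_sockets) {
                    (None, 0) => tokio::time::sleep(Duration::MAX),      // idle: park the branch
                    (None, _) => tokio::time::sleep(Duration::ZERO),     // poll immediately
                    (Some(until), _) => tokio::time::sleep_until(until), // poll at the deadline
                } => {
                    polls += 1;
                    if polls == 3 {
                        break;
                    }
                    // Pretend the state machine asked to be polled again in 10ms.
                    next_poll = Some(tokio::time::Instant::now() + Duration::from_millis(10));
                }
            }
        }
        println!("polled {} times", polls);
    }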
179
src/wg.rs

@@ -1,15 +1,15 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;

use crate::Bus;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use log::Level;
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket};
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet};
use tokio::net::UdpSocket;
use tokio::sync::RwLock;

use crate::config::{Config, PortProtocol};
use crate::virtual_iface::VirtualPort;
use crate::events::Event;

/// The capacity of the channel for received IP packets.
pub const DISPATCH_CAPACITY: usize = 1_000;
@@ -26,17 +26,13 @@ pub struct WireGuardTunnel {
    udp: UdpSocket,
    /// The address of the public WireGuard endpoint (UDP).
    pub(crate) endpoint: SocketAddr,
    /// Maps virtual ports to the corresponding IP packet dispatcher.
    virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
    /// IP packet dispatcher for unroutable packets. `None` if not initialized.
    sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
    /// The max transmission unit for WireGuard.
    pub(crate) max_transmission_unit: usize,
    /// Event bus
    bus: Bus,
}

impl WireGuardTunnel {
    /// Initialize a new WireGuard tunnel.
    pub async fn new(config: &Config) -> anyhow::Result<Self> {
    pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> {
        let source_peer_ip = config.source_peer_ip;
        let peer = Self::create_tunnel(config)?;
        let endpoint = config.endpoint_addr;
@@ -46,16 +42,13 @@ impl WireGuardTunnel {
        })
        .await
        .with_context(|| "Failed to create UDP socket for WireGuard connection")?;
        let virtual_port_ip_tx = Default::default();

        Ok(Self {
            source_peer_ip,
            peer,
            udp,
            endpoint,
            virtual_port_ip_tx,
            sink_ip_tx: RwLock::new(None),
            max_transmission_unit: config.max_transmission_unit,
            bus,
        })
    }
@@ -90,31 +83,20 @@ impl WireGuardTunnel {
        Ok(())
    }

    /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
    pub fn register_virtual_interface(
        &self,
        virtual_port: VirtualPort,
        sender: tokio::sync::mpsc::Sender<Vec<u8>>,
    ) -> anyhow::Result<()> {
        self.virtual_port_ip_tx.insert(virtual_port, sender);
        Ok(())
    pub async fn produce_task(&self) -> ! {
        trace!("Starting WireGuard production task");
        let mut endpoint = self.bus.new_endpoint();

        loop {
            if let Event::OutboundInternetPacket(data) = endpoint.recv().await {
                match self.send_ip_packet(&data).await {
                    Ok(_) => {}
                    Err(e) => {
                        error!("{:?}", e);
                    }
                }
            }
        }

    /// Register the sink interface, which receives unroutable IP packets.
    pub async fn register_sink_interface(
        &self,
    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);

        let mut sink_ip_tx = self.sink_ip_tx.write().await;
        *sink_ip_tx = Some(sender);

        Ok(receiver)
    }

    /// Releases the virtual interface from IP dispatch.
    pub fn release_virtual_interface(&self, virtual_port: VirtualPort) {
        self.virtual_port_ip_tx.remove(&virtual_port);
    }
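The new `produce_task` replaces the per-port registration API: instead of virtual interfaces registering mpsc senders, they publish outbound-packet events, and this task forwards each one into the tunnel. A standalone sketch of that filter-and-forward pattern over tokio's broadcast channel; the event type here is a simplified stand-in, the tunnel send is stubbed with a print, and the tokio `full` feature set is assumed:

    use tokio::sync::broadcast;

    // Illustrative event type; the real bus carries richer variants.
    #[derive(Debug, Clone)]
    enum Event {
        OutboundInternetPacket(Vec<u8>),
        Other,
    }

    // Forward only the packets destined for the tunnel, ignoring all other bus traffic.
    async fn produce(mut rx: broadcast::Receiver<Event>) {
        while let Ok(event) = rx.recv().await {
            if let Event::OutboundInternetPacket(data) = event {
                println!("would send {} bytes through the tunnel", data.len());
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = broadcast::channel(16);
        let task = tokio::spawn(produce(rx));
        tx.send(Event::Other).unwrap();
        tx.send(Event::OutboundInternetPacket(vec![0u8; 3])).unwrap();
        drop(tx); // closing the channel ends the consumer loop
        task.await.unwrap();
    }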
    /// WireGuard routine task. Handles handshakes, keep-alives, etc.

@@ -160,6 +142,7 @@ impl WireGuardTunnel {
    /// decapsulates them, and dispatches newly received IP packets.
    pub async fn consume_task(&self) -> ! {
        trace!("Starting WireGuard consumption task");
        let endpoint = self.bus.new_endpoint();

        loop {
            let mut recv_buf = [0u8; MAX_PACKET];
@@ -212,38 +195,8 @@ impl WireGuardTunnel {
                    // For debugging purposes: parse the packet
                    trace_ip_packet("Received IP packet", packet);

                    match self.route_ip_packet(packet) {
                        RouteResult::Dispatch(port) => {
                            let sender = self.virtual_port_ip_tx.get(&port);
                            if let Some(sender_guard) = sender {
                                let sender = sender_guard.value();
                                match sender.send(packet.to_vec()).await {
                                    Ok(_) => {
                                        trace!(
                                            "Dispatched received IP packet to virtual port {}",
                                            port
                                        );
                                    }
                                    Err(e) => {
                                        error!(
                                            "Failed to dispatch received IP packet to virtual port {}: {}",
                                            port, e
                                        );
                                    }
                                }
                            } else {
                                warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
                            }
                        }
                        RouteResult::Sink => {
                            trace!("Sending unroutable IP packet received from WireGuard endpoint to sink interface");
                            self.route_ip_sink(packet).await.unwrap_or_else(|e| {
                                error!("Failed to send unroutable IP packet to sink: {:?}", e)
                            });
                        }
                        RouteResult::Drop => {
                            trace!("Dropped unroutable IP packet received from WireGuard endpoint");
                        }
                    if let Some(proto) = self.route_protocol(packet) {
                        endpoint.send(Event::InboundInternetPacket(proto, packet.into()));
                    }
                }
                _ => {}
@@ -264,89 +217,32 @@ impl WireGuardTunnel {
        .with_context(|| "Failed to initialize boringtun Tunn")
    }

    /// Makes a decision on the handling of an incoming IP packet.
    fn route_ip_packet(&self, packet: &[u8]) -> RouteResult {
    /// Determine the inner protocol of the incoming IP packet (TCP/UDP).
    fn route_protocol(&self, packet: &[u8]) -> Option<PortProtocol> {
        match IpVersion::of_packet(packet) {
            Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
                .ok()
                // Only care if the packet is destined for this tunnel
                .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
                .map(|packet| match packet.protocol() {
                    IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
                    IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
                    IpProtocol::Tcp => Some(PortProtocol::Tcp),
                    IpProtocol::Udp => Some(PortProtocol::Udp),
                    // Unrecognized protocol, so we cannot determine where to route
                    _ => Some(RouteResult::Drop),
                    _ => None,
                })
                .flatten()
                .unwrap_or(RouteResult::Drop),
                .flatten(),
            Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet)
                .ok()
                // Only care if the packet is destined for this tunnel
                .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
                .map(|packet| match packet.next_header() {
                    IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
                    IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
                    IpProtocol::Tcp => Some(PortProtocol::Tcp),
                    IpProtocol::Udp => Some(PortProtocol::Udp),
                    // Unrecognized protocol, so we cannot determine where to route
                    _ => Some(RouteResult::Drop),
                    _ => None,
                })
                .flatten()
                .unwrap_or(RouteResult::Drop),
                .flatten(),
            _ => RouteResult::Drop,
            _ => None,
        }
    }

    /// Makes a decision on the handling of an incoming TCP segment.
    fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
        TcpPacket::new_checked(segment)
            .ok()
            .map(|tcp| {
                if self
                    .virtual_port_ip_tx
                    .get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
                    .is_some()
                {
                    RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
                } else if tcp.rst() {
                    RouteResult::Drop
                } else {
                    RouteResult::Sink
                }
            })
            .unwrap_or(RouteResult::Drop)
    }

    /// Makes a decision on the handling of an incoming UDP datagram.
    fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult {
        UdpPacket::new_checked(datagram)
            .ok()
            .map(|udp| {
                if self
                    .virtual_port_ip_tx
                    .get(&VirtualPort(udp.dst_port(), PortProtocol::Udp))
                    .is_some()
                {
                    RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp))
                } else {
                    RouteResult::Drop
                }
            })
            .unwrap_or(RouteResult::Drop)
    }

    /// Route a packet to the IP sink interface.
    async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
        let ip_sink_tx = self.sink_ip_tx.read().await;

        if let Some(ip_sink_tx) = &*ip_sink_tx {
            ip_sink_tx
                .send(packet.to_vec())
                .await
                .with_context(|| "Failed to dispatch IP packet to sink interface")
        } else {
            warn!(
                "Could not dispatch unroutable IP packet to sink because interface is not active."
            );
            Ok(())
        }
    }
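With routing removed from wg.rs, `route_protocol` only classifies the inner protocol; each virtual interface decides what to do with the resulting inbound-packet event. A simplified stand-in for that classification, reading the IPv4 protocol field directly instead of using smoltcp's checked parsers (which, as above, also verify the destination address):

    /// Illustrative mirror of the classification result.
    #[derive(Debug, PartialEq)]
    enum PortProtocol { Tcp, Udp }

    fn classify_ipv4(packet: &[u8]) -> Option<PortProtocol> {
        // The IPv4 "protocol" field lives at byte offset 9 of the header.
        match *packet.get(9)? {
            6 => Some(PortProtocol::Tcp),  // IANA protocol number for TCP
            17 => Some(PortProtocol::Udp), // IANA protocol number for UDP
            _ => None,                     // unrecognized: nothing to dispatch
        }
    }

    fn main() {
        // 20-byte IPv4 header with protocol = UDP (17).
        let mut header = [0u8; 20];
        header[0] = 0x45; // version 4, IHL 5
        header[9] = 17;
        assert_eq!(classify_ipv4(&header), Some(PortProtocol::Udp));
    }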
@@ -370,12 +266,3 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
        }
    }
}

enum RouteResult {
    /// Dispatch the packet to the virtual port.
    Dispatch(VirtualPort),
    /// The packet is not routable, and should be sent to the sink interface.
    Sink,
    /// The packet is not routable, and can be safely ignored.
    Drop,
}