mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 23:18:31 -04:00
Improve reliability using event-based synchronization
This commit is contained in:
parent
62b2641627
commit
51788c9557
12 changed files with 628 additions and 805 deletions
|
@ -14,10 +14,51 @@ pub trait VirtualInterfacePoll {
|
|||
|
||||
/// Virtual port.
|
||||
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct VirtualPort(pub u16, pub PortProtocol);
|
||||
pub struct VirtualPort(u16, PortProtocol);
|
||||
|
||||
impl VirtualPort {
|
||||
/// Create a new `VirtualPort` instance, with the given port number and associated protocol.
|
||||
pub fn new(port: u16, proto: PortProtocol) -> Self {
|
||||
VirtualPort(port, proto)
|
||||
}
|
||||
|
||||
/// The port number
|
||||
pub fn num(&self) -> u16 {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// The protocol of this port.
|
||||
pub fn proto(&self) -> PortProtocol {
|
||||
self.1
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VirtualPort> for u16 {
|
||||
fn from(port: VirtualPort) -> Self {
|
||||
port.num()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&VirtualPort> for u16 {
|
||||
fn from(port: &VirtualPort) -> Self {
|
||||
port.num()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VirtualPort> for PortProtocol {
|
||||
fn from(port: VirtualPort) -> Self {
|
||||
port.proto()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&VirtualPort> for PortProtocol {
|
||||
fn from(port: &VirtualPort) -> Self {
|
||||
port.proto()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for VirtualPort {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[{}:{}]", self.0, self.1)
|
||||
write!(f, "[{}:{}]", self.num(), self.proto())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,284 +1,253 @@
|
|||
use crate::config::{PortForwardConfig, PortProtocol};
|
||||
use crate::events::Event;
|
||||
use crate::virtual_device::VirtualIpDevice;
|
||||
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
|
||||
use crate::wg::WireGuardTunnel;
|
||||
use crate::Bus;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use smoltcp::iface::InterfaceBuilder;
|
||||
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
|
||||
use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState};
|
||||
use smoltcp::wire::{IpAddress, IpCidr};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
const MAX_PACKET: usize = 65536;
|
||||
|
||||
/// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa.
|
||||
pub struct TcpVirtualInterface {
|
||||
/// The virtual port assigned to the virtual client, used to
|
||||
/// route Layer 4 segments/datagrams to and from the WireGuard tunnel.
|
||||
virtual_port: u16,
|
||||
/// The overall port-forward configuration: used for the destination address (on which
|
||||
/// the virtual server listens) and the protocol in use.
|
||||
port_forward: PortForwardConfig,
|
||||
/// The WireGuard tunnel to send IP packets to.
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
/// Abort signal to shutdown the virtual interface and its parent task.
|
||||
abort: Arc<AtomicBool>,
|
||||
/// Channel sender for pushing Layer 7 data back to the real client.
|
||||
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
|
||||
/// Channel receiver for processing Layer 7 data through the virtual interface.
|
||||
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
/// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data.
|
||||
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
|
||||
source_peer_ip: IpAddr,
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
device: VirtualIpDevice,
|
||||
bus: Bus,
|
||||
}
|
||||
|
||||
impl TcpVirtualInterface {
|
||||
/// Initialize the parameters for a new virtual interface.
|
||||
/// Use the `poll_loop()` future to start the virtual interface poll loop.
|
||||
pub fn new(
|
||||
virtual_port: u16,
|
||||
port_forward: PortForwardConfig,
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
abort: Arc<AtomicBool>,
|
||||
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
|
||||
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
bus: Bus,
|
||||
device: VirtualIpDevice,
|
||||
source_peer_ip: IpAddr,
|
||||
) -> Self {
|
||||
Self {
|
||||
virtual_port,
|
||||
port_forward,
|
||||
wg,
|
||||
abort,
|
||||
data_to_real_client_tx,
|
||||
data_to_virtual_server_rx,
|
||||
virtual_client_ready_tx,
|
||||
port_forwards: port_forwards
|
||||
.into_iter()
|
||||
.filter(|f| matches!(f.protocol, PortProtocol::Tcp))
|
||||
.collect(),
|
||||
device,
|
||||
source_peer_ip,
|
||||
bus,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<TcpSocket<'static>> {
|
||||
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
|
||||
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
|
||||
|
||||
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
|
||||
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
|
||||
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
|
||||
|
||||
socket
|
||||
.listen((
|
||||
IpAddress::from(port_forward.destination.ip()),
|
||||
port_forward.destination.port(),
|
||||
))
|
||||
.with_context(|| "Virtual server socket failed to listen")?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
fn new_client_socket() -> anyhow::Result<TcpSocket<'static>> {
|
||||
let rx_data = vec![0u8; MAX_PACKET];
|
||||
let tx_data = vec![0u8; MAX_PACKET];
|
||||
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
|
||||
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
|
||||
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VirtualInterfacePoll for TcpVirtualInterface {
|
||||
async fn poll_loop(self) -> anyhow::Result<()> {
|
||||
let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx);
|
||||
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
|
||||
let source_peer_ip = self.wg.source_peer_ip;
|
||||
// Create CIDR block for source peer IP + each port forward IP
|
||||
let addresses: Vec<IpCidr> = {
|
||||
let mut addresses = HashSet::new();
|
||||
addresses.insert(IpAddress::from(self.source_peer_ip));
|
||||
for config in self.port_forwards.iter() {
|
||||
addresses.insert(IpAddress::from(config.destination.ip()));
|
||||
}
|
||||
addresses
|
||||
.into_iter()
|
||||
.map(|addr| IpCidr::new(addr, 32))
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Create a device and interface to simulate IP packets
|
||||
// In essence:
|
||||
// * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
|
||||
// * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
|
||||
// * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
|
||||
// * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
|
||||
// * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
|
||||
|
||||
// Consumer for IP packets to send through the virtual interface
|
||||
// Initialize the interface
|
||||
let device =
|
||||
VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg)
|
||||
.with_context(|| "Failed to initialize TCP VirtualIpDevice")?;
|
||||
|
||||
// there are always 2 sockets: 1 virtual client and 1 virtual server.
|
||||
let mut sockets: [_; 2] = Default::default();
|
||||
let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..])
|
||||
.ip_addrs([
|
||||
// Interface handles IP packets for the sender and recipient
|
||||
IpCidr::new(IpAddress::from(source_peer_ip), 32),
|
||||
IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32),
|
||||
])
|
||||
let mut iface = InterfaceBuilder::new(self.device, vec![])
|
||||
.ip_addrs(addresses)
|
||||
.finalize();
|
||||
|
||||
// Server socket: this is a placeholder for the interface to route new connections to.
|
||||
let server_socket: anyhow::Result<TcpSocket> = {
|
||||
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
|
||||
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
|
||||
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
|
||||
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
|
||||
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
|
||||
// Create virtual server for each port forward
|
||||
for port_forward in self.port_forwards.iter() {
|
||||
let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
|
||||
iface.add_socket(server_socket);
|
||||
}
|
||||
|
||||
socket
|
||||
.listen((
|
||||
IpAddress::from(self.port_forward.destination.ip()),
|
||||
self.port_forward.destination.port(),
|
||||
))
|
||||
.with_context(|| "Virtual server socket failed to listen")?;
|
||||
// The next time to poll the interface. Can be None for instant poll.
|
||||
let mut next_poll: Option<tokio::time::Instant> = None;
|
||||
|
||||
Ok(socket)
|
||||
};
|
||||
// Bus endpoint to read events
|
||||
let mut endpoint = self.bus.new_endpoint();
|
||||
|
||||
let client_socket: anyhow::Result<TcpSocket> = {
|
||||
let rx_data = vec![0u8; MAX_PACKET];
|
||||
let tx_data = vec![0u8; MAX_PACKET];
|
||||
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
|
||||
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
|
||||
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
|
||||
Ok(socket)
|
||||
};
|
||||
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
|
||||
|
||||
let _server_handle = virtual_interface.add_socket(server_socket?);
|
||||
let client_handle = virtual_interface.add_socket(client_socket?);
|
||||
|
||||
// Any data that wasn't sent because it was over the sending buffer limit
|
||||
let mut tx_extra = Vec::new();
|
||||
|
||||
// Counts the connection attempts by the virtual client
|
||||
let mut connection_attempts = 0;
|
||||
// Whether the client has successfully connected before. Prevents the case of connecting again.
|
||||
let mut has_connected = false;
|
||||
// Data packets to send from a virtual client
|
||||
let mut send_queue: HashMap<VirtualPort, VecDeque<Vec<u8>>> = HashMap::new();
|
||||
|
||||
loop {
|
||||
let loop_start = smoltcp::time::Instant::now();
|
||||
tokio::select! {
|
||||
_ = match (next_poll, port_client_handle_map.len()) {
|
||||
(None, 0) => tokio::time::sleep(Duration::MAX),
|
||||
(None, _) => tokio::time::sleep(Duration::ZERO),
|
||||
(Some(until), _) => tokio::time::sleep_until(until),
|
||||
} => {
|
||||
let loop_start = smoltcp::time::Instant::now();
|
||||
|
||||
// Shutdown occurs when the real client closes the connection,
|
||||
// or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore.
|
||||
// One last poll-loop iteration is executed so that the RST segment can be dispatched.
|
||||
let shutdown = self.abort.load(Ordering::Relaxed);
|
||||
|
||||
if shutdown {
|
||||
// Shutdown: sends a RST packet.
|
||||
trace!("[{}] Shutting down virtual interface", self.virtual_port);
|
||||
let client_socket = virtual_interface.get_socket::<TcpSocket>(client_handle);
|
||||
client_socket.abort();
|
||||
}
|
||||
|
||||
match virtual_interface.poll(loop_start) {
|
||||
Ok(processed) if processed => {
|
||||
trace!(
|
||||
"[{}] Virtual interface polled some packets to be processed",
|
||||
self.virtual_port
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"[{}] Virtual interface poll error: {:?}",
|
||||
self.virtual_port, e
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
{
|
||||
let (client_socket, context) =
|
||||
virtual_interface.get_socket_and_context::<TcpSocket>(client_handle);
|
||||
|
||||
if !shutdown && client_socket.state() == TcpState::Closed && !has_connected {
|
||||
// Not shutting down, but the client socket is closed, and the client never successfully connected.
|
||||
if connection_attempts < 10 {
|
||||
// Try to connect
|
||||
client_socket
|
||||
.connect(
|
||||
context,
|
||||
(
|
||||
IpAddress::from(self.port_forward.destination.ip()),
|
||||
self.port_forward.destination.port(),
|
||||
),
|
||||
(IpAddress::from(source_peer_ip), self.virtual_port),
|
||||
)
|
||||
.with_context(|| "Virtual server socket failed to listen")?;
|
||||
if connection_attempts > 0 {
|
||||
debug!(
|
||||
"[{}] Virtual client retrying connection in 500ms",
|
||||
self.virtual_port
|
||||
);
|
||||
// Not our first connection attempt, wait a little bit.
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
match iface.poll(loop_start) {
|
||||
Ok(processed) if processed => {
|
||||
trace!("TCP virtual interface polled some packets to be processed");
|
||||
}
|
||||
} else {
|
||||
// Too many connection attempts
|
||||
self.abort.store(true, Ordering::Relaxed);
|
||||
}
|
||||
connection_attempts += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if client_socket.state() == TcpState::Established {
|
||||
// Prevent reconnection if the server later closes.
|
||||
has_connected = true;
|
||||
}
|
||||
|
||||
if client_socket.can_recv() {
|
||||
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
|
||||
Ok(data) => {
|
||||
trace!(
|
||||
"[{}] Virtual client received {} bytes of data",
|
||||
self.virtual_port,
|
||||
data.len()
|
||||
);
|
||||
// Send it to the real client
|
||||
if let Err(e) = self.data_to_real_client_tx.send(data).await {
|
||||
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"[{}] Failed to read from virtual client socket: {:?}",
|
||||
self.virtual_port, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
if client_socket.can_send() {
|
||||
if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() {
|
||||
virtual_client_ready_tx
|
||||
.send(())
|
||||
.expect("Failed to notify real client that virtual client is ready");
|
||||
Err(e) => error!("TCP virtual interface poll error: {:?}", e),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let mut to_transfer = None;
|
||||
|
||||
if tx_extra.is_empty() {
|
||||
// The payload segment from the previous loop is complete,
|
||||
// we can now read the next payload in the queue.
|
||||
if let Ok(data) = data_to_virtual_server_rx.try_recv() {
|
||||
to_transfer = Some(data);
|
||||
} else if client_socket.state() == TcpState::CloseWait {
|
||||
// No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN),
|
||||
// the interface is shutdown.
|
||||
trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port);
|
||||
self.abort.store(true, Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
|
||||
if !to_transfer_slice.is_empty() {
|
||||
let total = to_transfer_slice.len();
|
||||
match client_socket.send_slice(to_transfer_slice) {
|
||||
Ok(sent) => {
|
||||
trace!(
|
||||
"[{}] Sent {}/{} bytes via virtual client socket",
|
||||
self.virtual_port,
|
||||
sent,
|
||||
total,
|
||||
);
|
||||
tx_extra = Vec::from(&to_transfer_slice[sent..total]);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"[{}] Failed to send slice via virtual client socket: {:?}",
|
||||
self.virtual_port, e
|
||||
);
|
||||
// Find client socket send data to
|
||||
for (virtual_port, client_handle) in port_client_handle_map.iter() {
|
||||
let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
|
||||
if client_socket.can_send() {
|
||||
if let Some(send_queue) = send_queue.get_mut(virtual_port) {
|
||||
let to_transfer = send_queue.pop_front();
|
||||
if let Some(to_transfer_slice) = to_transfer.as_deref() {
|
||||
let total = to_transfer_slice.len();
|
||||
match client_socket.send_slice(to_transfer_slice) {
|
||||
Ok(sent) => {
|
||||
if sent < total {
|
||||
// Sometimes only a subset is sent, so the rest needs to be sent on the next poll
|
||||
let tx_extra = Vec::from(&to_transfer_slice[sent..total]);
|
||||
send_queue.push_front(tx_extra);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to send slice via virtual client socket: {:?}", e
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if client_socket.state() == TcpState::CloseWait {
|
||||
client_socket.close();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shutdown {
|
||||
break;
|
||||
}
|
||||
// Find client socket recv data from
|
||||
for (virtual_port, client_handle) in port_client_handle_map.iter() {
|
||||
let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
|
||||
if client_socket.can_recv() {
|
||||
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
|
||||
Ok(data) => {
|
||||
if !data.is_empty() {
|
||||
endpoint.send(Event::RemoteData(*virtual_port, data));
|
||||
break;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to read from virtual client socket: {:?}", e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match virtual_interface.poll_delay(loop_start) {
|
||||
Some(smoltcp::time::Duration::ZERO) => {
|
||||
continue;
|
||||
// Find closed sockets
|
||||
port_client_handle_map.retain(|virtual_port, client_handle| {
|
||||
let client_socket = iface.get_socket::<TcpSocket>(*client_handle);
|
||||
if client_socket.state() == TcpState::Closed {
|
||||
endpoint.send(Event::ClientConnectionDropped(*virtual_port));
|
||||
send_queue.remove(virtual_port);
|
||||
false
|
||||
} else {
|
||||
// Not closed, retain
|
||||
true
|
||||
}
|
||||
});
|
||||
|
||||
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
|
||||
next_poll = match iface.poll_delay(loop_start) {
|
||||
Some(smoltcp::time::Duration::ZERO) => None,
|
||||
Some(delay) => {
|
||||
trace!("TCP Virtual interface delayed next poll by {}", delay);
|
||||
Some(tokio::time::Instant::now() + Duration::from_millis(delay.total_millis()))
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
}
|
||||
_ => {
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
event = endpoint.recv() => {
|
||||
match event {
|
||||
Event::ClientConnectionInitiated(port_forward, virtual_port) => {
|
||||
let client_socket = TcpVirtualInterface::new_client_socket()?;
|
||||
let client_handle = iface.add_socket(client_socket);
|
||||
|
||||
// Add handle to map
|
||||
port_client_handle_map.insert(virtual_port, client_handle);
|
||||
send_queue.insert(virtual_port, VecDeque::new());
|
||||
|
||||
let (client_socket, context) = iface.get_socket_and_context::<TcpSocket>(client_handle);
|
||||
|
||||
client_socket
|
||||
.connect(
|
||||
context,
|
||||
(
|
||||
IpAddress::from(port_forward.destination.ip()),
|
||||
port_forward.destination.port(),
|
||||
),
|
||||
(IpAddress::from(self.source_peer_ip), virtual_port.num()),
|
||||
)
|
||||
.with_context(|| "Virtual server socket failed to listen")?;
|
||||
|
||||
next_poll = None;
|
||||
}
|
||||
Event::ClientConnectionDropped(virtual_port) => {
|
||||
if let Some(client_handle) = port_client_handle_map.get(&virtual_port) {
|
||||
let client_handle = *client_handle;
|
||||
port_client_handle_map.remove(&virtual_port);
|
||||
send_queue.remove(&virtual_port);
|
||||
|
||||
let client_socket = iface.get_socket::<TcpSocket>(client_handle);
|
||||
client_socket.close();
|
||||
next_poll = None;
|
||||
}
|
||||
}
|
||||
Event::LocalData(virtual_port, data) if send_queue.contains_key(&virtual_port) => {
|
||||
if let Some(send_queue) = send_queue.get_mut(&virtual_port) {
|
||||
send_queue.push_back(data);
|
||||
next_poll = None;
|
||||
}
|
||||
}
|
||||
Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Tcp => {
|
||||
next_poll = None;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!("[{}] Virtual interface task terminated", self.virtual_port);
|
||||
self.abort.store(true, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,39 +1,39 @@
|
|||
use anyhow::Context;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
#![allow(dead_code)]
|
||||
use std::net::IpAddr;
|
||||
|
||||
use crate::{Bus, PortProtocol};
|
||||
use async_trait::async_trait;
|
||||
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
|
||||
use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
|
||||
use smoltcp::wire::{IpAddress, IpCidr};
|
||||
|
||||
use crate::config::PortForwardConfig;
|
||||
use crate::virtual_device::VirtualIpDevice;
|
||||
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
|
||||
use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
|
||||
use crate::virtual_iface::VirtualInterfacePoll;
|
||||
|
||||
const MAX_PACKET: usize = 65536;
|
||||
|
||||
pub struct UdpVirtualInterface {
|
||||
port_forward: PortForwardConfig,
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
|
||||
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
|
||||
source_peer_ip: IpAddr,
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
device: VirtualIpDevice,
|
||||
bus: Bus,
|
||||
}
|
||||
|
||||
impl UdpVirtualInterface {
|
||||
/// Initialize the parameters for a new virtual interface.
|
||||
/// Use the `poll_loop()` future to start the virtual interface poll loop.
|
||||
pub fn new(
|
||||
port_forward: PortForwardConfig,
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
|
||||
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
|
||||
port_forwards: Vec<PortForwardConfig>,
|
||||
bus: Bus,
|
||||
device: VirtualIpDevice,
|
||||
source_peer_ip: IpAddr,
|
||||
) -> Self {
|
||||
Self {
|
||||
port_forward,
|
||||
wg,
|
||||
data_to_real_client_tx,
|
||||
data_to_virtual_server_rx,
|
||||
port_forwards: port_forwards
|
||||
.into_iter()
|
||||
.filter(|f| matches!(f.protocol, PortProtocol::Udp))
|
||||
.collect(),
|
||||
device,
|
||||
source_peer_ip,
|
||||
bus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -41,160 +41,9 @@ impl UdpVirtualInterface {
|
|||
#[async_trait]
|
||||
impl VirtualInterfacePoll for UdpVirtualInterface {
|
||||
async fn poll_loop(self) -> anyhow::Result<()> {
|
||||
// Data receiver to dispatch using virtual client sockets
|
||||
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
|
||||
|
||||
// The IP to bind client sockets to
|
||||
let source_peer_ip = self.wg.source_peer_ip;
|
||||
|
||||
// The IP/port to bind the server socket to
|
||||
let destination = self.port_forward.destination;
|
||||
|
||||
// Initialize a channel for IP packets.
|
||||
// The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel.
|
||||
// The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel.
|
||||
let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
|
||||
|
||||
let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx);
|
||||
let mut virtual_interface = InterfaceBuilder::new(device, vec![])
|
||||
.ip_addrs([
|
||||
// Interface handles IP packets for the sender and recipient
|
||||
IpCidr::new(source_peer_ip.into(), 32),
|
||||
IpCidr::new(destination.ip().into(), 32),
|
||||
])
|
||||
.finalize();
|
||||
|
||||
// Server socket: this is a placeholder for the interface.
|
||||
let server_socket: anyhow::Result<UdpSocket> = {
|
||||
static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = [];
|
||||
static mut UDP_SERVER_RX_DATA: [u8; 0] = [];
|
||||
static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = [];
|
||||
static mut UDP_SERVER_TX_DATA: [u8; 0] = [];
|
||||
let udp_rx_buffer =
|
||||
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
|
||||
&mut UDP_SERVER_RX_DATA[..]
|
||||
});
|
||||
let udp_tx_buffer =
|
||||
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
|
||||
&mut UDP_SERVER_TX_DATA[..]
|
||||
});
|
||||
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
|
||||
|
||||
socket
|
||||
.bind((IpAddress::from(destination.ip()), destination.port()))
|
||||
.with_context(|| "UDP virtual server socket failed to listen")?;
|
||||
|
||||
Ok(socket)
|
||||
};
|
||||
|
||||
let _server_handle = virtual_interface.add_socket(server_socket?);
|
||||
|
||||
// A map of virtual port to client socket.
|
||||
let mut client_sockets: HashMap<VirtualPort, SocketHandle> = HashMap::new();
|
||||
|
||||
// The next instant required to poll the virtual interface
|
||||
// None means "immediate poll required".
|
||||
let mut next_poll: Option<tokio::time::Instant> = None;
|
||||
|
||||
loop {
|
||||
let wg = self.wg.clone();
|
||||
tokio::select! {
|
||||
// Wait the recommended amount of time by smoltcp, and poll again.
|
||||
_ = match next_poll {
|
||||
None => tokio::time::sleep(Duration::ZERO),
|
||||
Some(until) => tokio::time::sleep_until(until)
|
||||
} => {
|
||||
let loop_start = smoltcp::time::Instant::now();
|
||||
|
||||
match virtual_interface.poll(loop_start) {
|
||||
Ok(processed) if processed => {
|
||||
trace!("UDP virtual interface polled some packets to be processed");
|
||||
}
|
||||
Err(e) => error!("UDP virtual interface poll error: {:?}", e),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Loop through each client socket and check if there is any data to send back
|
||||
// to the real client.
|
||||
for (virtual_port, client_socket_handle) in client_sockets.iter() {
|
||||
let client_socket = virtual_interface.get_socket::<UdpSocket>(*client_socket_handle);
|
||||
match client_socket.recv() {
|
||||
Ok((data, _peer)) => {
|
||||
// Send the data back to the real client using MPSC channel
|
||||
self.data_to_real_client_tx
|
||||
.send((*virtual_port, data.to_vec()))
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
"[{}] Failed to dispatch data from virtual client to real client: {:?}",
|
||||
virtual_port, e
|
||||
);
|
||||
});
|
||||
}
|
||||
Err(smoltcp::Error::Exhausted) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"[{}] Failed to read from virtual client socket: {:?}",
|
||||
virtual_port, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next_poll = match virtual_interface.poll_delay(loop_start) {
|
||||
Some(smoltcp::time::Duration::ZERO) => None,
|
||||
Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
// Wait for data to be received from the real client
|
||||
data_recv_result = data_to_virtual_server_rx.recv() => {
|
||||
if let Some((client_port, data)) = data_recv_result {
|
||||
// Register the socket in WireGuard Tunnel (overrides any previous registration as well)
|
||||
wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
"[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
|
||||
client_port, e
|
||||
);
|
||||
});
|
||||
|
||||
let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
|
||||
let rx_meta = vec![UdpPacketMetadata::EMPTY; 10];
|
||||
let tx_meta = vec![UdpPacketMetadata::EMPTY; 10];
|
||||
let rx_data = vec![0u8; MAX_PACKET];
|
||||
let tx_data = vec![0u8; MAX_PACKET];
|
||||
let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
|
||||
let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
|
||||
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
|
||||
|
||||
socket
|
||||
.bind((IpAddress::from(wg.source_peer_ip), client_port.0))
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
"[{}] UDP virtual client socket failed to bind: {:?}",
|
||||
client_port, e
|
||||
);
|
||||
});
|
||||
|
||||
virtual_interface.add_socket(socket)
|
||||
});
|
||||
|
||||
let client_socket = virtual_interface.get_socket::<UdpSocket>(*client_socket_handle);
|
||||
client_socket
|
||||
.send_slice(
|
||||
&data,
|
||||
(IpAddress::from(destination.ip()), destination.port()).into(),
|
||||
)
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
"[{}] Failed to send data to virtual server: {:?}",
|
||||
client_port, e
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: Create smoltcp virtual device and interface
|
||||
// TODO: Create smoltcp virtual servers for `port_forwards`
|
||||
// TODO: listen on events
|
||||
futures::future::pending().await
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue