mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 17:38:32 -04:00
Improve reliability using event-based synchronization
This commit is contained in:
parent
62b2641627
commit
51788c9557
12 changed files with 628 additions and 805 deletions
181
src/wg.rs
181
src/wg.rs
|
@ -1,15 +1,15 @@
|
|||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::Bus;
|
||||
use anyhow::Context;
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use log::Level;
|
||||
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket};
|
||||
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::{Config, PortProtocol};
|
||||
use crate::virtual_iface::VirtualPort;
|
||||
use crate::events::Event;
|
||||
|
||||
/// The capacity of the channel for received IP packets.
|
||||
pub const DISPATCH_CAPACITY: usize = 1_000;
|
||||
|
@ -26,17 +26,13 @@ pub struct WireGuardTunnel {
|
|||
udp: UdpSocket,
|
||||
/// The address of the public WireGuard endpoint (UDP).
|
||||
pub(crate) endpoint: SocketAddr,
|
||||
/// Maps virtual ports to the corresponding IP packet dispatcher.
|
||||
virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
|
||||
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
|
||||
/// The max transmission unit for WireGuard.
|
||||
pub(crate) max_transmission_unit: usize,
|
||||
/// Event bus
|
||||
bus: Bus,
|
||||
}
|
||||
|
||||
impl WireGuardTunnel {
|
||||
/// Initialize a new WireGuard tunnel.
|
||||
pub async fn new(config: &Config) -> anyhow::Result<Self> {
|
||||
pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> {
|
||||
let source_peer_ip = config.source_peer_ip;
|
||||
let peer = Self::create_tunnel(config)?;
|
||||
let endpoint = config.endpoint_addr;
|
||||
|
@ -46,16 +42,13 @@ impl WireGuardTunnel {
|
|||
})
|
||||
.await
|
||||
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
|
||||
let virtual_port_ip_tx = Default::default();
|
||||
|
||||
Ok(Self {
|
||||
source_peer_ip,
|
||||
peer,
|
||||
udp,
|
||||
endpoint,
|
||||
virtual_port_ip_tx,
|
||||
sink_ip_tx: RwLock::new(None),
|
||||
max_transmission_unit: config.max_transmission_unit,
|
||||
bus,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -90,31 +83,20 @@ impl WireGuardTunnel {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
|
||||
pub fn register_virtual_interface(
|
||||
&self,
|
||||
virtual_port: VirtualPort,
|
||||
sender: tokio::sync::mpsc::Sender<Vec<u8>>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.virtual_port_ip_tx.insert(virtual_port, sender);
|
||||
Ok(())
|
||||
}
|
||||
pub async fn produce_task(&self) -> ! {
|
||||
trace!("Starting WireGuard production task");
|
||||
let mut endpoint = self.bus.new_endpoint();
|
||||
|
||||
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
|
||||
pub async fn register_sink_interface(
|
||||
&self,
|
||||
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
|
||||
|
||||
let mut sink_ip_tx = self.sink_ip_tx.write().await;
|
||||
*sink_ip_tx = Some(sender);
|
||||
|
||||
Ok(receiver)
|
||||
}
|
||||
|
||||
/// Releases the virtual interface from IP dispatch.
|
||||
pub fn release_virtual_interface(&self, virtual_port: VirtualPort) {
|
||||
self.virtual_port_ip_tx.remove(&virtual_port);
|
||||
loop {
|
||||
if let Event::OutboundInternetPacket(data) = endpoint.recv().await {
|
||||
match self.send_ip_packet(&data).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
error!("{:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
||||
|
@ -160,6 +142,7 @@ impl WireGuardTunnel {
|
|||
/// decapsulates them, and dispatches newly received IP packets.
|
||||
pub async fn consume_task(&self) -> ! {
|
||||
trace!("Starting WireGuard consumption task");
|
||||
let endpoint = self.bus.new_endpoint();
|
||||
|
||||
loop {
|
||||
let mut recv_buf = [0u8; MAX_PACKET];
|
||||
|
@ -212,38 +195,8 @@ impl WireGuardTunnel {
|
|||
// For debugging purposes: parse packet
|
||||
trace_ip_packet("Received IP packet", packet);
|
||||
|
||||
match self.route_ip_packet(packet) {
|
||||
RouteResult::Dispatch(port) => {
|
||||
let sender = self.virtual_port_ip_tx.get(&port);
|
||||
if let Some(sender_guard) = sender {
|
||||
let sender = sender_guard.value();
|
||||
match sender.send(packet.to_vec()).await {
|
||||
Ok(_) => {
|
||||
trace!(
|
||||
"Dispatched received IP packet to virtual port {}",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to dispatch received IP packet to virtual port {}: {}",
|
||||
port, e
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
|
||||
}
|
||||
}
|
||||
RouteResult::Sink => {
|
||||
trace!("Sending unroutable IP packet received from WireGuard endpoint to sink interface");
|
||||
self.route_ip_sink(packet).await.unwrap_or_else(|e| {
|
||||
error!("Failed to send unroutable IP packet to sink: {:?}", e)
|
||||
});
|
||||
}
|
||||
RouteResult::Drop => {
|
||||
trace!("Dropped unroutable IP packet received from WireGuard endpoint");
|
||||
}
|
||||
if let Some(proto) = self.route_protocol(packet) {
|
||||
endpoint.send(Event::InboundInternetPacket(proto, packet.into()));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
|
@ -264,89 +217,32 @@ impl WireGuardTunnel {
|
|||
.with_context(|| "Failed to initialize boringtun Tunn")
|
||||
}
|
||||
|
||||
/// Makes a decision on the handling of an incoming IP packet.
|
||||
fn route_ip_packet(&self, packet: &[u8]) -> RouteResult {
|
||||
/// Determine the inner protocol of the incoming IP packet (TCP/UDP).
|
||||
fn route_protocol(&self, packet: &[u8]) -> Option<PortProtocol> {
|
||||
match IpVersion::of_packet(packet) {
|
||||
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
|
||||
.ok()
|
||||
// Only care if the packet is destined for this tunnel
|
||||
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
|
||||
.map(|packet| match packet.protocol() {
|
||||
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
|
||||
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
|
||||
IpProtocol::Tcp => Some(PortProtocol::Tcp),
|
||||
IpProtocol::Udp => Some(PortProtocol::Udp),
|
||||
// Unrecognized protocol, so we cannot determine where to route
|
||||
_ => Some(RouteResult::Drop),
|
||||
_ => None,
|
||||
})
|
||||
.flatten()
|
||||
.unwrap_or(RouteResult::Drop),
|
||||
.flatten(),
|
||||
Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet)
|
||||
.ok()
|
||||
// Only care if the packet is destined for this tunnel
|
||||
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
|
||||
.map(|packet| match packet.next_header() {
|
||||
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
|
||||
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
|
||||
IpProtocol::Tcp => Some(PortProtocol::Tcp),
|
||||
IpProtocol::Udp => Some(PortProtocol::Udp),
|
||||
// Unrecognized protocol, so we cannot determine where to route
|
||||
_ => Some(RouteResult::Drop),
|
||||
_ => None,
|
||||
})
|
||||
.flatten()
|
||||
.unwrap_or(RouteResult::Drop),
|
||||
_ => RouteResult::Drop,
|
||||
}
|
||||
}
|
||||
|
||||
/// Makes a decision on the handling of an incoming TCP segment.
|
||||
fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
|
||||
TcpPacket::new_checked(segment)
|
||||
.ok()
|
||||
.map(|tcp| {
|
||||
if self
|
||||
.virtual_port_ip_tx
|
||||
.get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
|
||||
.is_some()
|
||||
{
|
||||
RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
|
||||
} else if tcp.rst() {
|
||||
RouteResult::Drop
|
||||
} else {
|
||||
RouteResult::Sink
|
||||
}
|
||||
})
|
||||
.unwrap_or(RouteResult::Drop)
|
||||
}
|
||||
|
||||
/// Makes a decision on the handling of an incoming UDP datagram.
|
||||
fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult {
|
||||
UdpPacket::new_checked(datagram)
|
||||
.ok()
|
||||
.map(|udp| {
|
||||
if self
|
||||
.virtual_port_ip_tx
|
||||
.get(&VirtualPort(udp.dst_port(), PortProtocol::Udp))
|
||||
.is_some()
|
||||
{
|
||||
RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp))
|
||||
} else {
|
||||
RouteResult::Drop
|
||||
}
|
||||
})
|
||||
.unwrap_or(RouteResult::Drop)
|
||||
}
|
||||
|
||||
/// Route a packet to the IP sink interface.
|
||||
async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
|
||||
let ip_sink_tx = self.sink_ip_tx.read().await;
|
||||
|
||||
if let Some(ip_sink_tx) = &*ip_sink_tx {
|
||||
ip_sink_tx
|
||||
.send(packet.to_vec())
|
||||
.await
|
||||
.with_context(|| "Failed to dispatch IP packet to sink interface")
|
||||
} else {
|
||||
warn!(
|
||||
"Could not dispatch unroutable IP packet to sink because interface is not active."
|
||||
);
|
||||
Ok(())
|
||||
.flatten(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -370,12 +266,3 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum RouteResult {
|
||||
/// Dispatch the packet to the virtual port.
|
||||
Dispatch(VirtualPort),
|
||||
/// The packet is not routable, and should be sent to the sink interface.
|
||||
Sink,
|
||||
/// The packet is not routable, and can be safely ignored.
|
||||
Drop,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue