Improve reliability using event-based synchronization

This commit is contained in:
Aram 🍐 2022-01-08 01:05:51 -05:00
parent 62b2641627
commit 51788c9557
12 changed files with 628 additions and 805 deletions

181
src/wg.rs
View file

@ -1,15 +1,15 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;
use crate::Bus;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use log::Level;
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket};
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet};
use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use crate::config::{Config, PortProtocol};
use crate::virtual_iface::VirtualPort;
use crate::events::Event;
/// The capacity of the channel for received IP packets.
pub const DISPATCH_CAPACITY: usize = 1_000;
@ -26,17 +26,13 @@ pub struct WireGuardTunnel {
udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP).
pub(crate) endpoint: SocketAddr,
/// Maps virtual ports to the corresponding IP packet dispatcher.
virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
/// The max transmission unit for WireGuard.
pub(crate) max_transmission_unit: usize,
/// Event bus
bus: Bus,
}
impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config) -> anyhow::Result<Self> {
pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?;
let endpoint = config.endpoint_addr;
@ -46,16 +42,13 @@ impl WireGuardTunnel {
})
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let virtual_port_ip_tx = Default::default();
Ok(Self {
source_peer_ip,
peer,
udp,
endpoint,
virtual_port_ip_tx,
sink_ip_tx: RwLock::new(None),
max_transmission_unit: config.max_transmission_unit,
bus,
})
}
@ -90,31 +83,20 @@ impl WireGuardTunnel {
Ok(())
}
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub fn register_virtual_interface(
&self,
virtual_port: VirtualPort,
sender: tokio::sync::mpsc::Sender<Vec<u8>>,
) -> anyhow::Result<()> {
self.virtual_port_ip_tx.insert(virtual_port, sender);
Ok(())
}
pub async fn produce_task(&self) -> ! {
trace!("Starting WireGuard production task");
let mut endpoint = self.bus.new_endpoint();
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub async fn register_sink_interface(
&self,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
let mut sink_ip_tx = self.sink_ip_tx.write().await;
*sink_ip_tx = Some(sender);
Ok(receiver)
}
/// Releases the virtual interface from IP dispatch.
pub fn release_virtual_interface(&self, virtual_port: VirtualPort) {
self.virtual_port_ip_tx.remove(&virtual_port);
loop {
if let Event::OutboundInternetPacket(data) = endpoint.recv().await {
match self.send_ip_packet(&data).await {
Ok(_) => {}
Err(e) => {
error!("{:?}", e);
}
}
}
}
}
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
@ -160,6 +142,7 @@ impl WireGuardTunnel {
/// decapsulates them, and dispatches newly received IP packets.
pub async fn consume_task(&self) -> ! {
trace!("Starting WireGuard consumption task");
let endpoint = self.bus.new_endpoint();
loop {
let mut recv_buf = [0u8; MAX_PACKET];
@ -212,38 +195,8 @@ impl WireGuardTunnel {
// For debugging purposes: parse packet
trace_ip_packet("Received IP packet", packet);
match self.route_ip_packet(packet) {
RouteResult::Dispatch(port) => {
let sender = self.virtual_port_ip_tx.get(&port);
if let Some(sender_guard) = sender {
let sender = sender_guard.value();
match sender.send(packet.to_vec()).await {
Ok(_) => {
trace!(
"Dispatched received IP packet to virtual port {}",
port
);
}
Err(e) => {
error!(
"Failed to dispatch received IP packet to virtual port {}: {}",
port, e
);
}
}
} else {
warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
}
}
RouteResult::Sink => {
trace!("Sending unroutable IP packet received from WireGuard endpoint to sink interface");
self.route_ip_sink(packet).await.unwrap_or_else(|e| {
error!("Failed to send unroutable IP packet to sink: {:?}", e)
});
}
RouteResult::Drop => {
trace!("Dropped unroutable IP packet received from WireGuard endpoint");
}
if let Some(proto) = self.route_protocol(packet) {
endpoint.send(Event::InboundInternetPacket(proto, packet.into()));
}
}
_ => {}
@ -264,89 +217,32 @@ impl WireGuardTunnel {
.with_context(|| "Failed to initialize boringtun Tunn")
}
/// Makes a decision on the handling of an incoming IP packet.
fn route_ip_packet(&self, packet: &[u8]) -> RouteResult {
/// Determine the inner protocol of the incoming IP packet (TCP/UDP).
fn route_protocol(&self, packet: &[u8]) -> Option<PortProtocol> {
match IpVersion::of_packet(packet) {
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
.ok()
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() {
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
_ => None,
})
.flatten()
.unwrap_or(RouteResult::Drop),
.flatten(),
Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet)
.ok()
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.next_header() {
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
_ => None,
})
.flatten()
.unwrap_or(RouteResult::Drop),
_ => RouteResult::Drop,
}
}
/// Makes a decision on the handling of an incoming TCP segment.
fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
TcpPacket::new_checked(segment)
.ok()
.map(|tcp| {
if self
.virtual_port_ip_tx
.get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
.is_some()
{
RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
} else if tcp.rst() {
RouteResult::Drop
} else {
RouteResult::Sink
}
})
.unwrap_or(RouteResult::Drop)
}
/// Makes a decision on the handling of an incoming UDP datagram.
fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult {
UdpPacket::new_checked(datagram)
.ok()
.map(|udp| {
if self
.virtual_port_ip_tx
.get(&VirtualPort(udp.dst_port(), PortProtocol::Udp))
.is_some()
{
RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp))
} else {
RouteResult::Drop
}
})
.unwrap_or(RouteResult::Drop)
}
/// Route a packet to the IP sink interface.
async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
let ip_sink_tx = self.sink_ip_tx.read().await;
if let Some(ip_sink_tx) = &*ip_sink_tx {
ip_sink_tx
.send(packet.to_vec())
.await
.with_context(|| "Failed to dispatch IP packet to sink interface")
} else {
warn!(
"Could not dispatch unroutable IP packet to sink because interface is not active."
);
Ok(())
.flatten(),
_ => None,
}
}
}
@ -370,12 +266,3 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
}
}
}
enum RouteResult {
/// Dispatch the packet to the virtual port.
Dispatch(VirtualPort),
/// The packet is not routable, and should be sent to the sink interface.
Sink,
/// The packet is not routable, and can be safely ignored.
Drop,
}