Remove broadcasting logic to fix simultaneous connection issues.

This commit is contained in:
Aram 🍐 2021-10-16 19:16:10 -04:00
parent b6739a5066
commit cfdbdc8f51
5 changed files with 86 additions and 123 deletions

View file

@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> {
init_logger(&config)?;
let port_pool = Arc::new(PortPool::new());
let wg = WireGuardTunnel::new(&config, port_pool.clone())
let wg = WireGuardTunnel::new(&config)
.await
.with_context(|| "Failed to initialize WireGuard tunnel")?;
let wg = Arc::new(wg);
@ -47,12 +47,6 @@ async fn main() -> anyhow::Result<()> {
tokio::spawn(async move { wg.consume_task().await });
}
{
// Start IP broadcast drain task for WireGuard
let wg = wg.clone();
tokio::spawn(async move { wg.broadcast_drain_task().await });
}
info!(
"Tunnelling [{}]->[{}] (via [{}] as peer {})",
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
@ -106,9 +100,14 @@ async fn tcp_proxy_server(
tokio::spawn(async move {
let port_pool = Arc::clone(&port_pool);
let result =
handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg)
.await;
let result = handle_tcp_proxy_connection(
socket,
virtual_port,
source_peer_ip,
dest_addr,
wg.clone(),
)
.await;
if let Err(e) = result {
error!(
@ -120,6 +119,7 @@ async fn tcp_proxy_server(
}
// Release port when connection drops
wg.release_virtual_interface(virtual_port);
port_pool.release(virtual_port);
});
}
@ -270,7 +270,8 @@ async fn virtual_tcp_interface(
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device = VirtualIpDevice::new(wg);
let device = VirtualIpDevice::new(virtual_port, wg)
.with_context(|| "Failed to initialize VirtualIpDevice")?;
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient

View file

@ -13,7 +13,7 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
pub struct PortPool {
/// Remaining ports
inner: lockfree::queue::Queue<u16>,
/// Ports in use
/// Ports in use, with their associated IP channel sender.
taken: lockfree::set::Set<u16>,
}

View file

@ -1,26 +1,26 @@
use crate::wg::WireGuardTunnel;
use anyhow::Context;
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use std::sync::Arc;
/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
/// are made available to this device using a broadcast channel receiver. IP packets sent from this device
/// are made available to this device using a channel receiver. IP packets sent from this device
/// are asynchronously sent out to the WireGuard tunnel.
pub struct VirtualIpDevice {
/// Tunnel to send IP packets to.
wg: Arc<WireGuardTunnel>,
/// Broadcast channel receiver for received IP packets.
ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
/// Channel receiver for received IP packets.
ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
}
impl VirtualIpDevice {
pub fn new(wg: Arc<WireGuardTunnel>) -> Self {
let ip_broadcast_rx = wg.subscribe();
pub fn new(virtual_port: u16, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let ip_dispatch_rx = wg
.register_virtual_interface(virtual_port)
.with_context(|| "Failed to register IP dispatch for virtual interface")?;
Self {
wg,
ip_broadcast_rx,
}
Ok(Self { wg, ip_dispatch_rx })
}
}
@ -29,7 +29,7 @@ impl<'a> Device<'a> for VirtualIpDevice {
type TxToken = TxToken;
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
match self.ip_broadcast_rx.try_recv() {
match self.ip_dispatch_rx.try_recv() {
Ok(buffer) => Some((
Self::RxToken { buffer },
Self::TxToken {

162
src/wg.rs
View file

@ -1,10 +1,8 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use futures::lock::Mutex;
use log::Level;
use smoltcp::phy::ChecksumCapabilities;
use smoltcp::wire::{
@ -12,14 +10,12 @@ use smoltcp::wire::{
TcpPacket, TcpRepr, TcpSeqNumber,
};
use tokio::net::UdpSocket;
use tokio::sync::broadcast::error::RecvError;
use crate::config::Config;
use crate::port_pool::PortPool;
use crate::MAX_PACKET;
/// The capacity of the broadcast channel for received IP packets.
const BROADCAST_CAPACITY: usize = 1_000;
/// The capacity of the channel for received IP packets.
const DISPATCH_CAPACITY: usize = 1_000;
/// A WireGuard tunnel. Encapsulates and decapsulates IP packets
/// to be sent to and received from a remote UDP endpoint.
@ -32,34 +28,27 @@ pub struct WireGuardTunnel {
udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP).
endpoint: SocketAddr,
/// Broadcast sender for received IP packets.
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
/// Sink so that the broadcaster doesn't close. A repeating task should drain this as much as possible.
ip_broadcast_rx_sink: Mutex<tokio::sync::broadcast::Receiver<Vec<u8>>>,
/// Port pool.
port_pool: Arc<PortPool>,
/// Maps virtual ports to the corresponding IP packet dispatcher.
virtual_port_ip_tx: lockfree::map::Map<u16, tokio::sync::mpsc::Sender<Vec<u8>>>,
}
impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config, port_pool: Arc<PortPool>) -> anyhow::Result<Self> {
pub async fn new(config: &Config) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?;
let udp = UdpSocket::bind("0.0.0.0:0")
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let endpoint = config.endpoint_addr;
let (ip_broadcast_tx, ip_broadcast_rx_sink) =
tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
let virtual_port_ip_tx = lockfree::map::Map::new();
Ok(Self {
source_peer_ip,
peer,
udp,
endpoint,
ip_broadcast_tx,
ip_broadcast_rx_sink: Mutex::new(ip_broadcast_rx_sink),
port_pool,
virtual_port_ip_tx,
})
}
@ -94,9 +83,24 @@ impl WireGuardTunnel {
Ok(())
}
/// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint.
pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Vec<u8>> {
self.ip_broadcast_tx.subscribe()
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub fn register_virtual_interface(
&self,
virtual_port: u16,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let existing = self.virtual_port_ip_tx.get(&virtual_port);
if existing.is_some() {
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
} else {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
self.virtual_port_ip_tx.insert(virtual_port, sender);
Ok(receiver)
}
}
/// Releases the virtual interface from IP dispatch.
pub fn release_virtual_interface(&self, virtual_port: u16) {
self.virtual_port_ip_tx.remove(&virtual_port);
}
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
@ -139,7 +143,7 @@ impl WireGuardTunnel {
}
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
/// decapsulates them, and broadcasts newly received IP packets.
/// decapsulates them, and dispatches newly received IP packets.
pub async fn consume_task(&self) -> ! {
trace!("Starting WireGuard consumption task");
@ -195,30 +199,33 @@ impl WireGuardTunnel {
trace_ip_packet("Received IP packet", packet);
match self.route_ip_packet(packet) {
RouteResult::Broadcast => {
// Broadcast IP packet
if self.ip_broadcast_tx.receiver_count() > 1 {
match self.ip_broadcast_tx.send(packet.to_vec()) {
Ok(n) => {
RouteResult::Dispatch(port) => {
let sender = self.virtual_port_ip_tx.get(&port);
if let Some(sender_guard) = sender {
let sender = sender_guard.val();
match sender.send(packet.to_vec()).await {
Ok(_) => {
trace!(
"Broadcasted received IP packet to {} virtual interfaces",
n - 1
"Dispatched received IP packet to virtual port {}",
port
);
}
Err(e) => {
error!(
"Failed to broadcast received IP packet to recipients: {}",
e
"Failed to dispatch received IP packet to virtual port {}: {}",
port, e
);
}
}
} else {
warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
}
}
RouteResult::TcpReset(packet) => {
RouteResult::TcpReset => {
trace!("Resetting dead TCP connection after packet from WireGuard endpoint");
self.send_ip_packet(&packet)
.await
.unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e));
self.route_tcp_sink(packet).await.unwrap_or_else(|e| {
error!("Failed to send TCP reset to sink: {:?}", e)
});
}
RouteResult::Drop => {
trace!("Dropped incoming IP packet from WireGuard endpoint");
@ -230,33 +237,6 @@ impl WireGuardTunnel {
}
}
/// A repeating task that drains the default IP broadcast channel receiver.
/// It is necessary to keep this receiver alive to prevent the overall channel from closing,
/// so draining its backlog regularly is required to avoid memory leaks.
pub async fn broadcast_drain_task(&self) {
trace!("Starting IP broadcast sink drain task");
loop {
let mut sink = self.ip_broadcast_rx_sink.lock().await;
match sink.recv().await {
Ok(_) => {
trace!("Drained a packet from IP broadcast sink");
}
Err(e) => match e {
RecvError::Closed => {
trace!("IP broadcast sink finished draining: channel closed");
break;
}
RecvError::Lagged(_) => {
warn!("IP broadcast sink is falling behind");
}
},
}
}
trace!("Stopped IP broadcast sink drain");
}
fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
Tunn::new(
config.private_key.clone(),
@ -278,14 +258,9 @@ impl WireGuardTunnel {
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() {
IpProtocol::Tcp => Some(self.route_tcp_segment(
IpVersion::Ipv4,
packet.src_addr().into(),
packet.dst_addr().into(),
packet.payload(),
)),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
})
.flatten()
.unwrap_or(RouteResult::Drop),
@ -294,14 +269,9 @@ impl WireGuardTunnel {
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.next_header() {
IpProtocol::Tcp => Some(self.route_tcp_segment(
IpVersion::Ipv6,
packet.src_addr().into(),
packet.dst_addr().into(),
packet.payload(),
)),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
})
.flatten()
.unwrap_or(RouteResult::Drop),
@ -310,40 +280,32 @@ impl WireGuardTunnel {
}
/// Makes a decision on the handling of an incoming TCP segment.
fn route_tcp_segment(
&self,
ip_version: IpVersion,
src_addr: IpAddress,
dst_addr: IpAddress,
segment: &[u8],
) -> RouteResult {
fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
TcpPacket::new_checked(segment)
.ok()
.map(|tcp| {
if self.port_pool.is_in_use(tcp.dst_port()) {
RouteResult::Broadcast
if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() {
RouteResult::Dispatch(tcp.dst_port())
} else if tcp.rst() {
RouteResult::Drop
} else {
// Port is not in use, but it's a TCP packet so we'll craft a RST.
RouteResult::TcpReset(craft_tcp_rst_reply(
ip_version,
src_addr,
tcp.src_port(),
dst_addr,
tcp.dst_port(),
tcp.ack_number(),
))
RouteResult::TcpReset
}
})
.unwrap_or(RouteResult::Drop)
}
/// Route a packet to the TCP sink interface.
async fn route_tcp_sink(&self, _packet: &[u8]) -> anyhow::Result<()> {
// TODO
Ok(())
}
}
/// Craft an IP packet containing a TCP RST segment, given an IP version,
/// source address (the one to reply to), destination address (the one the reply comes from),
/// and the ACK number received in the initiating TCP segment.
fn craft_tcp_rst_reply(
fn _craft_tcp_rst_reply(
ip_version: IpVersion,
source_addr: IpAddress,
source_port: u16,
@ -450,10 +412,10 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
}
enum RouteResult {
/// The packet can be broadcasted to the virtual interfaces
Broadcast,
/// Dispatch the packet to the virtual port.
Dispatch(u16),
/// The packet is not routable so it may be reset.
TcpReset(Vec<u8>),
TcpReset,
/// The packet can be safely ignored.
Drop,
}