Remove broadcasting logic to fix simultaneous connection issues.

Aram 🍐 2021-10-16 19:16:10 -04:00
parent b6739a5066
commit cfdbdc8f51
5 changed files with 86 additions and 123 deletions


@@ -130,7 +130,7 @@ forward it to the server's port, which handles the TCP segment. The server respo
 the peer's local WireGuard interface, gets encrypted, forwarded to the WireGuard endpoint, and then finally back to onetun's UDP port.
 When onetun receives an encrypted packet from the WireGuard endpoint, it decrypts it using boringtun.
-The resulting IP packet is broadcasted to all virtual interfaces running inside onetun; once the corresponding
+The resulting IP packet is dispatched to the corresponding virtual interface running inside onetun; once the corresponding
 interface is matched, the IP packet is read and unpacked, and the virtual client's TCP state is updated.
 Whenever data is sent by the real client, it is simply "sent" by the virtual client, which kicks off the whole IP encapsulation
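
In effect, the commit replaces one broadcast channel (where every virtual interface saw every decrypted packet) with one bounded mpsc channel per virtual port, keyed by the packet's TCP destination port. A minimal sketch of that dispatch pattern, using tokio and a Mutex-guarded HashMap instead of the lockfree map the commit actually uses (all names here are illustrative, not the crate's API):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;

/// Illustrative dispatcher: one bounded channel per virtual port.
#[derive(Clone, Default)]
struct Dispatcher {
    senders: Arc<Mutex<HashMap<u16, mpsc::Sender<Vec<u8>>>>>,
}

impl Dispatcher {
    /// Register a port and hand back the receiving half, or None if taken.
    fn register(&self, port: u16) -> Option<mpsc::Receiver<Vec<u8>>> {
        let mut map = self.senders.lock().unwrap();
        if map.contains_key(&port) {
            return None; // port already registered
        }
        let (tx, rx) = mpsc::channel(1_000);
        map.insert(port, tx);
        Some(rx)
    }

    /// Route a decrypted IP packet to the one interface that owns the port.
    async fn dispatch(&self, port: u16, packet: Vec<u8>) {
        // Clone the sender so the lock is released before awaiting.
        let tx = self.senders.lock().unwrap().get(&port).cloned();
        if let Some(tx) = tx {
            let _ = tx.send(packet).await; // receiver gone: packet is dropped
        }
    }

    /// Drop the port's sender when the connection closes.
    fn release(&self, port: u16) {
        self.senders.lock().unwrap().remove(&port);
    }
}

Because each packet now has exactly one recipient, the sink receiver and drain task that kept the broadcast channel alive become unnecessary, which is what the deletions below remove.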


@@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> {
     init_logger(&config)?;
     let port_pool = Arc::new(PortPool::new());
-    let wg = WireGuardTunnel::new(&config, port_pool.clone())
+    let wg = WireGuardTunnel::new(&config)
         .await
         .with_context(|| "Failed to initialize WireGuard tunnel")?;
     let wg = Arc::new(wg);
@@ -47,12 +47,6 @@ async fn main() -> anyhow::Result<()> {
         tokio::spawn(async move { wg.consume_task().await });
     }
-    {
-        // Start IP broadcast drain task for WireGuard
-        let wg = wg.clone();
-        tokio::spawn(async move { wg.broadcast_drain_task().await });
-    }
-
     info!(
         "Tunnelling [{}]->[{}] (via [{}] as peer {})",
         &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
@@ -106,8 +100,13 @@ async fn tcp_proxy_server(
         tokio::spawn(async move {
             let port_pool = Arc::clone(&port_pool);
-            let result =
-                handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg)
-                    .await;
+            let result = handle_tcp_proxy_connection(
+                socket,
+                virtual_port,
+                source_peer_ip,
+                dest_addr,
+                wg.clone(),
+            )
+            .await;
             if let Err(e) = result {
@@ -120,6 +119,7 @@ async fn tcp_proxy_server(
             }
             // Release port when connection drops
+            wg.release_virtual_interface(virtual_port);
             port_pool.release(virtual_port);
         });
     }
@@ -270,7 +270,8 @@ async fn virtual_tcp_interface(
     // Consumer for IP packets to send through the virtual interface
     // Initialize the interface
-    let device = VirtualIpDevice::new(wg);
+    let device = VirtualIpDevice::new(virtual_port, wg)
+        .with_context(|| "Failed to initialize VirtualIpDevice")?;
     let mut virtual_interface = InterfaceBuilder::new(device)
         .ip_addrs([
             // Interface handles IP packets for the sender and recipient


@@ -13,7 +13,7 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
 pub struct PortPool {
     /// Remaining ports
     inner: lockfree::queue::Queue<u16>,
-    /// Ports in use
+    /// Ports in use, with their associated IP channel sender.
     taken: lockfree::set::Set<u16>,
 }


@@ -1,26 +1,26 @@
 use crate::wg::WireGuardTunnel;
+use anyhow::Context;
 use smoltcp::phy::{Device, DeviceCapabilities, Medium};
 use smoltcp::time::Instant;
 use std::sync::Arc;

 /// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
-/// are made available to this device using a broadcast channel receiver. IP packets sent from this device
+/// are made available to this device using a channel receiver. IP packets sent from this device
 /// are asynchronously sent out to the WireGuard tunnel.
 pub struct VirtualIpDevice {
     /// Tunnel to send IP packets to.
     wg: Arc<WireGuardTunnel>,
-    /// Broadcast channel receiver for received IP packets.
-    ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
+    /// Channel receiver for received IP packets.
+    ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
 }

 impl VirtualIpDevice {
-    pub fn new(wg: Arc<WireGuardTunnel>) -> Self {
-        let ip_broadcast_rx = wg.subscribe();
-        Self {
-            wg,
-            ip_broadcast_rx,
-        }
+    pub fn new(virtual_port: u16, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
+        let ip_dispatch_rx = wg
+            .register_virtual_interface(virtual_port)
+            .with_context(|| "Failed to register IP dispatch for virtual interface")?;
+        Ok(Self { wg, ip_dispatch_rx })
     }
 }
@@ -29,7 +29,7 @@ impl<'a> Device<'a> for VirtualIpDevice {
     type TxToken = TxToken;

     fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
-        match self.ip_broadcast_rx.try_recv() {
+        match self.ip_dispatch_rx.try_recv() {
             Ok(buffer) => Some((
                 Self::RxToken { buffer },
                 Self::TxToken {

src/wg.rs

@@ -1,10 +1,8 @@
 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
-use std::sync::Arc;
 use std::time::Duration;

 use anyhow::Context;
 use boringtun::noise::{Tunn, TunnResult};
-use futures::lock::Mutex;
 use log::Level;
 use smoltcp::phy::ChecksumCapabilities;
 use smoltcp::wire::{
@@ -12,14 +10,12 @@ use smoltcp::wire::{
     TcpPacket, TcpRepr, TcpSeqNumber,
 };
 use tokio::net::UdpSocket;
-use tokio::sync::broadcast::error::RecvError;

 use crate::config::Config;
-use crate::port_pool::PortPool;
 use crate::MAX_PACKET;

-/// The capacity of the broadcast channel for received IP packets.
-const BROADCAST_CAPACITY: usize = 1_000;
+/// The capacity of the channel for received IP packets.
+const DISPATCH_CAPACITY: usize = 1_000;

 /// A WireGuard tunnel. Encapsulates and decapsulates IP packets
 /// to be sent to and received from a remote UDP endpoint.
@@ -32,34 +28,27 @@ pub struct WireGuardTunnel {
     udp: UdpSocket,
     /// The address of the public WireGuard endpoint (UDP).
     endpoint: SocketAddr,
-    /// Broadcast sender for received IP packets.
-    ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
-    /// Sink so that the broadcaster doesn't close. A repeating task should drain this as much as possible.
-    ip_broadcast_rx_sink: Mutex<tokio::sync::broadcast::Receiver<Vec<u8>>>,
-    /// Port pool.
-    port_pool: Arc<PortPool>,
+    /// Maps virtual ports to the corresponding IP packet dispatcher.
+    virtual_port_ip_tx: lockfree::map::Map<u16, tokio::sync::mpsc::Sender<Vec<u8>>>,
 }

 impl WireGuardTunnel {
     /// Initialize a new WireGuard tunnel.
-    pub async fn new(config: &Config, port_pool: Arc<PortPool>) -> anyhow::Result<Self> {
+    pub async fn new(config: &Config) -> anyhow::Result<Self> {
         let source_peer_ip = config.source_peer_ip;
         let peer = Self::create_tunnel(config)?;
         let udp = UdpSocket::bind("0.0.0.0:0")
             .await
             .with_context(|| "Failed to create UDP socket for WireGuard connection")?;
         let endpoint = config.endpoint_addr;
-        let (ip_broadcast_tx, ip_broadcast_rx_sink) =
-            tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
+        let virtual_port_ip_tx = lockfree::map::Map::new();

         Ok(Self {
             source_peer_ip,
             peer,
             udp,
             endpoint,
-            ip_broadcast_tx,
-            ip_broadcast_rx_sink: Mutex::new(ip_broadcast_rx_sink),
-            port_pool,
+            virtual_port_ip_tx,
         })
     }
@@ -94,9 +83,24 @@ impl WireGuardTunnel {
         Ok(())
     }

-    /// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint.
-    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Vec<u8>> {
-        self.ip_broadcast_tx.subscribe()
+    /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
+    pub fn register_virtual_interface(
+        &self,
+        virtual_port: u16,
+    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
+        let existing = self.virtual_port_ip_tx.get(&virtual_port);
+        if existing.is_some() {
+            Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
+        } else {
+            let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
+            self.virtual_port_ip_tx.insert(virtual_port, sender);
+            Ok(receiver)
+        }
+    }
+
+    /// Releases the virtual interface from IP dispatch.
+    pub fn release_virtual_interface(&self, virtual_port: u16) {
+        self.virtual_port_ip_tx.remove(&virtual_port);
     }

     /// WireGuard Routine task. Handles Handshake, keep-alive, etc.
@@ -139,7 +143,7 @@ impl WireGuardTunnel {
     }

     /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
-    /// decapsulates them, and broadcasts newly received IP packets.
+    /// decapsulates them, and dispatches newly received IP packets.
     pub async fn consume_task(&self) -> ! {
         trace!("Starting WireGuard consumption task");
@@ -195,30 +199,33 @@ impl WireGuardTunnel {
         trace_ip_packet("Received IP packet", packet);

         match self.route_ip_packet(packet) {
-            RouteResult::Broadcast => {
-                // Broadcast IP packet
-                if self.ip_broadcast_tx.receiver_count() > 1 {
-                    match self.ip_broadcast_tx.send(packet.to_vec()) {
-                        Ok(n) => {
+            RouteResult::Dispatch(port) => {
+                let sender = self.virtual_port_ip_tx.get(&port);
+                if let Some(sender_guard) = sender {
+                    let sender = sender_guard.val();
+                    match sender.send(packet.to_vec()).await {
+                        Ok(_) => {
                             trace!(
-                                "Broadcasted received IP packet to {} virtual interfaces",
-                                n - 1
+                                "Dispatched received IP packet to virtual port {}",
+                                port
                             );
                         }
                         Err(e) => {
                             error!(
-                                "Failed to broadcast received IP packet to recipients: {}",
-                                e
+                                "Failed to dispatch received IP packet to virtual port {}: {}",
+                                port, e
                             );
                         }
                     }
+                } else {
+                    warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
                 }
             }
-            RouteResult::TcpReset(packet) => {
+            RouteResult::TcpReset => {
                 trace!("Resetting dead TCP connection after packet from WireGuard endpoint");
-                self.send_ip_packet(&packet)
-                    .await
-                    .unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e));
+                self.route_tcp_sink(packet).await.unwrap_or_else(|e| {
+                    error!("Failed to send TCP reset to sink: {:?}", e)
+                });
             }
             RouteResult::Drop => {
                 trace!("Dropped incoming IP packet from WireGuard endpoint");
@@ -230,33 +237,6 @@ impl WireGuardTunnel {
         }
     }

-    /// A repeating task that drains the default IP broadcast channel receiver.
-    /// It is necessary to keep this receiver alive to prevent the overall channel from closing,
-    /// so draining its backlog regularly is required to avoid memory leaks.
-    pub async fn broadcast_drain_task(&self) {
-        trace!("Starting IP broadcast sink drain task");
-        loop {
-            let mut sink = self.ip_broadcast_rx_sink.lock().await;
-            match sink.recv().await {
-                Ok(_) => {
-                    trace!("Drained a packet from IP broadcast sink");
-                }
-                Err(e) => match e {
-                    RecvError::Closed => {
-                        trace!("IP broadcast sink finished draining: channel closed");
-                        break;
-                    }
-                    RecvError::Lagged(_) => {
-                        warn!("IP broadcast sink is falling behind");
-                    }
-                },
-            }
-        }
-        trace!("Stopped IP broadcast sink drain");
-    }
-
     fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
         Tunn::new(
             config.private_key.clone(),
@@ -278,14 +258,9 @@ impl WireGuardTunnel {
                 // Only care if the packet is destined for this tunnel
                 .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
                 .map(|packet| match packet.protocol() {
-                    IpProtocol::Tcp => Some(self.route_tcp_segment(
-                        IpVersion::Ipv4,
-                        packet.src_addr().into(),
-                        packet.dst_addr().into(),
-                        packet.payload(),
-                    )),
-                    // Unrecognized protocol, so we'll allow it.
-                    _ => Some(RouteResult::Broadcast),
+                    IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
+                    // Unrecognized protocol, so we cannot determine where to route
+                    _ => Some(RouteResult::Drop),
                 })
                 .flatten()
                 .unwrap_or(RouteResult::Drop),
@@ -294,14 +269,9 @@ impl WireGuardTunnel {
                 // Only care if the packet is destined for this tunnel
                 .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
                 .map(|packet| match packet.next_header() {
-                    IpProtocol::Tcp => Some(self.route_tcp_segment(
-                        IpVersion::Ipv6,
-                        packet.src_addr().into(),
-                        packet.dst_addr().into(),
-                        packet.payload(),
-                    )),
-                    // Unrecognized protocol, so we'll allow it.
-                    _ => Some(RouteResult::Broadcast),
+                    IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
+                    // Unrecognized protocol, so we cannot determine where to route
+                    _ => Some(RouteResult::Drop),
                 })
                 .flatten()
                 .unwrap_or(RouteResult::Drop),
@@ -310,40 +280,32 @@ impl WireGuardTunnel {
     }

     /// Makes a decision on the handling of an incoming TCP segment.
-    fn route_tcp_segment(
-        &self,
-        ip_version: IpVersion,
-        src_addr: IpAddress,
-        dst_addr: IpAddress,
-        segment: &[u8],
-    ) -> RouteResult {
+    fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
         TcpPacket::new_checked(segment)
             .ok()
             .map(|tcp| {
-                if self.port_pool.is_in_use(tcp.dst_port()) {
-                    RouteResult::Broadcast
+                if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() {
+                    RouteResult::Dispatch(tcp.dst_port())
                 } else if tcp.rst() {
                     RouteResult::Drop
                 } else {
-                    // Port is not in use, but it's a TCP packet so we'll craft a RST.
-                    RouteResult::TcpReset(craft_tcp_rst_reply(
-                        ip_version,
-                        src_addr,
-                        tcp.src_port(),
-                        dst_addr,
-                        tcp.dst_port(),
-                        tcp.ack_number(),
-                    ))
+                    RouteResult::TcpReset
                 }
             })
             .unwrap_or(RouteResult::Drop)
     }
+
+    /// Route a packet to the TCP sink interface.
+    async fn route_tcp_sink(&self, _packet: &[u8]) -> anyhow::Result<()> {
+        // TODO
+        Ok(())
+    }
 }

 /// Craft an IP packet containing a TCP RST segment, given an IP version,
 /// source address (the one to reply to), destination address (the one the reply comes from),
 /// and the ACK number received in the initiating TCP segment.
-fn craft_tcp_rst_reply(
+fn _craft_tcp_rst_reply(
     ip_version: IpVersion,
     source_addr: IpAddress,
     source_port: u16,
@@ -450,10 +412,10 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
 }

 enum RouteResult {
-    /// The packet can be broadcasted to the virtual interfaces
-    Broadcast,
+    /// Dispatch the packet to the virtual port.
+    Dispatch(u16),
     /// The packet is not routable so it may be reset.
-    TcpReset(Vec<u8>),
+    TcpReset,
     /// The packet can be safely ignored.
     Drop,
 }
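
Taken together, the lifecycle of one proxied connection under the new scheme looks roughly like this (an illustrative sketch of the call order across the diffs above, not code from the commit; error handling and the smoltcp plumbing are elided):

// Hypothetical connection handler, using the APIs introduced in this commit.
async fn handle_connection(
    wg: std::sync::Arc<WireGuardTunnel>,
    virtual_port: u16,
) -> anyhow::Result<()> {
    // 1. Registering claims the port's dispatch slot and yields the mpsc
    //    receiver that VirtualIpDevice::new(virtual_port, wg) wires into smoltcp.
    let mut ip_dispatch_rx = wg.register_virtual_interface(virtual_port)?;

    // 2. consume_task() routes each decrypted packet by its TCP destination
    //    port, so only this receiver ever sees this connection's packets.
    while let Some(packet) = ip_dispatch_rx.recv().await {
        // ... feed `packet` into the virtual interface / TCP state machine ...
    }

    // 3. On disconnect, removing the sender frees the port; later packets to
    //    it are classified as RouteResult::TcpReset (the reset sink itself is
    //    still a TODO in this commit) rather than queueing indefinitely.
    wg.release_virtual_interface(virtual_port);
    Ok(())
}

This is also why the fix addresses simultaneous connections: under the old broadcast channel, every connection's interface had to filter every packet, and a single slow consumer could lag the shared channel for all of them; per-port channels isolate each connection's backlog.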