mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 06:38:32 -04:00
Remove broadcasting logic to fix simultaneous connection issues.
This commit is contained in:
parent
b6739a5066
commit
cfdbdc8f51
5 changed files with 86 additions and 123 deletions
|
@ -130,7 +130,7 @@ forward it to the server's port, which handles the TCP segment. The server respo
|
|||
the peer's local WireGuard interface, gets encrypted, forwarded to the WireGuard endpoint, and then finally back to onetun's UDP port.
|
||||
|
||||
When onetun receives an encrypted packet from the WireGuard endpoint, it decrypts it using boringtun.
|
||||
The resulting IP packet is broadcasted to all virtual interfaces running inside onetun; once the corresponding
|
||||
The resulting IP packet is dispatched to the corresponding virtual interface running inside onetun; once the corresponding
|
||||
interface is matched, the IP packet is read and unpacked, and the virtual client's TCP state is updated.
|
||||
|
||||
Whenever data is sent by the real client, it is simply "sent" by the virtual client, which kicks off the whole IP encapsulation
|
||||
|
|
23
src/main.rs
23
src/main.rs
|
@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
init_logger(&config)?;
|
||||
let port_pool = Arc::new(PortPool::new());
|
||||
|
||||
let wg = WireGuardTunnel::new(&config, port_pool.clone())
|
||||
let wg = WireGuardTunnel::new(&config)
|
||||
.await
|
||||
.with_context(|| "Failed to initialize WireGuard tunnel")?;
|
||||
let wg = Arc::new(wg);
|
||||
|
@ -47,12 +47,6 @@ async fn main() -> anyhow::Result<()> {
|
|||
tokio::spawn(async move { wg.consume_task().await });
|
||||
}
|
||||
|
||||
{
|
||||
// Start IP broadcast drain task for WireGuard
|
||||
let wg = wg.clone();
|
||||
tokio::spawn(async move { wg.broadcast_drain_task().await });
|
||||
}
|
||||
|
||||
info!(
|
||||
"Tunnelling [{}]->[{}] (via [{}] as peer {})",
|
||||
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
|
||||
|
@ -106,9 +100,14 @@ async fn tcp_proxy_server(
|
|||
|
||||
tokio::spawn(async move {
|
||||
let port_pool = Arc::clone(&port_pool);
|
||||
let result =
|
||||
handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg)
|
||||
.await;
|
||||
let result = handle_tcp_proxy_connection(
|
||||
socket,
|
||||
virtual_port,
|
||||
source_peer_ip,
|
||||
dest_addr,
|
||||
wg.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
error!(
|
||||
|
@ -120,6 +119,7 @@ async fn tcp_proxy_server(
|
|||
}
|
||||
|
||||
// Release port when connection drops
|
||||
wg.release_virtual_interface(virtual_port);
|
||||
port_pool.release(virtual_port);
|
||||
});
|
||||
}
|
||||
|
@ -270,7 +270,8 @@ async fn virtual_tcp_interface(
|
|||
|
||||
// Consumer for IP packets to send through the virtual interface
|
||||
// Initialize the interface
|
||||
let device = VirtualIpDevice::new(wg);
|
||||
let device = VirtualIpDevice::new(virtual_port, wg)
|
||||
.with_context(|| "Failed to initialize VirtualIpDevice")?;
|
||||
let mut virtual_interface = InterfaceBuilder::new(device)
|
||||
.ip_addrs([
|
||||
// Interface handles IP packets for the sender and recipient
|
||||
|
|
|
@ -13,7 +13,7 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
|
|||
pub struct PortPool {
|
||||
/// Remaining ports
|
||||
inner: lockfree::queue::Queue<u16>,
|
||||
/// Ports in use
|
||||
/// Ports in use, with their associated IP channel sender.
|
||||
taken: lockfree::set::Set<u16>,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,26 +1,26 @@
|
|||
use crate::wg::WireGuardTunnel;
|
||||
use anyhow::Context;
|
||||
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
|
||||
use smoltcp::time::Instant;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
|
||||
/// are made available to this device using a broadcast channel receiver. IP packets sent from this device
|
||||
/// are made available to this device using a channel receiver. IP packets sent from this device
|
||||
/// are asynchronously sent out to the WireGuard tunnel.
|
||||
pub struct VirtualIpDevice {
|
||||
/// Tunnel to send IP packets to.
|
||||
wg: Arc<WireGuardTunnel>,
|
||||
/// Broadcast channel receiver for received IP packets.
|
||||
ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
/// Channel receiver for received IP packets.
|
||||
ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl VirtualIpDevice {
|
||||
pub fn new(wg: Arc<WireGuardTunnel>) -> Self {
|
||||
let ip_broadcast_rx = wg.subscribe();
|
||||
pub fn new(virtual_port: u16, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
|
||||
let ip_dispatch_rx = wg
|
||||
.register_virtual_interface(virtual_port)
|
||||
.with_context(|| "Failed to register IP dispatch for virtual interface")?;
|
||||
|
||||
Self {
|
||||
wg,
|
||||
ip_broadcast_rx,
|
||||
}
|
||||
Ok(Self { wg, ip_dispatch_rx })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ impl<'a> Device<'a> for VirtualIpDevice {
|
|||
type TxToken = TxToken;
|
||||
|
||||
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
|
||||
match self.ip_broadcast_rx.try_recv() {
|
||||
match self.ip_dispatch_rx.try_recv() {
|
||||
Ok(buffer) => Some((
|
||||
Self::RxToken { buffer },
|
||||
Self::TxToken {
|
||||
|
|
162
src/wg.rs
162
src/wg.rs
|
@ -1,10 +1,8 @@
|
|||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use futures::lock::Mutex;
|
||||
use log::Level;
|
||||
use smoltcp::phy::ChecksumCapabilities;
|
||||
use smoltcp::wire::{
|
||||
|
@ -12,14 +10,12 @@ use smoltcp::wire::{
|
|||
TcpPacket, TcpRepr, TcpSeqNumber,
|
||||
};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::broadcast::error::RecvError;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::port_pool::PortPool;
|
||||
use crate::MAX_PACKET;
|
||||
|
||||
/// The capacity of the broadcast channel for received IP packets.
|
||||
const BROADCAST_CAPACITY: usize = 1_000;
|
||||
/// The capacity of the channel for received IP packets.
|
||||
const DISPATCH_CAPACITY: usize = 1_000;
|
||||
|
||||
/// A WireGuard tunnel. Encapsulates and decapsulates IP packets
|
||||
/// to be sent to and received from a remote UDP endpoint.
|
||||
|
@ -32,34 +28,27 @@ pub struct WireGuardTunnel {
|
|||
udp: UdpSocket,
|
||||
/// The address of the public WireGuard endpoint (UDP).
|
||||
endpoint: SocketAddr,
|
||||
/// Broadcast sender for received IP packets.
|
||||
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
|
||||
/// Sink so that the broadcaster doesn't close. A repeating task should drain this as much as possible.
|
||||
ip_broadcast_rx_sink: Mutex<tokio::sync::broadcast::Receiver<Vec<u8>>>,
|
||||
/// Port pool.
|
||||
port_pool: Arc<PortPool>,
|
||||
/// Maps virtual ports to the corresponding IP packet dispatcher.
|
||||
virtual_port_ip_tx: lockfree::map::Map<u16, tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl WireGuardTunnel {
|
||||
/// Initialize a new WireGuard tunnel.
|
||||
pub async fn new(config: &Config, port_pool: Arc<PortPool>) -> anyhow::Result<Self> {
|
||||
pub async fn new(config: &Config) -> anyhow::Result<Self> {
|
||||
let source_peer_ip = config.source_peer_ip;
|
||||
let peer = Self::create_tunnel(config)?;
|
||||
let udp = UdpSocket::bind("0.0.0.0:0")
|
||||
.await
|
||||
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
|
||||
let endpoint = config.endpoint_addr;
|
||||
let (ip_broadcast_tx, ip_broadcast_rx_sink) =
|
||||
tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
|
||||
let virtual_port_ip_tx = lockfree::map::Map::new();
|
||||
|
||||
Ok(Self {
|
||||
source_peer_ip,
|
||||
peer,
|
||||
udp,
|
||||
endpoint,
|
||||
ip_broadcast_tx,
|
||||
ip_broadcast_rx_sink: Mutex::new(ip_broadcast_rx_sink),
|
||||
port_pool,
|
||||
virtual_port_ip_tx,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -94,9 +83,24 @@ impl WireGuardTunnel {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint.
|
||||
pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Vec<u8>> {
|
||||
self.ip_broadcast_tx.subscribe()
|
||||
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
|
||||
pub fn register_virtual_interface(
|
||||
&self,
|
||||
virtual_port: u16,
|
||||
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
|
||||
let existing = self.virtual_port_ip_tx.get(&virtual_port);
|
||||
if existing.is_some() {
|
||||
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
|
||||
} else {
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
|
||||
self.virtual_port_ip_tx.insert(virtual_port, sender);
|
||||
Ok(receiver)
|
||||
}
|
||||
}
|
||||
|
||||
/// Releases the virtual interface from IP dispatch.
|
||||
pub fn release_virtual_interface(&self, virtual_port: u16) {
|
||||
self.virtual_port_ip_tx.remove(&virtual_port);
|
||||
}
|
||||
|
||||
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
||||
|
@ -139,7 +143,7 @@ impl WireGuardTunnel {
|
|||
}
|
||||
|
||||
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
|
||||
/// decapsulates them, and broadcasts newly received IP packets.
|
||||
/// decapsulates them, and dispatches newly received IP packets.
|
||||
pub async fn consume_task(&self) -> ! {
|
||||
trace!("Starting WireGuard consumption task");
|
||||
|
||||
|
@ -195,30 +199,33 @@ impl WireGuardTunnel {
|
|||
trace_ip_packet("Received IP packet", packet);
|
||||
|
||||
match self.route_ip_packet(packet) {
|
||||
RouteResult::Broadcast => {
|
||||
// Broadcast IP packet
|
||||
if self.ip_broadcast_tx.receiver_count() > 1 {
|
||||
match self.ip_broadcast_tx.send(packet.to_vec()) {
|
||||
Ok(n) => {
|
||||
RouteResult::Dispatch(port) => {
|
||||
let sender = self.virtual_port_ip_tx.get(&port);
|
||||
if let Some(sender_guard) = sender {
|
||||
let sender = sender_guard.val();
|
||||
match sender.send(packet.to_vec()).await {
|
||||
Ok(_) => {
|
||||
trace!(
|
||||
"Broadcasted received IP packet to {} virtual interfaces",
|
||||
n - 1
|
||||
"Dispatched received IP packet to virtual port {}",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to broadcast received IP packet to recipients: {}",
|
||||
e
|
||||
"Failed to dispatch received IP packet to virtual port {}: {}",
|
||||
port, e
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
|
||||
}
|
||||
}
|
||||
RouteResult::TcpReset(packet) => {
|
||||
RouteResult::TcpReset => {
|
||||
trace!("Resetting dead TCP connection after packet from WireGuard endpoint");
|
||||
self.send_ip_packet(&packet)
|
||||
.await
|
||||
.unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e));
|
||||
self.route_tcp_sink(packet).await.unwrap_or_else(|e| {
|
||||
error!("Failed to send TCP reset to sink: {:?}", e)
|
||||
});
|
||||
}
|
||||
RouteResult::Drop => {
|
||||
trace!("Dropped incoming IP packet from WireGuard endpoint");
|
||||
|
@ -230,33 +237,6 @@ impl WireGuardTunnel {
|
|||
}
|
||||
}
|
||||
|
||||
/// A repeating task that drains the default IP broadcast channel receiver.
|
||||
/// It is necessary to keep this receiver alive to prevent the overall channel from closing,
|
||||
/// so draining its backlog regularly is required to avoid memory leaks.
|
||||
pub async fn broadcast_drain_task(&self) {
|
||||
trace!("Starting IP broadcast sink drain task");
|
||||
|
||||
loop {
|
||||
let mut sink = self.ip_broadcast_rx_sink.lock().await;
|
||||
match sink.recv().await {
|
||||
Ok(_) => {
|
||||
trace!("Drained a packet from IP broadcast sink");
|
||||
}
|
||||
Err(e) => match e {
|
||||
RecvError::Closed => {
|
||||
trace!("IP broadcast sink finished draining: channel closed");
|
||||
break;
|
||||
}
|
||||
RecvError::Lagged(_) => {
|
||||
warn!("IP broadcast sink is falling behind");
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
trace!("Stopped IP broadcast sink drain");
|
||||
}
|
||||
|
||||
fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
|
||||
Tunn::new(
|
||||
config.private_key.clone(),
|
||||
|
@ -278,14 +258,9 @@ impl WireGuardTunnel {
|
|||
// Only care if the packet is destined for this tunnel
|
||||
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
|
||||
.map(|packet| match packet.protocol() {
|
||||
IpProtocol::Tcp => Some(self.route_tcp_segment(
|
||||
IpVersion::Ipv4,
|
||||
packet.src_addr().into(),
|
||||
packet.dst_addr().into(),
|
||||
packet.payload(),
|
||||
)),
|
||||
// Unrecognized protocol, so we'll allow it.
|
||||
_ => Some(RouteResult::Broadcast),
|
||||
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
|
||||
// Unrecognized protocol, so we cannot determine where to route
|
||||
_ => Some(RouteResult::Drop),
|
||||
})
|
||||
.flatten()
|
||||
.unwrap_or(RouteResult::Drop),
|
||||
|
@ -294,14 +269,9 @@ impl WireGuardTunnel {
|
|||
// Only care if the packet is destined for this tunnel
|
||||
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
|
||||
.map(|packet| match packet.next_header() {
|
||||
IpProtocol::Tcp => Some(self.route_tcp_segment(
|
||||
IpVersion::Ipv6,
|
||||
packet.src_addr().into(),
|
||||
packet.dst_addr().into(),
|
||||
packet.payload(),
|
||||
)),
|
||||
// Unrecognized protocol, so we'll allow it.
|
||||
_ => Some(RouteResult::Broadcast),
|
||||
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
|
||||
// Unrecognized protocol, so we cannot determine where to route
|
||||
_ => Some(RouteResult::Drop),
|
||||
})
|
||||
.flatten()
|
||||
.unwrap_or(RouteResult::Drop),
|
||||
|
@ -310,40 +280,32 @@ impl WireGuardTunnel {
|
|||
}
|
||||
|
||||
/// Makes a decision on the handling of an incoming TCP segment.
|
||||
fn route_tcp_segment(
|
||||
&self,
|
||||
ip_version: IpVersion,
|
||||
src_addr: IpAddress,
|
||||
dst_addr: IpAddress,
|
||||
segment: &[u8],
|
||||
) -> RouteResult {
|
||||
fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
|
||||
TcpPacket::new_checked(segment)
|
||||
.ok()
|
||||
.map(|tcp| {
|
||||
if self.port_pool.is_in_use(tcp.dst_port()) {
|
||||
RouteResult::Broadcast
|
||||
if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() {
|
||||
RouteResult::Dispatch(tcp.dst_port())
|
||||
} else if tcp.rst() {
|
||||
RouteResult::Drop
|
||||
} else {
|
||||
// Port is not in use, but it's a TCP packet so we'll craft a RST.
|
||||
RouteResult::TcpReset(craft_tcp_rst_reply(
|
||||
ip_version,
|
||||
src_addr,
|
||||
tcp.src_port(),
|
||||
dst_addr,
|
||||
tcp.dst_port(),
|
||||
tcp.ack_number(),
|
||||
))
|
||||
RouteResult::TcpReset
|
||||
}
|
||||
})
|
||||
.unwrap_or(RouteResult::Drop)
|
||||
}
|
||||
|
||||
/// Route a packet to the TCP sink interface.
|
||||
async fn route_tcp_sink(&self, _packet: &[u8]) -> anyhow::Result<()> {
|
||||
// TODO
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Craft an IP packet containing a TCP RST segment, given an IP version,
|
||||
/// source address (the one to reply to), destination address (the one the reply comes from),
|
||||
/// and the ACK number received in the initiating TCP segment.
|
||||
fn craft_tcp_rst_reply(
|
||||
fn _craft_tcp_rst_reply(
|
||||
ip_version: IpVersion,
|
||||
source_addr: IpAddress,
|
||||
source_port: u16,
|
||||
|
@ -450,10 +412,10 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
|
|||
}
|
||||
|
||||
enum RouteResult {
|
||||
/// The packet can be broadcasted to the virtual interfaces
|
||||
Broadcast,
|
||||
/// Dispatch the packet to the virtual port.
|
||||
Dispatch(u16),
|
||||
/// The packet is not routable so it may be reset.
|
||||
TcpReset(Vec<u8>),
|
||||
TcpReset,
|
||||
/// The packet can be safely ignored.
|
||||
Drop,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue