Merge pull request #13 from aramperes/12-simultaneous

This commit is contained in:
Aram 🍐 2021-10-16 19:47:05 -04:00 committed by GitHub
commit ce3b23e562
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 161 additions and 215 deletions

View file

@ -130,8 +130,8 @@ forward it to the server's port, which handles the TCP segment. The server respo
the peer's local WireGuard interface, gets encrypted, forwarded to the WireGuard endpoint, and then finally back to onetun's UDP port.
When onetun receives an encrypted packet from the WireGuard endpoint, it decrypts it using boringtun.
The resulting IP packet is broadcasted to all virtual interfaces running inside onetun; once the corresponding
interface is matched, the IP packet is read and unpacked, and the virtual client's TCP state is updated.
The resulting IP packet is dispatched to the corresponding virtual interface running inside onetun;
the IP packet is then read and processed by the virtual interface, and the virtual client's TCP state is updated.
Whenever data is sent by the real client, it is simply "sent" by the virtual client, which kicks off the whole IP encapsulation
and WireGuard encryption again. When data is sent by the real server, it ends up routed in the virtual interface, which allows

35
src/ip_sink.rs Normal file
View file

@ -0,0 +1,35 @@
use crate::virtual_device::VirtualIpDevice;
use crate::wg::WireGuardTunnel;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::socket::SocketSet;
use std::sync::Arc;
use tokio::time::Duration;
/// A repeating task that processes unroutable IP packets.
pub async fn run_ip_sink_interface(wg: Arc<WireGuardTunnel>) -> ! {
// Initialize interface
let device = VirtualIpDevice::new_sink(wg)
.await
.expect("Failed to initialize VirtualIpDevice for sink interface");
// No sockets on sink interface
let mut socket_set_entries: [_; 0] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let mut virtual_interface = InterfaceBuilder::new(device).ip_addrs([]).finalize();
loop {
let loop_start = smoltcp::time::Instant::now();
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
trace!("[SINK] Virtual interface polled some packets to be processed",);
tokio::time::sleep(Duration::from_millis(1)).await;
}
Err(e) => {
error!("[SINK] Virtual interface poll error: {:?}", e);
}
_ => {
tokio::time::sleep(Duration::from_millis(5)).await;
}
}
}
}

View file

@ -18,6 +18,7 @@ use crate::virtual_device::VirtualIpDevice;
use crate::wg::WireGuardTunnel;
pub mod config;
pub mod ip_sink;
pub mod port_pool;
pub mod virtual_device;
pub mod wg;
@ -30,7 +31,7 @@ async fn main() -> anyhow::Result<()> {
init_logger(&config)?;
let port_pool = Arc::new(PortPool::new());
let wg = WireGuardTunnel::new(&config, port_pool.clone())
let wg = WireGuardTunnel::new(&config)
.await
.with_context(|| "Failed to initialize WireGuard tunnel")?;
let wg = Arc::new(wg);
@ -48,9 +49,9 @@ async fn main() -> anyhow::Result<()> {
}
{
// Start IP broadcast drain task for WireGuard
// Start IP sink task for incoming IP packets
let wg = wg.clone();
tokio::spawn(async move { wg.broadcast_drain_task().await });
tokio::spawn(async move { ip_sink::run_ip_sink_interface(wg).await });
}
info!(
@ -106,8 +107,13 @@ async fn tcp_proxy_server(
tokio::spawn(async move {
let port_pool = Arc::clone(&port_pool);
let result =
handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg)
let result = handle_tcp_proxy_connection(
socket,
virtual_port,
source_peer_ip,
dest_addr,
wg.clone(),
)
.await;
if let Err(e) = result {
@ -120,6 +126,7 @@ async fn tcp_proxy_server(
}
// Release port when connection drops
wg.release_virtual_interface(virtual_port);
port_pool.release(virtual_port);
});
}
@ -270,14 +277,14 @@ async fn virtual_tcp_interface(
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device = VirtualIpDevice::new(wg);
let device = VirtualIpDevice::new(virtual_port, wg)
.with_context(|| "Failed to initialize VirtualIpDevice")?;
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr.ip()), 32),
])
.any_ip(true)
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.

View file

@ -13,7 +13,7 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
pub struct PortPool {
/// Remaining ports
inner: lockfree::queue::Queue<u16>,
/// Ports in use
/// Ports in use, with their associated IP channel sender.
taken: lockfree::set::Set<u16>,
}

View file

@ -1,26 +1,34 @@
use crate::wg::WireGuardTunnel;
use anyhow::Context;
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use std::sync::Arc;
/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
/// are made available to this device using a broadcast channel receiver. IP packets sent from this device
/// are made available to this device using a channel receiver. IP packets sent from this device
/// are asynchronously sent out to the WireGuard tunnel.
pub struct VirtualIpDevice {
/// Tunnel to send IP packets to.
wg: Arc<WireGuardTunnel>,
/// Broadcast channel receiver for received IP packets.
ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
/// Channel receiver for received IP packets.
ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
}
impl VirtualIpDevice {
pub fn new(wg: Arc<WireGuardTunnel>) -> Self {
let ip_broadcast_rx = wg.subscribe();
pub fn new(virtual_port: u16, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let ip_dispatch_rx = wg
.register_virtual_interface(virtual_port)
.with_context(|| "Failed to register IP dispatch for virtual interface")?;
Self {
wg,
ip_broadcast_rx,
Ok(Self { wg, ip_dispatch_rx })
}
pub async fn new_sink(wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let ip_dispatch_rx = wg
.register_sink_interface()
.await
.with_context(|| "Failed to register IP dispatch for sink virtual interface")?;
Ok(Self { wg, ip_dispatch_rx })
}
}
@ -29,7 +37,7 @@ impl<'a> Device<'a> for VirtualIpDevice {
type TxToken = TxToken;
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
match self.ip_broadcast_rx.try_recv() {
match self.ip_dispatch_rx.try_recv() {
Ok(buffer) => Some((
Self::RxToken { buffer },
Self::TxToken {

280
src/wg.rs
View file

@ -1,25 +1,18 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use futures::lock::Mutex;
use log::Level;
use smoltcp::phy::ChecksumCapabilities;
use smoltcp::wire::{
IpAddress, IpProtocol, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, TcpControl,
TcpPacket, TcpRepr, TcpSeqNumber,
};
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket};
use tokio::net::UdpSocket;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::RwLock;
use crate::config::Config;
use crate::port_pool::PortPool;
use crate::MAX_PACKET;
/// The capacity of the broadcast channel for received IP packets.
const BROADCAST_CAPACITY: usize = 1_000;
/// The capacity of the channel for received IP packets.
const DISPATCH_CAPACITY: usize = 1_000;
/// A WireGuard tunnel. Encapsulates and decapsulates IP packets
/// to be sent to and received from a remote UDP endpoint.
@ -32,34 +25,30 @@ pub struct WireGuardTunnel {
udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP).
endpoint: SocketAddr,
/// Broadcast sender for received IP packets.
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
/// Sink so that the broadcaster doesn't close. A repeating task should drain this as much as possible.
ip_broadcast_rx_sink: Mutex<tokio::sync::broadcast::Receiver<Vec<u8>>>,
/// Port pool.
port_pool: Arc<PortPool>,
/// Maps virtual ports to the corresponding IP packet dispatcher.
virtual_port_ip_tx: lockfree::map::Map<u16, tokio::sync::mpsc::Sender<Vec<u8>>>,
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
}
impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config, port_pool: Arc<PortPool>) -> anyhow::Result<Self> {
pub async fn new(config: &Config) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?;
let udp = UdpSocket::bind("0.0.0.0:0")
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let endpoint = config.endpoint_addr;
let (ip_broadcast_tx, ip_broadcast_rx_sink) =
tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
let virtual_port_ip_tx = lockfree::map::Map::new();
Ok(Self {
source_peer_ip,
peer,
udp,
endpoint,
ip_broadcast_tx,
ip_broadcast_rx_sink: Mutex::new(ip_broadcast_rx_sink),
port_pool,
virtual_port_ip_tx,
sink_ip_tx: RwLock::new(None),
})
}
@ -94,9 +83,36 @@ impl WireGuardTunnel {
Ok(())
}
/// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint.
pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Vec<u8>> {
self.ip_broadcast_tx.subscribe()
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub fn register_virtual_interface(
&self,
virtual_port: u16,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let existing = self.virtual_port_ip_tx.get(&virtual_port);
if existing.is_some() {
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
} else {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
self.virtual_port_ip_tx.insert(virtual_port, sender);
Ok(receiver)
}
}
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub async fn register_sink_interface(
&self,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
let mut sink_ip_tx = self.sink_ip_tx.write().await;
*sink_ip_tx = Some(sender);
Ok(receiver)
}
/// Releases the virtual interface from IP dispatch.
pub fn release_virtual_interface(&self, virtual_port: u16) {
self.virtual_port_ip_tx.remove(&virtual_port);
}
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
@ -139,7 +155,7 @@ impl WireGuardTunnel {
}
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
/// decapsulates them, and broadcasts newly received IP packets.
/// decapsulates them, and dispatches newly received IP packets.
pub async fn consume_task(&self) -> ! {
trace!("Starting WireGuard consumption task");
@ -195,33 +211,36 @@ impl WireGuardTunnel {
trace_ip_packet("Received IP packet", packet);
match self.route_ip_packet(packet) {
RouteResult::Broadcast => {
// Broadcast IP packet
if self.ip_broadcast_tx.receiver_count() > 1 {
match self.ip_broadcast_tx.send(packet.to_vec()) {
Ok(n) => {
RouteResult::Dispatch(port) => {
let sender = self.virtual_port_ip_tx.get(&port);
if let Some(sender_guard) = sender {
let sender = sender_guard.val();
match sender.send(packet.to_vec()).await {
Ok(_) => {
trace!(
"Broadcasted received IP packet to {} virtual interfaces",
n - 1
"Dispatched received IP packet to virtual port {}",
port
);
}
Err(e) => {
error!(
"Failed to broadcast received IP packet to recipients: {}",
e
"Failed to dispatch received IP packet to virtual port {}: {}",
port, e
);
}
}
} else {
warn!("[{}] Race condition: failed to get virtual port sender after it was dispatched", port);
}
}
RouteResult::TcpReset(packet) => {
trace!("Resetting dead TCP connection after packet from WireGuard endpoint");
self.send_ip_packet(&packet)
.await
.unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e));
RouteResult::Sink => {
trace!("Sending unroutable IP packet received from WireGuard endpoint to sink interface");
self.route_ip_sink(packet).await.unwrap_or_else(|e| {
error!("Failed to send unroutable IP packet to sink: {:?}", e)
});
}
RouteResult::Drop => {
trace!("Dropped incoming IP packet from WireGuard endpoint");
trace!("Dropped unroutable IP packet received from WireGuard endpoint");
}
}
}
@ -230,33 +249,6 @@ impl WireGuardTunnel {
}
}
/// A repeating task that drains the default IP broadcast channel receiver.
/// It is necessary to keep this receiver alive to prevent the overall channel from closing,
/// so draining its backlog regularly is required to avoid memory leaks.
pub async fn broadcast_drain_task(&self) {
trace!("Starting IP broadcast sink drain task");
loop {
let mut sink = self.ip_broadcast_rx_sink.lock().await;
match sink.recv().await {
Ok(_) => {
trace!("Drained a packet from IP broadcast sink");
}
Err(e) => match e {
RecvError::Closed => {
trace!("IP broadcast sink finished draining: channel closed");
break;
}
RecvError::Lagged(_) => {
warn!("IP broadcast sink is falling behind");
}
},
}
}
trace!("Stopped IP broadcast sink drain");
}
fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
Tunn::new(
config.private_key.clone(),
@ -278,14 +270,9 @@ impl WireGuardTunnel {
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() {
IpProtocol::Tcp => Some(self.route_tcp_segment(
IpVersion::Ipv4,
packet.src_addr().into(),
packet.dst_addr().into(),
packet.payload(),
)),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
})
.flatten()
.unwrap_or(RouteResult::Drop),
@ -294,14 +281,9 @@ impl WireGuardTunnel {
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.next_header() {
IpProtocol::Tcp => Some(self.route_tcp_segment(
IpVersion::Ipv6,
packet.src_addr().into(),
packet.dst_addr().into(),
packet.payload(),
)),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
})
.flatten()
.unwrap_or(RouteResult::Drop),
@ -310,123 +292,37 @@ impl WireGuardTunnel {
}
/// Makes a decision on the handling of an incoming TCP segment.
fn route_tcp_segment(
&self,
ip_version: IpVersion,
src_addr: IpAddress,
dst_addr: IpAddress,
segment: &[u8],
) -> RouteResult {
fn route_tcp_segment(&self, segment: &[u8]) -> RouteResult {
TcpPacket::new_checked(segment)
.ok()
.map(|tcp| {
if self.port_pool.is_in_use(tcp.dst_port()) {
RouteResult::Broadcast
if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() {
RouteResult::Dispatch(tcp.dst_port())
} else if tcp.rst() {
RouteResult::Drop
} else {
// Port is not in use, but it's a TCP packet so we'll craft a RST.
RouteResult::TcpReset(craft_tcp_rst_reply(
ip_version,
src_addr,
tcp.src_port(),
dst_addr,
tcp.dst_port(),
tcp.ack_number(),
))
RouteResult::Sink
}
})
.unwrap_or(RouteResult::Drop)
}
}
/// Craft an IP packet containing a TCP RST segment, given an IP version,
/// source address (the one to reply to), destination address (the one the reply comes from),
/// and the ACK number received in the initiating TCP segment.
fn craft_tcp_rst_reply(
ip_version: IpVersion,
source_addr: IpAddress,
source_port: u16,
dest_addr: IpAddress,
dest_port: u16,
ack_number: TcpSeqNumber,
) -> Vec<u8> {
let tcp_repr = TcpRepr {
src_port: dest_port,
dst_port: source_port,
control: TcpControl::Rst,
seq_number: ack_number,
ack_number: None,
window_len: 0,
window_scale: None,
max_seg_size: None,
sack_permitted: false,
sack_ranges: [None, None, None],
payload: &[],
};
/// Route a packet to the IP sink interface.
async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
let ip_sink_tx = self.sink_ip_tx.read().await;
let mut tcp_buffer = vec![0u8; 20];
let mut tcp_packet = &mut TcpPacket::new_unchecked(&mut tcp_buffer);
tcp_repr.emit(
&mut tcp_packet,
&dest_addr,
&source_addr,
&ChecksumCapabilities::default(),
if let Some(ip_sink_tx) = &*ip_sink_tx {
ip_sink_tx
.send(packet.to_vec())
.await
.with_context(|| "Failed to dispatch IP packet to sink interface")
} else {
warn!(
"Could not dispatch unroutable IP packet to sink because interface is not active."
);
let mut ip_buffer = vec![0u8; MAX_PACKET];
let (header_len, total_len) = match ip_version {
IpVersion::Ipv4 => {
let dest_addr = match dest_addr {
IpAddress::Ipv4(dest_addr) => dest_addr,
_ => panic!(),
};
let source_addr = match source_addr {
IpAddress::Ipv4(source_addr) => source_addr,
_ => panic!(),
};
let mut ip_packet = &mut Ipv4Packet::new_unchecked(&mut ip_buffer);
let ip_repr = Ipv4Repr {
src_addr: dest_addr,
dst_addr: source_addr,
protocol: IpProtocol::Tcp,
payload_len: tcp_buffer.len(),
hop_limit: 64,
};
ip_repr.emit(&mut ip_packet, &ChecksumCapabilities::default());
(
ip_packet.header_len() as usize,
ip_packet.total_len() as usize,
)
Ok(())
}
IpVersion::Ipv6 => {
let dest_addr = match dest_addr {
IpAddress::Ipv6(dest_addr) => dest_addr,
_ => panic!(),
};
let source_addr = match source_addr {
IpAddress::Ipv6(source_addr) => source_addr,
_ => panic!(),
};
let mut ip_packet = &mut Ipv6Packet::new_unchecked(&mut ip_buffer);
let ip_repr = Ipv6Repr {
src_addr: dest_addr,
dst_addr: source_addr,
next_header: IpProtocol::Tcp,
payload_len: tcp_buffer.len(),
hop_limit: 64,
};
ip_repr.emit(&mut ip_packet);
(ip_packet.header_len(), ip_packet.total_len())
}
_ => panic!(),
};
ip_buffer[header_len..total_len].copy_from_slice(&tcp_buffer);
let packet: &[u8] = &ip_buffer[..total_len];
packet.to_vec()
}
fn trace_ip_packet(message: &str, packet: &[u8]) {
@ -450,10 +346,10 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
}
enum RouteResult {
/// The packet can be broadcasted to the virtual interfaces
Broadcast,
/// The packet is not routable so it may be reset.
TcpReset(Vec<u8>),
/// The packet can be safely ignored.
/// Dispatch the packet to the virtual port.
Dispatch(u16),
/// The packet is not routable, and should be sent to the sink interface.
Sink,
/// The packet is not routable, and can be safely ignored.
Drop,
}