From 492875c39297e7721fc32ca83ff2a9c02f3441c8 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 13 Oct 2021 20:23:12 -0400 Subject: [PATCH] Consumption task, and keep-alive --- src/config.rs | 27 +++++++++- src/main.rs | 32 +++++++++--- src/wg.rs | 135 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 183 insertions(+), 11 deletions(-) diff --git a/src/config.rs b/src/config.rs index b559e82..1a5ebeb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,7 @@ -use anyhow::Context; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; +use anyhow::Context; use boringtun::crypto::{X25519PublicKey, X25519SecretKey}; use clap::{App, Arg}; @@ -13,6 +13,7 @@ pub struct Config { pub(crate) endpoint_public_key: Arc, pub(crate) endpoint_addr: SocketAddr, pub(crate) source_peer_ip: IpAddr, + pub(crate) keepalive_seconds: Option, } impl Config { @@ -54,7 +55,13 @@ impl Config { .takes_value(true) .long("source-peer-ip") .env("ONETUN_SOURCE_PEER_IP") - .help("The source IP to identify this peer as (local). Example: 192.168.4.3") + .help("The source IP to identify this peer as (local). Example: 192.168.4.3"), + Arg::with_name("keep-alive") + .required(false) + .takes_value(true) + .long("keep-alive") + .env("ONETUN_KEEP_ALIVE") + .help("Configures a persistent keep-alive for the WireGuard tunnel, in seconds.") ]).get_matches(); Ok(Self { @@ -74,6 +81,8 @@ impl Config { .with_context(|| "Invalid endpoint address")?, source_peer_ip: parse_ip(matches.value_of("source-peer-ip")) .with_context(|| "Invalid source peer IP")?, + keepalive_seconds: parse_keep_alive(matches.value_of("keep-alive")) + .with_context(|| "Invalid keep-alive value")?, }) } } @@ -103,3 +112,17 @@ fn parse_public_key(s: Option<&str>) -> anyhow::Result { .map_err(|e| anyhow::anyhow!("{}", e)) .with_context(|| "Invalid public key") } + +fn parse_keep_alive(s: Option<&str>) -> anyhow::Result> { + if let Some(s) = s { + let parsed: u16 = s.parse().with_context(|| { + format!( + "Keep-alive must be a number between 0 and {} seconds", + u16::MAX + ) + })?; + Ok(Some(parsed)) + } else { + Ok(None) + } +} diff --git a/src/main.rs b/src/main.rs index c63e1c0..d9ecaaf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use anyhow::Context; use tokio::io::Interest; -use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::net::{TcpListener, TcpStream}; use crate::config::Config; use crate::port_pool::PortPool; @@ -31,24 +31,38 @@ async fn main() -> anyhow::Result<()> { .with_context(|| "Failed to initialize WireGuard tunnel")?; let wg = Arc::new(wg); - // Start routine task for WireGuard - tokio::spawn(async move { Arc::clone(&wg).routine_task().await }); + { + // Start routine task for WireGuard + let wg = wg.clone(); + tokio::spawn(async move { wg.routine_task().await }); + } + + { + // Start consumption task for WireGuard + let wg = wg.clone(); + tokio::spawn(async move { wg.consume_task().await }); + } info!( "Tunnelling [{}]->[{}] (via [{}] as peer {})", &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip ); - tcp_proxy_server(config.source_addr.clone(), port_pool.clone()).await + tcp_proxy_server(config.source_addr.clone(), port_pool.clone(), wg).await } /// Starts the server that listens on TCP connections. -async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc) -> anyhow::Result<()> { +async fn tcp_proxy_server( + listen_addr: SocketAddr, + port_pool: Arc, + wg: Arc, +) -> anyhow::Result<()> { let listener = TcpListener::bind(listen_addr) .await .with_context(|| "Failed to listen on TCP proxy server")?; loop { + let wg = wg.clone(); let port_pool = port_pool.clone(); let (socket, peer_addr) = listener .accept() @@ -73,7 +87,7 @@ async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc) -> tokio::spawn(async move { let port_pool = Arc::clone(&port_pool); - let result = handle_tcp_proxy_connection(socket, virtual_port).await; + let result = handle_tcp_proxy_connection(socket, virtual_port, wg).await; if let Err(e) = result { error!( @@ -91,7 +105,11 @@ async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc) -> } /// Handles a new TCP connection with its assigned virtual port. -async fn handle_tcp_proxy_connection(socket: TcpStream, virtual_port: u16) -> anyhow::Result<()> { +async fn handle_tcp_proxy_connection( + socket: TcpStream, + virtual_port: u16, + wg: Arc, +) -> anyhow::Result<()> { loop { let ready = socket .ready(Interest::READABLE | Interest::WRITABLE) diff --git a/src/wg.rs b/src/wg.rs index 7493dd3..b00a854 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -1,5 +1,4 @@ use std::net::SocketAddr; -use std::sync::Arc; use std::time::Duration; use anyhow::Context; @@ -9,6 +8,9 @@ use tokio::net::UdpSocket; use crate::config::Config; use crate::MAX_PACKET; +/// The capacity of the broadcast channel for received IP packets. +const BROADCAST_CAPACITY: usize = 1_000_000; + pub struct WireGuardTunnel { /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. peer: Box, @@ -16,6 +18,8 @@ pub struct WireGuardTunnel { udp: UdpSocket, /// The address of the public WireGuard endpoint (UDP). endpoint: SocketAddr, + /// Broadcast sender for received IP packets. + ip_broadcast_tx: tokio::sync::broadcast::Sender>, } impl WireGuardTunnel { @@ -26,14 +30,51 @@ impl WireGuardTunnel { .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; let endpoint = config.endpoint_addr; + let (ip_broadcast_tx, _) = tokio::sync::broadcast::channel(BROADCAST_CAPACITY); Ok(Self { peer, udp, endpoint, + ip_broadcast_tx, }) } + /// Encapsulates and sends an IP packet through to the WireGuard endpoint. + pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> { + let mut send_buf = [0u8; MAX_PACKET]; + match self.peer.encapsulate(packet, &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + self.udp + .send_to(packet, self.endpoint) + .await + .with_context(|| "Failed to send encrypted IP packet to WireGuard endpoint.")?; + debug!( + "Sent {} bytes to WireGuard endpoint (encrypted IP packet)", + packet.len() + ); + } + TunnResult::Err(e) => { + error!("Failed to encapsulate IP packet: {:?}", e); + } + TunnResult::Done => { + // Ignored + } + other => { + error!( + "Unexpected WireGuard state during encapsulation: {:?}", + other + ); + } + }; + Ok(()) + } + + /// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint. + pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver> { + self.ip_broadcast_tx.subscribe() + } + /// WireGuard Routine task. Handles Handshake, keep-alive, etc. pub async fn routine_task(&self) -> ! { trace!("Starting WireGuard routine task"); @@ -73,12 +114,86 @@ impl WireGuardTunnel { } } + /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint, + /// decapsulates them, and broadcasts newly received IP packets. + pub async fn consume_task(&self) -> ! { + trace!("Starting WireGuard consumption task"); + + loop { + let mut recv_buf = [0u8; MAX_PACKET]; + let mut send_buf = [0u8; MAX_PACKET]; + + let size = match self.udp.recv(&mut recv_buf).await { + Ok(size) => size, + Err(e) => { + error!("Failed to read from WireGuard endpoint: {:?}", e); + // Sleep a little bit and try again + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } + }; + + let data = &recv_buf[..size]; + match self.peer.decapsulate(None, data, &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + continue; + } + }; + loop { + let mut send_buf = [0u8; MAX_PACKET]; + match self.peer.decapsulate(None, &[], &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + break; + } + }; + } + _ => { + break; + } + } + } + } + TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { + debug!( + "WireGuard endpoint sent IP packet of {} bytes", + packet.len() + ); + + // For debugging purposes: parse packet + trace_ip_packet(packet); + + // Broadcast IP packet + match self.ip_broadcast_tx.send(packet.to_vec()) { + Ok(n) => { + trace!("Broadcasted received IP packet to {} recipients", n); + } + Err(e) => { + error!( + "Failed to broadcast received IP packet to recipients: {:?}", + e + ); + } + } + } + _ => {} + } + } + } + fn create_tunnel(config: &Config) -> anyhow::Result> { Tunn::new( config.private_key.clone(), config.endpoint_public_key.clone(), None, - None, + config.keepalive_seconds, 0, None, ) @@ -86,3 +201,19 @@ impl WireGuardTunnel { .with_context(|| "Failed to initialize boringtun Tunn") } } + +fn trace_ip_packet(packet: &[u8]) { + use smoltcp::wire::*; + + match IpVersion::of_packet(&packet) { + Ok(IpVersion::Ipv4) => trace!( + "IPv4 packet received: {}", + PrettyPrinter::>::new("", &packet) + ), + Ok(IpVersion::Ipv6) => trace!( + "IPv6 packet received: {}", + PrettyPrinter::>::new("", &packet) + ), + _ => {} + } +}