From 1d703facc0ab323270a842b2a4021a7eeb06d16f Mon Sep 17 00:00:00 2001 From: Aram Peres <6775216+aramperes@users.noreply.github.com> Date: Sun, 24 Dec 2023 14:42:34 -0500 Subject: [PATCH] Implement locking of Tunn in WireGuardTunnel --- src/wg.rs | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/wg.rs b/src/wg.rs index e1edf01..5f5b735 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -9,6 +9,7 @@ use boringtun::noise::{Tunn, TunnResult}; use log::Level; use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet}; use tokio::net::UdpSocket; +use tokio::sync::Mutex; use crate::config::{Config, PortProtocol}; use crate::events::Event; @@ -23,7 +24,7 @@ const MAX_PACKET: usize = 65536; pub struct WireGuardTunnel { pub(crate) source_peer_ip: IpAddr, /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. - peer: Box, + peer: Mutex>, /// The UDP socket for the public WireGuard endpoint to connect to. udp: UdpSocket, /// The address of the public WireGuard endpoint (UDP). @@ -36,7 +37,7 @@ impl WireGuardTunnel { /// Initialize a new WireGuard tunnel. pub async fn new(config: &Config, bus: Bus) -> anyhow::Result { let source_peer_ip = config.source_peer_ip; - let peer = Self::create_tunnel(config)?; + let peer = Mutex::new(Box::new(Self::create_tunnel(config)?)); let endpoint = config.endpoint_addr; let udp = UdpSocket::bind(config.endpoint_bind_addr) .await @@ -55,7 +56,11 @@ impl WireGuardTunnel { pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> { trace_ip_packet("Sending IP packet", packet); let mut send_buf = [0u8; MAX_PACKET]; - match self.peer.encapsulate(packet, &mut send_buf) { + let encapsulate_result = { + let mut peer = self.peer.lock().await; + peer.encapsulate(packet, &mut send_buf) + }; + match encapsulate_result { TunnResult::WriteToNetwork(packet) => { self.udp .send_to(packet, self.endpoint) @@ -104,7 +109,7 @@ impl WireGuardTunnel { loop { let mut send_buf = [0u8; MAX_PACKET]; - let tun_result = self.peer.update_timers(&mut send_buf); + let tun_result = { self.peer.lock().await.update_timers(&mut send_buf) }; self.handle_routine_tun_result(tun_result).await; } } @@ -131,7 +136,11 @@ impl WireGuardTunnel { warn!("Wireguard handshake has expired!"); let mut buf = vec![0u8; MAX_PACKET]; - let result = self.peer.format_handshake_initiation(&mut buf[..], false); + let result = self + .peer + .lock() + .await + .format_handshake_initiation(&mut buf[..], false); self.handle_routine_tun_result(result).await } @@ -172,7 +181,11 @@ impl WireGuardTunnel { }; let data = &recv_buf[..size]; - match self.peer.decapsulate(None, data, &mut send_buf) { + let decapsulate_result = { + let mut peer = self.peer.lock().await; + peer.decapsulate(None, data, &mut send_buf) + }; + match decapsulate_result { TunnResult::WriteToNetwork(packet) => { match self.udp.send_to(packet, self.endpoint).await { Ok(_) => {} @@ -181,9 +194,10 @@ impl WireGuardTunnel { continue; } }; + let mut peer = self.peer.lock().await; loop { let mut send_buf = [0u8; MAX_PACKET]; - match self.peer.decapsulate(None, &[], &mut send_buf) { + match peer.decapsulate(None, &[], &mut send_buf) { TunnResult::WriteToNetwork(packet) => { match self.udp.send_to(packet, self.endpoint).await { Ok(_) => {} @@ -217,10 +231,13 @@ impl WireGuardTunnel { } } - fn create_tunnel(config: &Config) -> anyhow::Result> { + fn create_tunnel(config: &Config) -> anyhow::Result { + let private = config.private_key.as_ref().clone(); + let public = *config.endpoint_public_key.as_ref(); + Tunn::new( - config.private_key.clone(), - config.endpoint_public_key.clone(), + private, + public, config.preshared_key, config.keepalive_seconds, 0,