Implement locking of Tunn in WireGuardTunnel

This commit is contained in:
Aram 🍐 2023-12-24 14:42:34 -05:00
parent e23cfc3e7e
commit 1d703facc0

View file

@ -9,6 +9,7 @@ use boringtun::noise::{Tunn, TunnResult};
use log::Level; use log::Level;
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet}; use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use crate::config::{Config, PortProtocol}; use crate::config::{Config, PortProtocol};
use crate::events::Event; use crate::events::Event;
@ -23,7 +24,7 @@ const MAX_PACKET: usize = 65536;
pub struct WireGuardTunnel { pub struct WireGuardTunnel {
pub(crate) source_peer_ip: IpAddr, pub(crate) source_peer_ip: IpAddr,
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
peer: Box<Tunn>, peer: Mutex<Box<Tunn>>,
/// The UDP socket for the public WireGuard endpoint to connect to. /// The UDP socket for the public WireGuard endpoint to connect to.
udp: UdpSocket, udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP). /// The address of the public WireGuard endpoint (UDP).
@ -36,7 +37,7 @@ impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel. /// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> { pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip; let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?; let peer = Mutex::new(Box::new(Self::create_tunnel(config)?));
let endpoint = config.endpoint_addr; let endpoint = config.endpoint_addr;
let udp = UdpSocket::bind(config.endpoint_bind_addr) let udp = UdpSocket::bind(config.endpoint_bind_addr)
.await .await
@ -55,7 +56,11 @@ impl WireGuardTunnel {
pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> { pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> {
trace_ip_packet("Sending IP packet", packet); trace_ip_packet("Sending IP packet", packet);
let mut send_buf = [0u8; MAX_PACKET]; let mut send_buf = [0u8; MAX_PACKET];
match self.peer.encapsulate(packet, &mut send_buf) { let encapsulate_result = {
let mut peer = self.peer.lock().await;
peer.encapsulate(packet, &mut send_buf)
};
match encapsulate_result {
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
self.udp self.udp
.send_to(packet, self.endpoint) .send_to(packet, self.endpoint)
@ -104,7 +109,7 @@ impl WireGuardTunnel {
loop { loop {
let mut send_buf = [0u8; MAX_PACKET]; let mut send_buf = [0u8; MAX_PACKET];
let tun_result = self.peer.update_timers(&mut send_buf); let tun_result = { self.peer.lock().await.update_timers(&mut send_buf) };
self.handle_routine_tun_result(tun_result).await; self.handle_routine_tun_result(tun_result).await;
} }
} }
@ -131,7 +136,11 @@ impl WireGuardTunnel {
warn!("Wireguard handshake has expired!"); warn!("Wireguard handshake has expired!");
let mut buf = vec![0u8; MAX_PACKET]; let mut buf = vec![0u8; MAX_PACKET];
let result = self.peer.format_handshake_initiation(&mut buf[..], false); let result = self
.peer
.lock()
.await
.format_handshake_initiation(&mut buf[..], false);
self.handle_routine_tun_result(result).await self.handle_routine_tun_result(result).await
} }
@ -172,7 +181,11 @@ impl WireGuardTunnel {
}; };
let data = &recv_buf[..size]; let data = &recv_buf[..size];
match self.peer.decapsulate(None, data, &mut send_buf) { let decapsulate_result = {
let mut peer = self.peer.lock().await;
peer.decapsulate(None, data, &mut send_buf)
};
match decapsulate_result {
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
match self.udp.send_to(packet, self.endpoint).await { match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {} Ok(_) => {}
@ -181,9 +194,10 @@ impl WireGuardTunnel {
continue; continue;
} }
}; };
let mut peer = self.peer.lock().await;
loop { loop {
let mut send_buf = [0u8; MAX_PACKET]; let mut send_buf = [0u8; MAX_PACKET];
match self.peer.decapsulate(None, &[], &mut send_buf) { match peer.decapsulate(None, &[], &mut send_buf) {
TunnResult::WriteToNetwork(packet) => { TunnResult::WriteToNetwork(packet) => {
match self.udp.send_to(packet, self.endpoint).await { match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {} Ok(_) => {}
@ -217,10 +231,13 @@ impl WireGuardTunnel {
} }
} }
fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> { fn create_tunnel(config: &Config) -> anyhow::Result<Tunn> {
let private = config.private_key.as_ref().clone();
let public = *config.endpoint_public_key.as_ref();
Tunn::new( Tunn::new(
config.private_key.clone(), private,
config.endpoint_public_key.clone(), public,
config.preshared_key, config.preshared_key,
config.keepalive_seconds, config.keepalive_seconds,
0, 0,