From 11c5ec99fd5aef913f210558b03490a5fb6b88da Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 20 Oct 2021 16:05:04 -0400 Subject: [PATCH] Replace lockfree with tokio::sync --- Cargo.lock | 27 +++++++++------------ Cargo.toml | 2 +- src/main.rs | 2 +- src/tunnel/mod.rs | 2 +- src/tunnel/tcp.rs | 60 ++++++++++++++++++++++++----------------------- src/tunnel/udp.rs | 2 ++ src/wg.rs | 10 ++++---- 7 files changed, 52 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f3ba00d..885b05b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -203,6 +203,16 @@ dependencies = [ "libc", ] +[[package]] +name = "dashmap" +version = "4.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e77a43b28d0668df09411cb0bc9a8c2adc40f9a048afe863e05fd43251e8e39c" +dependencies = [ + "cfg-if", + "num_cpus", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -469,15 +479,6 @@ dependencies = [ "scopeguard", ] -[[package]] -name = "lockfree" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23" -dependencies = [ - "owned-alloc", -] - [[package]] name = "log" version = "0.4.14" @@ -609,8 +610,8 @@ dependencies = [ "async-trait", "boringtun", "clap", + "dashmap", "futures", - "lockfree", "log", "nom", "pretty_env_logger", @@ -619,12 +620,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "owned-alloc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6" - [[package]] name = "parking_lot" version = "0.11.2" diff --git a/Cargo.toml b/Cargo.toml index d552c2b..905ac32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,8 @@ pretty_env_logger = "0.3" anyhow = "1" smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" } tokio = { version = "1", features = ["full"] } -lockfree = "0.5.1" futures = "0.3.17" rand = "0.8.4" nom = "7" async-trait = "0.1.51" +dashmap = "4.0.2" diff --git a/src/main.rs b/src/main.rs index 0d301d0..b2f1f75 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> { init_logger(&config)?; // Initialize the port pool for each protocol - let tcp_port_pool = Arc::new(TcpPortPool::new()); + let tcp_port_pool = TcpPortPool::new(); // TODO: udp_port_pool let wg = WireGuardTunnel::new(&config) diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 00fe4e8..a4042c9 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -12,7 +12,7 @@ pub mod udp; pub async fn port_forward( port_forward: PortForwardConfig, source_peer_ip: IpAddr, - tcp_port_pool: Arc, + tcp_port_pool: TcpPortPool, wg: Arc, ) -> anyhow::Result<()> { info!( diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 1f01642..fbbdbc2 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -3,6 +3,7 @@ use crate::virtual_iface::tcp::TcpVirtualInterface; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; +use std::collections::{HashSet, VecDeque}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; @@ -20,7 +21,7 @@ const PORT_RANGE: Range = MIN_PORT..MAX_PORT; /// Starts the server that listens on TCP connections. pub async fn tcp_proxy_server( port_forward: PortForwardConfig, - port_pool: Arc, + port_pool: TcpPortPool, wg: Arc, ) -> anyhow::Result<()> { let listener = TcpListener::bind(port_forward.source) @@ -38,7 +39,7 @@ pub async fn tcp_proxy_server( // Assign a 'virtual port': this is a unique port number used to route IP packets // received from the WireGuard tunnel. It is the port number that the virtual client will // listen on. - let virtual_port = match port_pool.next() { + let virtual_port = match port_pool.next().await { Ok(port) => port, Err(e) => { error!( @@ -52,7 +53,7 @@ pub async fn tcp_proxy_server( info!("[{}] Incoming connection from {}", virtual_port, peer_addr); tokio::spawn(async move { - let port_pool = Arc::clone(&port_pool); + let port_pool = port_pool.clone(); let result = handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; @@ -67,7 +68,7 @@ pub async fn tcp_proxy_server( // Release port when connection drops wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp)); - port_pool.release(virtual_port); + port_pool.release(virtual_port).await; }); } } @@ -203,12 +204,9 @@ async fn handle_tcp_proxy_connection( } /// A pool of virtual ports available for TCP connections. -/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`. +#[derive(Clone)] pub struct TcpPortPool { - /// Remaining ports - inner: lockfree::queue::Queue, - /// Ports in use, with their associated IP channel sender. - taken: lockfree::set::Set, + inner: Arc>, } impl Default for TcpPortPool { @@ -220,37 +218,41 @@ impl Default for TcpPortPool { impl TcpPortPool { /// Initializes a new pool of virtual ports. pub fn new() -> Self { - let inner = lockfree::queue::Queue::default(); + let mut inner = TcpPortPoolInner::default(); let mut ports: Vec = PORT_RANGE.collect(); ports.shuffle(&mut thread_rng()); - ports.into_iter().for_each(|p| inner.push(p) as ()); + ports + .into_iter() + .for_each(|p| inner.queue.push_back(p) as ()); Self { - inner, - taken: lockfree::set::Set::new(), + inner: Arc::new(tokio::sync::RwLock::new(inner)), } } /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). - pub fn next(&self) -> anyhow::Result { - let port = self - .inner - .pop() + pub async fn next(&self) -> anyhow::Result { + let mut inner = self.inner.write().await; + let port = inner + .queue + .pop_front() .with_context(|| "Virtual port pool is exhausted")?; - self.taken - .insert(port) - .ok() - .with_context(|| "Failed to insert taken")?; + inner.taken.insert(port); Ok(port) } /// Releases a port back into the pool. - pub fn release(&self, port: u16) { - self.inner.push(port); - self.taken.remove(&port); - } - - /// Whether the given port is in use by a virtual interface. - pub fn is_in_use(&self, port: u16) -> bool { - self.taken.contains(&port) + pub async fn release(&self, port: u16) { + let mut inner = self.inner.write().await; + inner.queue.push_back(port); + inner.taken.remove(&port); } } + +/// Non thread-safe inner logic for TCP port pool. +#[derive(Debug, Clone, Default)] +struct TcpPortPoolInner { + /// Remaining ports in the pool. + queue: VecDeque, + /// Ports taken out of the pool. + taken: HashSet, +} diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 7f9fda4..2326351 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -9,9 +9,11 @@ use crate::wg::WireGuardTunnel; const MAX_PACKET: usize = 65536; /// How long to keep the UDP peer address assigned to its virtual specified port, in seconds. +/// TODO: Make this configurable by the CLI const UDP_TIMEOUT_SECONDS: u64 = 60; /// To prevent port-flooding, we set a limit on the amount of open ports per IP address. +/// TODO: Make this configurable by the CLI const PORTS_PER_IP: usize = 100; pub async fn udp_proxy_server( diff --git a/src/wg.rs b/src/wg.rs index d10ad50..653568e 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -27,7 +27,7 @@ pub struct WireGuardTunnel { /// The address of the public WireGuard endpoint (UDP). pub(crate) endpoint: SocketAddr, /// Maps virtual ports to the corresponding IP packet dispatcher. - virtual_port_ip_tx: lockfree::map::Map>>, + virtual_port_ip_tx: dashmap::DashMap>>, /// IP packet dispatcher for unroutable packets. `None` if not initialized. sink_ip_tx: RwLock>>>, } @@ -41,7 +41,7 @@ impl WireGuardTunnel { .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; let endpoint = config.endpoint_addr; - let virtual_port_ip_tx = lockfree::map::Map::new(); + let virtual_port_ip_tx = Default::default(); Ok(Self { source_peer_ip, @@ -89,8 +89,8 @@ impl WireGuardTunnel { &self, virtual_port: VirtualPort, ) -> anyhow::Result>> { - let existing = self.virtual_port_ip_tx.get(&virtual_port); - if existing.is_some() { + let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); + if existing { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); @@ -215,7 +215,7 @@ impl WireGuardTunnel { RouteResult::Dispatch(port) => { let sender = self.virtual_port_ip_tx.get(&port); if let Some(sender_guard) = sender { - let sender = sender_guard.val(); + let sender = sender_guard.value(); match sender.send(packet.to_vec()).await { Ok(_) => { trace!(