Replace lockfree with tokio::sync

This commit is contained in:
Aram 🍐 2021-10-20 16:05:04 -04:00
parent 5cec6d4943
commit 11c5ec99fd
7 changed files with 52 additions and 53 deletions

View file

@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> {
init_logger(&config)?;
// Initialize the port pool for each protocol
let tcp_port_pool = Arc::new(TcpPortPool::new());
let tcp_port_pool = TcpPortPool::new();
// TODO: udp_port_pool
let wg = WireGuardTunnel::new(&config)

View file

@ -12,7 +12,7 @@ pub mod udp;
pub async fn port_forward(
port_forward: PortForwardConfig,
source_peer_ip: IpAddr,
tcp_port_pool: Arc<TcpPortPool>,
tcp_port_pool: TcpPortPool,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
info!(

View file

@ -3,6 +3,7 @@ use crate::virtual_iface::tcp::TcpVirtualInterface;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
use anyhow::Context;
use std::collections::{HashSet, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
@ -20,7 +21,7 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
/// Starts the server that listens on TCP connections.
pub async fn tcp_proxy_server(
port_forward: PortForwardConfig,
port_pool: Arc<TcpPortPool>,
port_pool: TcpPortPool,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(port_forward.source)
@ -38,7 +39,7 @@ pub async fn tcp_proxy_server(
// Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will
// listen on.
let virtual_port = match port_pool.next() {
let virtual_port = match port_pool.next().await {
Ok(port) => port,
Err(e) => {
error!(
@ -52,7 +53,7 @@ pub async fn tcp_proxy_server(
info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
tokio::spawn(async move {
let port_pool = Arc::clone(&port_pool);
let port_pool = port_pool.clone();
let result =
handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
@ -67,7 +68,7 @@ pub async fn tcp_proxy_server(
// Release port when connection drops
wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp));
port_pool.release(virtual_port);
port_pool.release(virtual_port).await;
});
}
}
@ -203,12 +204,9 @@ async fn handle_tcp_proxy_connection(
}
/// A pool of virtual ports available for TCP connections.
/// This structure is thread-safe; it can be cloned cheaply and shared across tasks.
#[derive(Clone)]
pub struct TcpPortPool {
/// Remaining ports
inner: lockfree::queue::Queue<u16>,
/// Ports in use, with their associated IP channel sender.
taken: lockfree::set::Set<u16>,
inner: Arc<tokio::sync::RwLock<TcpPortPoolInner>>,
}
impl Default for TcpPortPool {
@ -220,37 +218,41 @@ impl Default for TcpPortPool {
impl TcpPortPool {
/// Initializes a new pool of virtual ports.
pub fn new() -> Self {
let inner = lockfree::queue::Queue::default();
let mut inner = TcpPortPoolInner::default();
let mut ports: Vec<u16> = PORT_RANGE.collect();
ports.shuffle(&mut thread_rng());
ports.into_iter().for_each(|p| inner.push(p) as ());
ports
.into_iter()
.for_each(|p| inner.queue.push_back(p) as ());
Self {
inner,
taken: lockfree::set::Set::new(),
inner: Arc::new(tokio::sync::RwLock::new(inner)),
}
}
    /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
pub fn next(&self) -> anyhow::Result<u16> {
let port = self
.inner
.pop()
pub async fn next(&self) -> anyhow::Result<u16> {
let mut inner = self.inner.write().await;
let port = inner
.queue
.pop_front()
.with_context(|| "Virtual port pool is exhausted")?;
self.taken
.insert(port)
.ok()
.with_context(|| "Failed to insert taken")?;
inner.taken.insert(port);
Ok(port)
}
/// Releases a port back into the pool.
pub fn release(&self, port: u16) {
self.inner.push(port);
self.taken.remove(&port);
}
/// Whether the given port is in use by a virtual interface.
pub fn is_in_use(&self, port: u16) -> bool {
self.taken.contains(&port)
pub async fn release(&self, port: u16) {
let mut inner = self.inner.write().await;
inner.queue.push_back(port);
inner.taken.remove(&port);
}
}
/// Non thread-safe inner logic for TCP port pool.
#[derive(Debug, Clone, Default)]
struct TcpPortPoolInner {
/// Remaining ports in the pool.
queue: VecDeque<u16>,
/// Ports taken out of the pool.
taken: HashSet<u16>,
}

View file

@ -9,9 +9,11 @@ use crate::wg::WireGuardTunnel;
const MAX_PACKET: usize = 65536;
/// How long to keep the UDP peer address assigned to its virtual specified port, in seconds.
/// TODO: Make this configurable by the CLI
const UDP_TIMEOUT_SECONDS: u64 = 60;
/// To prevent port-flooding, we set a limit on the amount of open ports per IP address.
/// TODO: Make this configurable by the CLI
const PORTS_PER_IP: usize = 100;
pub async fn udp_proxy_server(

View file

@ -27,7 +27,7 @@ pub struct WireGuardTunnel {
/// The address of the public WireGuard endpoint (UDP).
pub(crate) endpoint: SocketAddr,
/// Maps virtual ports to the corresponding IP packet dispatcher.
virtual_port_ip_tx: lockfree::map::Map<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
}
@ -41,7 +41,7 @@ impl WireGuardTunnel {
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let endpoint = config.endpoint_addr;
let virtual_port_ip_tx = lockfree::map::Map::new();
let virtual_port_ip_tx = Default::default();
Ok(Self {
source_peer_ip,
@ -89,8 +89,8 @@ impl WireGuardTunnel {
&self,
virtual_port: VirtualPort,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let existing = self.virtual_port_ip_tx.get(&virtual_port);
if existing.is_some() {
let existing = self.virtual_port_ip_tx.contains_key(&virtual_port);
if existing {
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
} else {
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
@ -215,7 +215,7 @@ impl WireGuardTunnel {
RouteResult::Dispatch(port) => {
let sender = self.virtual_port_ip_tx.get(&port);
if let Some(sender_guard) = sender {
let sender = sender_guard.val();
let sender = sender_guard.value();
match sender.send(packet.to_vec()).await {
Ok(_) => {
trace!(