mirror of
https://github.com/arampoire/onetun.git
synced 2025-12-01 00:20:24 -05:00
Replace lockfree with tokio::sync
This commit is contained in:
parent
5cec6d4943
commit
11c5ec99fd
7 changed files with 52 additions and 53 deletions
27
Cargo.lock
generated
27
Cargo.lock
generated
|
|
@ -203,6 +203,16 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dashmap"
|
||||||
|
version = "4.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e77a43b28d0668df09411cb0bc9a8c2adc40f9a048afe863e05fd43251e8e39c"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"num_cpus",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs-next"
|
name = "dirs-next"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
|
|
@ -469,15 +479,6 @@ dependencies = [
|
||||||
"scopeguard",
|
"scopeguard",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "lockfree"
|
|
||||||
version = "0.5.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23"
|
|
||||||
dependencies = [
|
|
||||||
"owned-alloc",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "log"
|
name = "log"
|
||||||
version = "0.4.14"
|
version = "0.4.14"
|
||||||
|
|
@ -609,8 +610,8 @@ dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"boringtun",
|
"boringtun",
|
||||||
"clap",
|
"clap",
|
||||||
|
"dashmap",
|
||||||
"futures",
|
"futures",
|
||||||
"lockfree",
|
|
||||||
"log",
|
"log",
|
||||||
"nom",
|
"nom",
|
||||||
"pretty_env_logger",
|
"pretty_env_logger",
|
||||||
|
|
@ -619,12 +620,6 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "owned-alloc"
|
|
||||||
version = "0.2.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking_lot"
|
name = "parking_lot"
|
||||||
version = "0.11.2"
|
version = "0.11.2"
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,8 @@ pretty_env_logger = "0.3"
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" }
|
smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" }
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
lockfree = "0.5.1"
|
|
||||||
futures = "0.3.17"
|
futures = "0.3.17"
|
||||||
rand = "0.8.4"
|
rand = "0.8.4"
|
||||||
nom = "7"
|
nom = "7"
|
||||||
async-trait = "0.1.51"
|
async-trait = "0.1.51"
|
||||||
|
dashmap = "4.0.2"
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
init_logger(&config)?;
|
init_logger(&config)?;
|
||||||
|
|
||||||
// Initialize the port pool for each protocol
|
// Initialize the port pool for each protocol
|
||||||
let tcp_port_pool = Arc::new(TcpPortPool::new());
|
let tcp_port_pool = TcpPortPool::new();
|
||||||
// TODO: udp_port_pool
|
// TODO: udp_port_pool
|
||||||
|
|
||||||
let wg = WireGuardTunnel::new(&config)
|
let wg = WireGuardTunnel::new(&config)
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ pub mod udp;
|
||||||
pub async fn port_forward(
|
pub async fn port_forward(
|
||||||
port_forward: PortForwardConfig,
|
port_forward: PortForwardConfig,
|
||||||
source_peer_ip: IpAddr,
|
source_peer_ip: IpAddr,
|
||||||
tcp_port_pool: Arc<TcpPortPool>,
|
tcp_port_pool: TcpPortPool,
|
||||||
wg: Arc<WireGuardTunnel>,
|
wg: Arc<WireGuardTunnel>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
info!(
|
info!(
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ use crate::virtual_iface::tcp::TcpVirtualInterface;
|
||||||
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
|
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
|
||||||
use crate::wg::WireGuardTunnel;
|
use crate::wg::WireGuardTunnel;
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
use std::collections::{HashSet, VecDeque};
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
@ -20,7 +21,7 @@ const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
|
||||||
/// Starts the server that listens on TCP connections.
|
/// Starts the server that listens on TCP connections.
|
||||||
pub async fn tcp_proxy_server(
|
pub async fn tcp_proxy_server(
|
||||||
port_forward: PortForwardConfig,
|
port_forward: PortForwardConfig,
|
||||||
port_pool: Arc<TcpPortPool>,
|
port_pool: TcpPortPool,
|
||||||
wg: Arc<WireGuardTunnel>,
|
wg: Arc<WireGuardTunnel>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let listener = TcpListener::bind(port_forward.source)
|
let listener = TcpListener::bind(port_forward.source)
|
||||||
|
|
@ -38,7 +39,7 @@ pub async fn tcp_proxy_server(
|
||||||
// Assign a 'virtual port': this is a unique port number used to route IP packets
|
// Assign a 'virtual port': this is a unique port number used to route IP packets
|
||||||
// received from the WireGuard tunnel. It is the port number that the virtual client will
|
// received from the WireGuard tunnel. It is the port number that the virtual client will
|
||||||
// listen on.
|
// listen on.
|
||||||
let virtual_port = match port_pool.next() {
|
let virtual_port = match port_pool.next().await {
|
||||||
Ok(port) => port,
|
Ok(port) => port,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(
|
error!(
|
||||||
|
|
@ -52,7 +53,7 @@ pub async fn tcp_proxy_server(
|
||||||
info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
|
info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let port_pool = Arc::clone(&port_pool);
|
let port_pool = port_pool.clone();
|
||||||
let result =
|
let result =
|
||||||
handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
|
handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
|
||||||
|
|
||||||
|
|
@ -67,7 +68,7 @@ pub async fn tcp_proxy_server(
|
||||||
|
|
||||||
// Release port when connection drops
|
// Release port when connection drops
|
||||||
wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp));
|
wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp));
|
||||||
port_pool.release(virtual_port);
|
port_pool.release(virtual_port).await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -203,12 +204,9 @@ async fn handle_tcp_proxy_connection(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A pool of virtual ports available for TCP connections.
|
/// A pool of virtual ports available for TCP connections.
|
||||||
/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`.
|
#[derive(Clone)]
|
||||||
pub struct TcpPortPool {
|
pub struct TcpPortPool {
|
||||||
/// Remaining ports
|
inner: Arc<tokio::sync::RwLock<TcpPortPoolInner>>,
|
||||||
inner: lockfree::queue::Queue<u16>,
|
|
||||||
/// Ports in use, with their associated IP channel sender.
|
|
||||||
taken: lockfree::set::Set<u16>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TcpPortPool {
|
impl Default for TcpPortPool {
|
||||||
|
|
@ -220,37 +218,41 @@ impl Default for TcpPortPool {
|
||||||
impl TcpPortPool {
|
impl TcpPortPool {
|
||||||
/// Initializes a new pool of virtual ports.
|
/// Initializes a new pool of virtual ports.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let inner = lockfree::queue::Queue::default();
|
let mut inner = TcpPortPoolInner::default();
|
||||||
let mut ports: Vec<u16> = PORT_RANGE.collect();
|
let mut ports: Vec<u16> = PORT_RANGE.collect();
|
||||||
ports.shuffle(&mut thread_rng());
|
ports.shuffle(&mut thread_rng());
|
||||||
ports.into_iter().for_each(|p| inner.push(p) as ());
|
ports
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|p| inner.queue.push_back(p) as ());
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner: Arc::new(tokio::sync::RwLock::new(inner)),
|
||||||
taken: lockfree::set::Set::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity).
|
/// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity).
|
||||||
pub fn next(&self) -> anyhow::Result<u16> {
|
pub async fn next(&self) -> anyhow::Result<u16> {
|
||||||
let port = self
|
let mut inner = self.inner.write().await;
|
||||||
.inner
|
let port = inner
|
||||||
.pop()
|
.queue
|
||||||
|
.pop_front()
|
||||||
.with_context(|| "Virtual port pool is exhausted")?;
|
.with_context(|| "Virtual port pool is exhausted")?;
|
||||||
self.taken
|
inner.taken.insert(port);
|
||||||
.insert(port)
|
|
||||||
.ok()
|
|
||||||
.with_context(|| "Failed to insert taken")?;
|
|
||||||
Ok(port)
|
Ok(port)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Releases a port back into the pool.
|
/// Releases a port back into the pool.
|
||||||
pub fn release(&self, port: u16) {
|
pub async fn release(&self, port: u16) {
|
||||||
self.inner.push(port);
|
let mut inner = self.inner.write().await;
|
||||||
self.taken.remove(&port);
|
inner.queue.push_back(port);
|
||||||
}
|
inner.taken.remove(&port);
|
||||||
|
|
||||||
/// Whether the given port is in use by a virtual interface.
|
|
||||||
pub fn is_in_use(&self, port: u16) -> bool {
|
|
||||||
self.taken.contains(&port)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Non thread-safe inner logic for TCP port pool.
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
struct TcpPortPoolInner {
|
||||||
|
/// Remaining ports in the pool.
|
||||||
|
queue: VecDeque<u16>,
|
||||||
|
/// Ports taken out of the pool.
|
||||||
|
taken: HashSet<u16>,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,11 @@ use crate::wg::WireGuardTunnel;
|
||||||
const MAX_PACKET: usize = 65536;
|
const MAX_PACKET: usize = 65536;
|
||||||
|
|
||||||
/// How long to keep the UDP peer address assigned to its virtual specified port, in seconds.
|
/// How long to keep the UDP peer address assigned to its virtual specified port, in seconds.
|
||||||
|
/// TODO: Make this configurable by the CLI
|
||||||
const UDP_TIMEOUT_SECONDS: u64 = 60;
|
const UDP_TIMEOUT_SECONDS: u64 = 60;
|
||||||
|
|
||||||
/// To prevent port-flooding, we set a limit on the amount of open ports per IP address.
|
/// To prevent port-flooding, we set a limit on the amount of open ports per IP address.
|
||||||
|
/// TODO: Make this configurable by the CLI
|
||||||
const PORTS_PER_IP: usize = 100;
|
const PORTS_PER_IP: usize = 100;
|
||||||
|
|
||||||
pub async fn udp_proxy_server(
|
pub async fn udp_proxy_server(
|
||||||
|
|
|
||||||
10
src/wg.rs
10
src/wg.rs
|
|
@ -27,7 +27,7 @@ pub struct WireGuardTunnel {
|
||||||
/// The address of the public WireGuard endpoint (UDP).
|
/// The address of the public WireGuard endpoint (UDP).
|
||||||
pub(crate) endpoint: SocketAddr,
|
pub(crate) endpoint: SocketAddr,
|
||||||
/// Maps virtual ports to the corresponding IP packet dispatcher.
|
/// Maps virtual ports to the corresponding IP packet dispatcher.
|
||||||
virtual_port_ip_tx: lockfree::map::Map<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
|
virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||||
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
|
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
|
||||||
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
|
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
|
||||||
}
|
}
|
||||||
|
|
@ -41,7 +41,7 @@ impl WireGuardTunnel {
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
|
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
|
||||||
let endpoint = config.endpoint_addr;
|
let endpoint = config.endpoint_addr;
|
||||||
let virtual_port_ip_tx = lockfree::map::Map::new();
|
let virtual_port_ip_tx = Default::default();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
source_peer_ip,
|
source_peer_ip,
|
||||||
|
|
@ -89,8 +89,8 @@ impl WireGuardTunnel {
|
||||||
&self,
|
&self,
|
||||||
virtual_port: VirtualPort,
|
virtual_port: VirtualPort,
|
||||||
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
|
) -> anyhow::Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
|
||||||
let existing = self.virtual_port_ip_tx.get(&virtual_port);
|
let existing = self.virtual_port_ip_tx.contains_key(&virtual_port);
|
||||||
if existing.is_some() {
|
if existing {
|
||||||
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
|
Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
|
||||||
} else {
|
} else {
|
||||||
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
|
let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
|
||||||
|
|
@ -215,7 +215,7 @@ impl WireGuardTunnel {
|
||||||
RouteResult::Dispatch(port) => {
|
RouteResult::Dispatch(port) => {
|
||||||
let sender = self.virtual_port_ip_tx.get(&port);
|
let sender = self.virtual_port_ip_tx.get(&port);
|
||||||
if let Some(sender_guard) = sender {
|
if let Some(sender_guard) = sender {
|
||||||
let sender = sender_guard.val();
|
let sender = sender_guard.value();
|
||||||
match sender.send(packet.to_vec()).await {
|
match sender.send(packet.to_vec()).await {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
trace!(
|
trace!(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue