diff --git a/Cargo.lock b/Cargo.lock index d106561..495b8db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,6 +120,12 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +[[package]] +name = "bytes" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" + [[package]] name = "cc" version = "1.0.71" @@ -388,6 +394,15 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "lockfree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23" +dependencies = [ + "owned-alloc", +] + [[package]] name = "log" version = "0.4.14" @@ -419,6 +434,37 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mio" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" +dependencies = [ + "libc", + "log", + "miow", + "ntapi", + "winapi", +] + +[[package]] +name = "miow" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" +dependencies = [ + "winapi", +] + +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -472,11 +518,19 @@ dependencies = [ "clap", "crossbeam-channel", "dashmap", + "lockfree", "log", "pretty_env_logger", "smoltcp", + "tokio", ] +[[package]] +name = "owned-alloc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6" + [[package]] name = "parking_lot" version = "0.11.2" @@ -502,6 +556,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "pin-project-lite" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" + [[package]] name = "pretty_env_logger" version = "0.3.1" @@ -615,6 +675,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "slog" version = "2.7.0" @@ -724,6 +793,37 @@ dependencies = [ "winapi", ] +[[package]] +name = "tokio" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "once_cell", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "tokio-macros", + "winapi", +] + +[[package]] +name = "tokio-macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-width" version = "0.1.9" diff --git a/Cargo.toml b/Cargo.toml index 65d1196..6edccb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,5 @@ anyhow = "1" crossbeam-channel = "0.5" smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" } dashmap = "4.0.2" +tokio = { version = "1", features = ["full"] } +lockfree = "0.5.1" diff --git a/src/_main.rs b/src/_main.rs new file mode 100644 index 0000000..0dfe75d --- /dev/null +++ b/src/_main.rs @@ -0,0 +1,540 @@ +#[macro_use] +extern crate log; + +use std::collections::HashMap; +use std::io::{Read, Write}; +use std::net::{ + IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket, +}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Barrier, Mutex, RwLock}; +use std::thread; +use std::time::Duration; + +use crate::client::ProxyClient; +use anyhow::Context; +use boringtun::device::peer::Peer; +use boringtun::noise::{Tunn, TunnResult}; +use crossbeam_channel::{Receiver, RecvError, Sender}; +use dashmap::DashMap; +use smoltcp::iface::InterfaceBuilder; +use smoltcp::phy::ChecksumCapabilities; +use smoltcp::socket::{SocketRef, SocketSet, TcpSocket, TcpSocketBuffer}; +use smoltcp::time::Instant; +use smoltcp::wire::{ + IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, PrettyPrinter, +}; + +use crate::config::Config; +use crate::virtual_device::VirtualIpDevice; + +mod client; +mod config; +mod virtual_device; + +const MAX_PACKET: usize = 65536; + +fn main() -> anyhow::Result<()> { + pretty_env_logger::init_custom_env("ONETUN_LOG"); + let config = Config::from_args().with_context(|| "Failed to read config")?; + debug!("Parsed arguments: {:?}", config); + + info!( + "Tunnelling [{}]->[{}] (via [{}] as peer {})", + &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip + ); + + let source_peer_ip = config.source_peer_ip; + let dest_addr_ip = config.dest_addr.ip(); + let dest_addr_port = config.dest_addr.port(); + let endpoint_addr = config.endpoint_addr; + + // tx/rx for unencrypted IP packets to send through wireguard tunnel + let (send_to_real_server_tx, send_to_real_server_rx) = + crossbeam_channel::unbounded::>(); + + // tx/rx for decrypted IP packets that were received through wireguard tunnel + let (send_to_virtual_interface_tx, send_to_virtual_interface_rx) = + crossbeam_channel::unbounded::>(); + + // Initialize peer based on config + let peer = Tunn::new( + config.private_key.clone(), + config.endpoint_public_key.clone(), + None, + None, + 0, + None, + ) + .map_err(|s| anyhow::anyhow!("{}", s)) + .with_context(|| "Failed to initialize peer")?; + + let peer = Arc::new(peer); + + let endpoint_socket = + Arc::new(UdpSocket::bind("0.0.0.0:0").with_context(|| "Failed to bind endpoint socket")?); + + let (new_client_tx, new_client_rx) = crossbeam_channel::unbounded::(); + let (dead_client_tx, dead_client_rx) = crossbeam_channel::unbounded::(); + + // tx/rx for IP packets the interface exchanged that should be filtered/routed + let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::>(); + + // Virtual interface thread + { + thread::spawn(move || { + // Virtual device: generated IP packets will be send to ip_tx, and IP packets that should be polled should be sent to given ip_rx. + let virtual_device = + VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone()); + + // Create a virtual interface that will generate the IP packets for us + let mut virtual_interface = InterfaceBuilder::new(virtual_device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(IpAddress::from(source_peer_ip), 32), + IpCidr::new(IpAddress::from(dest_addr_ip), 32), + ]) + .any_ip(true) + .finalize(); + + // Server socket: this is a placeholder for the interface to route new connections to. + // TODO: Determine if we even need buffers here. + let server_socket = { + static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .listen((IpAddress::from(dest_addr_ip), dest_addr_port)) + .expect("Virtual server socket failed to listen"); + + socket + }; + + // Socket set: there is always 1 TCP socket for the server, and the rest are client sockets created over time. + let socket_set_entries = Vec::new(); + let mut socket_set = SocketSet::new(socket_set_entries); + socket_set.add(server_socket); + + // Gate socket_set behind RwLock so we can add clients in the background + let socket_set = Arc::new(RwLock::new(socket_set)); + let socket_set_1 = socket_set.clone(); + let socket_set_2 = socket_set.clone(); + + let client_port_to_handle = Arc::new(DashMap::new()); + let client_port_to_handle_1 = client_port_to_handle.clone(); + + let client_handle_to_client = Arc::new(DashMap::new()); + let client_handle_to_client_1 = client_handle_to_client.clone(); + let client_handle_to_client_2 = client_handle_to_client.clone(); + + // Checks if there are new clients to initialize, and adds them to the socket_set + thread::spawn(move || { + let socket_set = socket_set_1; + let client_handle_to_client = client_handle_to_client_1; + + loop { + let client = new_client_rx.recv().expect("failed to read new_client_rx"); + + // Create a virtual client socket for the client + let client_socket = { + static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + let tcp_rx_buffer = + TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); + let tcp_tx_buffer = + TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .connect( + (IpAddress::from(dest_addr_ip), dest_addr_port), + (IpAddress::from(source_peer_ip), client.virtual_port), + ) + .expect("failed to connect virtual client"); + + socket + }; + + // Add to socket set: this makes the ip/port combination routable in the interface, so that IP packets + // received from WG actually go somewhere. + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to add new client"); + let client_handle = socket_set.add(client_socket); + + // Map the client handle by port so we can look it up later + client_port_to_handle.insert(client.virtual_port, client_handle); + client_handle_to_client.insert(client_handle, client); + } + }); + + // Checks if there are clients that disconnected, and removes them from the socket_set + thread::spawn(move || { + let dead_client_rx = dead_client_rx.clone(); + let socket_set = socket_set_2; + let client_handle_to_client = client_handle_to_client_2; + let client_port_to_handle = client_port_to_handle_1; + + loop { + let client_port = dead_client_rx + .recv() + .expect("failed to read dead_client_rx"); + + // Get handle, if any + let handle = client_port_to_handle.remove(&client_port); + + // Remove handle from socket set and from map (handle -> client def) + if let Some((_, handle)) = handle { + // Remove socket set + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to add new client"); + socket_set.remove(handle); + client_handle_to_client.remove(&handle); + debug!("Removed client from socket set: vport={}", client_port); + } + } + }); + + loop { + let loop_start = Instant::now(); + + // Poll virtual interface + // Note: minimize lock time on socket set so new clients can fit in + { + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to poll interface"); + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { + debug!("Virtual interface polled and processed some packets"); + } + Err(e) => { + error!("Virtual interface poll error: {}", e); + } + _ => {} + } + } + + { + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to get client socket"); + // Process packets for each client + for x in client_handle_to_client.iter() { + let client_handle = x.key(); + let client_def = x.value(); + + let mut client_socket: SocketRef = + socket_set.get(*client_handle); + + // Send data received from the real client as the virtual client + if client_socket.can_send() { + while !client_def.data_rx.is_empty() { + let to_send = client_def + .data_rx + .recv() + .expect("failed to read from client data_rx channel"); + client_socket.send_slice(&to_send).expect("virtual client failed to send data as received from data_rx channel"); + } + } + + // Send data received by the virtual client to the real client + if client_socket.can_recv() { + let data = client_socket + .recv(|b| (b.len(), b.to_vec())) + .expect("virtual client failed to recv"); + client_def + .data_tx + .send(data) + .expect("failed to send data to client data_tx channel"); + } + } + } + + // Use poll_delay to know when is the next time to poll. + { + let socket_set = socket_set + .read() + .expect("failed to acquire read lock on socket_set to poll_delay"); + + match virtual_interface.poll_delay(&socket_set, loop_start) { + Some(smoltcp::time::Duration::ZERO) => {} + Some(delay) => { + thread::sleep(std::time::Duration::from_millis(delay.millis())) + } + _ => thread::sleep(std::time::Duration::from_millis(1)), + } + } + } + }); + } + + // Packet routing thread + // Filters packets sent by the virtual interface, so that only the ones that should be sent + // to the real server are. + thread::spawn(move || { + loop { + let recv = send_to_ip_filter_rx + .recv() + .expect("failed to read send_to_ip_filter_rx channel"); + let src_addr: IpAddr = match IpVersion::of_packet(&recv) { + Ok(v) => match v { + IpVersion::Ipv4 => { + match Ipv4Repr::parse( + &Ipv4Packet::new_unchecked(&recv), + &ChecksumCapabilities::ignored(), + ) { + Ok(packet) => Ipv4Addr::from(packet.src_addr).into(), + Err(e) => { + error!("Unable to determine source IPv4 from packet: {}", e); + continue; + } + } + } + IpVersion::Ipv6 => match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) { + Ok(packet) => Ipv6Addr::from(packet.src_addr).into(), + Err(e) => { + error!("Unable to determine source IPv6 from packet: {}", e); + continue; + } + }, + _ => { + error!("Unable to determine IP version from packet: unspecified",); + continue; + } + }, + Err(e) => { + error!("Unable to determine IP version from packet: {}", e); + continue; + } + }; + if src_addr == source_peer_ip { + debug!( + "IP packet: {} bytes from {} to send to WG", + recv.len(), + src_addr + ); + // Add to queue to be encapsulated and sent by other thread + send_to_real_server_tx + .send(recv) + .expect("failed to write to send_to_real_server_tx channel"); + } + } + }); + + // Thread that encapsulates and sends WG packets + { + let peer = peer.clone(); + let endpoint_socket = endpoint_socket.clone(); + + thread::spawn(move || { + let peer = peer.clone(); + loop { + let mut send_buf = [0u8; MAX_PACKET]; + match send_to_real_server_rx.recv() { + Ok(next) => match peer.encapsulate(next.as_slice(), &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + endpoint_socket + .send_to(packet, endpoint_addr) + .expect("failed to send packet to wg endpoint"); + debug!("Sent {} bytes through WG (encryped)", packet.len()); + } + TunnResult::Err(e) => { + error!("Failed to encapsulate: {:?}", e); + } + TunnResult::Done => { + // Ignored + } + other => { + error!("Unexpected TunnResult during encapsulation: {:?}", other); + } + }, + Err(e) => { + error!( + "Failed to consume from send_to_real_server_rx channel: {}", + e + ); + } + } + } + }); + } + + // Thread that decapulates WG IP packets and feeds them to the interface + { + let peer = peer.clone(); + let endpoint_socket = endpoint_socket.clone(); + + thread::spawn(move || loop { + // Listen on the network + let mut recv_buf = [0u8; MAX_PACKET]; + let mut send_buf = [0u8; MAX_PACKET]; + + let n = match endpoint_socket.recv(&mut recv_buf) { + Ok(n) => n, + Err(e) => { + error!("Failed to read from endpoint socket: {}", e); + break; + } + }; + + let data = &recv_buf[..n]; + match peer.decapsulate(None, data, &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + endpoint_socket + .send_to(packet, endpoint_addr) + .expect("failed to send packet to wg endpoint"); + loop { + let mut send_buf = [0u8; MAX_PACKET]; + match peer.decapsulate(None, &[], &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + endpoint_socket + .send_to(packet, endpoint_addr) + .expect("failed to send packet to wg endpoint"); + } + _ => { + break; + } + } + } + } + TunnResult::WriteToTunnelV4(packet, _) => { + debug!( + "Got {} bytes to send back to virtual interface", + packet.len() + ); + + // For debugging purposes: parse packet + { + match IpVersion::of_packet(&packet) { + Ok(IpVersion::Ipv4) => trace!( + "IPv4 packet received: {}", + PrettyPrinter::>::new("", &packet) + ), + Ok(IpVersion::Ipv6) => trace!( + "IPv6 packet received: {}", + PrettyPrinter::>::new("", &packet) + ), + _ => {} + } + } + + send_to_virtual_interface_tx + .send(packet.to_vec()) + .expect("failed to queue received wg packet"); + } + TunnResult::WriteToTunnelV6(packet, _) => { + debug!( + "Got {} bytes to send back to virtual interface", + packet.len() + ); + send_to_virtual_interface_tx + .send(packet.to_vec()) + .expect("failed to queue received wg packet"); + } + _ => {} + } + }); + } + + // Maintenance thread + { + let peer = peer.clone(); + let endpoint_socket = endpoint_socket.clone(); + + thread::spawn(move || loop { + let mut send_buf = [0u8; MAX_PACKET]; + match peer.update_timers(&mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + debug!("Sending maintenance message: {} bytes", packet.len()); + endpoint_socket + .send_to(packet, endpoint_addr) + .expect("failed to send maintenance packet to endpoint address"); + } + _ => {} + } + + thread::sleep(Duration::from_millis(200)); + }); + } + + let proxy_listener = TcpListener::bind(config.source_addr).unwrap(); + for client_stream in proxy_listener.incoming() { + client_stream + .map(|client_stream| { + let dead_client_tx = dead_client_tx.clone(); + + // Pick a port + // TODO: Pool + let port = 60000; + let (data_to_read_tx, data_to_read_rx) = crossbeam_channel::unbounded::>(); + let (data_to_send_tx, data_to_send_rx) = crossbeam_channel::unbounded::>(); + + let client_addr = client_stream + .peer_addr() + .expect("client has no peer address"); + info!("[{}] Incoming connection from {}", port, client_addr); + + let client = ProxyClient { + virtual_port: port, + data_tx: data_to_send_tx, + data_rx: data_to_read_rx, + }; + + // Register the new client with the virtual interface + new_client_tx.send(client.clone()).expect("failed to notify virtual interface of new client"); + + // Reads data from the client + thread::spawn(move || { + let mut client_stream = client_stream; + + // todo: change this to tokio? + client_stream + .set_nonblocking(true) + .expect("failed to set nonblocking"); + + loop { + let mut buffer = [0; MAX_PACKET]; + let read = client_stream.read(&mut buffer); + match read { + Ok(size) if size == 0 => { + info!("[{}] Connection closed by client: {}", port, client_addr); + break; + } + Ok(size) => { + debug!("[{}] Data received from client: {} bytes", port, size); + let data = &buffer[..size]; + data_to_read_tx.send(data.to_vec()) + .unwrap_or_else(|e| error!("[{}] failed to send data to data_to_read_tx channel as received from client: {}", port, e)); + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // Ignore and continue + } + Err(e) => { + warn!("[{}] Connection error: {}", port, e); + break; + } + } + + while !data_to_send_rx.is_empty() { + let recv = data_to_send_rx.recv().expect("failed to read data_to_send_rx"); + client_stream + .write(recv.as_slice()) + .unwrap_or_else(|e| { + error!("[{}] failed to send write to client stream: {}", port, e); + 0 + }); + } + } + + dead_client_tx.send(port).expect("failed to send to dead_client_tx channel"); + }); + }) + .unwrap_or_else(|e| error!("{:?}", e)); + } + Ok(()) +} diff --git a/src/config.rs b/src/config.rs index 3c67a0d..b559e82 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,7 +6,7 @@ use boringtun::crypto::{X25519PublicKey, X25519SecretKey}; use clap::{App, Arg}; #[derive(Clone, Debug)] -pub(crate) struct Config { +pub struct Config { pub(crate) source_addr: SocketAddr, pub(crate) dest_addr: SocketAddr, pub(crate) private_key: Arc, diff --git a/src/main.rs b/src/main.rs index 0dfe75d..9a41e18 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,540 +1,135 @@ #[macro_use] extern crate log; -use std::collections::HashMap; -use std::io::{Read, Write}; -use std::net::{ - IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket, -}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Barrier, Mutex, RwLock}; -use std::thread; -use std::time::Duration; +use std::net::SocketAddr; +use std::sync::Arc; -use crate::client::ProxyClient; use anyhow::Context; -use boringtun::device::peer::Peer; -use boringtun::noise::{Tunn, TunnResult}; -use crossbeam_channel::{Receiver, RecvError, Sender}; -use dashmap::DashMap; -use smoltcp::iface::InterfaceBuilder; -use smoltcp::phy::ChecksumCapabilities; -use smoltcp::socket::{SocketRef, SocketSet, TcpSocket, TcpSocketBuffer}; -use smoltcp::time::Instant; -use smoltcp::wire::{ - IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, PrettyPrinter, -}; +use tokio::io::Interest; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; use crate::config::Config; -use crate::virtual_device::VirtualIpDevice; +use crate::port_pool::PortPool; -mod client; -mod config; -mod virtual_device; +pub mod client; +pub mod config; +pub mod port_pool; +pub mod virtual_device; +pub mod wg; -const MAX_PACKET: usize = 65536; +pub const MAX_PACKET: usize = 65536; -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { pretty_env_logger::init_custom_env("ONETUN_LOG"); let config = Config::from_args().with_context(|| "Failed to read config")?; - debug!("Parsed arguments: {:?}", config); + let peer = Arc::new(wg::create_tunnel(&config)?); + let port_pool = Arc::new(PortPool::new()); + + // endpoint_addr: The address of the public WireGuard endpoint; UDP. + let endpoint_addr = config.endpoint_addr; + + // wireguard_udp: The UDP socket used to communicate with the public WireGuard endpoint. + let wireguard_udp = UdpSocket::bind("0.0.0.0:0") + .await + .with_context(|| "Failed to create UDP socket for WireGuard connection")?; + let wireguard_udp = Arc::new(wireguard_udp); + + // Start routine task for WireGuard + tokio::spawn( + async move { wg::routine(peer.clone(), wireguard_udp.clone(), endpoint_addr).await }, + ); info!( "Tunnelling [{}]->[{}] (via [{}] as peer {})", &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip ); - let source_peer_ip = config.source_peer_ip; - let dest_addr_ip = config.dest_addr.ip(); - let dest_addr_port = config.dest_addr.port(); - let endpoint_addr = config.endpoint_addr; + tcp_proxy_server(config.source_addr.clone(), port_pool.clone()).await +} - // tx/rx for unencrypted IP packets to send through wireguard tunnel - let (send_to_real_server_tx, send_to_real_server_rx) = - crossbeam_channel::unbounded::>(); +/// Starts the server that listens on TCP connections. +async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc) -> anyhow::Result<()> { + let listener = TcpListener::bind(listen_addr) + .await + .with_context(|| "Failed to listen on TCP proxy server")?; - // tx/rx for decrypted IP packets that were received through wireguard tunnel - let (send_to_virtual_interface_tx, send_to_virtual_interface_rx) = - crossbeam_channel::unbounded::>(); + loop { + let port_pool = port_pool.clone(); + let (socket, peer_addr) = listener + .accept() + .await + .with_context(|| "Failed to accept connection on TCP proxy server")?; - // Initialize peer based on config - let peer = Tunn::new( - config.private_key.clone(), - config.endpoint_public_key.clone(), - None, - None, - 0, - None, - ) - .map_err(|s| anyhow::anyhow!("{}", s)) - .with_context(|| "Failed to initialize peer")?; - - let peer = Arc::new(peer); - - let endpoint_socket = - Arc::new(UdpSocket::bind("0.0.0.0:0").with_context(|| "Failed to bind endpoint socket")?); - - let (new_client_tx, new_client_rx) = crossbeam_channel::unbounded::(); - let (dead_client_tx, dead_client_rx) = crossbeam_channel::unbounded::(); - - // tx/rx for IP packets the interface exchanged that should be filtered/routed - let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::>(); - - // Virtual interface thread - { - thread::spawn(move || { - // Virtual device: generated IP packets will be send to ip_tx, and IP packets that should be polled should be sent to given ip_rx. - let virtual_device = - VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone()); - - // Create a virtual interface that will generate the IP packets for us - let mut virtual_interface = InterfaceBuilder::new(virtual_device) - .ip_addrs([ - // Interface handles IP packets for the sender and recipient - IpCidr::new(IpAddress::from(source_peer_ip), 32), - IpCidr::new(IpAddress::from(dest_addr_ip), 32), - ]) - .any_ip(true) - .finalize(); - - // Server socket: this is a placeholder for the interface to route new connections to. - // TODO: Determine if we even need buffers here. - let server_socket = { - static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - - socket - .listen((IpAddress::from(dest_addr_ip), dest_addr_port)) - .expect("Virtual server socket failed to listen"); - - socket - }; - - // Socket set: there is always 1 TCP socket for the server, and the rest are client sockets created over time. - let socket_set_entries = Vec::new(); - let mut socket_set = SocketSet::new(socket_set_entries); - socket_set.add(server_socket); - - // Gate socket_set behind RwLock so we can add clients in the background - let socket_set = Arc::new(RwLock::new(socket_set)); - let socket_set_1 = socket_set.clone(); - let socket_set_2 = socket_set.clone(); - - let client_port_to_handle = Arc::new(DashMap::new()); - let client_port_to_handle_1 = client_port_to_handle.clone(); - - let client_handle_to_client = Arc::new(DashMap::new()); - let client_handle_to_client_1 = client_handle_to_client.clone(); - let client_handle_to_client_2 = client_handle_to_client.clone(); - - // Checks if there are new clients to initialize, and adds them to the socket_set - thread::spawn(move || { - let socket_set = socket_set_1; - let client_handle_to_client = client_handle_to_client_1; - - loop { - let client = new_client_rx.recv().expect("failed to read new_client_rx"); - - // Create a virtual client socket for the client - let client_socket = { - static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = - TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); - let tcp_tx_buffer = - TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - - socket - .connect( - (IpAddress::from(dest_addr_ip), dest_addr_port), - (IpAddress::from(source_peer_ip), client.virtual_port), - ) - .expect("failed to connect virtual client"); - - socket - }; - - // Add to socket set: this makes the ip/port combination routable in the interface, so that IP packets - // received from WG actually go somewhere. - let mut socket_set = socket_set - .write() - .expect("failed to acquire lock on socket_set to add new client"); - let client_handle = socket_set.add(client_socket); - - // Map the client handle by port so we can look it up later - client_port_to_handle.insert(client.virtual_port, client_handle); - client_handle_to_client.insert(client_handle, client); - } - }); - - // Checks if there are clients that disconnected, and removes them from the socket_set - thread::spawn(move || { - let dead_client_rx = dead_client_rx.clone(); - let socket_set = socket_set_2; - let client_handle_to_client = client_handle_to_client_2; - let client_port_to_handle = client_port_to_handle_1; - - loop { - let client_port = dead_client_rx - .recv() - .expect("failed to read dead_client_rx"); - - // Get handle, if any - let handle = client_port_to_handle.remove(&client_port); - - // Remove handle from socket set and from map (handle -> client def) - if let Some((_, handle)) = handle { - // Remove socket set - let mut socket_set = socket_set - .write() - .expect("failed to acquire lock on socket_set to add new client"); - socket_set.remove(handle); - client_handle_to_client.remove(&handle); - debug!("Removed client from socket set: vport={}", client_port); - } - } - }); - - loop { - let loop_start = Instant::now(); - - // Poll virtual interface - // Note: minimize lock time on socket set so new clients can fit in - { - let mut socket_set = socket_set - .write() - .expect("failed to acquire lock on socket_set to poll interface"); - match virtual_interface.poll(&mut socket_set, loop_start) { - Ok(processed) if processed => { - debug!("Virtual interface polled and processed some packets"); - } - Err(e) => { - error!("Virtual interface poll error: {}", e); - } - _ => {} - } - } - - { - let mut socket_set = socket_set - .write() - .expect("failed to acquire lock on socket_set to get client socket"); - // Process packets for each client - for x in client_handle_to_client.iter() { - let client_handle = x.key(); - let client_def = x.value(); - - let mut client_socket: SocketRef = - socket_set.get(*client_handle); - - // Send data received from the real client as the virtual client - if client_socket.can_send() { - while !client_def.data_rx.is_empty() { - let to_send = client_def - .data_rx - .recv() - .expect("failed to read from client data_rx channel"); - client_socket.send_slice(&to_send).expect("virtual client failed to send data as received from data_rx channel"); - } - } - - // Send data received by the virtual client to the real client - if client_socket.can_recv() { - let data = client_socket - .recv(|b| (b.len(), b.to_vec())) - .expect("virtual client failed to recv"); - client_def - .data_tx - .send(data) - .expect("failed to send data to client data_tx channel"); - } - } - } - - // Use poll_delay to know when is the next time to poll. - { - let socket_set = socket_set - .read() - .expect("failed to acquire read lock on socket_set to poll_delay"); - - match virtual_interface.poll_delay(&socket_set, loop_start) { - Some(smoltcp::time::Duration::ZERO) => {} - Some(delay) => { - thread::sleep(std::time::Duration::from_millis(delay.millis())) - } - _ => thread::sleep(std::time::Duration::from_millis(1)), - } - } + // Assign a 'virtual port': this is a unique port number used to route IP packets + // received from the WireGuard tunnel. It is the port number that the virtual client will + // listen on. + let virtual_port = match port_pool.next() { + Ok(port) => port, + Err(e) => { + error!( + "Failed to assign virtual port number for connection [{}]: {:?}", + peer_addr, e + ); + continue; } + }; + + info!("[{}] Incoming connection from {}", virtual_port, peer_addr); + + tokio::spawn(async move { + let port_pool = Arc::clone(&port_pool); + let result = handle_tcp_proxy_connection(socket, virtual_port).await; + + if let Err(e) = result { + error!( + "[{}] Connection dropped un-gracefully: {:?}", + virtual_port, e + ); + } else { + info!("[{}] Connection closed by client", virtual_port); + } + + // Release port when connection drops + port_pool.release(virtual_port); }); } +} - // Packet routing thread - // Filters packets sent by the virtual interface, so that only the ones that should be sent - // to the real server are. - thread::spawn(move || { - loop { - let recv = send_to_ip_filter_rx - .recv() - .expect("failed to read send_to_ip_filter_rx channel"); - let src_addr: IpAddr = match IpVersion::of_packet(&recv) { - Ok(v) => match v { - IpVersion::Ipv4 => { - match Ipv4Repr::parse( - &Ipv4Packet::new_unchecked(&recv), - &ChecksumCapabilities::ignored(), - ) { - Ok(packet) => Ipv4Addr::from(packet.src_addr).into(), - Err(e) => { - error!("Unable to determine source IPv4 from packet: {}", e); - continue; - } - } - } - IpVersion::Ipv6 => match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) { - Ok(packet) => Ipv6Addr::from(packet.src_addr).into(), - Err(e) => { - error!("Unable to determine source IPv6 from packet: {}", e); - continue; - } - }, - _ => { - error!("Unable to determine IP version from packet: unspecified",); - continue; - } - }, - Err(e) => { - error!("Unable to determine IP version from packet: {}", e); +/// Handles a new TCP connection with its assigned virtual port. +async fn handle_tcp_proxy_connection(socket: TcpStream, virtual_port: u16) -> anyhow::Result<()> { + loop { + let ready = socket + .ready(Interest::READABLE | Interest::WRITABLE) + .await + .with_context(|| "Failed to wait for TCP proxy socket readiness")?; + + if ready.is_readable() { + let mut buffer = [0u8; MAX_PACKET]; + + match socket.try_read(&mut buffer) { + Ok(size) if size > 0 => { + let data = &buffer[..size]; + debug!( + "[{}] Read {} bytes of TCP data from real client", + virtual_port, size + ); + trace!("[{}] {:?}", virtual_port, data); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { continue; } - }; - if src_addr == source_peer_ip { - debug!( - "IP packet: {} bytes from {} to send to WG", - recv.len(), - src_addr - ); - // Add to queue to be encapsulated and sent by other thread - send_to_real_server_tx - .send(recv) - .expect("failed to write to send_to_real_server_tx channel"); + Err(e) => { + return Err(e).with_context(|| "Failed to read from real client TCP socket"); + } + _ => {} } } - }); - // Thread that encapsulates and sends WG packets - { - let peer = peer.clone(); - let endpoint_socket = endpoint_socket.clone(); + if ready.is_writable() {} - thread::spawn(move || { - let peer = peer.clone(); - loop { - let mut send_buf = [0u8; MAX_PACKET]; - match send_to_real_server_rx.recv() { - Ok(next) => match peer.encapsulate(next.as_slice(), &mut send_buf) { - TunnResult::WriteToNetwork(packet) => { - endpoint_socket - .send_to(packet, endpoint_addr) - .expect("failed to send packet to wg endpoint"); - debug!("Sent {} bytes through WG (encryped)", packet.len()); - } - TunnResult::Err(e) => { - error!("Failed to encapsulate: {:?}", e); - } - TunnResult::Done => { - // Ignored - } - other => { - error!("Unexpected TunnResult during encapsulation: {:?}", other); - } - }, - Err(e) => { - error!( - "Failed to consume from send_to_real_server_rx channel: {}", - e - ); - } - } - } - }); + if ready.is_read_closed() || ready.is_write_closed() { + return Ok(()); + } } - - // Thread that decapulates WG IP packets and feeds them to the interface - { - let peer = peer.clone(); - let endpoint_socket = endpoint_socket.clone(); - - thread::spawn(move || loop { - // Listen on the network - let mut recv_buf = [0u8; MAX_PACKET]; - let mut send_buf = [0u8; MAX_PACKET]; - - let n = match endpoint_socket.recv(&mut recv_buf) { - Ok(n) => n, - Err(e) => { - error!("Failed to read from endpoint socket: {}", e); - break; - } - }; - - let data = &recv_buf[..n]; - match peer.decapsulate(None, data, &mut send_buf) { - TunnResult::WriteToNetwork(packet) => { - endpoint_socket - .send_to(packet, endpoint_addr) - .expect("failed to send packet to wg endpoint"); - loop { - let mut send_buf = [0u8; MAX_PACKET]; - match peer.decapsulate(None, &[], &mut send_buf) { - TunnResult::WriteToNetwork(packet) => { - endpoint_socket - .send_to(packet, endpoint_addr) - .expect("failed to send packet to wg endpoint"); - } - _ => { - break; - } - } - } - } - TunnResult::WriteToTunnelV4(packet, _) => { - debug!( - "Got {} bytes to send back to virtual interface", - packet.len() - ); - - // For debugging purposes: parse packet - { - match IpVersion::of_packet(&packet) { - Ok(IpVersion::Ipv4) => trace!( - "IPv4 packet received: {}", - PrettyPrinter::>::new("", &packet) - ), - Ok(IpVersion::Ipv6) => trace!( - "IPv6 packet received: {}", - PrettyPrinter::>::new("", &packet) - ), - _ => {} - } - } - - send_to_virtual_interface_tx - .send(packet.to_vec()) - .expect("failed to queue received wg packet"); - } - TunnResult::WriteToTunnelV6(packet, _) => { - debug!( - "Got {} bytes to send back to virtual interface", - packet.len() - ); - send_to_virtual_interface_tx - .send(packet.to_vec()) - .expect("failed to queue received wg packet"); - } - _ => {} - } - }); - } - - // Maintenance thread - { - let peer = peer.clone(); - let endpoint_socket = endpoint_socket.clone(); - - thread::spawn(move || loop { - let mut send_buf = [0u8; MAX_PACKET]; - match peer.update_timers(&mut send_buf) { - TunnResult::WriteToNetwork(packet) => { - debug!("Sending maintenance message: {} bytes", packet.len()); - endpoint_socket - .send_to(packet, endpoint_addr) - .expect("failed to send maintenance packet to endpoint address"); - } - _ => {} - } - - thread::sleep(Duration::from_millis(200)); - }); - } - - let proxy_listener = TcpListener::bind(config.source_addr).unwrap(); - for client_stream in proxy_listener.incoming() { - client_stream - .map(|client_stream| { - let dead_client_tx = dead_client_tx.clone(); - - // Pick a port - // TODO: Pool - let port = 60000; - let (data_to_read_tx, data_to_read_rx) = crossbeam_channel::unbounded::>(); - let (data_to_send_tx, data_to_send_rx) = crossbeam_channel::unbounded::>(); - - let client_addr = client_stream - .peer_addr() - .expect("client has no peer address"); - info!("[{}] Incoming connection from {}", port, client_addr); - - let client = ProxyClient { - virtual_port: port, - data_tx: data_to_send_tx, - data_rx: data_to_read_rx, - }; - - // Register the new client with the virtual interface - new_client_tx.send(client.clone()).expect("failed to notify virtual interface of new client"); - - // Reads data from the client - thread::spawn(move || { - let mut client_stream = client_stream; - - // todo: change this to tokio? - client_stream - .set_nonblocking(true) - .expect("failed to set nonblocking"); - - loop { - let mut buffer = [0; MAX_PACKET]; - let read = client_stream.read(&mut buffer); - match read { - Ok(size) if size == 0 => { - info!("[{}] Connection closed by client: {}", port, client_addr); - break; - } - Ok(size) => { - debug!("[{}] Data received from client: {} bytes", port, size); - let data = &buffer[..size]; - data_to_read_tx.send(data.to_vec()) - .unwrap_or_else(|e| error!("[{}] failed to send data to data_to_read_tx channel as received from client: {}", port, e)); - } - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - // Ignore and continue - } - Err(e) => { - warn!("[{}] Connection error: {}", port, e); - break; - } - } - - while !data_to_send_rx.is_empty() { - let recv = data_to_send_rx.recv().expect("failed to read data_to_send_rx"); - client_stream - .write(recv.as_slice()) - .unwrap_or_else(|e| { - error!("[{}] failed to send write to client stream: {}", port, e); - 0 - }); - } - } - - dead_client_tx.send(port).expect("failed to send to dead_client_tx channel"); - }); - }) - .unwrap_or_else(|e| error!("{:?}", e)); - } - Ok(()) } diff --git a/src/port_pool.rs b/src/port_pool.rs new file mode 100644 index 0000000..b52dc84 --- /dev/null +++ b/src/port_pool.rs @@ -0,0 +1,31 @@ +use std::ops::Range; + +use anyhow::Context; + +const MIN_PORT: u16 = 32768; +const MAX_PORT: u16 = 60999; +const PORT_RANGE: Range = MIN_PORT..MAX_PORT; + +pub struct PortPool { + inner: lockfree::queue::Queue, +} + +impl PortPool { + pub fn new() -> Self { + let inner = lockfree::queue::Queue::default(); + PORT_RANGE.for_each(|p| inner.push(p) as ()); + Self { + inner, + } + } + + pub fn next(&self) -> anyhow::Result { + self.inner + .pop() + .with_context(|| "Virtual port pool is exhausted") + } + + pub fn release(&self, port: u16) { + self.inner.push(port); + } +} diff --git a/src/wg.rs b/src/wg.rs new file mode 100644 index 0000000..f1cafb1 --- /dev/null +++ b/src/wg.rs @@ -0,0 +1,64 @@ +use anyhow::Context; +use boringtun::noise::{Tunn, TunnResult}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::UdpSocket; + +use crate::config::Config; +use crate::MAX_PACKET; + +pub fn create_tunnel(config: &Config) -> anyhow::Result> { + Tunn::new( + config.private_key.clone(), + config.endpoint_public_key.clone(), + None, + None, + 0, + None, + ) + .map_err(|s| anyhow::anyhow!("{}", s)) + .with_context(|| "Failed to initialize peer") +} + +/// WireGuard Routine task. Handles Handshake, keep-alive, etc. +pub async fn routine( + peer: Arc>, + wireguard_udp: Arc, + endpoint_addr: SocketAddr, +) { + debug!("Started WireGuard routine thread"); + loop { + let mut send_buf = [0u8; MAX_PACKET]; + match peer.update_timers(&mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + debug!( + "Sending routine packet of {} bytes to WireGuard endpoint", + packet.len() + ); + match wireguard_udp.send_to(packet, endpoint_addr).await { + Ok(_) => {} + Err(e) => { + error!( + "Failed to send routine packet to WireGuard endpoint: {:?}", + e + ); + } + }; + } + TunnResult::Err(e) => { + error!( + "Failed to prepare routine packet for WireGuard endpoint: {:?}", + e + ); + } + TunnResult::Done => { + // Sleep for a bit + tokio::time::sleep(Duration::from_millis(100)).await; + } + other => { + warn!("Unexpected WireGuard routine task state: {:?}", other); + } + } + } +}