From f2e6ec1e3f18065c22a995708b0c4e7faf6a2d4a Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 13 Oct 2021 03:19:56 -0400 Subject: [PATCH] One interface/server for all. Doesn't quite work for repeats. --- Cargo.lock | 21 ++ Cargo.toml | 1 + src/client.rs | 9 + src/main.rs | 501 +++++++++++++++++++++++++----------------- src/virtual_device.rs | 6 +- 5 files changed, 335 insertions(+), 203 deletions(-) create mode 100644 src/client.rs diff --git a/Cargo.lock b/Cargo.lock index 3edfa80..d106561 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -206,6 +206,16 @@ dependencies = [ "libc", ] +[[package]] +name = "dashmap" +version = "4.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e77a43b28d0668df09411cb0bc9a8c2adc40f9a048afe863e05fd43251e8e39c" +dependencies = [ + "cfg-if", + "num_cpus", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -428,6 +438,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.26.2" @@ -451,6 +471,7 @@ dependencies = [ "boringtun", "clap", "crossbeam-channel", + "dashmap", "log", "pretty_env_logger", "smoltcp", diff --git a/Cargo.toml b/Cargo.toml index 603c457..65d1196 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ pretty_env_logger = "0.3" anyhow = "1" crossbeam-channel = "0.5" smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" } +dashmap = "4.0.2" diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..8a5afe7 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,9 @@ +#[derive(Clone)] +pub struct ProxyClient { + /// Unique identifier for this client (used as a port number in the virtual interface). + pub virtual_port: u16, + /// For sending data to the client. + pub data_tx: crossbeam_channel::Sender>, + /// For receiving data from the client. + pub data_rx: crossbeam_channel::Receiver>, +} diff --git a/src/main.rs b/src/main.rs index 71a2918..0dfe75d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,23 +11,24 @@ use std::sync::{Arc, Barrier, Mutex, RwLock}; use std::thread; use std::time::Duration; +use crate::client::ProxyClient; use anyhow::Context; -use boringtun::crypto::{X25519PublicKey, X25519SecretKey}; use boringtun::device::peer::Peer; use boringtun::noise::{Tunn, TunnResult}; -use clap::{App, Arg}; use crossbeam_channel::{Receiver, RecvError, Sender}; +use dashmap::DashMap; use smoltcp::iface::InterfaceBuilder; use smoltcp::phy::ChecksumCapabilities; use smoltcp::socket::{SocketRef, SocketSet, TcpSocket, TcpSocketBuffer}; use smoltcp::time::Instant; use smoltcp::wire::{ - IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, + IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, PrettyPrinter, }; use crate::config::Config; use crate::virtual_device::VirtualIpDevice; +mod client; mod config; mod virtual_device; @@ -73,6 +74,258 @@ fn main() -> anyhow::Result<()> { let endpoint_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").with_context(|| "Failed to bind endpoint socket")?); + let (new_client_tx, new_client_rx) = crossbeam_channel::unbounded::(); + let (dead_client_tx, dead_client_rx) = crossbeam_channel::unbounded::(); + + // tx/rx for IP packets the interface exchanged that should be filtered/routed + let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::>(); + + // Virtual interface thread + { + thread::spawn(move || { + // Virtual device: generated IP packets will be send to ip_tx, and IP packets that should be polled should be sent to given ip_rx. + let virtual_device = + VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone()); + + // Create a virtual interface that will generate the IP packets for us + let mut virtual_interface = InterfaceBuilder::new(virtual_device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(IpAddress::from(source_peer_ip), 32), + IpCidr::new(IpAddress::from(dest_addr_ip), 32), + ]) + .any_ip(true) + .finalize(); + + // Server socket: this is a placeholder for the interface to route new connections to. + // TODO: Determine if we even need buffers here. + let server_socket = { + static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .listen((IpAddress::from(dest_addr_ip), dest_addr_port)) + .expect("Virtual server socket failed to listen"); + + socket + }; + + // Socket set: there is always 1 TCP socket for the server, and the rest are client sockets created over time. + let socket_set_entries = Vec::new(); + let mut socket_set = SocketSet::new(socket_set_entries); + socket_set.add(server_socket); + + // Gate socket_set behind RwLock so we can add clients in the background + let socket_set = Arc::new(RwLock::new(socket_set)); + let socket_set_1 = socket_set.clone(); + let socket_set_2 = socket_set.clone(); + + let client_port_to_handle = Arc::new(DashMap::new()); + let client_port_to_handle_1 = client_port_to_handle.clone(); + + let client_handle_to_client = Arc::new(DashMap::new()); + let client_handle_to_client_1 = client_handle_to_client.clone(); + let client_handle_to_client_2 = client_handle_to_client.clone(); + + // Checks if there are new clients to initialize, and adds them to the socket_set + thread::spawn(move || { + let socket_set = socket_set_1; + let client_handle_to_client = client_handle_to_client_1; + + loop { + let client = new_client_rx.recv().expect("failed to read new_client_rx"); + + // Create a virtual client socket for the client + let client_socket = { + static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; + let tcp_rx_buffer = + TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); + let tcp_tx_buffer = + TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .connect( + (IpAddress::from(dest_addr_ip), dest_addr_port), + (IpAddress::from(source_peer_ip), client.virtual_port), + ) + .expect("failed to connect virtual client"); + + socket + }; + + // Add to socket set: this makes the ip/port combination routable in the interface, so that IP packets + // received from WG actually go somewhere. + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to add new client"); + let client_handle = socket_set.add(client_socket); + + // Map the client handle by port so we can look it up later + client_port_to_handle.insert(client.virtual_port, client_handle); + client_handle_to_client.insert(client_handle, client); + } + }); + + // Checks if there are clients that disconnected, and removes them from the socket_set + thread::spawn(move || { + let dead_client_rx = dead_client_rx.clone(); + let socket_set = socket_set_2; + let client_handle_to_client = client_handle_to_client_2; + let client_port_to_handle = client_port_to_handle_1; + + loop { + let client_port = dead_client_rx + .recv() + .expect("failed to read dead_client_rx"); + + // Get handle, if any + let handle = client_port_to_handle.remove(&client_port); + + // Remove handle from socket set and from map (handle -> client def) + if let Some((_, handle)) = handle { + // Remove socket set + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to add new client"); + socket_set.remove(handle); + client_handle_to_client.remove(&handle); + debug!("Removed client from socket set: vport={}", client_port); + } + } + }); + + loop { + let loop_start = Instant::now(); + + // Poll virtual interface + // Note: minimize lock time on socket set so new clients can fit in + { + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to poll interface"); + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { + debug!("Virtual interface polled and processed some packets"); + } + Err(e) => { + error!("Virtual interface poll error: {}", e); + } + _ => {} + } + } + + { + let mut socket_set = socket_set + .write() + .expect("failed to acquire lock on socket_set to get client socket"); + // Process packets for each client + for x in client_handle_to_client.iter() { + let client_handle = x.key(); + let client_def = x.value(); + + let mut client_socket: SocketRef = + socket_set.get(*client_handle); + + // Send data received from the real client as the virtual client + if client_socket.can_send() { + while !client_def.data_rx.is_empty() { + let to_send = client_def + .data_rx + .recv() + .expect("failed to read from client data_rx channel"); + client_socket.send_slice(&to_send).expect("virtual client failed to send data as received from data_rx channel"); + } + } + + // Send data received by the virtual client to the real client + if client_socket.can_recv() { + let data = client_socket + .recv(|b| (b.len(), b.to_vec())) + .expect("virtual client failed to recv"); + client_def + .data_tx + .send(data) + .expect("failed to send data to client data_tx channel"); + } + } + } + + // Use poll_delay to know when is the next time to poll. + { + let socket_set = socket_set + .read() + .expect("failed to acquire read lock on socket_set to poll_delay"); + + match virtual_interface.poll_delay(&socket_set, loop_start) { + Some(smoltcp::time::Duration::ZERO) => {} + Some(delay) => { + thread::sleep(std::time::Duration::from_millis(delay.millis())) + } + _ => thread::sleep(std::time::Duration::from_millis(1)), + } + } + } + }); + } + + // Packet routing thread + // Filters packets sent by the virtual interface, so that only the ones that should be sent + // to the real server are. + thread::spawn(move || { + loop { + let recv = send_to_ip_filter_rx + .recv() + .expect("failed to read send_to_ip_filter_rx channel"); + let src_addr: IpAddr = match IpVersion::of_packet(&recv) { + Ok(v) => match v { + IpVersion::Ipv4 => { + match Ipv4Repr::parse( + &Ipv4Packet::new_unchecked(&recv), + &ChecksumCapabilities::ignored(), + ) { + Ok(packet) => Ipv4Addr::from(packet.src_addr).into(), + Err(e) => { + error!("Unable to determine source IPv4 from packet: {}", e); + continue; + } + } + } + IpVersion::Ipv6 => match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) { + Ok(packet) => Ipv6Addr::from(packet.src_addr).into(), + Err(e) => { + error!("Unable to determine source IPv6 from packet: {}", e); + continue; + } + }, + _ => { + error!("Unable to determine IP version from packet: unspecified",); + continue; + } + }, + Err(e) => { + error!("Unable to determine IP version from packet: {}", e); + continue; + } + }; + if src_addr == source_peer_ip { + debug!( + "IP packet: {} bytes from {} to send to WG", + recv.len(), + src_addr + ); + // Add to queue to be encapsulated and sent by other thread + send_to_real_server_tx + .send(recv) + .expect("failed to write to send_to_real_server_tx channel"); + } + } + }); + // Thread that encapsulates and sends WG packets { let peer = peer.clone(); @@ -88,6 +341,7 @@ fn main() -> anyhow::Result<()> { endpoint_socket .send_to(packet, endpoint_addr) .expect("failed to send packet to wg endpoint"); + debug!("Sent {} bytes through WG (encryped)", packet.len()); } TunnResult::Err(e) => { error!("Failed to encapsulate: {:?}", e); @@ -109,6 +363,8 @@ fn main() -> anyhow::Result<()> { } }); } + + // Thread that decapulates WG IP packets and feeds them to the interface { let peer = peer.clone(); let endpoint_socket = endpoint_socket.clone(); @@ -147,13 +403,35 @@ fn main() -> anyhow::Result<()> { } } TunnResult::WriteToTunnelV4(packet, _) => { - debug!("Got {} bytes to send back to client", packet.len()); + debug!( + "Got {} bytes to send back to virtual interface", + packet.len() + ); + + // For debugging purposes: parse packet + { + match IpVersion::of_packet(&packet) { + Ok(IpVersion::Ipv4) => trace!( + "IPv4 packet received: {}", + PrettyPrinter::>::new("", &packet) + ), + Ok(IpVersion::Ipv6) => trace!( + "IPv6 packet received: {}", + PrettyPrinter::>::new("", &packet) + ), + _ => {} + } + } + send_to_virtual_interface_tx .send(packet.to_vec()) .expect("failed to queue received wg packet"); } TunnResult::WriteToTunnelV6(packet, _) => { - debug!("Got {} bytes to send back to client", packet.len()); + debug!( + "Got {} bytes to send back to virtual interface", + packet.len() + ); send_to_virtual_interface_tx .send(packet.to_vec()) .expect("failed to queue received wg packet"); @@ -188,115 +466,62 @@ fn main() -> anyhow::Result<()> { for client_stream in proxy_listener.incoming() { client_stream .map(|client_stream| { - let send_to_real_server_tx = send_to_real_server_tx.clone(); - let send_to_virtual_interface_rx = send_to_virtual_interface_rx.clone(); + let dead_client_tx = dead_client_tx.clone(); // Pick a port // TODO: Pool let port = 60000; + let (data_to_read_tx, data_to_read_rx) = crossbeam_channel::unbounded::>(); + let (data_to_send_tx, data_to_send_rx) = crossbeam_channel::unbounded::>(); let client_addr = client_stream .peer_addr() .expect("client has no peer address"); info!("[{}] Incoming connection from {}", port, client_addr); - // tx/rx for data received from the client - // this data is received - let (send_to_virtual_client_tx, send_to_virtual_client_rx) = crossbeam_channel::unbounded::>(); + let client = ProxyClient { + virtual_port: port, + data_tx: data_to_send_tx, + data_rx: data_to_read_rx, + }; - // tx/rx for packets received from the destination - // this data is received from the WG endpoint; the IP packets are routed using the port number - let (send_to_real_client_tx, send_to_real_client_rx) = crossbeam_channel::unbounded::>(); - - // tx/rx for IP packets the interface exchanged that should be filtered/routed - let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::>(); - - let stopped = Arc::new(AtomicBool::new(false)); - let stopped_1 = Arc::clone(&stopped); - let stopped_2 = Arc::clone(&stopped); + // Register the new client with the virtual interface + new_client_tx.send(client.clone()).expect("failed to notify virtual interface of new client"); // Reads data from the client thread::spawn(move || { - let stopped = stopped_1.clone(); - let mut client_stream = client_stream; + + // todo: change this to tokio? client_stream .set_nonblocking(true) .expect("failed to set nonblocking"); - loop { - if stopped.load(Ordering::Relaxed) { - break; - } + loop { let mut buffer = [0; MAX_PACKET]; let read = client_stream.read(&mut buffer); match read { Ok(size) if size == 0 => { info!("[{}] Connection closed by client: {}", port, client_addr); - stopped.store(true, Ordering::Relaxed); break; } Ok(size) => { debug!("[{}] Data received from client: {} bytes", port, size); let data = &buffer[..size]; - send_to_virtual_client_tx - .send(data.to_vec()) - .unwrap_or_else(|e| error!("[{}] failed to send data to client_received_tx channel as received from client: {}", port, e)); + data_to_read_tx.send(data.to_vec()) + .unwrap_or_else(|e| error!("[{}] failed to send data to data_to_read_tx channel as received from client: {}", port, e)); } Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { // Ignore and continue } Err(e) => { warn!("[{}] Connection error: {}", port, e); - stopped.store(true, Ordering::Relaxed); break; } } - while !send_to_ip_filter_rx.is_empty() { - let recv = send_to_ip_filter_rx.recv().expect("failed to read send_to_ip_filter_rx"); - let src_addr: IpAddr = match IpVersion::of_packet(&recv) { - Ok(v) => { - match v { - IpVersion::Ipv4 => { - match Ipv4Repr::parse(&Ipv4Packet::new_unchecked(&recv), &ChecksumCapabilities::ignored()) { - Ok(packet) => Ipv4Addr::from(packet.src_addr).into(), - Err(e) => { - error!("[{}] Unable to determine source IPv4 from packet: {}", port, e); - continue; - } - } - } - IpVersion::Ipv6 => { - match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) { - Ok(packet) => Ipv6Addr::from(packet.src_addr).into(), - Err(e) => { - error!("[{}] Unable to determine source IPv6 from packet: {}", port, e); - continue; - } - } - } - _ => { - error!("[{}] Unable to determine IP version from packet: unspecified", port); - continue; - } - } - } - Err(e) => { - error!("[{}] Unable to determine IP version from packet: {}", port, e); - continue; - } - }; - - if src_addr == source_peer_ip { - debug!("[{}] IP packet: {} bytes from {} to send to WG", port, recv.len(), src_addr); - // Add to queue to be encapsulated and sent by other thread - send_to_real_server_tx.send(recv).expect("failed to write to send_to_real_server_tx channel"); - } - } - - while !send_to_real_client_rx.is_empty() { - let recv = send_to_real_client_rx.recv().expect("failed to read destination_sent_rx"); + while !data_to_send_rx.is_empty() { + let recv = data_to_send_rx.recv().expect("failed to read data_to_send_rx"); client_stream .write(recv.as_slice()) .unwrap_or_else(|e| { @@ -305,131 +530,9 @@ fn main() -> anyhow::Result<()> { }); } } + + dead_client_tx.send(port).expect("failed to send to dead_client_tx channel"); }); - - // This thread simulates the IP-layer communication between the client and server. - // * When we get data from the 'real' client, we send it via the virtual client - // * When the virtual client sends data, it generates IP packets, which are captures via ip_rx/ip_tx - thread::spawn(move || { - let stopped = Arc::clone(&stopped_2); - - let server_socket = { - static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer) - }; - - let client_socket = { - static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); - TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer) - }; - - let mut socket_set_entries: [_; 2] = Default::default(); - let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); - let server_handle = socket_set.add(server_socket); - let client_handle = socket_set.add(client_socket); - - // Virtual device - let device = VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone()); - - // Create a virtual interface to simulate TCP connection - let mut iface = InterfaceBuilder::new(device) - .ip_addrs([ - // Interface handles IP packets for the sender and recipient - IpCidr::new(IpAddress::from(source_peer_ip), 32), - IpCidr::new(IpAddress::from(dest_addr_ip), 32), - ]) - .any_ip(true) - .finalize(); - - // keeps track of whether the virtual clients needs to be initialized - let mut started = false; - - loop { - let loop_start = Instant::now(); - if stopped.load(Ordering::Relaxed) { - debug!("[{}] Killing virtual thread", port); - break; - } - - match iface.poll(&mut socket_set, loop_start) { - Ok(processed) => { - if processed { - debug!("[{}] virtual iface polled and processed some packets", port); - } - } - Err(e) => { - error!("[{}] virtual iface poll error: {:?}", port, e); - break; - } - } - - // Spawn a server socket so the virtual interface allows routing - // Note: the server socket is never read, since the IP packets are intercepted - // at the interface level. - { - let mut server_socket: SocketRef = socket_set.get(server_handle); - if !started { - // Open the virtual server socket - match server_socket.listen((IpAddress::from(dest_addr_ip), dest_addr_port)) { - Ok(_) => { - debug!("[{}] Virtual server listening: {}", port, server_socket.local_endpoint()); - } - Err(e) => { - error!("[{}] Virtual server failed to listen: {}", port, e); - break; - } - } - } - } - - // Virtual client - { - let mut client_socket: SocketRef = socket_set.get(client_handle); - if !started { - client_socket.connect( - (IpAddress::from(dest_addr_ip), dest_addr_port), - (IpAddress::from(source_peer_ip), port), - ) - .expect("failed to connect virtual client"); - debug!("[{}] Virtual client connected", port); - } - if client_socket.can_send() { - while !send_to_virtual_client_rx.is_empty() { - let to_send = send_to_virtual_client_rx.recv().expect("failed to read from client_received_rx channel"); - client_socket.send_slice(to_send.as_slice()).expect("virtual client failed to send data from channel"); - } - } - if client_socket.can_recv() { - let data = client_socket.recv(|b| (b.len(), b.to_vec())).expect("failed to recv"); - send_to_real_client_tx.send(data).expect("failed to send to channel send_to_real_client_tx"); - } - if !client_socket.is_open() { - warn!("[{}] Client socket is no longer open", port); - break; - } - } - - // After the first loop, the client and server have started - started = true; - - match iface.poll_delay(&socket_set, loop_start) { - Some(smoltcp::time::Duration::ZERO) => {} - Some(delay) => std::thread::sleep(std::time::Duration::from_millis(delay.millis())), - _ => {} - } - } - - // if this thread ends, end the other ones too - debug!("[{}] Virtual thread stopped", port); - stopped.store(true, Ordering::Relaxed); - }); - // * When the real destination sends IP packets (via WG endpoint), we send it via the device/interface }) .unwrap_or_else(|e| error!("{:?}", e)); } diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 392ac5e..24ddfe4 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,9 +1,7 @@ -use std::collections::VecDeque; - -use smoltcp::phy::{ChecksumCapabilities, Device, DeviceCapabilities, Medium}; +use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; -use smoltcp::wire::{Ipv4Packet, Ipv4Repr}; +#[derive(Clone)] pub struct VirtualIpDevice { /// Channel for packets sent by the interface. ip_tx: crossbeam_channel::Sender>,