diff --git a/src/main.rs b/src/main.rs index 649dafe..71a2918 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,10 +49,12 @@ fn main() -> anyhow::Result<()> { let endpoint_addr = config.endpoint_addr; // tx/rx for unencrypted IP packets to send through wireguard tunnel - let (wg_send_tx, wg_send_rx) = crossbeam_channel::unbounded::>(); + let (send_to_real_server_tx, send_to_real_server_rx) = + crossbeam_channel::unbounded::>(); - // tx/rx for unencrypted IP packets that were received through wireguard tunnel - let (wg_recv_tx, wg_recv_rx) = crossbeam_channel::unbounded::>(); + // tx/rx for decrypted IP packets that were received through wireguard tunnel + let (send_to_virtual_interface_tx, send_to_virtual_interface_rx) = + crossbeam_channel::unbounded::>(); // Initialize peer based on config let peer = Tunn::new( @@ -80,7 +82,7 @@ fn main() -> anyhow::Result<()> { let peer = peer.clone(); loop { let mut send_buf = [0u8; MAX_PACKET]; - match wg_send_rx.recv() { + match send_to_real_server_rx.recv() { Ok(next) => match peer.encapsulate(next.as_slice(), &mut send_buf) { TunnResult::WriteToNetwork(packet) => { endpoint_socket @@ -98,7 +100,10 @@ fn main() -> anyhow::Result<()> { } }, Err(e) => { - error!("Failed to consume from wg_send_rx channel: {}", e); + error!( + "Failed to consume from send_to_real_server_rx channel: {}", + e + ); } } } @@ -143,13 +148,13 @@ fn main() -> anyhow::Result<()> { } TunnResult::WriteToTunnelV4(packet, _) => { debug!("Got {} bytes to send back to client", packet.len()); - wg_recv_tx + send_to_virtual_interface_tx .send(packet.to_vec()) .expect("failed to queue received wg packet"); } TunnResult::WriteToTunnelV6(packet, _) => { debug!("Got {} bytes to send back to client", packet.len()); - wg_recv_tx + send_to_virtual_interface_tx .send(packet.to_vec()) .expect("failed to queue received wg packet"); } @@ -183,7 +188,8 @@ fn main() -> anyhow::Result<()> { for client_stream in proxy_listener.incoming() { client_stream .map(|client_stream| { - let wg_send_tx = wg_send_tx.clone(); + let send_to_real_server_tx = send_to_real_server_tx.clone(); + let send_to_virtual_interface_rx = send_to_virtual_interface_rx.clone(); // Pick a port // TODO: Pool @@ -196,14 +202,14 @@ fn main() -> anyhow::Result<()> { // tx/rx for data received from the client // this data is received - let (client_received_tx, client_received_rx) = crossbeam_channel::unbounded::>(); + let (send_to_virtual_client_tx, send_to_virtual_client_rx) = crossbeam_channel::unbounded::>(); // tx/rx for packets received from the destination // this data is received from the WG endpoint; the IP packets are routed using the port number - let (destination_sent_tx, destination_sent_rx) = crossbeam_channel::unbounded::>(); + let (send_to_real_client_tx, send_to_real_client_rx) = crossbeam_channel::unbounded::>(); - // tx/rx for packets the virtual client sent and that should be sent to the wg tunnel - let (ip_tx, ip_rx) = crossbeam_channel::unbounded::>(); + // tx/rx for IP packets the interface exchanged that should be filtered/routed + let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::>(); let stopped = Arc::new(AtomicBool::new(false)); let stopped_1 = Arc::clone(&stopped); @@ -233,7 +239,7 @@ fn main() -> anyhow::Result<()> { Ok(size) => { debug!("[{}] Data received from client: {} bytes", port, size); let data = &buffer[..size]; - client_received_tx + send_to_virtual_client_tx .send(data.to_vec()) .unwrap_or_else(|e| error!("[{}] failed to send data to client_received_tx channel as received from client: {}", port, e)); } @@ -247,8 +253,8 @@ fn main() -> anyhow::Result<()> { } } - while !ip_rx.is_empty() { - let recv = ip_rx.recv().expect("failed to read ip_rx"); + while !send_to_ip_filter_rx.is_empty() { + let recv = send_to_ip_filter_rx.recv().expect("failed to read send_to_ip_filter_rx"); let src_addr: IpAddr = match IpVersion::of_packet(&recv) { Ok(v) => { match v { @@ -285,12 +291,12 @@ fn main() -> anyhow::Result<()> { if src_addr == source_peer_ip { debug!("[{}] IP packet: {} bytes from {} to send to WG", port, recv.len(), src_addr); // Add to queue to be encapsulated and sent by other thread - wg_send_tx.send(recv).expect("failed to write to wg_send_tx channel"); + send_to_real_server_tx.send(recv).expect("failed to write to send_to_real_server_tx channel"); } } - while !destination_sent_rx.is_empty() { - let recv = destination_sent_rx.recv().expect("failed to read destination_sent_rx"); + while !send_to_real_client_rx.is_empty() { + let recv = send_to_real_client_rx.recv().expect("failed to read destination_sent_rx"); client_stream .write(recv.as_slice()) .unwrap_or_else(|e| { @@ -329,7 +335,7 @@ fn main() -> anyhow::Result<()> { let client_handle = socket_set.add(client_socket); // Virtual device - let device = VirtualIpDevice::new(ip_tx); + let device = VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone()); // Create a virtual interface to simulate TCP connection let mut iface = InterfaceBuilder::new(device) @@ -363,7 +369,9 @@ fn main() -> anyhow::Result<()> { } } - // Server socket polling + // Spawn a server socket so the virtual interface allows routing + // Note: the server socket is never read, since the IP packets are intercepted + // at the interface level. { let mut server_socket: SocketRef = socket_set.get(server_handle); if !started { @@ -378,22 +386,6 @@ fn main() -> anyhow::Result<()> { } } } - if server_socket.can_recv() { - let buffer = server_socket - .recv(|buffer| { (buffer.len(), buffer.to_vec()) }); - match buffer { - Ok(buffer) => { - debug!("[{}] Virtual server socket read: {} bytes", port, buffer.len()); - } - Err(e) => { - error!("[{}] Virtual server failed to read: {}", port, e); - break; - } - } - } - if server_socket.can_send() { - // TODO: See if this is actually useful - } } // Virtual client @@ -408,14 +400,14 @@ fn main() -> anyhow::Result<()> { debug!("[{}] Virtual client connected", port); } if client_socket.can_send() { - while !client_received_rx.is_empty() { - let to_send = client_received_rx.recv().expect("failed to read from client_received_rx channel"); + while !send_to_virtual_client_rx.is_empty() { + let to_send = send_to_virtual_client_rx.recv().expect("failed to read from client_received_rx channel"); client_socket.send_slice(to_send.as_slice()).expect("virtual client failed to send data from channel"); } } if client_socket.can_recv() { - // TODO: See if this is actually useful? - client_socket.recv(|b| (b.len(), 0)).expect("failed to recv"); + let data = client_socket.recv(|b| (b.len(), b.to_vec())).expect("failed to recv"); + send_to_real_client_tx.send(data).expect("failed to send to channel send_to_real_client_tx"); } if !client_socket.is_open() { warn!("[{}] Client socket is no longer open", port); diff --git a/src/virtual_device.rs b/src/virtual_device.rs index e251f4c..392ac5e 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,42 +1,45 @@ +use std::collections::VecDeque; + use smoltcp::phy::{ChecksumCapabilities, Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; use smoltcp::wire::{Ipv4Packet, Ipv4Repr}; -use std::collections::VecDeque; pub struct VirtualIpDevice { - queue: VecDeque>, - /// Sends IP packets + /// Channel for packets sent by the interface. ip_tx: crossbeam_channel::Sender>, + ip_rx: crossbeam_channel::Receiver>, } impl VirtualIpDevice { - pub fn new(ip_tx: crossbeam_channel::Sender>) -> Self { - Self { - queue: VecDeque::new(), - ip_tx, - } + pub fn new( + ip_tx: crossbeam_channel::Sender>, + ip_rx: crossbeam_channel::Receiver>, + ) -> Self { + Self { ip_tx, ip_rx } } } impl<'a> Device<'a> for VirtualIpDevice { type RxToken = RxToken; - type TxToken = TxToken<'a>; + type TxToken = TxToken; fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - self.queue.pop_front().map(move |buffer| { - let rx = RxToken { buffer }; - let tx = TxToken { - queue: &mut self.queue, - tx: Some(self.ip_tx.clone()), - }; - (rx, tx) - }) + if !self.ip_rx.is_empty() { + let buffer = self.ip_rx.recv().expect("failed to read ip_rx"); + Some(( + RxToken { buffer }, + TxToken { + ip_tx: self.ip_tx.clone(), + }, + )) + } else { + None + } } fn transmit(&'a mut self) -> Option { Some(TxToken { - queue: &mut self.queue, - tx: Some(self.ip_tx.clone()), + ip_tx: self.ip_tx.clone(), }) } @@ -63,12 +66,11 @@ impl smoltcp::phy::RxToken for RxToken { } #[doc(hidden)] -pub struct TxToken<'a> { - queue: &'a mut VecDeque>, - tx: Option>>, +pub struct TxToken { + ip_tx: crossbeam_channel::Sender>, } -impl<'a> smoltcp::phy::TxToken for TxToken<'a> { +impl<'a> smoltcp::phy::TxToken for TxToken { fn consume(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result where F: FnOnce(&mut [u8]) -> smoltcp::Result, @@ -76,8 +78,9 @@ impl<'a> smoltcp::phy::TxToken for TxToken<'a> { let mut buffer = Vec::new(); buffer.resize(len, 0); let result = f(&mut buffer); - self.tx.map(|tx| tx.send(buffer.clone())); - self.queue.push_back(buffer); + self.ip_tx + .send(buffer.clone()) + .expect("failed to send to ip_tx"); result } }