diff --git a/Cargo.lock b/Cargo.lock index 542b73d..2b8ed7f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -779,7 +779,8 @@ checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" [[package]] name = "smoltcp" version = "0.8.0" -source = "git+https://github.com/smoltcp-rs/smoltcp?rev=25c539bb7c96789270f032ede2a967cf0fe5cf57#25c539bb7c96789270f032ede2a967cf0fe5cf57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2308a1657c8db1f5b4993bab4e620bdbe5623bd81f254cf60326767bb243237" dependencies = [ "bitflags", "byteorder", diff --git a/Cargo.toml b/Cargo.toml index 5efad48..50dd3c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ clap = { version = "2.33", default-features = false, features = ["suggestions"] log = "0.4" pretty_env_logger = "0.4" anyhow = "1" -smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", rev = "25c539bb7c96789270f032ede2a967cf0fe5cf57" } +smoltcp = "0.8.0" tokio = { version = "1", features = ["full"] } futures = "0.3.17" rand = "0.8.4" diff --git a/src/ip_sink.rs b/src/ip_sink.rs index f17eb21..a209da9 100644 --- a/src/ip_sink.rs +++ b/src/ip_sink.rs @@ -1,7 +1,6 @@ use crate::virtual_device::VirtualIpDevice; use crate::wg::WireGuardTunnel; use smoltcp::iface::InterfaceBuilder; -use smoltcp::socket::SocketSet; use std::sync::Arc; use tokio::time::Duration; @@ -13,13 +12,12 @@ pub async fn run_ip_sink_interface(wg: Arc) -> ! { .expect("Failed to initialize VirtualIpDevice for sink interface"); // No sockets on sink interface - let mut socket_set_entries: [_; 0] = Default::default(); - let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); - let mut virtual_interface = InterfaceBuilder::new(device).ip_addrs([]).finalize(); + let mut sockets: [_; 0] = Default::default(); + let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..]).ip_addrs([]).finalize(); loop { let loop_start = smoltcp::time::Instant::now(); - match virtual_interface.poll(&mut socket_set, loop_start) { + match virtual_interface.poll(loop_start) { Ok(processed) if processed => { trace!("[SINK] Virtual interface polled some packets to be processed",); tokio::time::sleep(Duration::from_millis(1)).await; diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index a632792..d34357a 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -5,7 +5,7 @@ use crate::wg::WireGuardTunnel; use anyhow::Context; use async_trait::async_trait; use smoltcp::iface::InterfaceBuilder; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState}; +use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState}; use smoltcp::wire::{IpAddress, IpCidr}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -77,7 +77,10 @@ impl VirtualInterfacePoll for TcpVirtualInterface { let device = VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; - let mut virtual_interface = InterfaceBuilder::new(device) + + // there are always 2 sockets: 1 virtual client and 1 virtual server. + let mut sockets: [_; 2] = Default::default(); + let mut virtual_interface = InterfaceBuilder::new(device, &mut sockets[..]) .ip_addrs([ // Interface handles IP packets for the sender and recipient IpCidr::new(IpAddress::from(source_peer_ip), 32), @@ -112,11 +115,8 @@ impl VirtualInterfacePoll for TcpVirtualInterface { Ok(socket) }; - // Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server. - let mut socket_set_entries: [_; 2] = Default::default(); - let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); - let _server_handle = socket_set.add(server_socket?); - let client_handle = socket_set.add(client_socket?); + let _server_handle = virtual_interface.add_socket(server_socket?); + let client_handle = virtual_interface.add_socket(client_socket?); // Any data that wasn't sent because it was over the sending buffer limit let mut tx_extra = Vec::new(); @@ -137,11 +137,11 @@ impl VirtualInterfacePoll for TcpVirtualInterface { if shutdown { // Shutdown: sends a RST packet. trace!("[{}] Shutting down virtual interface", self.virtual_port); - let mut client_socket = socket_set.get::(client_handle); + let client_socket = virtual_interface.get_socket::(client_handle); client_socket.abort(); } - match virtual_interface.poll(&mut socket_set, loop_start) { + match virtual_interface.poll(loop_start) { Ok(processed) if processed => { trace!( "[{}] Virtual interface polled some packets to be processed", @@ -158,7 +158,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } { - let mut client_socket = socket_set.get::(client_handle); + let (client_socket, context) = virtual_interface.get_socket_and_context::(client_handle); if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { // Not shutting down, but the client socket is closed, and the client never successfully connected. @@ -166,6 +166,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Try to connect client_socket .connect( + context, ( IpAddress::from(self.port_forward.destination.ip()), self.port_forward.destination.port(), @@ -266,7 +267,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { break; } - match virtual_interface.poll_delay(&socket_set, loop_start) { + match virtual_interface.poll_delay(loop_start) { Some(smoltcp::time::Duration::ZERO) => { continue; } diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 212cb8d..1fdbb63 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use smoltcp::iface::InterfaceBuilder; -use smoltcp::socket::{SocketHandle, SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; +use smoltcp::iface::{InterfaceBuilder, SocketHandle}; +use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; use smoltcp::wire::{IpAddress, IpCidr}; use crate::config::PortForwardConfig; @@ -56,7 +56,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx); - let mut virtual_interface = InterfaceBuilder::new(device) + let mut virtual_interface = InterfaceBuilder::new(device, vec![]) .ip_addrs([ // Interface handles IP packets for the sender and recipient IpCidr::new(source_peer_ip.into(), 32), @@ -87,8 +87,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { Ok(socket) }; - let mut socket_set = SocketSet::new(vec![]); - let _server_handle = socket_set.add(server_socket?); + let _server_handle = virtual_interface.add_socket(server_socket?); // A map of virtual port to client socket. let mut client_sockets: HashMap = HashMap::new(); @@ -107,7 +106,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } => { let loop_start = smoltcp::time::Instant::now(); - match virtual_interface.poll(&mut socket_set, loop_start) { + match virtual_interface.poll(loop_start) { Ok(processed) if processed => { trace!("UDP virtual interface polled some packets to be processed"); } @@ -118,7 +117,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { // Loop through each client socket and check if there is any data to send back // to the real client. for (virtual_port, client_socket_handle) in client_sockets.iter() { - let mut client_socket = socket_set.get::(*client_socket_handle); + let client_socket = virtual_interface.get_socket::(*client_socket_handle); match client_socket.recv() { Ok((data, _peer)) => { // Send the data back to the real client using MPSC channel @@ -142,7 +141,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } } - next_poll = match virtual_interface.poll_delay(&socket_set, loop_start) { + next_poll = match virtual_interface.poll_delay(loop_start) { Some(smoltcp::time::Duration::ZERO) => None, Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())), None => None, @@ -178,10 +177,10 @@ impl VirtualInterfacePoll for UdpVirtualInterface { ); }); - socket_set.add(socket) + virtual_interface.add_socket(socket) }); - let mut client_socket = socket_set.get::(*client_socket_handle); + let client_socket = virtual_interface.get_socket::(*client_socket_handle); client_socket .send_slice( &data,