From 38fc217a29d36b77e6a5baf944d07b6e1b6d3b1f Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sat, 21 Oct 2023 11:12:18 +0800 Subject: [PATCH] smoltcp version 0.10 applied --- Cargo.lock | 164 ++++++++++++++++++++++++++++++++++++--- Cargo.toml | 12 ++- src/virtual_device.rs | 34 ++++---- src/virtual_iface/tcp.rs | 100 +++++++++++++----------- src/virtual_iface/udp.rs | 107 +++++++++++++------------ src/wg.rs | 2 +- 6 files changed, 296 insertions(+), 123 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7c8df37..ed1457b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,7 +46,7 @@ checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", ] [[package]] @@ -57,7 +57,16 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", +] + +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", ] [[package]] @@ -204,6 +213,44 @@ dependencies = [ "memchr", ] +[[package]] +name = "critical-section" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216" + +[[package]] +name = "defmt" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a2d011b2fee29fb7d659b83c43fce9a2cb4df453e16d441a51448e448f3f98" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54f0216f6c5acb5ae1a47050a6645024e6edafc2ee32d421955eccfef12ef92e" +dependencies = [ + "defmt-parser", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.41", +] + +[[package]] +name = "defmt-parser" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "269924c02afd7f94bc4cecbfa5c379f6ffcf9766b3408fe63d22c728654eccd0" +dependencies = [ + "thiserror", +] + [[package]] name = "env_logger" version = "0.7.1" @@ -283,7 +330,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", ] [[package]] @@ -333,12 +380,34 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32", + "rustc_version", + "spin 0.9.8", + "stable_deref_trait", +] + [[package]] name = "hermit-abi" version = "0.1.19" @@ -614,6 +683,30 @@ dependencies = [ "indexmap", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.70" @@ -715,7 +808,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted 0.7.1", "web-sys", "winapi", @@ -727,6 +820,15 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.28" @@ -755,6 +857,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" + [[package]] name = "slab" version = "0.4.9" @@ -772,12 +880,15 @@ checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "smoltcp" -version = "0.8.2" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee34c1e1bfc7e9206cc0fb8030a90129b4e319ab53856249bb27642cab914fb3" +checksum = "8d2e3a36ac8fea7b94e666dfa3871063d6e0a5c9d5d4fec9a1a6b7b6760f0229" dependencies = [ "bitflags 1.3.2", "byteorder", + "cfg-if", + "defmt", + "heapless", "log", "managed", ] @@ -798,12 +909,37 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strsim" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.41" @@ -851,7 +987,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", ] [[package]] @@ -879,7 +1015,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", ] [[package]] @@ -902,7 +1038,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", ] [[package]] @@ -932,6 +1068,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "walkdir" version = "2.4.0" @@ -969,7 +1111,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.41", "wasm-bindgen-shared", ] @@ -991,7 +1133,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 8a27b36..da35c45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,16 @@ futures = "0.3" rand = "0.8" nom = "7" async-trait = "0.1" -priority-queue = "1.3.0" -smoltcp = { version = "0.8.2", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-udp", "socket-tcp"] } +priority-queue = "1.3" +smoltcp = { version = "0.10", default-features = false, features = [ + "std", + "log", + "medium-ip", + "proto-ipv4", + "proto-ipv6", + "socket-udp", + "socket-tcp", +] } bytes = "1" base64 = "0.13" diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 57ad3a1..7af4d27 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,13 +1,15 @@ -use std::collections::VecDeque; -use std::sync::{Arc, Mutex}; - -use bytes::{BufMut, Bytes, BytesMut}; -use smoltcp::phy::{Device, DeviceCapabilities, Medium}; -use smoltcp::time::Instant; - use crate::config::PortProtocol; use crate::events::{BusSender, Event}; use crate::Bus; +use bytes::{BufMut, Bytes, BytesMut}; +use smoltcp::{ + phy::{DeviceCapabilities, Medium}, + time::Instant, +}; +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; /// A virtual device that processes IP packets through smoltcp and WireGuard. pub struct VirtualIpDevice { @@ -52,11 +54,11 @@ impl VirtualIpDevice { } } -impl<'a> Device<'a> for VirtualIpDevice { - type RxToken = RxToken; - type TxToken = TxToken; +impl smoltcp::phy::Device for VirtualIpDevice { + type RxToken<'a> = RxToken where Self: 'a; + type TxToken<'a> = TxToken where Self: 'a; - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { let next = { let mut queue = self .process_queue @@ -81,7 +83,7 @@ impl<'a> Device<'a> for VirtualIpDevice { } } - fn transmit(&'a mut self) -> Option { + fn transmit(&mut self, _timestamp: Instant) -> Option> { Some(TxToken { sender: self.bus_sender.clone(), }) @@ -101,9 +103,9 @@ pub struct RxToken { } impl smoltcp::phy::RxToken for RxToken { - fn consume(mut self, _timestamp: Instant, f: F) -> smoltcp::Result + fn consume(mut self, f: F) -> R where - F: FnOnce(&mut [u8]) -> smoltcp::Result, + F: FnOnce(&mut [u8]) -> R, { f(&mut self.buffer) } @@ -115,9 +117,9 @@ pub struct TxToken { } impl smoltcp::phy::TxToken for TxToken { - fn consume(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result + fn consume(self, len: usize, f: F) -> R where - F: FnOnce(&mut [u8]) -> smoltcp::Result, + F: FnOnce(&mut [u8]) -> R, { let mut buffer = vec![0; len]; let result = f(&mut buffer); diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 8a7a509..24957bd 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -1,30 +1,34 @@ -use std::collections::{HashMap, HashSet, VecDeque}; -use std::net::IpAddr; -use std::time::Duration; - -use anyhow::Context; -use async_trait::async_trait; -use bytes::Bytes; -use smoltcp::iface::{InterfaceBuilder, SocketHandle}; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState}; -use smoltcp::wire::{IpAddress, IpCidr}; - use crate::config::{PortForwardConfig, PortProtocol}; use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::Bus; +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use smoltcp::{ + iface::{Config, Interface, SocketHandle, SocketSet}, + socket::tcp, + time::Instant, + wire::{HardwareAddress, IpAddress, IpCidr}, +}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + net::IpAddr, + time::Duration, +}; const MAX_PACKET: usize = 65536; /// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa. -pub struct TcpVirtualInterface { +pub struct TcpVirtualInterface<'a> { source_peer_ip: IpAddr, port_forwards: Vec, bus: Bus, + sockets: SocketSet<'a>, } -impl TcpVirtualInterface { +impl<'a> TcpVirtualInterface<'a> { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { @@ -35,16 +39,17 @@ impl TcpVirtualInterface { .collect(), source_peer_ip, bus, + sockets: SocketSet::new([]), } } - fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { static mut TCP_SERVER_RX_DATA: [u8; 0] = []; static mut TCP_SERVER_TX_DATA: [u8; 0] = []; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); socket .listen(( @@ -56,12 +61,12 @@ impl TcpVirtualInterface { Ok(socket) } - fn new_client_socket() -> anyhow::Result> { + fn new_client_socket() -> anyhow::Result> { let rx_data = vec![0u8; MAX_PACKET]; let tx_data = vec![0u8; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); - let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); - let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data); + let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data); + let socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); Ok(socket) } @@ -79,20 +84,31 @@ impl TcpVirtualInterface { } #[async_trait] -impl VirtualInterfacePoll for TcpVirtualInterface { - async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { +impl VirtualInterfacePoll for TcpVirtualInterface<'_> { + async fn poll_loop(mut self, mut device: VirtualIpDevice) -> anyhow::Result<()> { // Create CIDR block for source peer IP + each port forward IP let addresses = self.addresses(); + let config = Config::new(HardwareAddress::Ip); // Create virtual interface (contains smoltcp state machine) - let mut iface = InterfaceBuilder::new(device, vec![]) - .ip_addrs(addresses) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + addresses.iter().for_each(|addr| { + ip_addrs.push(*addr).unwrap(); + }); + }); + iface.set_any_ip(true); + + // Maps virtual port to its client socket handle + let mut port_client_handle_map: HashMap = HashMap::new(); // Create virtual server for each port forward for port_forward in self.port_forwards.iter() { let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?; - iface.add_socket(server_socket); + let handle = self.sockets.add(server_socket); + let virtual_port = + VirtualPort::new(port_forward.destination.port(), port_forward.protocol); + port_client_handle_map.insert(virtual_port, handle); } // The next time to poll the interface. Can be None for instant poll. @@ -101,9 +117,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Bus endpoint to read events let mut endpoint = self.bus.new_endpoint(); - // Maps virtual port to its client socket handle - let mut port_client_handle_map: HashMap = HashMap::new(); - // Data packets to send from a virtual client let mut send_queue: HashMap> = HashMap::new(); @@ -118,11 +131,11 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Find closed sockets port_client_handle_map.retain(|virtual_port, client_handle| { - let client_socket = iface.get_socket::(*client_handle); - if client_socket.state() == TcpState::Closed { + let client_socket = self.sockets.get_mut::(*client_handle); + if client_socket.state() == tcp::State::Closed { endpoint.send(Event::ClientConnectionDropped(*virtual_port)); send_queue.remove(virtual_port); - iface.remove_socket(*client_handle); + self.sockets.remove(*client_handle); false } else { // Not closed, retain @@ -130,16 +143,12 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } }); - match iface.poll(loop_start) { - Ok(processed) if processed => { - trace!("TCP virtual interface polled some packets to be processed"); - } - Err(e) => error!("TCP virtual interface poll error: {:?}", e), - _ => {} + if iface.poll(loop_start, &mut device, &mut self.sockets) { + log::trace!("TCP virtual interface polled some packets to be processed"); } for (virtual_port, client_handle) in port_client_handle_map.iter() { - let client_socket = iface.get_socket::(*client_handle); + let client_socket = self.sockets.get_mut::(*client_handle); if client_socket.can_send() { if let Some(send_queue) = send_queue.get_mut(virtual_port) { let to_transfer = send_queue.pop_front(); @@ -159,7 +168,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { ); } } - } else if client_socket.state() == TcpState::CloseWait { + } else if client_socket.state() == tcp::State::CloseWait { client_socket.close(); } } @@ -182,7 +191,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) - next_poll = match iface.poll_delay(loop_start) { + next_poll = match iface.poll_delay(loop_start, &self.sockets) { Some(smoltcp::time::Duration::ZERO) => None, Some(delay) => { trace!("TCP Virtual interface delayed next poll by {}", delay); @@ -195,13 +204,14 @@ impl VirtualInterfacePoll for TcpVirtualInterface { match event { Event::ClientConnectionInitiated(port_forward, virtual_port) => { let client_socket = TcpVirtualInterface::new_client_socket()?; - let client_handle = iface.add_socket(client_socket); + let client_handle = self.sockets.add(client_socket); // Add handle to map port_client_handle_map.insert(virtual_port, client_handle); send_queue.insert(virtual_port, VecDeque::new()); - let (client_socket, context) = iface.get_socket_and_context::(client_handle); + let client_socket = self.sockets.get_mut::(client_handle); + let context = iface.context(); client_socket .connect( @@ -218,7 +228,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } Event::ClientConnectionDropped(virtual_port) => { if let Some(client_handle) = port_client_handle_map.get(&virtual_port) { - let client_socket = iface.get_socket::(*client_handle); + let client_socket = self.sockets.get_mut::(*client_handle); client_socket.close(); next_poll = None; } diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index a3d1652..8426e4a 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,29 +1,33 @@ -use std::collections::{HashMap, HashSet, VecDeque}; -use std::net::IpAddr; -use std::time::Duration; - -use anyhow::Context; -use async_trait::async_trait; -use bytes::Bytes; -use smoltcp::iface::{InterfaceBuilder, SocketHandle}; -use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; -use smoltcp::wire::{IpAddress, IpCidr}; - use crate::config::PortForwardConfig; use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::{Bus, PortProtocol}; +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use smoltcp::{ + iface::{Config, Interface, SocketHandle, SocketSet}, + socket::udp::{self, UdpMetadata}, + time::Instant, + wire::{HardwareAddress, IpAddress, IpCidr}, +}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + net::IpAddr, + time::Duration, +}; const MAX_PACKET: usize = 65536; -pub struct UdpVirtualInterface { +pub struct UdpVirtualInterface<'a> { source_peer_ip: IpAddr, port_forwards: Vec, bus: Bus, + sockets: SocketSet<'a>, } -impl UdpVirtualInterface { +impl<'a> UdpVirtualInterface<'a> { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { @@ -34,21 +38,24 @@ impl UdpVirtualInterface { .collect(), source_peer_ip, bus, + sockets: SocketSet::new([]), } } - fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { - static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + static mut UDP_SERVER_RX_META: [udp::PacketMetadata; 0] = []; static mut UDP_SERVER_RX_DATA: [u8; 0] = []; - static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_TX_META: [udp::PacketMetadata; 0] = []; static mut UDP_SERVER_TX_DATA: [u8; 0] = []; - let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { - &mut UDP_SERVER_RX_DATA[..] - }); - let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { - &mut UDP_SERVER_TX_DATA[..] - }); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + let udp_rx_buffer = + udp::PacketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { + &mut UDP_SERVER_RX_DATA[..] + }); + let udp_tx_buffer = + udp::PacketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { + &mut UDP_SERVER_TX_DATA[..] + }); + let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); socket .bind(( IpAddress::from(port_forward.destination.ip()), @@ -61,14 +68,14 @@ impl UdpVirtualInterface { fn new_client_socket( source_peer_ip: IpAddr, client_port: VirtualPort, - ) -> anyhow::Result> { - let rx_meta = vec![UdpPacketMetadata::EMPTY; 10]; - let tx_meta = vec![UdpPacketMetadata::EMPTY; 10]; + ) -> anyhow::Result> { + let rx_meta = vec![udp::PacketMetadata::EMPTY; 10]; + let tx_meta = vec![udp::PacketMetadata::EMPTY; 10]; let rx_data = vec![0u8; MAX_PACKET]; let tx_data = vec![0u8; MAX_PACKET]; - let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); - let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + let udp_rx_buffer = udp::PacketBuffer::new(rx_meta, rx_data); + let udp_tx_buffer = udp::PacketBuffer::new(tx_meta, tx_data); + let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); socket .bind((IpAddress::from(source_peer_ip), client_port.num())) .with_context(|| "UDP virtual client failed to bind")?; @@ -89,20 +96,31 @@ impl UdpVirtualInterface { } #[async_trait] -impl VirtualInterfacePoll for UdpVirtualInterface { - async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { +impl<'a> VirtualInterfacePoll for UdpVirtualInterface<'a> { + async fn poll_loop(mut self, mut device: VirtualIpDevice) -> anyhow::Result<()> { // Create CIDR block for source peer IP + each port forward IP let addresses = self.addresses(); + let config = Config::new(HardwareAddress::Ip); // Create virtual interface (contains smoltcp state machine) - let mut iface = InterfaceBuilder::new(device, vec![]) - .ip_addrs(addresses) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + addresses.iter().for_each(|addr| { + ip_addrs.push(*addr).unwrap(); + }); + }); + iface.set_any_ip(true); + + // Maps virtual port to its client socket handle + let mut port_client_handle_map: HashMap = HashMap::new(); // Create virtual server for each port forward for port_forward in self.port_forwards.iter() { let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?; - iface.add_socket(server_socket); + let handle = self.sockets.add(server_socket); + let virtual_port = + VirtualPort::new(port_forward.destination.port(), port_forward.protocol); + port_client_handle_map.insert(virtual_port, handle); } // The next time to poll the interface. Can be None for instant poll. @@ -111,9 +129,6 @@ impl VirtualInterfacePoll for UdpVirtualInterface { // Bus endpoint to read events let mut endpoint = self.bus.new_endpoint(); - // Maps virtual port to its client socket handle - let mut port_client_handle_map: HashMap = HashMap::new(); - // Data packets to send from a virtual client let mut send_queue: HashMap> = HashMap::new(); @@ -127,16 +142,12 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } => { let loop_start = smoltcp::time::Instant::now(); - match iface.poll(loop_start) { - Ok(processed) if processed => { - trace!("UDP virtual interface polled some packets to be processed"); - } - Err(e) => error!("UDP virtual interface poll error: {:?}", e), - _ => {} + if iface.poll(loop_start, &mut device, &mut self.sockets) { + log::trace!("UDP virtual interface polled some packets to be processed"); } for (virtual_port, client_handle) in port_client_handle_map.iter() { - let client_socket = iface.get_socket::(*client_handle); + let client_socket = self.sockets.get_mut::(*client_handle); if client_socket.can_send() { if let Some(send_queue) = send_queue.get_mut(virtual_port) { let to_transfer = send_queue.pop_front(); @@ -144,7 +155,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { client_socket .send_slice( &data, - (IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(), + UdpMetadata::from(port_forward.destination), ) .unwrap_or_else(|e| { error!( @@ -172,7 +183,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) - next_poll = match iface.poll_delay(loop_start) { + next_poll = match iface.poll_delay(loop_start, &self.sockets) { Some(smoltcp::time::Duration::ZERO) => None, Some(delay) => { trace!("UDP Virtual interface delayed next poll by {}", delay); @@ -190,7 +201,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } else { // Client socket does not exist let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?; - let client_handle = iface.add_socket(client_socket); + let client_handle = self.sockets.add(client_socket); // Add handle to map port_client_handle_map.insert(virtual_port, client_handle); diff --git a/src/wg.rs b/src/wg.rs index 14efa23..e1edf01 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -237,7 +237,7 @@ impl WireGuardTunnel { .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) - .and_then(|packet| match packet.protocol() { + .and_then(|packet| match packet.next_header() { IpProtocol::Tcp => Some(PortProtocol::Tcp), IpProtocol::Udp => Some(PortProtocol::Udp), // Unrecognized protocol, so we cannot determine where to route