From ff7b5ec2ca0bc97318eba5eaf6a97eb733523583 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sat, 21 Oct 2023 11:12:18 +0800 Subject: [PATCH] smoltcp version 0.10 applied --- Cargo.toml | 2 +- src/virtual_device.rs | 37 ++++++------- src/virtual_iface/tcp.rs | 102 +++++++++++++++++++----------------- src/virtual_iface/udp.rs | 109 +++++++++++++++++++++------------------ src/wg.rs | 2 +- 5 files changed, 137 insertions(+), 115 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 583f3a7..5c0363a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ rand = "0.8" nom = "7" async-trait = "0.1" priority-queue = "1.3" -smoltcp = { version = "0.8.2", default-features = false, features = [ +smoltcp = { version = "0.10", default-features = false, features = [ "std", "log", "medium-ip", diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 0054690..7af4d27 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,13 +1,15 @@ -use std::collections::VecDeque; -use std::sync::{Arc, Mutex}; - -use bytes::{BufMut, Bytes, BytesMut}; -use smoltcp::phy::{Device, DeviceCapabilities, Medium}; -use smoltcp::time::Instant; - use crate::config::PortProtocol; use crate::events::{BusSender, Event}; use crate::Bus; +use bytes::{BufMut, Bytes, BytesMut}; +use smoltcp::{ + phy::{DeviceCapabilities, Medium}, + time::Instant, +}; +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; /// A virtual device that processes IP packets through smoltcp and WireGuard. pub struct VirtualIpDevice { @@ -52,11 +54,11 @@ impl VirtualIpDevice { } } -impl<'a> Device<'a> for VirtualIpDevice { - type RxToken = RxToken; - type TxToken = TxToken; +impl smoltcp::phy::Device for VirtualIpDevice { + type RxToken<'a> = RxToken where Self: 'a; + type TxToken<'a> = TxToken where Self: 'a; - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { let next = { let mut queue = self .process_queue @@ -81,7 +83,7 @@ impl<'a> Device<'a> for VirtualIpDevice { } } - fn transmit(&'a mut self) -> Option { + fn transmit(&mut self, _timestamp: Instant) -> Option> { Some(TxToken { sender: self.bus_sender.clone(), }) @@ -101,9 +103,9 @@ pub struct RxToken { } impl smoltcp::phy::RxToken for RxToken { - fn consume(mut self, _timestamp: Instant, f: F) -> smoltcp::Result + fn consume(mut self, f: F) -> R where - F: FnOnce(&mut [u8]) -> smoltcp::Result, + F: FnOnce(&mut [u8]) -> R, { f(&mut self.buffer) } @@ -115,12 +117,11 @@ pub struct TxToken { } impl smoltcp::phy::TxToken for TxToken { - fn consume(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result + fn consume(self, len: usize, f: F) -> R where - F: FnOnce(&mut [u8]) -> smoltcp::Result, + F: FnOnce(&mut [u8]) -> R, { - let mut buffer = Vec::new(); - buffer.resize(len, 0); + let mut buffer = vec![0; len]; let result = f(&mut buffer); self.sender .send(Event::OutboundInternetPacket(buffer.into())); diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 32706a0..24957bd 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -1,30 +1,34 @@ -use std::collections::{HashMap, HashSet, VecDeque}; -use std::net::IpAddr; -use std::time::Duration; - -use anyhow::Context; -use async_trait::async_trait; -use bytes::Bytes; -use smoltcp::iface::{InterfaceBuilder, SocketHandle}; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState}; -use smoltcp::wire::{IpAddress, IpCidr}; - use crate::config::{PortForwardConfig, PortProtocol}; use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::Bus; +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use smoltcp::{ + iface::{Config, Interface, SocketHandle, SocketSet}, + socket::tcp, + time::Instant, + wire::{HardwareAddress, IpAddress, IpCidr}, +}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + net::IpAddr, + time::Duration, +}; const MAX_PACKET: usize = 65536; /// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa. -pub struct TcpVirtualInterface { +pub struct TcpVirtualInterface<'a> { source_peer_ip: IpAddr, port_forwards: Vec, bus: Bus, + sockets: SocketSet<'a>, } -impl TcpVirtualInterface { +impl<'a> TcpVirtualInterface<'a> { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { @@ -35,16 +39,17 @@ impl TcpVirtualInterface { .collect(), source_peer_ip, bus, + sockets: SocketSet::new([]), } } - fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { static mut TCP_SERVER_RX_DATA: [u8; 0] = []; static mut TCP_SERVER_TX_DATA: [u8; 0] = []; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); socket .listen(( @@ -56,12 +61,12 @@ impl TcpVirtualInterface { Ok(socket) } - fn new_client_socket() -> anyhow::Result> { + fn new_client_socket() -> anyhow::Result> { let rx_data = vec![0u8; MAX_PACKET]; let tx_data = vec![0u8; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); - let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); - let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data); + let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data); + let socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); Ok(socket) } @@ -79,20 +84,31 @@ impl TcpVirtualInterface { } #[async_trait] -impl VirtualInterfacePoll for TcpVirtualInterface { - async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { +impl VirtualInterfacePoll for TcpVirtualInterface<'_> { + async fn poll_loop(mut self, mut device: VirtualIpDevice) -> anyhow::Result<()> { // Create CIDR block for source peer IP + each port forward IP let addresses = self.addresses(); + let config = Config::new(HardwareAddress::Ip); // Create virtual interface (contains smoltcp state machine) - let mut iface = InterfaceBuilder::new(device, vec![]) - .ip_addrs(addresses) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + addresses.iter().for_each(|addr| { + ip_addrs.push(*addr).unwrap(); + }); + }); + iface.set_any_ip(true); + + // Maps virtual port to its client socket handle + let mut port_client_handle_map: HashMap = HashMap::new(); // Create virtual server for each port forward for port_forward in self.port_forwards.iter() { let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?; - iface.add_socket(server_socket); + let handle = self.sockets.add(server_socket); + let virtual_port = + VirtualPort::new(port_forward.destination.port(), port_forward.protocol); + port_client_handle_map.insert(virtual_port, handle); } // The next time to poll the interface. Can be None for instant poll. @@ -101,9 +117,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Bus endpoint to read events let mut endpoint = self.bus.new_endpoint(); - // Maps virtual port to its client socket handle - let mut port_client_handle_map: HashMap = HashMap::new(); - // Data packets to send from a virtual client let mut send_queue: HashMap> = HashMap::new(); @@ -118,11 +131,11 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Find closed sockets port_client_handle_map.retain(|virtual_port, client_handle| { - let client_socket = iface.get_socket::(*client_handle); - if client_socket.state() == TcpState::Closed { + let client_socket = self.sockets.get_mut::(*client_handle); + if client_socket.state() == tcp::State::Closed { endpoint.send(Event::ClientConnectionDropped(*virtual_port)); send_queue.remove(virtual_port); - iface.remove_socket(*client_handle); + self.sockets.remove(*client_handle); false } else { // Not closed, retain @@ -130,16 +143,12 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } }); - match iface.poll(loop_start) { - Ok(processed) if processed => { - trace!("TCP virtual interface polled some packets to be processed"); - } - Err(e) => error!("TCP virtual interface poll error: {:?}", e), - _ => {} + if iface.poll(loop_start, &mut device, &mut self.sockets) { + log::trace!("TCP virtual interface polled some packets to be processed"); } for (virtual_port, client_handle) in port_client_handle_map.iter() { - let client_socket = iface.get_socket::(*client_handle); + let client_socket = self.sockets.get_mut::(*client_handle); if client_socket.can_send() { if let Some(send_queue) = send_queue.get_mut(virtual_port) { let to_transfer = send_queue.pop_front(); @@ -159,7 +168,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { ); } } - } else if client_socket.state() == TcpState::CloseWait { + } else if client_socket.state() == tcp::State::CloseWait { client_socket.close(); } } @@ -182,7 +191,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) - next_poll = match iface.poll_delay(loop_start) { + next_poll = match iface.poll_delay(loop_start, &self.sockets) { Some(smoltcp::time::Duration::ZERO) => None, Some(delay) => { trace!("TCP Virtual interface delayed next poll by {}", delay); @@ -195,13 +204,14 @@ impl VirtualInterfacePoll for TcpVirtualInterface { match event { Event::ClientConnectionInitiated(port_forward, virtual_port) => { let client_socket = TcpVirtualInterface::new_client_socket()?; - let client_handle = iface.add_socket(client_socket); + let client_handle = self.sockets.add(client_socket); // Add handle to map port_client_handle_map.insert(virtual_port, client_handle); send_queue.insert(virtual_port, VecDeque::new()); - let (client_socket, context) = iface.get_socket_and_context::(client_handle); + let client_socket = self.sockets.get_mut::(client_handle); + let context = iface.context(); client_socket .connect( @@ -218,7 +228,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } Event::ClientConnectionDropped(virtual_port) => { if let Some(client_handle) = port_client_handle_map.get(&virtual_port) { - let client_socket = iface.get_socket::(*client_handle); + let client_socket = self.sockets.get_mut::(*client_handle); client_socket.close(); next_poll = None; } @@ -229,7 +239,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { next_poll = None; } } - Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Tcp => { + Event::VirtualDeviceFed(PortProtocol::Tcp) => { next_poll = None; } _ => {} diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index be63071..8426e4a 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,29 +1,33 @@ -use std::collections::{HashMap, HashSet, VecDeque}; -use std::net::IpAddr; -use std::time::Duration; - -use anyhow::Context; -use async_trait::async_trait; -use bytes::Bytes; -use smoltcp::iface::{InterfaceBuilder, SocketHandle}; -use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; -use smoltcp::wire::{IpAddress, IpCidr}; - use crate::config::PortForwardConfig; use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::{Bus, PortProtocol}; +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use smoltcp::{ + iface::{Config, Interface, SocketHandle, SocketSet}, + socket::udp::{self, UdpMetadata}, + time::Instant, + wire::{HardwareAddress, IpAddress, IpCidr}, +}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + net::IpAddr, + time::Duration, +}; const MAX_PACKET: usize = 65536; -pub struct UdpVirtualInterface { +pub struct UdpVirtualInterface<'a> { source_peer_ip: IpAddr, port_forwards: Vec, bus: Bus, + sockets: SocketSet<'a>, } -impl UdpVirtualInterface { +impl<'a> UdpVirtualInterface<'a> { /// Initialize the parameters for a new virtual interface. /// Use the `poll_loop()` future to start the virtual interface poll loop. pub fn new(port_forwards: Vec, bus: Bus, source_peer_ip: IpAddr) -> Self { @@ -34,21 +38,24 @@ impl UdpVirtualInterface { .collect(), source_peer_ip, bus, + sockets: SocketSet::new([]), } } - fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { - static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; + fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result> { + static mut UDP_SERVER_RX_META: [udp::PacketMetadata; 0] = []; static mut UDP_SERVER_RX_DATA: [u8; 0] = []; - static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_TX_META: [udp::PacketMetadata; 0] = []; static mut UDP_SERVER_TX_DATA: [u8; 0] = []; - let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { - &mut UDP_SERVER_RX_DATA[..] - }); - let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { - &mut UDP_SERVER_TX_DATA[..] - }); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + let udp_rx_buffer = + udp::PacketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { + &mut UDP_SERVER_RX_DATA[..] + }); + let udp_tx_buffer = + udp::PacketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { + &mut UDP_SERVER_TX_DATA[..] + }); + let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); socket .bind(( IpAddress::from(port_forward.destination.ip()), @@ -61,14 +68,14 @@ impl UdpVirtualInterface { fn new_client_socket( source_peer_ip: IpAddr, client_port: VirtualPort, - ) -> anyhow::Result> { - let rx_meta = vec![UdpPacketMetadata::EMPTY; 10]; - let tx_meta = vec![UdpPacketMetadata::EMPTY; 10]; + ) -> anyhow::Result> { + let rx_meta = vec![udp::PacketMetadata::EMPTY; 10]; + let tx_meta = vec![udp::PacketMetadata::EMPTY; 10]; let rx_data = vec![0u8; MAX_PACKET]; let tx_data = vec![0u8; MAX_PACKET]; - let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); - let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + let udp_rx_buffer = udp::PacketBuffer::new(rx_meta, rx_data); + let udp_tx_buffer = udp::PacketBuffer::new(tx_meta, tx_data); + let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); socket .bind((IpAddress::from(source_peer_ip), client_port.num())) .with_context(|| "UDP virtual client failed to bind")?; @@ -89,20 +96,31 @@ impl UdpVirtualInterface { } #[async_trait] -impl VirtualInterfacePoll for UdpVirtualInterface { - async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { +impl<'a> VirtualInterfacePoll for UdpVirtualInterface<'a> { + async fn poll_loop(mut self, mut device: VirtualIpDevice) -> anyhow::Result<()> { // Create CIDR block for source peer IP + each port forward IP let addresses = self.addresses(); + let config = Config::new(HardwareAddress::Ip); // Create virtual interface (contains smoltcp state machine) - let mut iface = InterfaceBuilder::new(device, vec![]) - .ip_addrs(addresses) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + addresses.iter().for_each(|addr| { + ip_addrs.push(*addr).unwrap(); + }); + }); + iface.set_any_ip(true); + + // Maps virtual port to its client socket handle + let mut port_client_handle_map: HashMap = HashMap::new(); // Create virtual server for each port forward for port_forward in self.port_forwards.iter() { let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?; - iface.add_socket(server_socket); + let handle = self.sockets.add(server_socket); + let virtual_port = + VirtualPort::new(port_forward.destination.port(), port_forward.protocol); + port_client_handle_map.insert(virtual_port, handle); } // The next time to poll the interface. Can be None for instant poll. @@ -111,9 +129,6 @@ impl VirtualInterfacePoll for UdpVirtualInterface { // Bus endpoint to read events let mut endpoint = self.bus.new_endpoint(); - // Maps virtual port to its client socket handle - let mut port_client_handle_map: HashMap = HashMap::new(); - // Data packets to send from a virtual client let mut send_queue: HashMap> = HashMap::new(); @@ -127,16 +142,12 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } => { let loop_start = smoltcp::time::Instant::now(); - match iface.poll(loop_start) { - Ok(processed) if processed => { - trace!("UDP virtual interface polled some packets to be processed"); - } - Err(e) => error!("UDP virtual interface poll error: {:?}", e), - _ => {} + if iface.poll(loop_start, &mut device, &mut self.sockets) { + log::trace!("UDP virtual interface polled some packets to be processed"); } for (virtual_port, client_handle) in port_client_handle_map.iter() { - let client_socket = iface.get_socket::(*client_handle); + let client_socket = self.sockets.get_mut::(*client_handle); if client_socket.can_send() { if let Some(send_queue) = send_queue.get_mut(virtual_port) { let to_transfer = send_queue.pop_front(); @@ -144,7 +155,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { client_socket .send_slice( &data, - (IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(), + UdpMetadata::from(port_forward.destination), ) .unwrap_or_else(|e| { error!( @@ -172,7 +183,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) - next_poll = match iface.poll_delay(loop_start) { + next_poll = match iface.poll_delay(loop_start, &self.sockets) { Some(smoltcp::time::Duration::ZERO) => None, Some(delay) => { trace!("UDP Virtual interface delayed next poll by {}", delay); @@ -190,7 +201,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } else { // Client socket does not exist let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?; - let client_handle = iface.add_socket(client_socket); + let client_handle = self.sockets.add(client_socket); // Add handle to map port_client_handle_map.insert(virtual_port, client_handle); @@ -198,7 +209,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { } next_poll = None; } - Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Udp => { + Event::VirtualDeviceFed(PortProtocol::Udp) => { next_poll = None; } _ => {} diff --git a/src/wg.rs b/src/wg.rs index 14efa23..e1edf01 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -237,7 +237,7 @@ impl WireGuardTunnel { .ok() // Only care if the packet is destined for this tunnel .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) - .and_then(|packet| match packet.protocol() { + .and_then(|packet| match packet.next_header() { IpProtocol::Tcp => Some(PortProtocol::Tcp), IpProtocol::Udp => Some(PortProtocol::Udp), // Unrecognized protocol, so we cannot determine where to route