smoltcp version 0.10 applied

This commit is contained in:
ssrlive 2023-10-21 11:12:18 +08:00
parent f212e85c41
commit ff7b5ec2ca
5 changed files with 137 additions and 115 deletions

View file

@ -28,7 +28,7 @@ rand = "0.8"
nom = "7" nom = "7"
async-trait = "0.1" async-trait = "0.1"
priority-queue = "1.3" priority-queue = "1.3"
smoltcp = { version = "0.8.2", default-features = false, features = [ smoltcp = { version = "0.10", default-features = false, features = [
"std", "std",
"log", "log",
"medium-ip", "medium-ip",

View file

@ -1,13 +1,15 @@
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use bytes::{BufMut, Bytes, BytesMut};
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use crate::config::PortProtocol; use crate::config::PortProtocol;
use crate::events::{BusSender, Event}; use crate::events::{BusSender, Event};
use crate::Bus; use crate::Bus;
use bytes::{BufMut, Bytes, BytesMut};
use smoltcp::{
phy::{DeviceCapabilities, Medium},
time::Instant,
};
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
/// A virtual device that processes IP packets through smoltcp and WireGuard. /// A virtual device that processes IP packets through smoltcp and WireGuard.
pub struct VirtualIpDevice { pub struct VirtualIpDevice {
@ -52,11 +54,11 @@ impl VirtualIpDevice {
} }
} }
impl<'a> Device<'a> for VirtualIpDevice { impl smoltcp::phy::Device for VirtualIpDevice {
type RxToken = RxToken; type RxToken<'a> = RxToken where Self: 'a;
type TxToken = TxToken; type TxToken<'a> = TxToken where Self: 'a;
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
let next = { let next = {
let mut queue = self let mut queue = self
.process_queue .process_queue
@ -81,7 +83,7 @@ impl<'a> Device<'a> for VirtualIpDevice {
} }
} }
fn transmit(&'a mut self) -> Option<Self::TxToken> { fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
Some(TxToken { Some(TxToken {
sender: self.bus_sender.clone(), sender: self.bus_sender.clone(),
}) })
@ -101,9 +103,9 @@ pub struct RxToken {
} }
impl smoltcp::phy::RxToken for RxToken { impl smoltcp::phy::RxToken for RxToken {
fn consume<R, F>(mut self, _timestamp: Instant, f: F) -> smoltcp::Result<R> fn consume<R, F>(mut self, f: F) -> R
where where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>, F: FnOnce(&mut [u8]) -> R,
{ {
f(&mut self.buffer) f(&mut self.buffer)
} }
@ -115,12 +117,11 @@ pub struct TxToken {
} }
impl smoltcp::phy::TxToken for TxToken { impl smoltcp::phy::TxToken for TxToken {
fn consume<R, F>(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result<R> fn consume<R, F>(self, len: usize, f: F) -> R
where where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>, F: FnOnce(&mut [u8]) -> R,
{ {
let mut buffer = Vec::new(); let mut buffer = vec![0; len];
buffer.resize(len, 0);
let result = f(&mut buffer); let result = f(&mut buffer);
self.sender self.sender
.send(Event::OutboundInternetPacket(buffer.into())); .send(Event::OutboundInternetPacket(buffer.into()));

View file

@ -1,30 +1,34 @@
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use std::time::Duration;
use anyhow::Context;
use async_trait::async_trait;
use bytes::Bytes;
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState};
use smoltcp::wire::{IpAddress, IpCidr};
use crate::config::{PortForwardConfig, PortProtocol}; use crate::config::{PortForwardConfig, PortProtocol};
use crate::events::Event; use crate::events::Event;
use crate::virtual_device::VirtualIpDevice; use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::Bus; use crate::Bus;
use anyhow::Context;
use async_trait::async_trait;
use bytes::Bytes;
use smoltcp::{
iface::{Config, Interface, SocketHandle, SocketSet},
socket::tcp,
time::Instant,
wire::{HardwareAddress, IpAddress, IpCidr},
};
use std::{
collections::{HashMap, HashSet, VecDeque},
net::IpAddr,
time::Duration,
};
const MAX_PACKET: usize = 65536; const MAX_PACKET: usize = 65536;
/// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa. /// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa.
pub struct TcpVirtualInterface { pub struct TcpVirtualInterface<'a> {
source_peer_ip: IpAddr, source_peer_ip: IpAddr,
port_forwards: Vec<PortForwardConfig>, port_forwards: Vec<PortForwardConfig>,
bus: Bus, bus: Bus,
sockets: SocketSet<'a>,
} }
impl TcpVirtualInterface { impl<'a> TcpVirtualInterface<'a> {
/// Initialize the parameters for a new virtual interface. /// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop. /// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self { pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
@ -35,16 +39,17 @@ impl TcpVirtualInterface {
.collect(), .collect(),
source_peer_ip, source_peer_ip,
bus, bus,
sockets: SocketSet::new([]),
} }
} }
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<TcpSocket<'static>> { fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<tcp::Socket<'static>> {
static mut TCP_SERVER_RX_DATA: [u8; 0] = []; static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
static mut TCP_SERVER_TX_DATA: [u8; 0] = []; static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); let mut socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
socket socket
.listen(( .listen((
@ -56,12 +61,12 @@ impl TcpVirtualInterface {
Ok(socket) Ok(socket)
} }
fn new_client_socket() -> anyhow::Result<TcpSocket<'static>> { fn new_client_socket() -> anyhow::Result<tcp::Socket<'static>> {
let rx_data = vec![0u8; MAX_PACKET]; let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET]; let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); let tcp_rx_buffer = tcp::SocketBuffer::new(rx_data);
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); let tcp_tx_buffer = tcp::SocketBuffer::new(tx_data);
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); let socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
Ok(socket) Ok(socket)
} }
@ -79,20 +84,31 @@ impl TcpVirtualInterface {
} }
#[async_trait] #[async_trait]
impl VirtualInterfacePoll for TcpVirtualInterface { impl VirtualInterfacePoll for TcpVirtualInterface<'_> {
async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { async fn poll_loop(mut self, mut device: VirtualIpDevice) -> anyhow::Result<()> {
// Create CIDR block for source peer IP + each port forward IP // Create CIDR block for source peer IP + each port forward IP
let addresses = self.addresses(); let addresses = self.addresses();
let config = Config::new(HardwareAddress::Ip);
// Create virtual interface (contains smoltcp state machine) // Create virtual interface (contains smoltcp state machine)
let mut iface = InterfaceBuilder::new(device, vec![]) let mut iface = Interface::new(config, &mut device, Instant::now());
.ip_addrs(addresses) iface.update_ip_addrs(|ip_addrs| {
.finalize(); addresses.iter().for_each(|addr| {
ip_addrs.push(*addr).unwrap();
});
});
iface.set_any_ip(true);
// Maps virtual port to its client socket handle
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// Create virtual server for each port forward // Create virtual server for each port forward
for port_forward in self.port_forwards.iter() { for port_forward in self.port_forwards.iter() {
let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?; let server_socket = TcpVirtualInterface::new_server_socket(*port_forward)?;
iface.add_socket(server_socket); let handle = self.sockets.add(server_socket);
let virtual_port =
VirtualPort::new(port_forward.destination.port(), port_forward.protocol);
port_client_handle_map.insert(virtual_port, handle);
} }
// The next time to poll the interface. Can be None for instant poll. // The next time to poll the interface. Can be None for instant poll.
@ -101,9 +117,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
// Bus endpoint to read events // Bus endpoint to read events
let mut endpoint = self.bus.new_endpoint(); let mut endpoint = self.bus.new_endpoint();
// Maps virtual port to its client socket handle
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// Data packets to send from a virtual client // Data packets to send from a virtual client
let mut send_queue: HashMap<VirtualPort, VecDeque<Bytes>> = HashMap::new(); let mut send_queue: HashMap<VirtualPort, VecDeque<Bytes>> = HashMap::new();
@ -118,11 +131,11 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
// Find closed sockets // Find closed sockets
port_client_handle_map.retain(|virtual_port, client_handle| { port_client_handle_map.retain(|virtual_port, client_handle| {
let client_socket = iface.get_socket::<TcpSocket>(*client_handle); let client_socket = self.sockets.get_mut::<tcp::Socket>(*client_handle);
if client_socket.state() == TcpState::Closed { if client_socket.state() == tcp::State::Closed {
endpoint.send(Event::ClientConnectionDropped(*virtual_port)); endpoint.send(Event::ClientConnectionDropped(*virtual_port));
send_queue.remove(virtual_port); send_queue.remove(virtual_port);
iface.remove_socket(*client_handle); self.sockets.remove(*client_handle);
false false
} else { } else {
// Not closed, retain // Not closed, retain
@ -130,16 +143,12 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
} }
}); });
match iface.poll(loop_start) { if iface.poll(loop_start, &mut device, &mut self.sockets) {
Ok(processed) if processed => { log::trace!("TCP virtual interface polled some packets to be processed");
trace!("TCP virtual interface polled some packets to be processed");
}
Err(e) => error!("TCP virtual interface poll error: {:?}", e),
_ => {}
} }
for (virtual_port, client_handle) in port_client_handle_map.iter() { for (virtual_port, client_handle) in port_client_handle_map.iter() {
let client_socket = iface.get_socket::<TcpSocket>(*client_handle); let client_socket = self.sockets.get_mut::<tcp::Socket>(*client_handle);
if client_socket.can_send() { if client_socket.can_send() {
if let Some(send_queue) = send_queue.get_mut(virtual_port) { if let Some(send_queue) = send_queue.get_mut(virtual_port) {
let to_transfer = send_queue.pop_front(); let to_transfer = send_queue.pop_front();
@ -159,7 +168,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
); );
} }
} }
} else if client_socket.state() == TcpState::CloseWait { } else if client_socket.state() == tcp::State::CloseWait {
client_socket.close(); client_socket.close();
} }
} }
@ -182,7 +191,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
} }
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls) // The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
next_poll = match iface.poll_delay(loop_start) { next_poll = match iface.poll_delay(loop_start, &self.sockets) {
Some(smoltcp::time::Duration::ZERO) => None, Some(smoltcp::time::Duration::ZERO) => None,
Some(delay) => { Some(delay) => {
trace!("TCP Virtual interface delayed next poll by {}", delay); trace!("TCP Virtual interface delayed next poll by {}", delay);
@ -195,13 +204,14 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
match event { match event {
Event::ClientConnectionInitiated(port_forward, virtual_port) => { Event::ClientConnectionInitiated(port_forward, virtual_port) => {
let client_socket = TcpVirtualInterface::new_client_socket()?; let client_socket = TcpVirtualInterface::new_client_socket()?;
let client_handle = iface.add_socket(client_socket); let client_handle = self.sockets.add(client_socket);
// Add handle to map // Add handle to map
port_client_handle_map.insert(virtual_port, client_handle); port_client_handle_map.insert(virtual_port, client_handle);
send_queue.insert(virtual_port, VecDeque::new()); send_queue.insert(virtual_port, VecDeque::new());
let (client_socket, context) = iface.get_socket_and_context::<TcpSocket>(client_handle); let client_socket = self.sockets.get_mut::<tcp::Socket>(client_handle);
let context = iface.context();
client_socket client_socket
.connect( .connect(
@ -218,7 +228,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
} }
Event::ClientConnectionDropped(virtual_port) => { Event::ClientConnectionDropped(virtual_port) => {
if let Some(client_handle) = port_client_handle_map.get(&virtual_port) { if let Some(client_handle) = port_client_handle_map.get(&virtual_port) {
let client_socket = iface.get_socket::<TcpSocket>(*client_handle); let client_socket = self.sockets.get_mut::<tcp::Socket>(*client_handle);
client_socket.close(); client_socket.close();
next_poll = None; next_poll = None;
} }
@ -229,7 +239,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
next_poll = None; next_poll = None;
} }
} }
Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Tcp => { Event::VirtualDeviceFed(PortProtocol::Tcp) => {
next_poll = None; next_poll = None;
} }
_ => {} _ => {}

View file

@ -1,29 +1,33 @@
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use std::time::Duration;
use anyhow::Context;
use async_trait::async_trait;
use bytes::Bytes;
use smoltcp::iface::{InterfaceBuilder, SocketHandle};
use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
use smoltcp::wire::{IpAddress, IpCidr};
use crate::config::PortForwardConfig; use crate::config::PortForwardConfig;
use crate::events::Event; use crate::events::Event;
use crate::virtual_device::VirtualIpDevice; use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::{Bus, PortProtocol}; use crate::{Bus, PortProtocol};
use anyhow::Context;
use async_trait::async_trait;
use bytes::Bytes;
use smoltcp::{
iface::{Config, Interface, SocketHandle, SocketSet},
socket::udp::{self, UdpMetadata},
time::Instant,
wire::{HardwareAddress, IpAddress, IpCidr},
};
use std::{
collections::{HashMap, HashSet, VecDeque},
net::IpAddr,
time::Duration,
};
const MAX_PACKET: usize = 65536; const MAX_PACKET: usize = 65536;
pub struct UdpVirtualInterface { pub struct UdpVirtualInterface<'a> {
source_peer_ip: IpAddr, source_peer_ip: IpAddr,
port_forwards: Vec<PortForwardConfig>, port_forwards: Vec<PortForwardConfig>,
bus: Bus, bus: Bus,
sockets: SocketSet<'a>,
} }
impl UdpVirtualInterface { impl<'a> UdpVirtualInterface<'a> {
/// Initialize the parameters for a new virtual interface. /// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop. /// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self { pub fn new(port_forwards: Vec<PortForwardConfig>, bus: Bus, source_peer_ip: IpAddr) -> Self {
@ -34,21 +38,24 @@ impl UdpVirtualInterface {
.collect(), .collect(),
source_peer_ip, source_peer_ip,
bus, bus,
sockets: SocketSet::new([]),
} }
} }
fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<UdpSocket<'static>> { fn new_server_socket(port_forward: PortForwardConfig) -> anyhow::Result<udp::Socket<'static>> {
static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; static mut UDP_SERVER_RX_META: [udp::PacketMetadata; 0] = [];
static mut UDP_SERVER_RX_DATA: [u8; 0] = []; static mut UDP_SERVER_RX_DATA: [u8; 0] = [];
static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; static mut UDP_SERVER_TX_META: [udp::PacketMetadata; 0] = [];
static mut UDP_SERVER_TX_DATA: [u8; 0] = []; static mut UDP_SERVER_TX_DATA: [u8; 0] = [];
let udp_rx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { let udp_rx_buffer =
udp::PacketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
&mut UDP_SERVER_RX_DATA[..] &mut UDP_SERVER_RX_DATA[..]
}); });
let udp_tx_buffer = UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { let udp_tx_buffer =
udp::PacketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
&mut UDP_SERVER_TX_DATA[..] &mut UDP_SERVER_TX_DATA[..]
}); });
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer);
socket socket
.bind(( .bind((
IpAddress::from(port_forward.destination.ip()), IpAddress::from(port_forward.destination.ip()),
@ -61,14 +68,14 @@ impl UdpVirtualInterface {
fn new_client_socket( fn new_client_socket(
source_peer_ip: IpAddr, source_peer_ip: IpAddr,
client_port: VirtualPort, client_port: VirtualPort,
) -> anyhow::Result<UdpSocket<'static>> { ) -> anyhow::Result<udp::Socket<'static>> {
let rx_meta = vec![UdpPacketMetadata::EMPTY; 10]; let rx_meta = vec![udp::PacketMetadata::EMPTY; 10];
let tx_meta = vec![UdpPacketMetadata::EMPTY; 10]; let tx_meta = vec![udp::PacketMetadata::EMPTY; 10];
let rx_data = vec![0u8; MAX_PACKET]; let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET]; let tx_data = vec![0u8; MAX_PACKET];
let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); let udp_rx_buffer = udp::PacketBuffer::new(rx_meta, rx_data);
let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); let udp_tx_buffer = udp::PacketBuffer::new(tx_meta, tx_data);
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer);
socket socket
.bind((IpAddress::from(source_peer_ip), client_port.num())) .bind((IpAddress::from(source_peer_ip), client_port.num()))
.with_context(|| "UDP virtual client failed to bind")?; .with_context(|| "UDP virtual client failed to bind")?;
@ -89,20 +96,31 @@ impl UdpVirtualInterface {
} }
#[async_trait] #[async_trait]
impl VirtualInterfacePoll for UdpVirtualInterface { impl<'a> VirtualInterfacePoll for UdpVirtualInterface<'a> {
async fn poll_loop(self, device: VirtualIpDevice) -> anyhow::Result<()> { async fn poll_loop(mut self, mut device: VirtualIpDevice) -> anyhow::Result<()> {
// Create CIDR block for source peer IP + each port forward IP // Create CIDR block for source peer IP + each port forward IP
let addresses = self.addresses(); let addresses = self.addresses();
let config = Config::new(HardwareAddress::Ip);
// Create virtual interface (contains smoltcp state machine) // Create virtual interface (contains smoltcp state machine)
let mut iface = InterfaceBuilder::new(device, vec![]) let mut iface = Interface::new(config, &mut device, Instant::now());
.ip_addrs(addresses) iface.update_ip_addrs(|ip_addrs| {
.finalize(); addresses.iter().for_each(|addr| {
ip_addrs.push(*addr).unwrap();
});
});
iface.set_any_ip(true);
// Maps virtual port to its client socket handle
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// Create virtual server for each port forward // Create virtual server for each port forward
for port_forward in self.port_forwards.iter() { for port_forward in self.port_forwards.iter() {
let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?; let server_socket = UdpVirtualInterface::new_server_socket(*port_forward)?;
iface.add_socket(server_socket); let handle = self.sockets.add(server_socket);
let virtual_port =
VirtualPort::new(port_forward.destination.port(), port_forward.protocol);
port_client_handle_map.insert(virtual_port, handle);
} }
// The next time to poll the interface. Can be None for instant poll. // The next time to poll the interface. Can be None for instant poll.
@ -111,9 +129,6 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
// Bus endpoint to read events // Bus endpoint to read events
let mut endpoint = self.bus.new_endpoint(); let mut endpoint = self.bus.new_endpoint();
// Maps virtual port to its client socket handle
let mut port_client_handle_map: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// Data packets to send from a virtual client // Data packets to send from a virtual client
let mut send_queue: HashMap<VirtualPort, VecDeque<(PortForwardConfig, Bytes)>> = let mut send_queue: HashMap<VirtualPort, VecDeque<(PortForwardConfig, Bytes)>> =
HashMap::new(); HashMap::new();
@ -127,16 +142,12 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
} => { } => {
let loop_start = smoltcp::time::Instant::now(); let loop_start = smoltcp::time::Instant::now();
match iface.poll(loop_start) { if iface.poll(loop_start, &mut device, &mut self.sockets) {
Ok(processed) if processed => { log::trace!("UDP virtual interface polled some packets to be processed");
trace!("UDP virtual interface polled some packets to be processed");
}
Err(e) => error!("UDP virtual interface poll error: {:?}", e),
_ => {}
} }
for (virtual_port, client_handle) in port_client_handle_map.iter() { for (virtual_port, client_handle) in port_client_handle_map.iter() {
let client_socket = iface.get_socket::<UdpSocket>(*client_handle); let client_socket = self.sockets.get_mut::<udp::Socket>(*client_handle);
if client_socket.can_send() { if client_socket.can_send() {
if let Some(send_queue) = send_queue.get_mut(virtual_port) { if let Some(send_queue) = send_queue.get_mut(virtual_port) {
let to_transfer = send_queue.pop_front(); let to_transfer = send_queue.pop_front();
@ -144,7 +155,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
client_socket client_socket
.send_slice( .send_slice(
&data, &data,
(IpAddress::from(port_forward.destination.ip()), port_forward.destination.port()).into(), UdpMetadata::from(port_forward.destination),
) )
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
error!( error!(
@ -172,7 +183,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
} }
// The virtual interface determines the next time to poll (this is to reduce unnecessary polls) // The virtual interface determines the next time to poll (this is to reduce unnecessary polls)
next_poll = match iface.poll_delay(loop_start) { next_poll = match iface.poll_delay(loop_start, &self.sockets) {
Some(smoltcp::time::Duration::ZERO) => None, Some(smoltcp::time::Duration::ZERO) => None,
Some(delay) => { Some(delay) => {
trace!("UDP Virtual interface delayed next poll by {}", delay); trace!("UDP Virtual interface delayed next poll by {}", delay);
@ -190,7 +201,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
} else { } else {
// Client socket does not exist // Client socket does not exist
let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?; let client_socket = UdpVirtualInterface::new_client_socket(self.source_peer_ip, virtual_port)?;
let client_handle = iface.add_socket(client_socket); let client_handle = self.sockets.add(client_socket);
// Add handle to map // Add handle to map
port_client_handle_map.insert(virtual_port, client_handle); port_client_handle_map.insert(virtual_port, client_handle);
@ -198,7 +209,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
} }
next_poll = None; next_poll = None;
} }
Event::VirtualDeviceFed(protocol) if protocol == PortProtocol::Udp => { Event::VirtualDeviceFed(PortProtocol::Udp) => {
next_poll = None; next_poll = None;
} }
_ => {} _ => {}

View file

@ -237,7 +237,7 @@ impl WireGuardTunnel {
.ok() .ok()
// Only care if the packet is destined for this tunnel // Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.and_then(|packet| match packet.protocol() { .and_then(|packet| match packet.next_header() {
IpProtocol::Tcp => Some(PortProtocol::Tcp), IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp), IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route // Unrecognized protocol, so we cannot determine where to route