From 0da6fa51de35ef1d323673e4273ea9eeb545d707 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Tue, 26 Oct 2021 00:38:22 -0400 Subject: [PATCH] udp: use tokio select instead of 1ms loop --- src/virtual_iface/udp.rs | 166 ++++++++++++++++++++++----------------- 1 file changed, 92 insertions(+), 74 deletions(-) diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 382d7c3..5b275de 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -93,91 +93,109 @@ impl VirtualInterfacePoll for UdpVirtualInterface { // A map of virtual port to client socket. let mut client_sockets: HashMap = HashMap::new(); + // The next instant required to poll the virtual interface + // None means "immediate poll required". + let mut next_poll: Option = None; + loop { - let loop_start = smoltcp::time::Instant::now(); let wg = self.wg.clone(); + tokio::select! { + // Wait the recommended amount of time by smoltcp, and poll again. + _ = match next_poll { + None => tokio::time::sleep(Duration::ZERO), + Some(until) => tokio::time::sleep_until(until) + } => { + let loop_start = smoltcp::time::Instant::now(); - match virtual_interface.poll(&mut socket_set, loop_start) { - Ok(processed) if processed => { - trace!("UDP virtual interface polled some packets to be processed"); + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { + trace!("UDP virtual interface polled some packets to be processed"); + } + Err(e) => error!("UDP virtual interface poll error: {:?}", e), + _ => {} + } + + // Loop through each client socket and check if there is any data to send back + // to the real client. + for (virtual_port, client_socket_handle) in client_sockets.iter() { + let mut client_socket = socket_set.get::(*client_socket_handle); + match client_socket.recv() { + Ok((data, _peer)) => { + // Send the data back to the real client using MPSC channel + self.data_to_real_client_tx + .send((*virtual_port, data.to_vec())) + .await + .unwrap_or_else(|e| { + error!( + "[{}] Failed to dispatch data from virtual client to real client: {:?}", + virtual_port, e + ); + }); + } + Err(smoltcp::Error::Exhausted) => {} + Err(e) => { + error!( + "[{}] Failed to read from virtual client socket: {:?}", + virtual_port, e + ); + } + } + } + + next_poll = match virtual_interface.poll_delay(&socket_set, loop_start) { + Some(smoltcp::time::Duration::ZERO) => None, + Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())), + None => None, + } } - Err(e) => error!("UDP virtual interface poll error: {:?}", e), - _ => {} - } - - // Loop through each client socket and check if there is any data to send back - // to the real client. - for (virtual_port, client_socket_handle) in client_sockets.iter() { - let mut client_socket = socket_set.get::(*client_socket_handle); - match client_socket.recv() { - Ok((data, _peer)) => { - // Send the data back to the real client using MPSC channel - self.data_to_real_client_tx - .send((*virtual_port, data.to_vec())) - .await + // Wait for data to be received from the real client + data_recv_result = data_to_virtual_server_rx.recv() => { + if let Some((client_port, data)) = data_recv_result { + // Register the socket in WireGuard Tunnel (overrides any previous registration as well) + wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) .unwrap_or_else(|e| { error!( - "[{}] Failed to dispatch data from virtual client to real client: {:?}", - virtual_port, e + "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}", + client_port, e + ); + }); + + let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| { + let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET]; + let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET]; + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); + let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + + socket + .bind((IpAddress::from(wg.source_peer_ip), client_port.0)) + .unwrap_or_else(|e| { + error!( + "[{}] UDP virtual client socket failed to bind: {:?}", + client_port, e + ); + }); + + socket_set.add(socket) + }); + + let mut client_socket = socket_set.get::(*client_socket_handle); + client_socket + .send_slice( + &data, + (IpAddress::from(destination.ip()), destination.port()).into(), + ) + .unwrap_or_else(|e| { + error!( + "[{}] Failed to send data to virtual server: {:?}", + client_port, e ); }); } - Err(smoltcp::Error::Exhausted) => {} - Err(e) => { - error!( - "[{}] Failed to read from virtual client socket: {:?}", - virtual_port, e - ); - } } } - - if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() { - // Register the socket in WireGuard Tunnel (overrides any previous registration as well) - wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) - .unwrap_or_else(|e| { - error!( - "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}", - client_port, e - ); - }); - - let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| { - let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET]; - let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET]; - let rx_data = vec![0u8; MAX_PACKET]; - let tx_data = vec![0u8; MAX_PACKET]; - let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data); - let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data); - let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - - socket - .bind((IpAddress::from(wg.source_peer_ip), client_port.0)) - .unwrap_or_else(|e| { - error!( - "[{}] UDP virtual client socket failed to bind: {:?}", - client_port, e - ); - }); - - socket_set.add(socket) - }); - - let mut client_socket = socket_set.get::(*client_socket_handle); - client_socket - .send_slice( - &data, - (IpAddress::from(destination.ip()), destination.port()).into(), - ) - .unwrap_or_else(|e| { - error!( - "[{}] Failed to send data to virtual server: {:?}", - client_port, e - ); - }); - } - - tokio::time::sleep(Duration::from_millis(1)).await; } } }