From 76b6a6e346cbc858ca77fad6b7618ebad173b523 Mon Sep 17 00:00:00 2001 From: Aram Peres <6775216+aramperes@users.noreply.github.com> Date: Thu, 12 Jan 2023 01:40:04 -0500 Subject: [PATCH] Use bytes --- Cargo.lock | 1 + Cargo.toml | 1 + src/events.rs | 9 +++++---- src/tunnel/tcp.rs | 19 ++++++++++--------- src/tunnel/udp.rs | 7 ++++--- src/virtual_device.rs | 26 ++++++++++++++++++-------- src/virtual_iface/tcp.rs | 25 ++++++++++++++----------- src/virtual_iface/udp.rs | 13 +++++++------ src/wg.rs | 2 +- 9 files changed, 61 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b9b7c16..d795033 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -437,6 +437,7 @@ dependencies = [ "async-recursion", "async-trait", "boringtun", + "bytes", "clap", "futures", "log", diff --git a/Cargo.toml b/Cargo.toml index d38288f..40f2276 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ priority-queue = "1.2.0" smoltcp = { version = "0.8.0", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-udp", "socket-tcp"] } # forward boringtuns tracing events to log tracing = { version = "0.1.36", default-features = false, features = ["log"] } +bytes = "1" # bin-only dependencies clap = { version = "2.33", default-features = false, features = ["suggestions"], optional = true } diff --git a/src/events.rs b/src/events.rs index cd76a49..d6582ce 100644 --- a/src/events.rs +++ b/src/events.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; @@ -16,13 +17,13 @@ pub enum Event { /// A connection was dropped from the pool and should be closed in all interfaces. ClientConnectionDropped(VirtualPort), /// Data received by the local server that should be sent to the virtual server. - LocalData(PortForwardConfig, VirtualPort, Vec), + LocalData(PortForwardConfig, VirtualPort, Bytes), /// Data received by the remote server that should be sent to the local client. - RemoteData(VirtualPort, Vec), + RemoteData(VirtualPort, Bytes), /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. - InboundInternetPacket(PortProtocol, Vec), + InboundInternetPacket(PortProtocol, Bytes), /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device. - OutboundInternetPacket(Vec), + OutboundInternetPacket(Bytes), /// Notifies that a virtual device read an IP packet. VirtualDeviceFed(PortProtocol), } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 557f5ad..b5e1ec5 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -1,17 +1,18 @@ -use crate::config::{PortForwardConfig, PortProtocol}; -use crate::virtual_iface::VirtualPort; -use anyhow::Context; use std::collections::VecDeque; -use std::sync::Arc; -use tokio::net::{TcpListener, TcpStream}; - use std::ops::Range; +use std::sync::Arc; use std::time::Duration; -use crate::events::{Bus, Event}; +use anyhow::Context; +use bytes::BytesMut; use rand::seq::SliceRandom; use rand::thread_rng; use tokio::io::AsyncWriteExt; +use tokio::net::{TcpListener, TcpStream}; + +use crate::config::{PortForwardConfig, PortProtocol}; +use crate::events::{Bus, Event}; +use crate::virtual_iface::VirtualPort; const MAX_PACKET: usize = 65536; const MIN_PORT: u16 = 1000; @@ -81,7 +82,7 @@ async fn handle_tcp_proxy_connection( let mut endpoint = bus.new_endpoint(); endpoint.send(Event::ClientConnectionInitiated(port_forward, virtual_port)); - let mut buffer = Vec::with_capacity(MAX_PACKET); + let mut buffer = BytesMut::with_capacity(MAX_PACKET); loop { tokio::select! { readable_result = socket.readable() => { @@ -90,7 +91,7 @@ async fn handle_tcp_proxy_connection( match socket.try_read_buf(&mut buffer) { Ok(size) if size > 0 => { let data = Vec::from(&buffer[..size]); - endpoint.send(Event::LocalData(port_forward, virtual_port, data)); + endpoint.send(Event::LocalData(port_forward, virtual_port, data.into())); // Reset buffer buffer.clear(); } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 1d25914..32fef15 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -4,14 +4,15 @@ use std::ops::Range; use std::sync::Arc; use std::time::Instant; -use crate::events::{Bus, Event}; use anyhow::Context; +use bytes::Bytes; use priority_queue::double_priority_queue::DoublePriorityQueue; use rand::seq::SliceRandom; use rand::thread_rng; use tokio::net::UdpSocket; use crate::config::{PortForwardConfig, PortProtocol}; +use crate::events::{Bus, Event}; use crate::virtual_iface::VirtualPort; const MAX_PACKET: usize = 65536; @@ -98,7 +99,7 @@ async fn next_udp_datagram( socket: &UdpSocket, buffer: &mut [u8], port_pool: UdpPortPool, -) -> anyhow::Result)>> { +) -> anyhow::Result> { let (size, peer_addr) = socket .recv_from(buffer) .await @@ -126,7 +127,7 @@ async fn next_udp_datagram( port_pool.update_last_transmit(port).await; let data = buffer[..size].to_vec(); - Ok(Some((port, data))) + Ok(Some((port, data.into()))) } /// A pool of virtual ports available for TCP connections. diff --git a/src/virtual_device.rs b/src/virtual_device.rs index e0e7e4d..0054690 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,10 +1,13 @@ +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; + +use bytes::{BufMut, Bytes, BytesMut}; +use smoltcp::phy::{Device, DeviceCapabilities, Medium}; +use smoltcp::time::Instant; + use crate::config::PortProtocol; use crate::events::{BusSender, Event}; use crate::Bus; -use smoltcp::phy::{Device, DeviceCapabilities, Medium}; -use smoltcp::time::Instant; -use std::collections::VecDeque; -use std::sync::{Arc, Mutex}; /// A virtual device that processes IP packets through smoltcp and WireGuard. pub struct VirtualIpDevice { @@ -13,7 +16,7 @@ pub struct VirtualIpDevice { /// Channel receiver for received IP packets. bus_sender: BusSender, /// Local queue for packets received from the bus that need to go through the smoltcp interface. - process_queue: Arc>>>, + process_queue: Arc>>, } impl VirtualIpDevice { @@ -63,7 +66,13 @@ impl<'a> Device<'a> for VirtualIpDevice { }; match next { Some(buffer) => Some(( - Self::RxToken { buffer }, + Self::RxToken { + buffer: { + let mut buf = BytesMut::new(); + buf.put(buffer); + buf + }, + }, Self::TxToken { sender: self.bus_sender.clone(), }, @@ -88,7 +97,7 @@ impl<'a> Device<'a> for VirtualIpDevice { #[doc(hidden)] pub struct RxToken { - buffer: Vec, + buffer: BytesMut, } impl smoltcp::phy::RxToken for RxToken { @@ -113,7 +122,8 @@ impl smoltcp::phy::TxToken for TxToken { let mut buffer = Vec::new(); buffer.resize(len, 0); let result = f(&mut buffer); - self.sender.send(Event::OutboundInternetPacket(buffer)); + self.sender + .send(Event::OutboundInternetPacket(buffer.into())); result } } diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 28eacac..32706a0 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -1,16 +1,19 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::net::IpAddr; +use std::time::Duration; + +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use smoltcp::iface::{InterfaceBuilder, SocketHandle}; +use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState}; +use smoltcp::wire::{IpAddress, IpCidr}; + use crate::config::{PortForwardConfig, PortProtocol}; use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::Bus; -use anyhow::Context; -use async_trait::async_trait; -use smoltcp::iface::{InterfaceBuilder, SocketHandle}; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer, TcpState}; -use smoltcp::wire::{IpAddress, IpCidr}; -use std::collections::{HashMap, HashSet, VecDeque}; -use std::net::IpAddr; -use std::time::Duration; const MAX_PACKET: usize = 65536; @@ -102,7 +105,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { let mut port_client_handle_map: HashMap = HashMap::new(); // Data packets to send from a virtual client - let mut send_queue: HashMap>> = HashMap::new(); + let mut send_queue: HashMap> = HashMap::new(); loop { tokio::select! { @@ -147,7 +150,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { if sent < total { // Sometimes only a subset is sent, so the rest needs to be sent on the next poll let tx_extra = Vec::from(&to_transfer_slice[sent..total]); - send_queue.push_front(tx_extra); + send_queue.push_front(tx_extra.into()); } } Err(e) => { @@ -162,7 +165,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } } if client_socket.can_recv() { - match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { + match client_socket.recv(|buffer| (buffer.len(), Bytes::from(buffer.to_vec()))) { Ok(data) => { debug!("[{}] Received {} bytes from virtual server", virtual_port, data.len()); if !data.is_empty() { diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index d939132..be63071 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,18 +1,19 @@ -use anyhow::Context; use std::collections::{HashMap, HashSet, VecDeque}; use std::net::IpAddr; +use std::time::Duration; -use crate::events::Event; -use crate::{Bus, PortProtocol}; +use anyhow::Context; use async_trait::async_trait; +use bytes::Bytes; use smoltcp::iface::{InterfaceBuilder, SocketHandle}; use smoltcp::socket::{UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; use smoltcp::wire::{IpAddress, IpCidr}; -use std::time::Duration; use crate::config::PortForwardConfig; +use crate::events::Event; use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; +use crate::{Bus, PortProtocol}; const MAX_PACKET: usize = 65536; @@ -114,7 +115,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { let mut port_client_handle_map: HashMap = HashMap::new(); // Data packets to send from a virtual client - let mut send_queue: HashMap)>> = + let mut send_queue: HashMap> = HashMap::new(); loop { @@ -158,7 +159,7 @@ impl VirtualInterfacePoll for UdpVirtualInterface { match client_socket.recv() { Ok((data, _peer)) => { if !data.is_empty() { - endpoint.send(Event::RemoteData(*virtual_port, data.to_vec())); + endpoint.send(Event::RemoteData(*virtual_port, data.to_vec().into())); } } Err(e) => { diff --git a/src/wg.rs b/src/wg.rs index 18ef7d5..0646606 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -209,7 +209,7 @@ impl WireGuardTunnel { trace_ip_packet("Received IP packet", packet); if let Some(proto) = self.route_protocol(packet) { - endpoint.send(Event::InboundInternetPacket(proto, packet.into())); + endpoint.send(Event::InboundInternetPacket(proto, packet.to_vec().into())); } } _ => {}