Getting there

This commit is contained in:
Aram 🍐 2021-10-13 21:52:06 -04:00
parent 8f5f8670af
commit f65a1e4e89
5 changed files with 299 additions and 42 deletions

View file

@ -1,18 +1,22 @@
#[macro_use]
extern crate log;
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
use smoltcp::wire::{IpAddress, IpCidr};
use tokio::io::Interest;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::error::TryRecvError;
use crate::config::Config;
use crate::port_pool::PortPool;
use crate::virtual_device::VirtualIpDevice;
use crate::wg::WireGuardTunnel;
pub mod client;
@ -51,12 +55,21 @@ async fn main() -> anyhow::Result<()> {
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
);
tcp_proxy_server(config.source_addr.clone(), port_pool.clone(), wg).await
tcp_proxy_server(
config.source_addr,
config.source_peer_ip,
config.dest_addr,
port_pool.clone(),
wg,
)
.await
}
/// Starts the server that listens on TCP connections.
async fn tcp_proxy_server(
listen_addr: SocketAddr,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
port_pool: Arc<PortPool>,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
@ -90,7 +103,9 @@ async fn tcp_proxy_server(
tokio::spawn(async move {
let port_pool = Arc::clone(&port_pool);
let result = handle_tcp_proxy_connection(socket, virtual_port, wg).await;
let result =
handle_tcp_proxy_connection(socket, virtual_port, source_peer_ip, dest_addr, wg)
.await;
if let Err(e) = result {
error!(
@ -111,6 +126,8 @@ async fn tcp_proxy_server(
async fn handle_tcp_proxy_connection(
socket: TcpStream,
virtual_port: u16,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
// Abort signal for stopping the Virtual Interface
@ -121,11 +138,22 @@ async fn handle_tcp_proxy_connection(
let (data_to_real_client_tx, mut data_to_real_client_rx) =
tokio::sync::mpsc::channel(1_000_000);
let (data_to_real_server_tx, data_to_real_server_rx) = tokio::sync::mpsc::channel(1_000_000);
// Spawn virtual interface
{
let abort = abort.clone();
tokio::spawn(async move {
virtual_tcp_interface(virtual_port, wg, abort, data_to_real_client_tx).await
virtual_tcp_interface(
virtual_port,
source_peer_ip,
dest_addr,
wg,
abort,
data_to_real_client_tx,
data_to_real_server_rx,
)
.await
});
}
@ -149,7 +177,15 @@ async fn handle_tcp_proxy_connection(
"[{}] Read {} bytes of TCP data from real client",
virtual_port, size
);
trace!("[{}] Read: {:?}", virtual_port, data);
match data_to_real_server_tx.send(data.to_vec()).await {
Err(e) => {
error!(
"[{}] Failed to dispatch data to virtual interface: {:?}",
virtual_port, e
);
}
_ => {}
}
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue;
@ -210,37 +246,135 @@ async fn handle_tcp_proxy_connection(
async fn virtual_tcp_interface(
virtual_port: u16,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
wg: Arc<WireGuardTunnel>,
abort: Arc<AtomicBool>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
mut data_to_real_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
) -> anyhow::Result<()> {
// Create a device and interface to simulate IP packets
// In essence:
// * TCP packets received from the 'real' client are 'sent' via the 'virtual client'
// * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
// * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
// * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
// * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
// * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device = VirtualIpDevice::new(wg);
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr.ip()), 32),
])
.any_ip(true)
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.
// TODO: Determine if we even need buffers here.
let server_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((IpAddress::from(dest_addr.ip()), dest_addr.port()))
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
};
let client_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.connect(
(IpAddress::from(dest_addr.ip()), dest_addr.port()),
(IpAddress::from(source_peer_ip), virtual_port),
)
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
};
// Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server.
let mut socket_set_entries: [_; 2] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let _server_handle = socket_set.add(server_socket?);
let client_handle = socket_set.add(client_socket?);
loop {
let loop_start = smoltcp::time::Instant::now();
if abort.load(Ordering::Relaxed) {
break;
}
// Test START
tokio::time::sleep(Duration::from_millis(1000)).await;
match data_to_real_client_tx.send(b"pong".to_vec()).await {
Ok(_) => {
trace!("Wrote stuff in the data_to_real_client_tx")
}
Err(e) => {
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
trace!(
"[{}] Virtual interface failed to dispatch data to parent task: {:?}",
virtual_port,
e
"[{}] Virtual interface polled some packets to be processed",
virtual_port
);
}
Err(e) => {
error!("[{}] Virtual interface poll error: {:?}", virtual_port, e);
}
_ => {}
}
// Test END
{
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
if client_socket.can_recv() {
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => {
// Send it to the real client
match data_to_real_client_tx.send(data).await {
Err(e) => {
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e);
}
_ => {}
}
}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
virtual_port, e
);
}
}
}
if client_socket.can_send() {
// Check if there is anything to send
match data_to_real_server_rx.try_recv() {
Ok(data) => match client_socket.send_slice(&data) {
Err(e) => {
error!(
"[{}] Failed to send slice via virtual client socket: {:?}",
virtual_port, e
);
}
_ => {}
},
Err(_) => {}
}
}
}
match virtual_interface.poll_delay(&socket_set, loop_start) {
None => tokio::time::sleep(Duration::from_millis(1)).await,
Some(smoltcp::time::Duration::ZERO) => {}
Some(delay) => tokio::time::sleep(Duration::from_millis(delay.millis())).await,
};
}
trace!("[{}] Virtual interface task terminated", virtual_port);
Ok(())

View file

@ -1,19 +1,17 @@
use crate::wg::WireGuardTunnel;
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use std::sync::Arc;
#[derive(Clone)]
pub struct VirtualIpDevice {
/// Channel for packets sent by the interface.
ip_tx: crossbeam_channel::Sender<Vec<u8>>,
ip_rx: crossbeam_channel::Receiver<Vec<u8>>,
/// Tunnel to send IP packets to.
wg: Arc<WireGuardTunnel>,
}
impl VirtualIpDevice {
pub fn new(
ip_tx: crossbeam_channel::Sender<Vec<u8>>,
ip_rx: crossbeam_channel::Receiver<Vec<u8>>,
) -> Self {
Self { ip_tx, ip_rx }
pub fn new(wg: Arc<WireGuardTunnel>) -> Self {
Self { wg }
}
}
@ -22,22 +20,21 @@ impl<'a> Device<'a> for VirtualIpDevice {
type TxToken = TxToken;
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
if !self.ip_rx.is_empty() {
let buffer = self.ip_rx.recv().expect("failed to read ip_rx");
Some((
RxToken { buffer },
TxToken {
ip_tx: self.ip_tx.clone(),
let mut consumer = self.wg.subscribe();
match consumer.try_recv() {
Ok(buffer) => Some((
Self::RxToken { buffer },
Self::TxToken {
wg: self.wg.clone(),
},
))
} else {
None
)),
Err(_) => None,
}
}
fn transmit(&'a mut self) -> Option<Self::TxToken> {
Some(TxToken {
ip_tx: self.ip_tx.clone(),
wg: self.wg.clone(),
})
}
@ -65,10 +62,10 @@ impl smoltcp::phy::RxToken for RxToken {
#[doc(hidden)]
pub struct TxToken {
ip_tx: crossbeam_channel::Sender<Vec<u8>>,
wg: Arc<WireGuardTunnel>,
}
impl<'a> smoltcp::phy::TxToken for TxToken {
impl smoltcp::phy::TxToken for TxToken {
fn consume<R, F>(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result<R>
where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>,
@ -76,9 +73,12 @@ impl<'a> smoltcp::phy::TxToken for TxToken {
let mut buffer = Vec::new();
buffer.resize(len, 0);
let result = f(&mut buffer);
self.ip_tx
.send(buffer.clone())
.expect("failed to send to ip_tx");
match futures::executor::block_on(self.wg.send_ip_packet(&buffer)) {
Ok(_) => {}
Err(e) => {
error!("Failed to send IP packet to WireGuard endpoint: {:?}", e);
}
}
result
}
}

View file

@ -21,6 +21,7 @@ pub struct WireGuardTunnel {
endpoint: SocketAddr,
/// Broadcast sender for received IP packets.
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
}
impl WireGuardTunnel {
@ -31,13 +32,15 @@ impl WireGuardTunnel {
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let endpoint = config.endpoint_addr;
let (ip_broadcast_tx, _) = tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
let (ip_broadcast_tx, ip_broadcast_rx) =
tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
Ok(Self {
peer,
udp,
endpoint,
ip_broadcast_tx,
ip_broadcast_rx,
})
}
@ -178,7 +181,7 @@ impl WireGuardTunnel {
}
Err(e) => {
error!(
"Failed to broadcast received IP packet to recipients: {:?}",
"Failed to broadcast received IP packet to recipients: {}",
e
);
}