One interface/server for all. Doesn't quite work for repeats.

This commit is contained in:
Aram 🍐 2021-10-13 03:19:56 -04:00
parent da8b216fb0
commit f2e6ec1e3f
5 changed files with 335 additions and 203 deletions

9
src/client.rs Normal file
View file

@ -0,0 +1,9 @@
#[derive(Clone)]
pub struct ProxyClient {
/// Unique identifier for this client (used as a port number in the virtual interface).
pub virtual_port: u16,
/// For sending data to the client.
pub data_tx: crossbeam_channel::Sender<Vec<u8>>,
/// For receiving data from the client.
pub data_rx: crossbeam_channel::Receiver<Vec<u8>>,
}

View file

@ -11,23 +11,24 @@ use std::sync::{Arc, Barrier, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use crate::client::ProxyClient;
use anyhow::Context;
use boringtun::crypto::{X25519PublicKey, X25519SecretKey};
use boringtun::device::peer::Peer;
use boringtun::noise::{Tunn, TunnResult};
use clap::{App, Arg};
use crossbeam_channel::{Receiver, RecvError, Sender};
use dashmap::DashMap;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::phy::ChecksumCapabilities;
use smoltcp::socket::{SocketRef, SocketSet, TcpSocket, TcpSocketBuffer};
use smoltcp::time::Instant;
use smoltcp::wire::{
IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr,
IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, PrettyPrinter,
};
use crate::config::Config;
use crate::virtual_device::VirtualIpDevice;
mod client;
mod config;
mod virtual_device;
@ -73,6 +74,258 @@ fn main() -> anyhow::Result<()> {
let endpoint_socket =
Arc::new(UdpSocket::bind("0.0.0.0:0").with_context(|| "Failed to bind endpoint socket")?);
let (new_client_tx, new_client_rx) = crossbeam_channel::unbounded::<ProxyClient>();
let (dead_client_tx, dead_client_rx) = crossbeam_channel::unbounded::<u16>();
// tx/rx for IP packets the interface exchanged that should be filtered/routed
let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
// Virtual interface thread
{
thread::spawn(move || {
// Virtual device: generated IP packets will be send to ip_tx, and IP packets that should be polled should be sent to given ip_rx.
let virtual_device =
VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone());
// Create a virtual interface that will generate the IP packets for us
let mut virtual_interface = InterfaceBuilder::new(virtual_device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr_ip), 32),
])
.any_ip(true)
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.
// TODO: Determine if we even need buffers here.
let server_socket = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((IpAddress::from(dest_addr_ip), dest_addr_port))
.expect("Virtual server socket failed to listen");
socket
};
// Socket set: there is always 1 TCP socket for the server, and the rest are client sockets created over time.
let socket_set_entries = Vec::new();
let mut socket_set = SocketSet::new(socket_set_entries);
socket_set.add(server_socket);
// Gate socket_set behind RwLock so we can add clients in the background
let socket_set = Arc::new(RwLock::new(socket_set));
let socket_set_1 = socket_set.clone();
let socket_set_2 = socket_set.clone();
let client_port_to_handle = Arc::new(DashMap::new());
let client_port_to_handle_1 = client_port_to_handle.clone();
let client_handle_to_client = Arc::new(DashMap::new());
let client_handle_to_client_1 = client_handle_to_client.clone();
let client_handle_to_client_2 = client_handle_to_client.clone();
// Checks if there are new clients to initialize, and adds them to the socket_set
thread::spawn(move || {
let socket_set = socket_set_1;
let client_handle_to_client = client_handle_to_client_1;
loop {
let client = new_client_rx.recv().expect("failed to read new_client_rx");
// Create a virtual client socket for the client
let client_socket = {
static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer =
TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] });
let tcp_tx_buffer =
TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.connect(
(IpAddress::from(dest_addr_ip), dest_addr_port),
(IpAddress::from(source_peer_ip), client.virtual_port),
)
.expect("failed to connect virtual client");
socket
};
// Add to socket set: this makes the ip/port combination routable in the interface, so that IP packets
// received from WG actually go somewhere.
let mut socket_set = socket_set
.write()
.expect("failed to acquire lock on socket_set to add new client");
let client_handle = socket_set.add(client_socket);
// Map the client handle by port so we can look it up later
client_port_to_handle.insert(client.virtual_port, client_handle);
client_handle_to_client.insert(client_handle, client);
}
});
// Checks if there are clients that disconnected, and removes them from the socket_set
thread::spawn(move || {
let dead_client_rx = dead_client_rx.clone();
let socket_set = socket_set_2;
let client_handle_to_client = client_handle_to_client_2;
let client_port_to_handle = client_port_to_handle_1;
loop {
let client_port = dead_client_rx
.recv()
.expect("failed to read dead_client_rx");
// Get handle, if any
let handle = client_port_to_handle.remove(&client_port);
// Remove handle from socket set and from map (handle -> client def)
if let Some((_, handle)) = handle {
// Remove socket set
let mut socket_set = socket_set
.write()
.expect("failed to acquire lock on socket_set to add new client");
socket_set.remove(handle);
client_handle_to_client.remove(&handle);
debug!("Removed client from socket set: vport={}", client_port);
}
}
});
loop {
let loop_start = Instant::now();
// Poll virtual interface
// Note: minimize lock time on socket set so new clients can fit in
{
let mut socket_set = socket_set
.write()
.expect("failed to acquire lock on socket_set to poll interface");
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
debug!("Virtual interface polled and processed some packets");
}
Err(e) => {
error!("Virtual interface poll error: {}", e);
}
_ => {}
}
}
{
let mut socket_set = socket_set
.write()
.expect("failed to acquire lock on socket_set to get client socket");
// Process packets for each client
for x in client_handle_to_client.iter() {
let client_handle = x.key();
let client_def = x.value();
let mut client_socket: SocketRef<TcpSocket> =
socket_set.get(*client_handle);
// Send data received from the real client as the virtual client
if client_socket.can_send() {
while !client_def.data_rx.is_empty() {
let to_send = client_def
.data_rx
.recv()
.expect("failed to read from client data_rx channel");
client_socket.send_slice(&to_send).expect("virtual client failed to send data as received from data_rx channel");
}
}
// Send data received by the virtual client to the real client
if client_socket.can_recv() {
let data = client_socket
.recv(|b| (b.len(), b.to_vec()))
.expect("virtual client failed to recv");
client_def
.data_tx
.send(data)
.expect("failed to send data to client data_tx channel");
}
}
}
// Use poll_delay to know when is the next time to poll.
{
let socket_set = socket_set
.read()
.expect("failed to acquire read lock on socket_set to poll_delay");
match virtual_interface.poll_delay(&socket_set, loop_start) {
Some(smoltcp::time::Duration::ZERO) => {}
Some(delay) => {
thread::sleep(std::time::Duration::from_millis(delay.millis()))
}
_ => thread::sleep(std::time::Duration::from_millis(1)),
}
}
}
});
}
// Packet routing thread
// Filters packets sent by the virtual interface, so that only the ones that should be sent
// to the real server are.
thread::spawn(move || {
loop {
let recv = send_to_ip_filter_rx
.recv()
.expect("failed to read send_to_ip_filter_rx channel");
let src_addr: IpAddr = match IpVersion::of_packet(&recv) {
Ok(v) => match v {
IpVersion::Ipv4 => {
match Ipv4Repr::parse(
&Ipv4Packet::new_unchecked(&recv),
&ChecksumCapabilities::ignored(),
) {
Ok(packet) => Ipv4Addr::from(packet.src_addr).into(),
Err(e) => {
error!("Unable to determine source IPv4 from packet: {}", e);
continue;
}
}
}
IpVersion::Ipv6 => match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) {
Ok(packet) => Ipv6Addr::from(packet.src_addr).into(),
Err(e) => {
error!("Unable to determine source IPv6 from packet: {}", e);
continue;
}
},
_ => {
error!("Unable to determine IP version from packet: unspecified",);
continue;
}
},
Err(e) => {
error!("Unable to determine IP version from packet: {}", e);
continue;
}
};
if src_addr == source_peer_ip {
debug!(
"IP packet: {} bytes from {} to send to WG",
recv.len(),
src_addr
);
// Add to queue to be encapsulated and sent by other thread
send_to_real_server_tx
.send(recv)
.expect("failed to write to send_to_real_server_tx channel");
}
}
});
// Thread that encapsulates and sends WG packets
{
let peer = peer.clone();
@ -88,6 +341,7 @@ fn main() -> anyhow::Result<()> {
endpoint_socket
.send_to(packet, endpoint_addr)
.expect("failed to send packet to wg endpoint");
debug!("Sent {} bytes through WG (encryped)", packet.len());
}
TunnResult::Err(e) => {
error!("Failed to encapsulate: {:?}", e);
@ -109,6 +363,8 @@ fn main() -> anyhow::Result<()> {
}
});
}
// Thread that decapulates WG IP packets and feeds them to the interface
{
let peer = peer.clone();
let endpoint_socket = endpoint_socket.clone();
@ -147,13 +403,35 @@ fn main() -> anyhow::Result<()> {
}
}
TunnResult::WriteToTunnelV4(packet, _) => {
debug!("Got {} bytes to send back to client", packet.len());
debug!(
"Got {} bytes to send back to virtual interface",
packet.len()
);
// For debugging purposes: parse packet
{
match IpVersion::of_packet(&packet) {
Ok(IpVersion::Ipv4) => trace!(
"IPv4 packet received: {}",
PrettyPrinter::<Ipv4Packet<&mut [u8]>>::new("", &packet)
),
Ok(IpVersion::Ipv6) => trace!(
"IPv6 packet received: {}",
PrettyPrinter::<Ipv6Packet<&mut [u8]>>::new("", &packet)
),
_ => {}
}
}
send_to_virtual_interface_tx
.send(packet.to_vec())
.expect("failed to queue received wg packet");
}
TunnResult::WriteToTunnelV6(packet, _) => {
debug!("Got {} bytes to send back to client", packet.len());
debug!(
"Got {} bytes to send back to virtual interface",
packet.len()
);
send_to_virtual_interface_tx
.send(packet.to_vec())
.expect("failed to queue received wg packet");
@ -188,115 +466,62 @@ fn main() -> anyhow::Result<()> {
for client_stream in proxy_listener.incoming() {
client_stream
.map(|client_stream| {
let send_to_real_server_tx = send_to_real_server_tx.clone();
let send_to_virtual_interface_rx = send_to_virtual_interface_rx.clone();
let dead_client_tx = dead_client_tx.clone();
// Pick a port
// TODO: Pool
let port = 60000;
let (data_to_read_tx, data_to_read_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
let (data_to_send_tx, data_to_send_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
let client_addr = client_stream
.peer_addr()
.expect("client has no peer address");
info!("[{}] Incoming connection from {}", port, client_addr);
// tx/rx for data received from the client
// this data is received
let (send_to_virtual_client_tx, send_to_virtual_client_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
let client = ProxyClient {
virtual_port: port,
data_tx: data_to_send_tx,
data_rx: data_to_read_rx,
};
// tx/rx for packets received from the destination
// this data is received from the WG endpoint; the IP packets are routed using the port number
let (send_to_real_client_tx, send_to_real_client_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
// tx/rx for IP packets the interface exchanged that should be filtered/routed
let (send_to_ip_filter_tx, send_to_ip_filter_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
let stopped = Arc::new(AtomicBool::new(false));
let stopped_1 = Arc::clone(&stopped);
let stopped_2 = Arc::clone(&stopped);
// Register the new client with the virtual interface
new_client_tx.send(client.clone()).expect("failed to notify virtual interface of new client");
// Reads data from the client
thread::spawn(move || {
let stopped = stopped_1.clone();
let mut client_stream = client_stream;
// todo: change this to tokio?
client_stream
.set_nonblocking(true)
.expect("failed to set nonblocking");
loop {
if stopped.load(Ordering::Relaxed) {
break;
}
loop {
let mut buffer = [0; MAX_PACKET];
let read = client_stream.read(&mut buffer);
match read {
Ok(size) if size == 0 => {
info!("[{}] Connection closed by client: {}", port, client_addr);
stopped.store(true, Ordering::Relaxed);
break;
}
Ok(size) => {
debug!("[{}] Data received from client: {} bytes", port, size);
let data = &buffer[..size];
send_to_virtual_client_tx
.send(data.to_vec())
.unwrap_or_else(|e| error!("[{}] failed to send data to client_received_tx channel as received from client: {}", port, e));
data_to_read_tx.send(data.to_vec())
.unwrap_or_else(|e| error!("[{}] failed to send data to data_to_read_tx channel as received from client: {}", port, e));
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// Ignore and continue
}
Err(e) => {
warn!("[{}] Connection error: {}", port, e);
stopped.store(true, Ordering::Relaxed);
break;
}
}
while !send_to_ip_filter_rx.is_empty() {
let recv = send_to_ip_filter_rx.recv().expect("failed to read send_to_ip_filter_rx");
let src_addr: IpAddr = match IpVersion::of_packet(&recv) {
Ok(v) => {
match v {
IpVersion::Ipv4 => {
match Ipv4Repr::parse(&Ipv4Packet::new_unchecked(&recv), &ChecksumCapabilities::ignored()) {
Ok(packet) => Ipv4Addr::from(packet.src_addr).into(),
Err(e) => {
error!("[{}] Unable to determine source IPv4 from packet: {}", port, e);
continue;
}
}
}
IpVersion::Ipv6 => {
match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) {
Ok(packet) => Ipv6Addr::from(packet.src_addr).into(),
Err(e) => {
error!("[{}] Unable to determine source IPv6 from packet: {}", port, e);
continue;
}
}
}
_ => {
error!("[{}] Unable to determine IP version from packet: unspecified", port);
continue;
}
}
}
Err(e) => {
error!("[{}] Unable to determine IP version from packet: {}", port, e);
continue;
}
};
if src_addr == source_peer_ip {
debug!("[{}] IP packet: {} bytes from {} to send to WG", port, recv.len(), src_addr);
// Add to queue to be encapsulated and sent by other thread
send_to_real_server_tx.send(recv).expect("failed to write to send_to_real_server_tx channel");
}
}
while !send_to_real_client_rx.is_empty() {
let recv = send_to_real_client_rx.recv().expect("failed to read destination_sent_rx");
while !data_to_send_rx.is_empty() {
let recv = data_to_send_rx.recv().expect("failed to read data_to_send_rx");
client_stream
.write(recv.as_slice())
.unwrap_or_else(|e| {
@ -305,131 +530,9 @@ fn main() -> anyhow::Result<()> {
});
}
}
dead_client_tx.send(port).expect("failed to send to dead_client_tx channel");
});
// This thread simulates the IP-layer communication between the client and server.
// * When we get data from the 'real' client, we send it via the virtual client
// * When the virtual client sends data, it generates IP packets, which are captures via ip_rx/ip_tx
thread::spawn(move || {
let stopped = Arc::clone(&stopped_2);
let server_socket = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer)
};
let client_socket = {
static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] });
TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer)
};
let mut socket_set_entries: [_; 2] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let server_handle = socket_set.add(server_socket);
let client_handle = socket_set.add(client_socket);
// Virtual device
let device = VirtualIpDevice::new(send_to_ip_filter_tx, send_to_virtual_interface_rx.clone());
// Create a virtual interface to simulate TCP connection
let mut iface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr_ip), 32),
])
.any_ip(true)
.finalize();
// keeps track of whether the virtual clients needs to be initialized
let mut started = false;
loop {
let loop_start = Instant::now();
if stopped.load(Ordering::Relaxed) {
debug!("[{}] Killing virtual thread", port);
break;
}
match iface.poll(&mut socket_set, loop_start) {
Ok(processed) => {
if processed {
debug!("[{}] virtual iface polled and processed some packets", port);
}
}
Err(e) => {
error!("[{}] virtual iface poll error: {:?}", port, e);
break;
}
}
// Spawn a server socket so the virtual interface allows routing
// Note: the server socket is never read, since the IP packets are intercepted
// at the interface level.
{
let mut server_socket: SocketRef<TcpSocket> = socket_set.get(server_handle);
if !started {
// Open the virtual server socket
match server_socket.listen((IpAddress::from(dest_addr_ip), dest_addr_port)) {
Ok(_) => {
debug!("[{}] Virtual server listening: {}", port, server_socket.local_endpoint());
}
Err(e) => {
error!("[{}] Virtual server failed to listen: {}", port, e);
break;
}
}
}
}
// Virtual client
{
let mut client_socket: SocketRef<TcpSocket> = socket_set.get(client_handle);
if !started {
client_socket.connect(
(IpAddress::from(dest_addr_ip), dest_addr_port),
(IpAddress::from(source_peer_ip), port),
)
.expect("failed to connect virtual client");
debug!("[{}] Virtual client connected", port);
}
if client_socket.can_send() {
while !send_to_virtual_client_rx.is_empty() {
let to_send = send_to_virtual_client_rx.recv().expect("failed to read from client_received_rx channel");
client_socket.send_slice(to_send.as_slice()).expect("virtual client failed to send data from channel");
}
}
if client_socket.can_recv() {
let data = client_socket.recv(|b| (b.len(), b.to_vec())).expect("failed to recv");
send_to_real_client_tx.send(data).expect("failed to send to channel send_to_real_client_tx");
}
if !client_socket.is_open() {
warn!("[{}] Client socket is no longer open", port);
break;
}
}
// After the first loop, the client and server have started
started = true;
match iface.poll_delay(&socket_set, loop_start) {
Some(smoltcp::time::Duration::ZERO) => {}
Some(delay) => std::thread::sleep(std::time::Duration::from_millis(delay.millis())),
_ => {}
}
}
// if this thread ends, end the other ones too
debug!("[{}] Virtual thread stopped", port);
stopped.store(true, Ordering::Relaxed);
});
// * When the real destination sends IP packets (via WG endpoint), we send it via the device/interface
})
.unwrap_or_else(|e| error!("{:?}", e));
}

View file

@ -1,9 +1,7 @@
use std::collections::VecDeque;
use smoltcp::phy::{ChecksumCapabilities, Device, DeviceCapabilities, Medium};
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use smoltcp::wire::{Ipv4Packet, Ipv4Repr};
#[derive(Clone)]
pub struct VirtualIpDevice {
/// Channel for packets sent by the interface.
ip_tx: crossbeam_channel::Sender<Vec<u8>>,