WIP with virtual sockets

This commit is contained in:
Aram 🍐 2021-10-11 23:24:56 -04:00
parent f3976be852
commit 95f929c00b
4 changed files with 534 additions and 290 deletions

View file

@ -1,9 +1,13 @@
#[macro_use]
extern crate log;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::{
IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket,
};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier};
use std::sync::{Arc, Barrier, Mutex, RwLock};
use std::thread;
use std::time::Duration;
@ -12,13 +16,20 @@ use boringtun::crypto::{X25519PublicKey, X25519SecretKey};
use boringtun::device::peer::Peer;
use boringtun::noise::{Tunn, TunnResult};
use clap::{App, Arg};
use packet::ip::Protocol;
use packet::Builder;
use smoltcp::wire::Ipv4Packet;
use crossbeam_channel::{Receiver, RecvError, Sender};
use smoltcp::iface::InterfaceBuilder;
use smoltcp::phy::ChecksumCapabilities;
use smoltcp::socket::{SocketRef, SocketSet, TcpSocket, TcpSocketBuffer};
use smoltcp::time::Instant;
use smoltcp::wire::{
IpAddress, IpCidr, IpRepr, IpVersion, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr,
};
use crate::config::Config;
use crate::virtual_device::VirtualIpDevice;
mod config;
mod virtual_device;
const MAX_PACKET: usize = 65536;
@ -27,224 +38,420 @@ fn main() -> anyhow::Result<()> {
let config = Config::from_args().with_context(|| "Failed to read config")?;
debug!("Parsed arguments: {:?}", config);
let peer = Arc::new(
Tunn::new(
config.private_key.clone(),
config.endpoint_public_key.clone(),
None,
None,
0,
None,
)
.map_err(|s| anyhow::anyhow!("{}", s))
.with_context(|| "Failed to initialize peer")?,
);
let source_sock = Arc::new(
UdpSocket::bind(&config.source_addr).with_context(|| "Failed to bind source socket")?,
);
let endpoint_sock =
Arc::new(UdpSocket::bind("0.0.0.0:0").with_context(|| "Failed to bind endpoint socket")?);
let endpoint_addr = config.endpoint_addr;
let source_peer_addr = SocketAddr::new(config.source_peer_ip, 1234);
let destination_addr = config.dest_addr;
let close = Arc::new(AtomicBool::new(false));
let mut handles = Vec::with_capacity(3);
// thread 1: read from endpoint, forward to peer
{
let close = close.clone();
let peer = peer.clone();
let source_sock = source_sock.clone();
let endpoint_sock = endpoint_sock.clone();
handles.push(thread::spawn(move || loop {
// Listen on the network
let mut recv_buf = [0u8; MAX_PACKET];
let mut send_buf = [0u8; MAX_PACKET];
let n = match endpoint_sock.recv(&mut recv_buf) {
Ok(n) => n,
Err(_) => {
if close.load(Ordering::Relaxed) {
return;
}
continue;
}
};
let data = &recv_buf[..n];
match peer.decapsulate(None, data, &mut send_buf) {
TunnResult::WriteToNetwork(packet) => {
send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
loop {
let mut send_buf = [0u8; MAX_PACKET];
match peer.decapsulate(None, &[], &mut send_buf) {
TunnResult::WriteToNetwork(packet) => {
send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
}
_ => {
break;
}
}
}
}
TunnResult::WriteToTunnelV4(packet, _) => {
source_sock.send(packet).unwrap();
}
TunnResult::WriteToTunnelV6(packet, _) => {
source_sock.send(packet).unwrap();
}
_ => {}
}
}));
}
// thread 2: read from peer socket
{
let close = close.clone();
let peer = peer.clone();
let source_sock = source_sock.clone();
let endpoint_sock = endpoint_sock.clone();
handles.push(thread::spawn(move || loop {
let mut recv_buf = [0u8; MAX_PACKET];
let mut send_buf = [0u8; MAX_PACKET];
let n = match source_sock.recv(&mut recv_buf) {
Ok(n) => n,
Err(_) => {
if close.load(Ordering::Relaxed) {
return;
}
continue;
}
};
let data = &recv_buf[..n];
// TODO: Support TCP
let ip_packet =
wrap_data_packet(Protocol::Udp, data, source_peer_addr, destination_addr)
.expect("Failed to wrap data packet");
debug!("Crafted IP packet: {:#?}", ip_packet);
match peer.encapsulate(ip_packet.as_slice(), &mut send_buf) {
TunnResult::WriteToNetwork(packet) => {
send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
}
TunnResult::Err(e) => {
error!("Failed to encapsulate: {:?}", e);
}
other => {
error!("Unexpected TunnResult during encapsulation: {:?}", other);
}
}
}));
}
// thread 3: maintenance
{
let close = close.clone();
let peer = peer.clone();
let endpoint_sock = endpoint_sock.clone();
handles.push(thread::spawn(move || loop {
if close.load(Ordering::Relaxed) {
return;
}
let mut send_buf = [0u8; MAX_PACKET];
match peer.update_timers(&mut send_buf) {
TunnResult::WriteToNetwork(packet) => {
send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
}
_ => {}
}
thread::sleep(Duration::from_millis(200));
}));
}
info!(
"Tunnelling [{}]->[{}] (via [{}] as peer {})",
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
);
for handle in handles {
handle.join().expect("Failed to join thread")
let source_peer_ip = config.source_peer_ip;
let dest_addr_ip = config.dest_addr.ip();
let dest_addr_port = config.dest_addr.port();
// Initialize peer based on config
let peer = Tunn::new(
config.private_key.clone(),
config.endpoint_public_key.clone(),
None,
None,
0,
None,
)
.map_err(|s| anyhow::anyhow!("{}", s))
.with_context(|| "Failed to initialize peer")?;
let proxy_listener = TcpListener::bind(config.source_addr).unwrap();
for client_stream in proxy_listener.incoming() {
client_stream
.map(|client_stream| {
// Pick a port
// TODO: Pool
let port = 60000;
let client_addr = client_stream
.peer_addr()
.expect("client has no peer address");
info!("[{}] Incoming connection from {}", port, client_addr);
// tx/rx for data received from the client
// this data is received
let (client_received_tx, client_received_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
// tx/rx for packets received from the destination
// this data is received from the WG endpoint; the IP packets are routed using the port number
let (destination_sent_tx, destination_sent_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
// tx/rx for packets the virtual client sent and that should be sent to the wg tunnel
let (ip_tx, ip_rx) = crossbeam_channel::unbounded::<Vec<u8>>();
let stopped = Arc::new(AtomicBool::new(false));
let stopped_1 = Arc::clone(&stopped);
let stopped_2 = Arc::clone(&stopped);
// Reads data from the client
thread::spawn(move || {
let stopped = stopped_1.clone();
let mut client_stream = client_stream;
client_stream
.set_nonblocking(true)
.expect("failed to set nonblocking");
loop {
if stopped.load(Ordering::Relaxed) {
break;
}
let mut buffer = [0; MAX_PACKET];
let read = client_stream.read(&mut buffer);
match read {
Ok(size) if size == 0 => {
info!("[{}] Connection closed by client: {}", port, client_addr);
stopped.store(true, Ordering::Relaxed);
break;
}
Ok(size) => {
debug!("[{}] Data received from client: {} bytes", port, size);
let data = &buffer[..size];
client_received_tx
.send(data.to_vec())
.unwrap_or_else(|e| error!("[{}] failed to send data to client_received_tx channel as received from client: {}", port, e));
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// Ignore and continue
}
Err(e) => {
warn!("[{}] Connection error: {}", port, e);
stopped.store(true, Ordering::Relaxed);
break;
}
}
while !ip_rx.is_empty() {
let recv = ip_rx.recv().expect("failed to read ip_rx");
let src_addr: IpAddr = match IpVersion::of_packet(&recv) {
Ok(v) => {
match v {
IpVersion::Ipv4 => {
match Ipv4Repr::parse(&Ipv4Packet::new_unchecked(&recv), &ChecksumCapabilities::ignored()) {
Ok(packet) => Ipv4Addr::from(packet.src_addr).into(),
Err(e) => {
error!("[{}] Unable to determine source IPv4 from packet: {}", port, e);
continue;
}
}
}
IpVersion::Ipv6 => {
match Ipv6Repr::parse(&Ipv6Packet::new_unchecked(&recv)) {
Ok(packet) => Ipv6Addr::from(packet.src_addr).into(),
Err(e) => {
error!("[{}] Unable to determine source IPv6 from packet: {}", port, e);
continue;
}
}
}
_ => {
error!("[{}] Unable to determine IP version from packet: unspecified", port);
continue;
}
}
}
Err(e) => {
error!("[{}] Unable to determine IP version from packet: {}", port, e);
continue;
}
};
if src_addr == source_peer_ip {
// TODO: Encapsulate and send to WG
debug!("[{}] IP packet: {} bytes from {} to send to WG", port, recv.len(), src_addr);
}
}
while !destination_sent_rx.is_empty() {
let recv = destination_sent_rx.recv().expect("failed to read destination_sent_rx");
client_stream
.write(recv.as_slice())
.unwrap_or_else(|e| {
error!("[{}] failed to send write to client stream: {}", port, e);
0
});
}
}
});
// This thread simulates the IP-layer communication between the client and server.
// * When we get data from the 'real' client, we send it via the virtual client
// * When the virtual client sends data, it generates IP packets, which are captures via ip_rx/ip_tx
// * When the real destination sends data (via WG endpoint), we send it via the virtual server
thread::spawn(move || {
let stopped = Arc::clone(&stopped_2);
let server_socket = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer)
};
let client_socket = {
static mut TCP_CLIENT_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_CLIENT_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] });
TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer)
};
let mut socket_set_entries: [_; 2] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let server_handle = socket_set.add(server_socket);
let client_handle = socket_set.add(client_socket);
// Virtual device
let device = VirtualIpDevice::new(ip_tx);
// Create a virtual interface to simulate TCP connection
let mut iface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr_ip), 32),
])
.any_ip(true)
.finalize();
// keeps track of whether the virtual clients needs to be initialized
let mut started = false;
loop {
let loop_start = Instant::now();
if stopped.load(Ordering::Relaxed) {
debug!("[{}] Killing virtual thread", port);
break;
}
match iface.poll(&mut socket_set, loop_start) {
Ok(processed) => {
if processed {
debug!("[{}] virtual iface polled and processed some packets", port);
}
}
Err(e) => {
error!("[{}] virtual iface poll error: {:?}", port, e);
break;
}
}
// Server socket polling
{
let mut server_socket: SocketRef<TcpSocket> = socket_set.get(server_handle);
if !started {
// Open the virtual server socket
match server_socket.listen((IpAddress::from(dest_addr_ip), dest_addr_port)) {
Ok(_) => {
debug!("[{}] Virtual server listening: {}", port, server_socket.local_endpoint());
}
Err(e) => {
error!("[{}] Virtual server failed to listen: {}", port, e);
break;
}
}
}
if server_socket.can_recv() {
let buffer = server_socket
.recv(|buffer| { (buffer.len(), buffer.to_vec()) });
match buffer {
Ok(buffer) => {
debug!("[{}] Virtual server socket read: {} bytes", port, buffer.len());
}
Err(e) => {
error!("[{}] Virtual server failed to read: {}", port, e);
break;
}
}
}
if server_socket.can_send() {
// TODO: See if this is actually useful
}
}
// Virtual client
{
let mut client_socket: SocketRef<TcpSocket> = socket_set.get(client_handle);
if !started {
client_socket.connect(
(IpAddress::from(dest_addr_ip), dest_addr_port),
(IpAddress::from(source_peer_ip), port),
)
.expect("failed to connect virtual client");
debug!("[{}] Virtual client connected", port);
}
if client_socket.can_send() {
while !client_received_rx.is_empty() {
let to_send = client_received_rx.recv().expect("failed to read from client_received_rx channel");
client_socket.send_slice(to_send.as_slice()).expect("virtual client failed to send data from channel");
}
}
if client_socket.can_recv() {
// TODO: See if this is actually useful?
client_socket.recv(|b| (b.len(), 0)).expect("failed to recv");
}
if !client_socket.is_open() {
warn!("[{}] Client socket is no longer open", port);
break;
}
}
// After the first loop, the client and server have started
started = true;
match iface.poll_delay(&socket_set, loop_start) {
Some(smoltcp::time::Duration::ZERO) => {}
Some(delay) => std::thread::sleep(std::time::Duration::from_millis(delay.millis())),
_ => {}
}
}
// if this thread ends, end the other ones too
debug!("[{}] Virtual thread stopped", port);
stopped.store(true, Ordering::Relaxed);
});
})
.unwrap_or_else(|e| error!("{:?}", e));
}
// TCP thread
// let mut handles = Vec::with_capacity(3);
//
// let endpoint_sock =
// Arc::new(UdpSocket::bind("0.0.0.0:0").with_context(|| "Failed to bind endpoint socket")?);
//
// let endpoint_addr = config.endpoint_addr;
//
// let source_peer_addr = SocketAddr::new(config.source_peer_ip, 1234);
// let destination_addr = config.dest_addr;
//
// let close = Arc::new(AtomicBool::new(false));
//
//
// // thread 1: read from endpoint, forward to peer
// {
// let close = close.clone();
// let peer = peer.clone();
// let source_sock = source_sock.clone();
// let endpoint_sock = endpoint_sock.clone();
//
// handles.push(thread::spawn(move || loop {
// // Listen on the network
// let mut recv_buf = [0u8; MAX_PACKET];
// let mut send_buf = [0u8; MAX_PACKET];
//
// let n = match endpoint_sock.recv(&mut recv_buf) {
// Ok(n) => n,
// Err(_) => {
// if close.load(Ordering::Relaxed) {
// return;
// }
// continue;
// }
// };
//
// let data = &recv_buf[..n];
// match peer.decapsulate(None, data, &mut send_buf) {
// TunnResult::WriteToNetwork(packet) => {
// send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
// loop {
// let mut send_buf = [0u8; MAX_PACKET];
// match peer.decapsulate(None, &[], &mut send_buf) {
// TunnResult::WriteToNetwork(packet) => {
// send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
// }
// _ => {
// break;
// }
// }
// }
// }
// TunnResult::WriteToTunnelV4(packet, _) => {
// source_sock.send(packet).unwrap();
// }
// TunnResult::WriteToTunnelV6(packet, _) => {
// source_sock.send(packet).unwrap();
// }
// _ => {}
// }
// }));
// }
//
// // thread 2: read from peer socket
// {
// let close = close.clone();
// let peer = peer.clone();
// let source_sock = source_sock.clone();
// let endpoint_sock = endpoint_sock.clone();
//
// handles.push(thread::spawn(move || loop {
// let mut recv_buf = [0u8; MAX_PACKET];
// let mut send_buf = [0u8; MAX_PACKET];
//
// let n = match source_sock.recv(&mut recv_buf) {
// Ok(n) => n,
// Err(_) => {
// if close.load(Ordering::Relaxed) {
// return;
// }
// continue;
// }
// };
//
// let data = &recv_buf[..n];
//
// // TODO: Support TCP
// let ip_packet =
// wrap_data_packet(Protocol::Udp, data, source_peer_addr, destination_addr)
// .expect("Failed to wrap data packet");
//
// debug!("Crafted IP packet: {:#?}", ip_packet);
//
// match peer.encapsulate(ip_packet.as_slice(), &mut send_buf) {
// TunnResult::WriteToNetwork(packet) => {
// send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
// }
// TunnResult::Err(e) => {
// error!("Failed to encapsulate: {:?}", e);
// }
// other => {
// error!("Unexpected TunnResult during encapsulation: {:?}", other);
// }
// }
// }));
// }
//
// // thread 3: maintenance
// {
// let close = close.clone();
// let peer = peer.clone();
// let endpoint_sock = endpoint_sock.clone();
//
// handles.push(thread::spawn(move || loop {
// if close.load(Ordering::Relaxed) {
// return;
// }
//
// let mut send_buf = [0u8; MAX_PACKET];
// match peer.update_timers(&mut send_buf) {
// TunnResult::WriteToNetwork(packet) => {
// send_packet(packet, endpoint_sock.clone(), endpoint_addr).unwrap();
// }
// _ => {}
// }
//
// thread::sleep(Duration::from_millis(200));
// }));
// }
//
//
// for handle in handles {
// handle.join().expect("Failed to join thread")
// }
Ok(())
}
// wraps a UDP packet with an IP layer packet with the wanted source & destination addresses
fn wrap_data_packet(
proto: Protocol,
data: &[u8],
source: SocketAddr,
destination: SocketAddr,
) -> anyhow::Result<Vec<u8>> {
match source {
SocketAddr::V4(source) => {
let mut builder = packet::ip::v4::Builder::default();
builder = builder
.source(*source.ip())
.with_context(|| "Failed to set packet source")?;
builder = builder
.payload(data)
.with_context(|| "Failed to set packet payload")?;
builder = builder
.protocol(proto)
.with_context(|| "Failed to set packet protocol")?;
builder = builder
.dscp(0)
.with_context(|| "Failed to set packet dcsp")?;
builder = builder
.id(12345)
.with_context(|| "Failed to set packet ID")?;
builder = builder
.ttl(16)
.with_context(|| "Failed to set packet TTL")?;
match destination {
SocketAddr::V4(destination) => {
builder = builder
.destination(*destination.ip())
.with_context(|| "Failed to set packet destination")?;
}
SocketAddr::V6(_) => {
return Err(anyhow::anyhow!(
"cannot use ipv6 destination with ipv4 source"
));
}
}
builder
.build()
.with_context(|| "Failed to build ipv4 packet")
}
SocketAddr::V6(_) => {
todo!("ipv6 support")
}
}
}
fn send_packet(
packet: &[u8],
endpoint_socket: Arc<UdpSocket>,
endpoint_addr: SocketAddr,
) -> anyhow::Result<usize> {
// todo: replace addr with peer_addr
let size = endpoint_socket
.send_to(packet, endpoint_addr)
.with_context(|| "Failed to send packet")?;
Ok(size)
}

83
src/virtual_device.rs Normal file
View file

@ -0,0 +1,83 @@
use smoltcp::phy::{ChecksumCapabilities, Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use smoltcp::wire::{Ipv4Packet, Ipv4Repr};
use std::collections::VecDeque;
pub struct VirtualIpDevice {
queue: VecDeque<Vec<u8>>,
/// Sends IP packets
ip_tx: crossbeam_channel::Sender<Vec<u8>>,
}
impl VirtualIpDevice {
pub fn new(ip_tx: crossbeam_channel::Sender<Vec<u8>>) -> Self {
Self {
queue: VecDeque::new(),
ip_tx,
}
}
}
impl<'a> Device<'a> for VirtualIpDevice {
type RxToken = RxToken;
type TxToken = TxToken<'a>;
fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> {
self.queue.pop_front().map(move |buffer| {
let rx = RxToken { buffer };
let tx = TxToken {
queue: &mut self.queue,
tx: Some(self.ip_tx.clone()),
};
(rx, tx)
})
}
fn transmit(&'a mut self) -> Option<Self::TxToken> {
Some(TxToken {
queue: &mut self.queue,
tx: Some(self.ip_tx.clone()),
})
}
fn capabilities(&self) -> DeviceCapabilities {
let mut cap = DeviceCapabilities::default();
cap.medium = Medium::Ip;
cap.max_transmission_unit = 65535;
cap
}
}
#[doc(hidden)]
pub struct RxToken {
buffer: Vec<u8>,
}
impl smoltcp::phy::RxToken for RxToken {
fn consume<R, F>(mut self, _timestamp: Instant, f: F) -> smoltcp::Result<R>
where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>,
{
f(&mut self.buffer)
}
}
#[doc(hidden)]
pub struct TxToken<'a> {
queue: &'a mut VecDeque<Vec<u8>>,
tx: Option<crossbeam_channel::Sender<Vec<u8>>>,
}
impl<'a> smoltcp::phy::TxToken for TxToken<'a> {
fn consume<R, F>(self, _timestamp: Instant, len: usize, f: F) -> smoltcp::Result<R>
where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>,
{
let mut buffer = Vec::new();
buffer.resize(len, 0);
let result = f(&mut buffer);
self.tx.map(|tx| tx.send(buffer.clone()));
self.queue.push_back(buffer);
result
}
}