This commit is contained in:
Aram 🍐 2021-10-14 21:31:09 -04:00
parent bf489900e6
commit ccb51fe5f8
3 changed files with 190 additions and 15 deletions

View file

@ -33,7 +33,7 @@ async fn main() -> anyhow::Result<()> {
let config = Config::from_args().with_context(|| "Failed to read config")?;
let port_pool = Arc::new(PortPool::new());
let wg = WireGuardTunnel::new(&config)
let wg = WireGuardTunnel::new(&config, port_pool.clone())
.await
.with_context(|| "Failed to initialize WireGuard tunnel")?;
let wg = Arc::new(wg);

View file

@ -7,23 +7,37 @@ const MAX_PORT: u16 = 60999;
const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
pub struct PortPool {
/// Remaining ports
inner: lockfree::queue::Queue<u16>,
/// Ports in use
taken: lockfree::set::Set<u16>,
}
impl PortPool {
pub fn new() -> Self {
let inner = lockfree::queue::Queue::default();
PORT_RANGE.for_each(|p| inner.push(p) as ());
Self { inner }
Self {
inner,
taken: lockfree::set::Set::new(),
}
}
pub fn next(&self) -> anyhow::Result<u16> {
self.inner
let port = self
.inner
.pop()
.with_context(|| "Virtual port pool is exhausted")
.with_context(|| "Virtual port pool is exhausted")?;
self.taken.insert(port);
Ok(port)
}
pub fn release(&self, port: u16) {
self.inner.push(port);
self.taken.remove(&port);
}
pub fn is_in_use(&self, port: u16) -> bool {
self.taken.contains(&port)
}
}

167
src/wg.rs
View file

@ -1,18 +1,26 @@
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use log::Level;
use smoltcp::phy::ChecksumCapabilities;
use smoltcp::wire::{
IpAddress, IpProtocol, IpRepr, IpVersion, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address,
Ipv6Packet, Ipv6Repr, TcpControl, TcpPacket, TcpRepr, TcpSeqNumber,
};
use tokio::net::UdpSocket;
use crate::config::Config;
use crate::port_pool::PortPool;
use crate::MAX_PACKET;
/// The capacity of the broadcast channel for received IP packets.
const BROADCAST_CAPACITY: usize = 1_000_000;
pub struct WireGuardTunnel {
source_peer_ip: IpAddr,
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
peer: Box<Tunn>,
/// The UDP socket for the public WireGuard endpoint to connect to.
@ -21,12 +29,16 @@ pub struct WireGuardTunnel {
endpoint: SocketAddr,
/// Broadcast sender for received IP packets.
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
/// Placeholder so that the broadcaster doesn't close.
ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
/// Port pool.
port_pool: Arc<PortPool>,
}
impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config) -> anyhow::Result<Self> {
pub async fn new(config: &Config, port_pool: Arc<PortPool>) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(&config)?;
let udp = UdpSocket::bind("0.0.0.0:0")
.await
@ -36,11 +48,13 @@ impl WireGuardTunnel {
tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
Ok(Self {
source_peer_ip,
peer,
udp,
endpoint,
ip_broadcast_tx,
ip_broadcast_rx,
port_pool,
})
}
@ -175,10 +189,16 @@ impl WireGuardTunnel {
// For debugging purposes: parse packet
trace_ip_packet("Received IP packet", packet);
match self.route(packet) {
RouteResult::Broadcast => {
// Broadcast IP packet
if self.ip_broadcast_tx.receiver_count() > 1 {
match self.ip_broadcast_tx.send(packet.to_vec()) {
Ok(n) => {
trace!("Broadcasted received IP packet to {} recipients", n);
trace!(
"Broadcasted received IP packet to {} virtual interfaces",
n - 1
);
}
Err(e) => {
error!(
@ -188,6 +208,18 @@ impl WireGuardTunnel {
}
}
}
}
RouteResult::TcpReset(packet) => {
trace!("Resetting dead TCP connection after packet from WireGuard endpoint");
self.send_ip_packet(&packet)
.await
.unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e));
}
RouteResult::Drop => {
trace!("Dropped incoming IP packet from WireGuard endpoint");
}
}
}
_ => {}
}
}
@ -205,6 +237,126 @@ impl WireGuardTunnel {
.map_err(|s| anyhow::anyhow!("{}", s))
.with_context(|| "Failed to initialize boringtun Tunn")
}
fn route(&self, packet: &[u8]) -> RouteResult {
match IpVersion::of_packet(&packet) {
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
.ok()
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() {
IpProtocol::Tcp => TcpPacket::new_checked(packet.payload()).ok().map(|tcp| {
if self.port_pool.is_in_use(tcp.dst_port()) {
RouteResult::Broadcast
} else if tcp.rst() {
RouteResult::Drop
} else {
// Port is not in use, but it's a TCP packet so we'll craft a RST.
RouteResult::TcpReset(craft_tcp_rst_reply(
IpVersion::Ipv4,
packet.src_addr().into(),
tcp.src_port(),
packet.dst_addr().into(),
tcp.dst_port(),
tcp.ack_number(),
))
}
}),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
})
.flatten()
.unwrap_or(RouteResult::Drop),
// TODO: IPv6
_ => RouteResult::Drop,
}
}
}
fn craft_tcp_rst_reply(
ip_version: IpVersion,
source_addr: IpAddress,
source_port: u16,
dest_addr: IpAddress,
dest_port: u16,
ack_number: TcpSeqNumber,
) -> Vec<u8> {
let tcp_repr = TcpRepr {
src_port: dest_port,
dst_port: source_port,
control: TcpControl::Rst,
seq_number: ack_number,
ack_number: None,
window_len: 0,
window_scale: None,
max_seg_size: None,
sack_permitted: false,
sack_ranges: [None, None, None],
payload: &[],
};
let mut tcp_buffer = vec![0u8; 20];
let mut tcp_packet = &mut TcpPacket::new_unchecked(&mut tcp_buffer);
tcp_repr.emit(
&mut tcp_packet,
&dest_addr,
&source_addr,
&ChecksumCapabilities::default(),
);
let mut ip_buffer = vec![0u8; MAX_PACKET];
let (header_len, total_len) = match ip_version {
IpVersion::Ipv4 => {
let dest_addr = match dest_addr {
IpAddress::Ipv4(dest_addr) => dest_addr,
_ => panic!(),
};
let source_addr = match source_addr {
IpAddress::Ipv4(source_addr) => source_addr,
_ => panic!(),
};
let mut ip_packet = &mut Ipv4Packet::new_unchecked(&mut ip_buffer);
let ip_repr = Ipv4Repr {
src_addr: dest_addr,
dst_addr: source_addr,
protocol: IpProtocol::Tcp,
payload_len: tcp_buffer.len(),
hop_limit: 64,
};
ip_repr.emit(&mut ip_packet, &ChecksumCapabilities::default());
(
ip_packet.header_len() as usize,
ip_packet.total_len() as usize,
)
}
IpVersion::Ipv6 => {
let dest_addr = match dest_addr {
IpAddress::Ipv6(dest_addr) => dest_addr,
_ => panic!(),
};
let source_addr = match source_addr {
IpAddress::Ipv6(source_addr) => source_addr,
_ => panic!(),
};
let mut ip_packet = &mut Ipv6Packet::new_unchecked(&mut ip_buffer);
let ip_repr = Ipv6Repr {
src_addr: dest_addr,
dst_addr: source_addr,
next_header: IpProtocol::Tcp,
payload_len: tcp_buffer.len(),
hop_limit: 64,
};
ip_repr.emit(&mut ip_packet);
(ip_packet.header_len(), ip_packet.total_len())
}
_ => panic!(),
};
ip_buffer[header_len..total_len].copy_from_slice(&tcp_buffer);
let packet: &[u8] = &ip_buffer[..total_len];
packet.to_vec()
}
fn trace_ip_packet(message: &str, packet: &[u8]) {
@ -226,3 +378,12 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
}
}
}
enum RouteResult {
/// The packet can be broadcasted to the virtual interfaces
Broadcast,
/// The packet is not routable so it may be reset.
TcpReset(Vec<u8>),
/// The packet can be safely ignored.
Drop,
}