mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 06:38:32 -04:00
yippee
This commit is contained in:
parent
bf489900e6
commit
ccb51fe5f8
3 changed files with 190 additions and 15 deletions
|
@ -33,7 +33,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let config = Config::from_args().with_context(|| "Failed to read config")?;
|
||||
let port_pool = Arc::new(PortPool::new());
|
||||
|
||||
let wg = WireGuardTunnel::new(&config)
|
||||
let wg = WireGuardTunnel::new(&config, port_pool.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to initialize WireGuard tunnel")?;
|
||||
let wg = Arc::new(wg);
|
||||
|
|
|
@ -7,23 +7,37 @@ const MAX_PORT: u16 = 60999;
|
|||
const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
|
||||
|
||||
pub struct PortPool {
|
||||
/// Remaining ports
|
||||
inner: lockfree::queue::Queue<u16>,
|
||||
/// Ports in use
|
||||
taken: lockfree::set::Set<u16>,
|
||||
}
|
||||
|
||||
impl PortPool {
|
||||
pub fn new() -> Self {
|
||||
let inner = lockfree::queue::Queue::default();
|
||||
PORT_RANGE.for_each(|p| inner.push(p) as ());
|
||||
Self { inner }
|
||||
Self {
|
||||
inner,
|
||||
taken: lockfree::set::Set::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next(&self) -> anyhow::Result<u16> {
|
||||
self.inner
|
||||
let port = self
|
||||
.inner
|
||||
.pop()
|
||||
.with_context(|| "Virtual port pool is exhausted")
|
||||
.with_context(|| "Virtual port pool is exhausted")?;
|
||||
self.taken.insert(port);
|
||||
Ok(port)
|
||||
}
|
||||
|
||||
pub fn release(&self, port: u16) {
|
||||
self.inner.push(port);
|
||||
self.taken.remove(&port);
|
||||
}
|
||||
|
||||
pub fn is_in_use(&self, port: u16) -> bool {
|
||||
self.taken.contains(&port)
|
||||
}
|
||||
}
|
||||
|
|
167
src/wg.rs
167
src/wg.rs
|
@ -1,18 +1,26 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use log::Level;
|
||||
use smoltcp::phy::ChecksumCapabilities;
|
||||
use smoltcp::wire::{
|
||||
IpAddress, IpProtocol, IpRepr, IpVersion, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address,
|
||||
Ipv6Packet, Ipv6Repr, TcpControl, TcpPacket, TcpRepr, TcpSeqNumber,
|
||||
};
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::port_pool::PortPool;
|
||||
use crate::MAX_PACKET;
|
||||
|
||||
/// The capacity of the broadcast channel for received IP packets.
|
||||
const BROADCAST_CAPACITY: usize = 1_000_000;
|
||||
|
||||
pub struct WireGuardTunnel {
|
||||
source_peer_ip: IpAddr,
|
||||
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
|
||||
peer: Box<Tunn>,
|
||||
/// The UDP socket for the public WireGuard endpoint to connect to.
|
||||
|
@ -21,12 +29,16 @@ pub struct WireGuardTunnel {
|
|||
endpoint: SocketAddr,
|
||||
/// Broadcast sender for received IP packets.
|
||||
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
|
||||
/// Placeholder so that the broadcaster doesn't close.
|
||||
ip_broadcast_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
/// Port pool.
|
||||
port_pool: Arc<PortPool>,
|
||||
}
|
||||
|
||||
impl WireGuardTunnel {
|
||||
/// Initialize a new WireGuard tunnel.
|
||||
pub async fn new(config: &Config) -> anyhow::Result<Self> {
|
||||
pub async fn new(config: &Config, port_pool: Arc<PortPool>) -> anyhow::Result<Self> {
|
||||
let source_peer_ip = config.source_peer_ip;
|
||||
let peer = Self::create_tunnel(&config)?;
|
||||
let udp = UdpSocket::bind("0.0.0.0:0")
|
||||
.await
|
||||
|
@ -36,11 +48,13 @@ impl WireGuardTunnel {
|
|||
tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
|
||||
|
||||
Ok(Self {
|
||||
source_peer_ip,
|
||||
peer,
|
||||
udp,
|
||||
endpoint,
|
||||
ip_broadcast_tx,
|
||||
ip_broadcast_rx,
|
||||
port_pool,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -175,10 +189,16 @@ impl WireGuardTunnel {
|
|||
// For debugging purposes: parse packet
|
||||
trace_ip_packet("Received IP packet", packet);
|
||||
|
||||
match self.route(packet) {
|
||||
RouteResult::Broadcast => {
|
||||
// Broadcast IP packet
|
||||
if self.ip_broadcast_tx.receiver_count() > 1 {
|
||||
match self.ip_broadcast_tx.send(packet.to_vec()) {
|
||||
Ok(n) => {
|
||||
trace!("Broadcasted received IP packet to {} recipients", n);
|
||||
trace!(
|
||||
"Broadcasted received IP packet to {} virtual interfaces",
|
||||
n - 1
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
|
@ -188,6 +208,18 @@ impl WireGuardTunnel {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
RouteResult::TcpReset(packet) => {
|
||||
trace!("Resetting dead TCP connection after packet from WireGuard endpoint");
|
||||
self.send_ip_packet(&packet)
|
||||
.await
|
||||
.unwrap_or_else(|e| error!("Failed to sent TCP reset: {:?}", e));
|
||||
}
|
||||
RouteResult::Drop => {
|
||||
trace!("Dropped incoming IP packet from WireGuard endpoint");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
@ -205,6 +237,126 @@ impl WireGuardTunnel {
|
|||
.map_err(|s| anyhow::anyhow!("{}", s))
|
||||
.with_context(|| "Failed to initialize boringtun Tunn")
|
||||
}
|
||||
|
||||
fn route(&self, packet: &[u8]) -> RouteResult {
|
||||
match IpVersion::of_packet(&packet) {
|
||||
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
|
||||
.ok()
|
||||
// Only care if the packet is destined for this tunnel
|
||||
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
|
||||
.map(|packet| match packet.protocol() {
|
||||
IpProtocol::Tcp => TcpPacket::new_checked(packet.payload()).ok().map(|tcp| {
|
||||
if self.port_pool.is_in_use(tcp.dst_port()) {
|
||||
RouteResult::Broadcast
|
||||
} else if tcp.rst() {
|
||||
RouteResult::Drop
|
||||
} else {
|
||||
// Port is not in use, but it's a TCP packet so we'll craft a RST.
|
||||
RouteResult::TcpReset(craft_tcp_rst_reply(
|
||||
IpVersion::Ipv4,
|
||||
packet.src_addr().into(),
|
||||
tcp.src_port(),
|
||||
packet.dst_addr().into(),
|
||||
tcp.dst_port(),
|
||||
tcp.ack_number(),
|
||||
))
|
||||
}
|
||||
}),
|
||||
// Unrecognized protocol, so we'll allow it.
|
||||
_ => Some(RouteResult::Broadcast),
|
||||
})
|
||||
.flatten()
|
||||
.unwrap_or(RouteResult::Drop),
|
||||
// TODO: IPv6
|
||||
_ => RouteResult::Drop,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn craft_tcp_rst_reply(
|
||||
ip_version: IpVersion,
|
||||
source_addr: IpAddress,
|
||||
source_port: u16,
|
||||
dest_addr: IpAddress,
|
||||
dest_port: u16,
|
||||
ack_number: TcpSeqNumber,
|
||||
) -> Vec<u8> {
|
||||
let tcp_repr = TcpRepr {
|
||||
src_port: dest_port,
|
||||
dst_port: source_port,
|
||||
control: TcpControl::Rst,
|
||||
seq_number: ack_number,
|
||||
ack_number: None,
|
||||
window_len: 0,
|
||||
window_scale: None,
|
||||
max_seg_size: None,
|
||||
sack_permitted: false,
|
||||
sack_ranges: [None, None, None],
|
||||
payload: &[],
|
||||
};
|
||||
|
||||
let mut tcp_buffer = vec![0u8; 20];
|
||||
let mut tcp_packet = &mut TcpPacket::new_unchecked(&mut tcp_buffer);
|
||||
tcp_repr.emit(
|
||||
&mut tcp_packet,
|
||||
&dest_addr,
|
||||
&source_addr,
|
||||
&ChecksumCapabilities::default(),
|
||||
);
|
||||
|
||||
let mut ip_buffer = vec![0u8; MAX_PACKET];
|
||||
|
||||
let (header_len, total_len) = match ip_version {
|
||||
IpVersion::Ipv4 => {
|
||||
let dest_addr = match dest_addr {
|
||||
IpAddress::Ipv4(dest_addr) => dest_addr,
|
||||
_ => panic!(),
|
||||
};
|
||||
let source_addr = match source_addr {
|
||||
IpAddress::Ipv4(source_addr) => source_addr,
|
||||
_ => panic!(),
|
||||
};
|
||||
|
||||
let mut ip_packet = &mut Ipv4Packet::new_unchecked(&mut ip_buffer);
|
||||
let ip_repr = Ipv4Repr {
|
||||
src_addr: dest_addr,
|
||||
dst_addr: source_addr,
|
||||
protocol: IpProtocol::Tcp,
|
||||
payload_len: tcp_buffer.len(),
|
||||
hop_limit: 64,
|
||||
};
|
||||
ip_repr.emit(&mut ip_packet, &ChecksumCapabilities::default());
|
||||
(
|
||||
ip_packet.header_len() as usize,
|
||||
ip_packet.total_len() as usize,
|
||||
)
|
||||
}
|
||||
IpVersion::Ipv6 => {
|
||||
let dest_addr = match dest_addr {
|
||||
IpAddress::Ipv6(dest_addr) => dest_addr,
|
||||
_ => panic!(),
|
||||
};
|
||||
let source_addr = match source_addr {
|
||||
IpAddress::Ipv6(source_addr) => source_addr,
|
||||
_ => panic!(),
|
||||
};
|
||||
let mut ip_packet = &mut Ipv6Packet::new_unchecked(&mut ip_buffer);
|
||||
let ip_repr = Ipv6Repr {
|
||||
src_addr: dest_addr,
|
||||
dst_addr: source_addr,
|
||||
next_header: IpProtocol::Tcp,
|
||||
payload_len: tcp_buffer.len(),
|
||||
hop_limit: 64,
|
||||
};
|
||||
ip_repr.emit(&mut ip_packet);
|
||||
(ip_packet.header_len(), ip_packet.total_len())
|
||||
}
|
||||
_ => panic!(),
|
||||
};
|
||||
|
||||
ip_buffer[header_len..total_len].copy_from_slice(&tcp_buffer);
|
||||
let packet: &[u8] = &ip_buffer[..total_len];
|
||||
packet.to_vec()
|
||||
}
|
||||
|
||||
fn trace_ip_packet(message: &str, packet: &[u8]) {
|
||||
|
@ -226,3 +378,12 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum RouteResult {
|
||||
/// The packet can be broadcasted to the virtual interfaces
|
||||
Broadcast,
|
||||
/// The packet is not routable so it may be reset.
|
||||
TcpReset(Vec<u8>),
|
||||
/// The packet can be safely ignored.
|
||||
Drop,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue