mirror of
https://github.com/aramperes/onetun.git
synced 2025-09-09 12:18:31 -04:00
Consumption task, and keep-alive
This commit is contained in:
parent
99a0d4370e
commit
492875c392
3 changed files with 183 additions and 11 deletions
|
@ -1,7 +1,7 @@
|
||||||
use anyhow::Context;
|
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
use boringtun::crypto::{X25519PublicKey, X25519SecretKey};
|
use boringtun::crypto::{X25519PublicKey, X25519SecretKey};
|
||||||
use clap::{App, Arg};
|
use clap::{App, Arg};
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ pub struct Config {
|
||||||
pub(crate) endpoint_public_key: Arc<X25519PublicKey>,
|
pub(crate) endpoint_public_key: Arc<X25519PublicKey>,
|
||||||
pub(crate) endpoint_addr: SocketAddr,
|
pub(crate) endpoint_addr: SocketAddr,
|
||||||
pub(crate) source_peer_ip: IpAddr,
|
pub(crate) source_peer_ip: IpAddr,
|
||||||
|
pub(crate) keepalive_seconds: Option<u16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
@ -54,7 +55,13 @@ impl Config {
|
||||||
.takes_value(true)
|
.takes_value(true)
|
||||||
.long("source-peer-ip")
|
.long("source-peer-ip")
|
||||||
.env("ONETUN_SOURCE_PEER_IP")
|
.env("ONETUN_SOURCE_PEER_IP")
|
||||||
.help("The source IP to identify this peer as (local). Example: 192.168.4.3")
|
.help("The source IP to identify this peer as (local). Example: 192.168.4.3"),
|
||||||
|
Arg::with_name("keep-alive")
|
||||||
|
.required(false)
|
||||||
|
.takes_value(true)
|
||||||
|
.long("keep-alive")
|
||||||
|
.env("ONETUN_KEEP_ALIVE")
|
||||||
|
.help("Configures a persistent keep-alive for the WireGuard tunnel, in seconds.")
|
||||||
]).get_matches();
|
]).get_matches();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
@ -74,6 +81,8 @@ impl Config {
|
||||||
.with_context(|| "Invalid endpoint address")?,
|
.with_context(|| "Invalid endpoint address")?,
|
||||||
source_peer_ip: parse_ip(matches.value_of("source-peer-ip"))
|
source_peer_ip: parse_ip(matches.value_of("source-peer-ip"))
|
||||||
.with_context(|| "Invalid source peer IP")?,
|
.with_context(|| "Invalid source peer IP")?,
|
||||||
|
keepalive_seconds: parse_keep_alive(matches.value_of("keep-alive"))
|
||||||
|
.with_context(|| "Invalid keep-alive value")?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,3 +112,17 @@ fn parse_public_key(s: Option<&str>) -> anyhow::Result<X25519PublicKey> {
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||||
.with_context(|| "Invalid public key")
|
.with_context(|| "Invalid public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_keep_alive(s: Option<&str>) -> anyhow::Result<Option<u16>> {
|
||||||
|
if let Some(s) = s {
|
||||||
|
let parsed: u16 = s.parse().with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Keep-alive must be a number between 0 and {} seconds",
|
||||||
|
u16::MAX
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(Some(parsed))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
32
src/main.rs
32
src/main.rs
|
@ -6,7 +6,7 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use tokio::io::Interest;
|
use tokio::io::Interest;
|
||||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::port_pool::PortPool;
|
use crate::port_pool::PortPool;
|
||||||
|
@ -31,24 +31,38 @@ async fn main() -> anyhow::Result<()> {
|
||||||
.with_context(|| "Failed to initialize WireGuard tunnel")?;
|
.with_context(|| "Failed to initialize WireGuard tunnel")?;
|
||||||
let wg = Arc::new(wg);
|
let wg = Arc::new(wg);
|
||||||
|
|
||||||
// Start routine task for WireGuard
|
{
|
||||||
tokio::spawn(async move { Arc::clone(&wg).routine_task().await });
|
// Start routine task for WireGuard
|
||||||
|
let wg = wg.clone();
|
||||||
|
tokio::spawn(async move { wg.routine_task().await });
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Start consumption task for WireGuard
|
||||||
|
let wg = wg.clone();
|
||||||
|
tokio::spawn(async move { wg.consume_task().await });
|
||||||
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Tunnelling [{}]->[{}] (via [{}] as peer {})",
|
"Tunnelling [{}]->[{}] (via [{}] as peer {})",
|
||||||
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
|
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
|
||||||
);
|
);
|
||||||
|
|
||||||
tcp_proxy_server(config.source_addr.clone(), port_pool.clone()).await
|
tcp_proxy_server(config.source_addr.clone(), port_pool.clone(), wg).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Starts the server that listens on TCP connections.
|
/// Starts the server that listens on TCP connections.
|
||||||
async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc<PortPool>) -> anyhow::Result<()> {
|
async fn tcp_proxy_server(
|
||||||
|
listen_addr: SocketAddr,
|
||||||
|
port_pool: Arc<PortPool>,
|
||||||
|
wg: Arc<WireGuardTunnel>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
let listener = TcpListener::bind(listen_addr)
|
let listener = TcpListener::bind(listen_addr)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to listen on TCP proxy server")?;
|
.with_context(|| "Failed to listen on TCP proxy server")?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
let wg = wg.clone();
|
||||||
let port_pool = port_pool.clone();
|
let port_pool = port_pool.clone();
|
||||||
let (socket, peer_addr) = listener
|
let (socket, peer_addr) = listener
|
||||||
.accept()
|
.accept()
|
||||||
|
@ -73,7 +87,7 @@ async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc<PortPool>) ->
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let port_pool = Arc::clone(&port_pool);
|
let port_pool = Arc::clone(&port_pool);
|
||||||
let result = handle_tcp_proxy_connection(socket, virtual_port).await;
|
let result = handle_tcp_proxy_connection(socket, virtual_port, wg).await;
|
||||||
|
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
error!(
|
error!(
|
||||||
|
@ -91,7 +105,11 @@ async fn tcp_proxy_server(listen_addr: SocketAddr, port_pool: Arc<PortPool>) ->
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handles a new TCP connection with its assigned virtual port.
|
/// Handles a new TCP connection with its assigned virtual port.
|
||||||
async fn handle_tcp_proxy_connection(socket: TcpStream, virtual_port: u16) -> anyhow::Result<()> {
|
async fn handle_tcp_proxy_connection(
|
||||||
|
socket: TcpStream,
|
||||||
|
virtual_port: u16,
|
||||||
|
wg: Arc<WireGuardTunnel>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
loop {
|
loop {
|
||||||
let ready = socket
|
let ready = socket
|
||||||
.ready(Interest::READABLE | Interest::WRITABLE)
|
.ready(Interest::READABLE | Interest::WRITABLE)
|
||||||
|
|
135
src/wg.rs
135
src/wg.rs
|
@ -1,5 +1,4 @@
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
@ -9,6 +8,9 @@ use tokio::net::UdpSocket;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::MAX_PACKET;
|
use crate::MAX_PACKET;
|
||||||
|
|
||||||
|
/// The capacity of the broadcast channel for received IP packets.
|
||||||
|
const BROADCAST_CAPACITY: usize = 1_000_000;
|
||||||
|
|
||||||
pub struct WireGuardTunnel {
|
pub struct WireGuardTunnel {
|
||||||
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
|
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
|
||||||
peer: Box<Tunn>,
|
peer: Box<Tunn>,
|
||||||
|
@ -16,6 +18,8 @@ pub struct WireGuardTunnel {
|
||||||
udp: UdpSocket,
|
udp: UdpSocket,
|
||||||
/// The address of the public WireGuard endpoint (UDP).
|
/// The address of the public WireGuard endpoint (UDP).
|
||||||
endpoint: SocketAddr,
|
endpoint: SocketAddr,
|
||||||
|
/// Broadcast sender for received IP packets.
|
||||||
|
ip_broadcast_tx: tokio::sync::broadcast::Sender<Vec<u8>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WireGuardTunnel {
|
impl WireGuardTunnel {
|
||||||
|
@ -26,14 +30,51 @@ impl WireGuardTunnel {
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
|
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
|
||||||
let endpoint = config.endpoint_addr;
|
let endpoint = config.endpoint_addr;
|
||||||
|
let (ip_broadcast_tx, _) = tokio::sync::broadcast::channel(BROADCAST_CAPACITY);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
peer,
|
peer,
|
||||||
udp,
|
udp,
|
||||||
endpoint,
|
endpoint,
|
||||||
|
ip_broadcast_tx,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Encapsulates and sends an IP packet through to the WireGuard endpoint.
|
||||||
|
pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> {
|
||||||
|
let mut send_buf = [0u8; MAX_PACKET];
|
||||||
|
match self.peer.encapsulate(packet, &mut send_buf) {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
self.udp
|
||||||
|
.send_to(packet, self.endpoint)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to send encrypted IP packet to WireGuard endpoint.")?;
|
||||||
|
debug!(
|
||||||
|
"Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
|
||||||
|
packet.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
TunnResult::Err(e) => {
|
||||||
|
error!("Failed to encapsulate IP packet: {:?}", e);
|
||||||
|
}
|
||||||
|
TunnResult::Done => {
|
||||||
|
// Ignored
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
error!(
|
||||||
|
"Unexpected WireGuard state during encapsulation: {:?}",
|
||||||
|
other
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new receiver for broadcasted IP packets, received from the WireGuard endpoint.
|
||||||
|
pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Vec<u8>> {
|
||||||
|
self.ip_broadcast_tx.subscribe()
|
||||||
|
}
|
||||||
|
|
||||||
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
||||||
pub async fn routine_task(&self) -> ! {
|
pub async fn routine_task(&self) -> ! {
|
||||||
trace!("Starting WireGuard routine task");
|
trace!("Starting WireGuard routine task");
|
||||||
|
@ -73,12 +114,86 @@ impl WireGuardTunnel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
|
||||||
|
/// decapsulates them, and broadcasts newly received IP packets.
|
||||||
|
pub async fn consume_task(&self) -> ! {
|
||||||
|
trace!("Starting WireGuard consumption task");
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET];
|
||||||
|
let mut send_buf = [0u8; MAX_PACKET];
|
||||||
|
|
||||||
|
let size = match self.udp.recv(&mut recv_buf).await {
|
||||||
|
Ok(size) => size,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to read from WireGuard endpoint: {:?}", e);
|
||||||
|
// Sleep a little bit and try again
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let data = &recv_buf[..size];
|
||||||
|
match self.peer.decapsulate(None, data, &mut send_buf) {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
match self.udp.send_to(packet, self.endpoint).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
loop {
|
||||||
|
let mut send_buf = [0u8; MAX_PACKET];
|
||||||
|
match self.peer.decapsulate(None, &[], &mut send_buf) {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
match self.udp.send_to(packet, self.endpoint).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => {
|
||||||
|
debug!(
|
||||||
|
"WireGuard endpoint sent IP packet of {} bytes",
|
||||||
|
packet.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
// For debugging purposes: parse packet
|
||||||
|
trace_ip_packet(packet);
|
||||||
|
|
||||||
|
// Broadcast IP packet
|
||||||
|
match self.ip_broadcast_tx.send(packet.to_vec()) {
|
||||||
|
Ok(n) => {
|
||||||
|
trace!("Broadcasted received IP packet to {} recipients", n);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Failed to broadcast received IP packet to recipients: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
|
fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
|
||||||
Tunn::new(
|
Tunn::new(
|
||||||
config.private_key.clone(),
|
config.private_key.clone(),
|
||||||
config.endpoint_public_key.clone(),
|
config.endpoint_public_key.clone(),
|
||||||
None,
|
None,
|
||||||
None,
|
config.keepalive_seconds,
|
||||||
0,
|
0,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
@ -86,3 +201,19 @@ impl WireGuardTunnel {
|
||||||
.with_context(|| "Failed to initialize boringtun Tunn")
|
.with_context(|| "Failed to initialize boringtun Tunn")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn trace_ip_packet(packet: &[u8]) {
|
||||||
|
use smoltcp::wire::*;
|
||||||
|
|
||||||
|
match IpVersion::of_packet(&packet) {
|
||||||
|
Ok(IpVersion::Ipv4) => trace!(
|
||||||
|
"IPv4 packet received: {}",
|
||||||
|
PrettyPrinter::<Ipv4Packet<&mut [u8]>>::new("", &packet)
|
||||||
|
),
|
||||||
|
Ok(IpVersion::Ipv6) => trace!(
|
||||||
|
"IPv6 packet received: {}",
|
||||||
|
PrettyPrinter::<Ipv6Packet<&mut [u8]>>::new("", &packet)
|
||||||
|
),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue