From cb09bb885743603e8a582eed0aabd41fec1fb09a Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Sun, 17 Oct 2021 21:57:41 -0400 Subject: [PATCH 01/17] WIP on UDP and multi-port-forward support --- Cargo.lock | 18 +++++++ Cargo.toml | 1 + src/config.rs | 143 ++++++++++++++++++++++++++++++++++++++++++++------ src/main.rs | 61 ++++++++++++++++----- src/wg.rs | 2 +- 5 files changed, 197 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 27fb027..edcdee8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -488,6 +488,12 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +[[package]] +name = "minimal-lexical" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c64630dcdd71f1a64c435f54885086a0de5d6a12d104d69b165fb7d5286d677" + [[package]] name = "miniz_oxide" version = "0.4.4" @@ -520,6 +526,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "nom" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffd9d26838a953b4af82cbeb9f1592c6798916983959be223a7124e992742c1" +dependencies = [ + "memchr", + "minimal-lexical", + "version_check", +] + [[package]] name = "ntapi" version = "0.3.6" @@ -583,6 +600,7 @@ dependencies = [ "futures", "lockfree", "log", + "nom", "pretty_env_logger", "rand", "smoltcp", diff --git a/Cargo.toml b/Cargo.toml index 2abd992..a6f927f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ tokio = { version = "1", features = ["full"] } lockfree = "0.5.1" futures = "0.3.17" rand = "0.8.4" +nom = "7" diff --git a/src/config.rs b/src/config.rs index 2c1a993..080232d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,6 @@ +use std::collections::HashSet; +use std::convert::TryFrom; +use std::fmt::{Display, Formatter}; use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; use std::sync::Arc; @@ -7,8 +10,7 @@ use clap::{App, Arg}; #[derive(Clone, Debug)] pub struct Config { - pub(crate) source_addr: SocketAddr, - pub(crate) dest_addr: SocketAddr, + pub(crate) port_forwards: Vec, pub(crate) private_key: Arc, pub(crate) endpoint_public_key: Arc, pub(crate) endpoint_addr: SocketAddr, @@ -23,16 +25,12 @@ impl Config { .author("Aram Peres ") .version(env!("CARGO_PKG_VERSION")) .args(&[ - Arg::with_name("SOURCE_ADDR") - .required(true) + Arg::with_name("PORT_FORWARD") + .required(false) + .multiple(true) .takes_value(true) - .env("ONETUN_SOURCE_ADDR") - .help("The source address (IP + port) to forward from. Example: 127.0.0.1:2115"), - Arg::with_name("DESTINATION_ADDR") - .required(true) - .takes_value(true) - .env("ONETUN_DESTINATION_ADDR") - .help("The destination address (IP + port) to forward to. The IP should be a peer registered in the Wireguard endpoint. Example: 192.168.4.2:2116"), + .help("Port forward configurations. The format of each argument is [src_host:]::[:TCP,UDP,...]. \ + Environment variables of the form 'ONETUN_PORT_FORWARD_[#]' are also accepted, where [#] starts at 1."), Arg::with_name("private-key") .required(true) .takes_value(true) @@ -72,11 +70,40 @@ impl Config { .help("Configures the log level and format.") ]).get_matches(); + // Combine `PORT_FORWARD` arg and `ONETUN_PORT_FORWARD_#` strings + let mut port_forward_strings = HashSet::new(); + matches.values_of("PORT_FORWARD").map(|values| { + values + .into_iter() + .map(|v| port_forward_strings.insert(v.to_string())) + .map(|_| ()) + }); + for n in 1.. 
{
+            if let Ok(env) = std::env::var(format!("ONETUN_PORT_FORWARD_{}", n)) {
+                port_forward_strings.insert(env);
+            } else {
+                break;
+            }
+        }
+        if port_forward_strings.is_empty() {
+            return Err(anyhow::anyhow!("No port forward configurations given."));
+        }
+
+        // Parse `PORT_FORWARD` strings into `PortForwardConfig`
+        let port_forwards: Vec<anyhow::Result<Vec<PortForwardConfig>>> = port_forward_strings
+            .into_iter()
+            .map(|s| PortForwardConfig::from_str(&s))
+            .collect();
+        let port_forwards: anyhow::Result<Vec<Vec<PortForwardConfig>>> =
+            port_forwards.into_iter().collect();
+        let port_forwards: Vec<PortForwardConfig> = port_forwards
+            .with_context(|| "Failed to parse port forward config")?
+            .into_iter()
+            .flatten()
+            .collect();
+
         Ok(Self {
-            source_addr: parse_addr(matches.value_of("SOURCE_ADDR"))
-                .with_context(|| "Invalid source address")?,
-            dest_addr: parse_addr(matches.value_of("DESTINATION_ADDR"))
-                .with_context(|| "Invalid destination address")?,
+            port_forwards,
             private_key: Arc::new(
                 parse_private_key(matches.value_of("private-key"))
                     .with_context(|| "Invalid private key")?,
@@ -137,3 +164,89 @@ fn parse_keep_alive(s: Option<&str>) -> anyhow::Result<Option<u16>> {
         Ok(None)
     }
 }
+
+#[derive(Debug, Clone, Copy)]
+pub struct PortForwardConfig {
+    /// The source IP and port where the local server will run.
+    pub source: SocketAddr,
+    /// The destination IP and port to which traffic will be forwarded.
+    pub destination: SocketAddr,
+    /// The transport protocol to use for the port (Layer 4).
+    pub protocol: PortProtocol,
+}
+
+impl PortForwardConfig {
+    /// Converts a string representation into `PortForwardConfig`.
+    ///
+    /// Sample formats:
+    /// - `127.0.0.1:8080:192.168.4.1:8081:TCP,UDP`
+    /// - `127.0.0.1:8080:192.168.4.1:8081:TCP`
+    /// - `0.0.0.0:8080:192.168.4.1:8081`
+    /// - `[::1]:8080:192.168.4.1:8081`
+    /// - `8080:192.168.4.1:8081`
+    /// - `8080:192.168.4.1:8081:TCP`
+    ///
+    /// Implementation Notes:
+    /// - The format is formalized as `[src_host:]<src_port>:<dst_host>:<dst_port>[:PROTO1,PROTO2,...]`
+    /// - `src_host` is optional and defaults to `127.0.0.1`.
+    /// - `src_host` and `dst_host` may be specified as IPv4, IPv6, or a FQDN to be resolved by DNS.
+    /// - IPv6 addresses must be prefixed with `[` and suffixed with `]`. Example: `[::1]`.
+    /// - Any `u16` is accepted as `src_port` and `dst_port`
+    /// - Specifying protocols (`PROTO1,PROTO2,...`) is optional and defaults to `TCP`. Values must be separated by commas.
+    pub fn from_str<'a>(s: &'a str) -> anyhow::Result<Vec<Self>> {
+        use nom::branch::alt;
+        use nom::bytes::complete::{is_not, take_until, take_while};
+        use nom::character::complete::char;
+        use nom::combinator::opt;
+        use nom::multi::separated_list0;
+        use nom::sequence::{delimited, terminated};
+        use nom::IResult;
+
+        Err(anyhow::anyhow!("TODO"))
+    }
+}
+
+impl Display for PortForwardConfig {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}:{}:{}", self.source, self.destination, self.protocol)
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum PortProtocol {
+    Tcp,
+    Udp,
+}
+
+impl TryFrom<&str> for PortProtocol {
+    type Error = anyhow::Error;
+
+    fn try_from(value: &str) -> anyhow::Result<Self> {
+        match value.to_uppercase().as_str() {
+            "TCP" => Ok(Self::Tcp),
+            "UDP" => Ok(Self::Udp),
+            _ => Err(anyhow::anyhow!("Invalid protocol specifier: {}", value)),
+        }
+    }
+}
+
+impl Display for PortProtocol {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                Self::Tcp => "TCP",
+                Self::Udp => "UDP",
+            }
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Tests the parsing of `PortForwardConfig`.
+ fn test_parse_port_forward_config() {} +} diff --git a/src/main.rs b/src/main.rs index 097ea9c..bb3dc34 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState}; use smoltcp::wire::{IpAddress, IpCidr}; use tokio::net::{TcpListener, TcpStream}; -use crate::config::Config; +use crate::config::{Config, PortForwardConfig, PortProtocol}; use crate::port_pool::PortPool; use crate::virtual_device::VirtualIpDevice; use crate::wg::WireGuardTunnel; @@ -54,26 +54,63 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { ip_sink::run_ip_sink_interface(wg).await }); } + { + let port_forwards = config.port_forwards; + let source_peer_ip = config.source_peer_ip; + + futures::future::try_join_all( + port_forwards + .into_iter() + .map(|pf| (pf, wg.clone(), port_pool.clone())) + .map(|(pf, wg, port_pool)| { + tokio::spawn(async move { + port_forward(pf, source_peer_ip, port_pool, wg) + .await + .with_context(|| format!("Port-forward failed: {})", pf)) + }) + }), + ) + .await + .with_context(|| "A port-forward instance failed.") + .map(|_| ()) + } +} + +async fn port_forward( + port_forward: PortForwardConfig, + source_peer_ip: IpAddr, + port_pool: Arc, + wg: Arc, +) -> anyhow::Result<()> { info!( - "Tunnelling [{}]->[{}] (via [{}] as peer {})", - &config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip + "Tunnelling {} [{}]->[{}] (via [{}] as peer {})", + port_forward.protocol, + port_forward.source, + port_forward.destination, + &wg.endpoint, + source_peer_ip ); - tcp_proxy_server( - config.source_addr, - config.source_peer_ip, - config.dest_addr, - port_pool.clone(), - wg, - ) - .await + match port_forward.protocol { + PortProtocol::Tcp => { + tcp_proxy_server( + port_forward.source, + port_forward.destination, + source_peer_ip, + port_pool, + wg, + ) + .await + } + PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), + } } /// Starts the server that listens on TCP connections. async fn tcp_proxy_server( listen_addr: SocketAddr, - source_peer_ip: IpAddr, dest_addr: SocketAddr, + source_peer_ip: IpAddr, port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { diff --git a/src/wg.rs b/src/wg.rs index 7fc5559..e2740e8 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -24,7 +24,7 @@ pub struct WireGuardTunnel { /// The UDP socket for the public WireGuard endpoint to connect to. udp: UdpSocket, /// The address of the public WireGuard endpoint (UDP). - endpoint: SocketAddr, + pub(crate) endpoint: SocketAddr, /// Maps virtual ports to the corresponding IP packet dispatcher. virtual_port_ip_tx: lockfree::map::Map>>, /// IP packet dispatcher for unroutable packets. `None` if not initialized. From 651ddaec494084129bb63f0a033f0f29e8ec5cea Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Mon, 18 Oct 2021 01:36:40 -0400 Subject: [PATCH 02/17] Implement port-forwarder configuration parsing --- src/config.rs | 279 +++++++++++++++++++++++++++++++++++++++++++++----- src/main.rs | 2 +- 2 files changed, 255 insertions(+), 26 deletions(-) diff --git a/src/config.rs b/src/config.rs index 080232d..3a46881 100644 --- a/src/config.rs +++ b/src/config.rs @@ -29,8 +29,19 @@ impl Config { .required(false) .multiple(true) .takes_value(true) - .help("Port forward configurations. The format of each argument is [src_host:]::[:TCP,UDP,...]. \ - Environment variables of the form 'ONETUN_PORT_FORWARD_[#]' are also accepted, where [#] starts at 1."), + .help("Port forward configurations. 
The format of each argument is [src_host:]<src_port>:<dst_host>:<dst_port>[:TCP,UDP,...], \
+                    where [src_host] is the local IP to listen on, <src_port> is the local port to listen on, <dst_host> is the remote peer IP to forward to, and <dst_port> is the remote port to forward to. \
+                    Environment variables of the form 'ONETUN_PORT_FORWARD_[#]' are also accepted, where [#] starts at 1.\n\
+                    Examples:\n\
+                    \t127.0.0.1:8080:192.168.4.1:8081:TCP,UDP\n\
+                    \t127.0.0.1:8080:192.168.4.1:8081:TCP\n\
+                    \t0.0.0.0:8080:192.168.4.1:8081\n\
+                    \t[::1]:8080:192.168.4.1:8081\n\
+                    \t8080:192.168.4.1:8081\n\
+                    \t8080:192.168.4.1:8081:TCP\n\
+                    \tlocalhost:8080:192.168.4.1:8081:TCP\n\
+                    \tlocalhost:8080:peer.intranet:8081:TCP\
+                    "),
             Arg::with_name("private-key")
                 .required(true)
                 .takes_value(true)
@@ -70,14 +81,13 @@ impl Config {
                 .help("Configures the log level and format.")
         ]).get_matches();
 
-        // Combine `PORT_FORWARD` arg and `ONETUN_PORT_FORWARD_#` strings
+        // Combine `PORT_FORWARD` arg and `ONETUN_PORT_FORWARD_#` envs
         let mut port_forward_strings = HashSet::new();
-        matches.values_of("PORT_FORWARD").map(|values| {
-            values
-                .into_iter()
-                .map(|v| port_forward_strings.insert(v.to_string()))
-                .map(|_| ())
-        });
+        if let Some(values) = matches.values_of("PORT_FORWARD") {
+            for value in values {
+                port_forward_strings.insert(value.to_owned());
+            }
+        }
         for n in 1.. {
             if let Ok(env) = std::env::var(format!("ONETUN_PORT_FORWARD_{}", n)) {
                 port_forward_strings.insert(env);
@@ -90,12 +100,10 @@ impl Config {
         }
 
         // Parse `PORT_FORWARD` strings into `PortForwardConfig`
-        let port_forwards: Vec<anyhow::Result<Vec<PortForwardConfig>>> = port_forward_strings
+        let port_forwards: anyhow::Result<Vec<Vec<PortForwardConfig>>> = port_forward_strings
             .into_iter()
-            .map(|s| PortForwardConfig::from_str(&s))
+            .map(|s| PortForwardConfig::from_notation(&s))
             .collect();
-        let port_forwards: anyhow::Result<Vec<Vec<PortForwardConfig>>> =
-            port_forwards.into_iter().collect();
         let port_forwards: Vec<PortForwardConfig> = port_forwards
             .with_context(|| "Failed to parse port forward config")?
             .into_iter()
             .flatten()
             .collect();
@@ -165,7 +173,7 @@ fn parse_keep_alive(s: Option<&str>) -> anyhow::Result<Option<u16>> {
     }
 }
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
 pub struct PortForwardConfig {
     /// The source IP and port where the local server will run.
     pub source: SocketAddr,
@@ -185,6 +193,8 @@ impl PortForwardConfig {
     /// - `[::1]:8080:192.168.4.1:8081`
     /// - `8080:192.168.4.1:8081`
     /// - `8080:192.168.4.1:8081:TCP`
+    /// - `localhost:8080:192.168.4.1:8081:TCP`
+    /// - `localhost:8080:peer.intranet:8081:TCP`
     ///
     /// Implementation Notes:
     /// - The format is formalized as `[src_host:]<src_port>:<dst_host>:<dst_port>[:PROTO1,PROTO2,...]`
     /// - `src_host` is optional and defaults to `127.0.0.1`.
@@ -193,16 +203,126 @@ impl PortForwardConfig {
     /// - `src_host` and `dst_host` may be specified as IPv4, IPv6, or a FQDN to be resolved by DNS.
     /// - IPv6 addresses must be prefixed with `[` and suffixed with `]`. Example: `[::1]`.
     /// - Any `u16` is accepted as `src_port` and `dst_port`
     /// - Specifying protocols (`PROTO1,PROTO2,...`) is optional and defaults to `TCP`. Values must be separated by commas.
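[To make the notation above concrete, a short usage sketch of the `from_notation` constructor this patch introduces. Assumption: the `config` module is reachable as `onetun::config`, which would require a library target the repository does not necessarily expose; the `main` wrapper is illustrative only.]

```
use onetun::config::PortForwardConfig; // assumed library path, see note above

fn main() -> anyhow::Result<()> {
    // One notation string may expand into several forwards: one per
    // protocol listed after the destination (here TCP and UDP).
    let forwards = PortForwardConfig::from_notation("127.0.0.1:8080:192.168.4.1:8081:TCP,UDP")?;
    for pf in &forwards {
        // Display renders as "source:destination:protocol".
        println!("{}", pf);
    }
    Ok(())
}
```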
-    pub fn from_str<'a>(s: &'a str) -> anyhow::Result<Vec<Self>> {
-        use nom::branch::alt;
-        use nom::bytes::complete::{is_not, take_until, take_while};
-        use nom::character::complete::char;
-        use nom::combinator::opt;
-        use nom::multi::separated_list0;
-        use nom::sequence::{delimited, terminated};
-        use nom::IResult;
+    pub fn from_notation(s: &str) -> anyhow::Result<Vec<Self>> {
+        mod parsers {
+            use nom::branch::alt;
+            use nom::bytes::complete::is_not;
+            use nom::character::complete::{alpha1, char, digit1};
+            use nom::combinator::{complete, map, opt, success};
+            use nom::error::ErrorKind;
+            use nom::multi::separated_list1;
+            use nom::sequence::{delimited, preceded, separated_pair, tuple};
+            use nom::IResult;
 
-        Err(anyhow::anyhow!("TODO"))
+            fn ipv6(s: &str) -> IResult<&str, &str> {
+                delimited(char('['), is_not("]"), char(']'))(s)
+            }
+
+            fn ipv4_or_fqdn(s: &str) -> IResult<&str, &str> {
+                let s = is_not(":")(s)?;
+                if s.1.chars().all(|c| c.is_ascii_digit()) {
+                    // If ipv4 or fqdn is all digits, it's not valid.
+                    Err(nom::Err::Error(nom::error::ParseError::from_error_kind(
+                        s.1,
+                        ErrorKind::Fail,
+                    )))
+                } else {
+                    Ok(s)
+                }
+            }
+
+            fn port(s: &str) -> IResult<&str, &str> {
+                digit1(s)
+            }
+
+            fn ip_or_fqdn(s: &str) -> IResult<&str, &str> {
+                alt((ipv6, ipv4_or_fqdn))(s)
+            }
+
+            fn no_ip(s: &str) -> IResult<&str, Option<&str>> {
+                success(None)(s)
+            }
+
+            fn src_addr(s: &str) -> IResult<&str, (Option<&str>, &str)> {
+                let with_ip = separated_pair(map(ip_or_fqdn, Some), char(':'), port);
+                let without_ip = tuple((no_ip, port));
+                alt((with_ip, without_ip))(s)
+            }
+
+            fn dst_addr(s: &str) -> IResult<&str, (&str, &str)> {
+                separated_pair(ip_or_fqdn, char(':'), port)(s)
+            }
+
+            fn protocol(s: &str) -> IResult<&str, &str> {
+                alpha1(s)
+            }
+
+            fn protocols(s: &str) -> IResult<&str, Option<Vec<&str>>> {
+                opt(preceded(char(':'), separated_list1(char(','), protocol)))(s)
+            }
+
+            #[allow(clippy::type_complexity)]
+            pub fn port_forward(
+                s: &str,
+            ) -> IResult<&str, ((Option<&str>, &str), (), (&str, &str), Option<Vec<&str>>)>
+            {
+                complete(tuple((
+                    src_addr,
+                    map(char(':'), |_| ()),
+                    dst_addr,
+                    protocols,
+                )))(s)
+            }
+        }
+
+        // TODO: Could improve error management with custom errors, so that the messages are more helpful.
+        let (src_addr, _, dst_addr, protocols) = parsers::port_forward(s)
+            .map_err(|e| anyhow::anyhow!("Invalid port-forward definition: {}", e))?
+            .1;
+
+        let source = (
+            src_addr.0.unwrap_or("127.0.0.1"),
+            src_addr
+                .1
+                .parse::<u16>()
+                .with_context(|| "Invalid source port")?,
+        )
+            .to_socket_addrs()
+            .with_context(|| "Invalid source address")?
+            .next()
+            .with_context(|| "Could not resolve source address")?;
+
+        let destination = (
+            dst_addr.0,
+            dst_addr
+                .1
+                .parse::<u16>()
+                .with_context(|| "Invalid destination port")?,
+        )
+            .to_socket_addrs() // TODO: Pass this as given and use DNS config instead (issue #15)
+            .with_context(|| "Invalid destination address")?
+            .next()
+            .with_context(|| "Could not resolve destination address")?;
+
+        // Parse protocols
+        let protocols = if let Some(protocols) = protocols {
+            let protocols: anyhow::Result<Vec<PortProtocol>> =
+                protocols.into_iter().map(PortProtocol::try_from).collect();
+            protocols
+        } else {
+            Ok(vec![PortProtocol::Tcp])
+        }
+        .with_context(|| "Failed to parse protocols")?;
+
+        // Returns a config for each protocol
+        Ok(protocols
+            .into_iter()
+            .map(|protocol| Self {
+                source,
+                destination,
+                protocol,
+            })
+            .collect())
     }
 }
 
@@ -212,7 +332,7 @@ impl Display for PortForwardConfig {
     }
 }
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
 pub enum PortProtocol {
     Tcp,
     Udp,
@@ -245,8 +365,117 @@ impl Display for PortProtocol {
 
 #[cfg(test)]
 mod tests {
+    use std::str::FromStr;
+
     use super::*;
 
     /// Tests the parsing of `PortForwardConfig`.
-    fn test_parse_port_forward_config() {}
+    #[test]
+    fn test_parse_port_forward_config_1() {
+        assert_eq!(
+            PortForwardConfig::from_notation("192.168.0.1:8080:192.168.4.1:8081:TCP,UDP")
+                .expect("Failed to parse"),
+            vec![
+                PortForwardConfig {
+                    source: SocketAddr::from_str("192.168.0.1:8080").unwrap(),
+                    destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                    protocol: PortProtocol::Tcp
+                },
+                PortForwardConfig {
+                    source: SocketAddr::from_str("192.168.0.1:8080").unwrap(),
+                    destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                    protocol: PortProtocol::Udp
+                }
+            ]
+        );
+    }
+    /// Tests the parsing of `PortForwardConfig`.
+    #[test]
+    fn test_parse_port_forward_config_2() {
+        assert_eq!(
+            PortForwardConfig::from_notation("192.168.0.1:8080:192.168.4.1:8081:TCP")
+                .expect("Failed to parse"),
+            vec![PortForwardConfig {
+                source: SocketAddr::from_str("192.168.0.1:8080").unwrap(),
+                destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                protocol: PortProtocol::Tcp
+            }]
+        );
+    }
+    /// Tests the parsing of `PortForwardConfig`.
+    #[test]
+    fn test_parse_port_forward_config_3() {
+        assert_eq!(
+            PortForwardConfig::from_notation("0.0.0.0:8080:192.168.4.1:8081")
+                .expect("Failed to parse"),
+            vec![PortForwardConfig {
+                source: SocketAddr::from_str("0.0.0.0:8080").unwrap(),
+                destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                protocol: PortProtocol::Tcp
+            }]
+        );
+    }
+    /// Tests the parsing of `PortForwardConfig`.
+    #[test]
+    fn test_parse_port_forward_config_4() {
+        assert_eq!(
+            PortForwardConfig::from_notation("[::1]:8080:192.168.4.1:8081")
+                .expect("Failed to parse"),
+            vec![PortForwardConfig {
+                source: SocketAddr::from_str("[::1]:8080").unwrap(),
+                destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                protocol: PortProtocol::Tcp
+            }]
+        );
+    }
+    /// Tests the parsing of `PortForwardConfig`.
+    #[test]
+    fn test_parse_port_forward_config_5() {
+        assert_eq!(
+            PortForwardConfig::from_notation("8080:192.168.4.1:8081").expect("Failed to parse"),
+            vec![PortForwardConfig {
+                source: SocketAddr::from_str("127.0.0.1:8080").unwrap(),
+                destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                protocol: PortProtocol::Tcp
+            }]
+        );
+    }
+    /// Tests the parsing of `PortForwardConfig`.
+    #[test]
+    fn test_parse_port_forward_config_6() {
+        assert_eq!(
+            PortForwardConfig::from_notation("8080:192.168.4.1:8081:TCP").expect("Failed to parse"),
+            vec![PortForwardConfig {
+                source: SocketAddr::from_str("127.0.0.1:8080").unwrap(),
+                destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
+                protocol: PortProtocol::Tcp
+            }]
+        );
+    }
+    /// Tests the parsing of `PortForwardConfig`.
+ #[test] + fn test_parse_port_forward_config_7() { + assert_eq!( + PortForwardConfig::from_notation("localhost:8080:192.168.4.1:8081") + .expect("Failed to parse"), + vec![PortForwardConfig { + source: "localhost:8080".to_socket_addrs().unwrap().next().unwrap(), + destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(), + protocol: PortProtocol::Tcp + }] + ); + } + /// Tests the parsing of `PortForwardConfig`. + #[test] + fn test_parse_port_forward_config_8() { + assert_eq!( + PortForwardConfig::from_notation("localhost:8080:localhost:8081:TCP") + .expect("Failed to parse"), + vec![PortForwardConfig { + source: "localhost:8080".to_socket_addrs().unwrap().next().unwrap(), + destination: "localhost:8081".to_socket_addrs().unwrap().next().unwrap(), + protocol: PortProtocol::Tcp + }] + ); + } } diff --git a/src/main.rs b/src/main.rs index bb3dc34..0a8cd1d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -66,7 +66,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { port_forward(pf, source_peer_ip, port_pool, wg) .await - .with_context(|| format!("Port-forward failed: {})", pf)) + .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }) }), ) From ed835c47d303af68fd02765d867eec3b4965bc2d Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Mon, 18 Oct 2021 03:54:13 -0400 Subject: [PATCH 03/17] Spawn tunnels in entirely separate threads --- src/main.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/main.rs b/src/main.rs index 0a8cd1d..95eb37a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -58,22 +58,22 @@ async fn main() -> anyhow::Result<()> { let port_forwards = config.port_forwards; let source_peer_ip = config.source_peer_ip; - futures::future::try_join_all( - port_forwards - .into_iter() - .map(|pf| (pf, wg.clone(), port_pool.clone())) - .map(|(pf, wg, port_pool)| { - tokio::spawn(async move { + port_forwards + .into_iter() + .map(|pf| (pf, wg.clone(), port_pool.clone())) + .for_each(move |(pf, wg, port_pool)| { + std::thread::spawn(move || { + let cpu_pool = tokio::runtime::Runtime::new().unwrap(); + cpu_pool.block_on(async move { port_forward(pf, source_peer_ip, port_pool, wg) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) - }) - }), - ) - .await - .with_context(|| "A port-forward instance failed.") - .map(|_| ()) + }); + }); + }); } + + futures::future::pending().await } async fn port_forward( From dbced52070e35a48b0298a01ba41c6b4d0264acd Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Mon, 18 Oct 2021 06:03:54 -0400 Subject: [PATCH 04/17] Attempt reconnection in virtual client --- src/main.rs | 50 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/src/main.rs b/src/main.rs index 95eb37a..bd28f96 100644 --- a/src/main.rs +++ b/src/main.rs @@ -212,7 +212,7 @@ async fn handle_tcp_proxy_connection( // Wait for virtual client to be ready. virtual_client_ready_rx .await - .expect("failed to wait for virtual client to be ready"); + .with_context(|| "Virtual client dropped before being ready.")?; trace!("[{}] Virtual client is ready to send data", virtual_port); loop { @@ -351,15 +351,7 @@ async fn virtual_tcp_interface( static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] 
}); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - - socket - .connect( - (IpAddress::from(dest_addr.ip()), dest_addr.port()), - (IpAddress::from(source_peer_ip), virtual_port), - ) - .with_context(|| "Virtual server socket failed to listen")?; - + let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); Ok(socket) }; @@ -372,11 +364,16 @@ async fn virtual_tcp_interface( // Any data that wasn't sent because it was over the sending buffer limit let mut tx_extra = Vec::new(); + // Counts the connection attempts by the virtual client + let mut connection_attempts = 0; + // Whether the client has successfully connected before. Prevents the case of connecting again. + let mut has_connected = false; + loop { let loop_start = smoltcp::time::Instant::now(); // Shutdown occurs when the real client closes the connection, - // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anmore. + // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore. // One last poll-loop iteration is executed so that the RST segment can be dispatched. let shutdown = abort.load(Ordering::Relaxed); @@ -403,6 +400,37 @@ async fn virtual_tcp_interface( { let mut client_socket = socket_set.get::(client_handle); + if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { + // Not shutting down, but the client socket is closed, and the client never successfully connected. + if connection_attempts < 10 { + // Try to connect + client_socket + .connect( + (IpAddress::from(dest_addr.ip()), dest_addr.port()), + (IpAddress::from(source_peer_ip), virtual_port), + ) + .with_context(|| "Virtual server socket failed to listen")?; + if connection_attempts > 0 { + debug!( + "[{}] Virtual client retrying connection in 500ms", + virtual_port + ); + // Not our first connection attempt, wait a little bit. + tokio::time::sleep(Duration::from_millis(500)).await; + } + } else { + // Too many connection attempts + abort.store(true, Ordering::Relaxed); + } + connection_attempts += 1; + continue; + } + + if client_socket.state() == TcpState::Established { + // Prevent reconnection if the server later closes. + has_connected = true; + } + if client_socket.can_recv() { match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { Ok(data) => { From 070c0f516242f805232ece780b08846163d324fa Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Mon, 18 Oct 2021 22:13:13 -0400 Subject: [PATCH 05/17] Use Vec instead of static mut for socket storage. 
Update smoltcp to fix #17 --- Cargo.lock | 3 ++- src/main.rs | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index edcdee8..e87a015 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -864,13 +864,14 @@ checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" [[package]] name = "smoltcp" version = "0.8.0" -source = "git+https://github.com/smoltcp-rs/smoltcp?branch=master#35e833e33dfd3e4efc3eb7d5de06bec17c54b011" +source = "git+https://github.com/smoltcp-rs/smoltcp?branch=master#25c539bb7c96789270f032ede2a967cf0fe5cf57" dependencies = [ "bitflags", "byteorder", "libc", "log", "managed", + "rand_core", ] [[package]] diff --git a/src/main.rs b/src/main.rs index bd28f96..a117a99 100644 --- a/src/main.rs +++ b/src/main.rs @@ -347,10 +347,10 @@ async fn virtual_tcp_interface( }; let client_socket: anyhow::Result = { - static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); + let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); Ok(socket) }; From c2d0b9719a2264ae6231dd52c53b030c7a5c1076 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Tue, 19 Oct 2021 00:43:59 -0400 Subject: [PATCH 06/17] Refactor TCP virtual interface code out of main. Removed unused server socket buffer. --- Cargo.lock | 12 ++ Cargo.toml | 1 + src/main.rs | 278 ++++----------------------------------- src/virtual_iface/mod.rs | 10 ++ src/virtual_iface/tcp.rs | 274 ++++++++++++++++++++++++++++++++++++++ src/wg.rs | 2 +- 6 files changed, 322 insertions(+), 255 deletions(-) create mode 100644 src/virtual_iface/mod.rs create mode 100644 src/virtual_iface/tcp.rs diff --git a/Cargo.lock b/Cargo.lock index e87a015..f3ba00d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,17 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e" +[[package]] +name = "async-trait" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44318e776df68115a881de9a8fd1b9e53368d7a4a5ce4cc48517da3393233a5e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atty" version = "0.2.14" @@ -595,6 +606,7 @@ name = "onetun" version = "0.1.11" dependencies = [ "anyhow", + "async-trait", "boringtun", "clap", "futures", diff --git a/Cargo.toml b/Cargo.toml index a6f927f..d552c2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,4 @@ lockfree = "0.5.1" futures = "0.3.17" rand = "0.8.4" nom = "7" +async-trait = "0.1.51" diff --git a/src/main.rs b/src/main.rs index a117a99..28846fa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,26 +1,24 @@ #[macro_use] extern crate log; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::Duration; use anyhow::Context; -use smoltcp::iface::InterfaceBuilder; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState}; -use smoltcp::wire::{IpAddress, IpCidr}; use tokio::net::{TcpListener, TcpStream}; use crate::config::{Config, PortForwardConfig, PortProtocol}; use 
crate::port_pool::PortPool; -use crate::virtual_device::VirtualIpDevice; +use crate::virtual_iface::tcp::TcpVirtualInterface; +use crate::virtual_iface::VirtualInterfacePoll; use crate::wg::WireGuardTunnel; pub mod config; pub mod ip_sink; pub mod port_pool; pub mod virtual_device; +pub mod virtual_iface; pub mod wg; pub const MAX_PACKET: usize = 65536; @@ -92,29 +90,18 @@ async fn port_forward( ); match port_forward.protocol { - PortProtocol::Tcp => { - tcp_proxy_server( - port_forward.source, - port_forward.destination, - source_peer_ip, - port_pool, - wg, - ) - .await - } + PortProtocol::Tcp => tcp_proxy_server(port_forward, port_pool, wg).await, PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), } } /// Starts the server that listens on TCP connections. async fn tcp_proxy_server( - listen_addr: SocketAddr, - dest_addr: SocketAddr, - source_peer_ip: IpAddr, + port_forward: PortForwardConfig, port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { - let listener = TcpListener::bind(listen_addr) + let listener = TcpListener::bind(port_forward.source) .await .with_context(|| "Failed to listen on TCP proxy server")?; @@ -144,14 +131,8 @@ async fn tcp_proxy_server( tokio::spawn(async move { let port_pool = Arc::clone(&port_pool); - let result = handle_tcp_proxy_connection( - socket, - virtual_port, - source_peer_ip, - dest_addr, - wg.clone(), - ) - .await; + let result = + handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; if let Err(e) = result { error!( @@ -173,8 +154,7 @@ async fn tcp_proxy_server( async fn handle_tcp_proxy_connection( socket: TcpStream, virtual_port: u16, - source_peer_ip: IpAddr, - dest_addr: SocketAddr, + port_forward: PortForwardConfig, wg: Arc, ) -> anyhow::Result<()> { // Abort signal for stopping the Virtual Interface @@ -194,18 +174,21 @@ async fn handle_tcp_proxy_connection( // Spawn virtual interface { let abort = abort.clone(); + let virtual_interface = TcpVirtualInterface::new( + virtual_port, + port_forward, + wg, + abort.clone(), + data_to_real_client_tx, + data_to_virtual_server_rx, + virtual_client_ready_tx, + ); + tokio::spawn(async move { - virtual_tcp_interface( - virtual_port, - source_peer_ip, - dest_addr, - wg, - abort, - data_to_real_client_tx, - data_to_virtual_server_rx, - virtual_client_ready_tx, - ) - .await + virtual_interface.poll_loop().await.unwrap_or_else(|e| { + error!("Virtual interface poll loop failed unexpectedly: {}", e); + abort.store(true, Ordering::Relaxed); + }) }); } @@ -297,219 +280,6 @@ async fn handle_tcp_proxy_connection( Ok(()) } -#[allow(clippy::too_many_arguments)] -async fn virtual_tcp_interface( - virtual_port: u16, - source_peer_ip: IpAddr, - dest_addr: SocketAddr, - wg: Arc, - abort: Arc, - data_to_real_client_tx: tokio::sync::mpsc::Sender>, - mut data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, - virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, -) -> anyhow::Result<()> { - let mut virtual_client_ready_tx = Some(virtual_client_ready_tx); - - // Create a device and interface to simulate IP packets - // In essence: - // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client' - // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel - // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface' - // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is 
discarded) - // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client - - // Consumer for IP packets to send through the virtual interface - // Initialize the interface - let device = VirtualIpDevice::new(virtual_port, wg) - .with_context(|| "Failed to initialize VirtualIpDevice")?; - let mut virtual_interface = InterfaceBuilder::new(device) - .ip_addrs([ - // Interface handles IP packets for the sender and recipient - IpCidr::new(IpAddress::from(source_peer_ip), 32), - IpCidr::new(IpAddress::from(dest_addr.ip()), 32), - ]) - .finalize(); - - // Server socket: this is a placeholder for the interface to route new connections to. - // TODO: Determine if we even need buffers here. - let server_socket: anyhow::Result = { - static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - - socket - .listen((IpAddress::from(dest_addr.ip()), dest_addr.port())) - .with_context(|| "Virtual server socket failed to listen")?; - - Ok(socket) - }; - - let client_socket: anyhow::Result = { - let rx_data = vec![0u8; MAX_PACKET]; - let tx_data = vec![0u8; MAX_PACKET]; - let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); - let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); - let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - Ok(socket) - }; - - // Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server. - let mut socket_set_entries: [_; 2] = Default::default(); - let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); - let _server_handle = socket_set.add(server_socket?); - let client_handle = socket_set.add(client_socket?); - - // Any data that wasn't sent because it was over the sending buffer limit - let mut tx_extra = Vec::new(); - - // Counts the connection attempts by the virtual client - let mut connection_attempts = 0; - // Whether the client has successfully connected before. Prevents the case of connecting again. - let mut has_connected = false; - - loop { - let loop_start = smoltcp::time::Instant::now(); - - // Shutdown occurs when the real client closes the connection, - // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore. - // One last poll-loop iteration is executed so that the RST segment can be dispatched. - let shutdown = abort.load(Ordering::Relaxed); - - if shutdown { - // Shutdown: sends a RST packet. - trace!("[{}] Shutting down virtual interface", virtual_port); - let mut client_socket = socket_set.get::(client_handle); - client_socket.abort(); - } - - match virtual_interface.poll(&mut socket_set, loop_start) { - Ok(processed) if processed => { - trace!( - "[{}] Virtual interface polled some packets to be processed", - virtual_port - ); - } - Err(e) => { - error!("[{}] Virtual interface poll error: {:?}", virtual_port, e); - } - _ => {} - } - - { - let mut client_socket = socket_set.get::(client_handle); - - if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { - // Not shutting down, but the client socket is closed, and the client never successfully connected. 
- if connection_attempts < 10 { - // Try to connect - client_socket - .connect( - (IpAddress::from(dest_addr.ip()), dest_addr.port()), - (IpAddress::from(source_peer_ip), virtual_port), - ) - .with_context(|| "Virtual server socket failed to listen")?; - if connection_attempts > 0 { - debug!( - "[{}] Virtual client retrying connection in 500ms", - virtual_port - ); - // Not our first connection attempt, wait a little bit. - tokio::time::sleep(Duration::from_millis(500)).await; - } - } else { - // Too many connection attempts - abort.store(true, Ordering::Relaxed); - } - connection_attempts += 1; - continue; - } - - if client_socket.state() == TcpState::Established { - // Prevent reconnection if the server later closes. - has_connected = true; - } - - if client_socket.can_recv() { - match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { - Ok(data) => { - trace!( - "[{}] Virtual client received {} bytes of data", - virtual_port, - data.len() - ); - // Send it to the real client - if let Err(e) = data_to_real_client_tx.send(data).await { - error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e); - } - } - Err(e) => { - error!( - "[{}] Failed to read from virtual client socket: {:?}", - virtual_port, e - ); - } - } - } - if client_socket.can_send() { - if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() { - virtual_client_ready_tx - .send(()) - .expect("Failed to notify real client that virtual client is ready"); - } - - let mut to_transfer = None; - - if tx_extra.is_empty() { - // The payload segment from the previous loop is complete, - // we can now read the next payload in the queue. - if let Ok(data) = data_to_virtual_server_rx.try_recv() { - to_transfer = Some(data); - } else if client_socket.state() == TcpState::CloseWait { - // No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN), - // the interface is shutdown. - trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", virtual_port); - abort.store(true, Ordering::Relaxed); - continue; - } - } - - let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice(); - if !to_transfer_slice.is_empty() { - let total = to_transfer_slice.len(); - match client_socket.send_slice(to_transfer_slice) { - Ok(sent) => { - trace!( - "[{}] Sent {}/{} bytes via virtual client socket", - virtual_port, - sent, - total, - ); - tx_extra = Vec::from(&to_transfer_slice[sent..total]); - } - Err(e) => { - error!( - "[{}] Failed to send slice via virtual client socket: {:?}", - virtual_port, e - ); - } - } - } - } - } - - if shutdown { - break; - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - trace!("[{}] Virtual interface task terminated", virtual_port); - abort.store(true, Ordering::Relaxed); - Ok(()) -} - fn init_logger(config: &Config) -> anyhow::Result<()> { let mut builder = pretty_env_logger::formatted_builder(); builder.parse_filters(&config.log); diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs new file mode 100644 index 0000000..b9d3354 --- /dev/null +++ b/src/virtual_iface/mod.rs @@ -0,0 +1,10 @@ +pub mod tcp; + +use async_trait::async_trait; + +#[async_trait] +pub trait VirtualInterfacePoll { + /// Initializes the virtual interface and processes incoming data to be dispatched + /// to the WireGuard tunnel and to the real client. 
+ async fn poll_loop(mut self) -> anyhow::Result<()>; +} diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs new file mode 100644 index 0000000..c9fee95 --- /dev/null +++ b/src/virtual_iface/tcp.rs @@ -0,0 +1,274 @@ +use crate::config::PortForwardConfig; +use crate::virtual_device::VirtualIpDevice; +use crate::virtual_iface::VirtualInterfacePoll; +use crate::wg::WireGuardTunnel; +use anyhow::Context; +use async_trait::async_trait; +use smoltcp::iface::InterfaceBuilder; +use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState}; +use smoltcp::wire::{IpAddress, IpCidr}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +const MAX_PACKET: usize = 65536; + +/// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa. +pub struct TcpVirtualInterface { + /// The virtual port assigned to the virtual client, used to + /// route Layer 4 segments/datagrams to and from the WireGuard tunnel. + virtual_port: u16, + /// The overall port-forward configuration: used for the destination address (on which + /// the virtual server listens) and the protocol in use. + port_forward: PortForwardConfig, + /// The WireGuard tunnel to send IP packets to. + wg: Arc, + /// Abort signal to shutdown the virtual interface and its parent task. + abort: Arc, + /// Channel sender for pushing Layer 7 data back to the real client. + data_to_real_client_tx: tokio::sync::mpsc::Sender>, + /// Channel receiver for processing Layer 7 data through the virtual interface. + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, + /// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data. + virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, +} + +impl TcpVirtualInterface { + /// Initialize the parameters for a new virtual interface. + /// Use the `poll_loop()` future to start the virtual interface poll loop. 
+ pub fn new( + virtual_port: u16, + port_forward: PortForwardConfig, + wg: Arc, + abort: Arc, + data_to_real_client_tx: tokio::sync::mpsc::Sender>, + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver>, + virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>, + ) -> Self { + Self { + virtual_port, + port_forward, + wg, + abort, + data_to_real_client_tx, + data_to_virtual_server_rx, + virtual_client_ready_tx, + } + } +} + +#[async_trait] +impl VirtualInterfacePoll for TcpVirtualInterface { + async fn poll_loop(self) -> anyhow::Result<()> { + let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx); + let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; + let source_peer_ip = self.wg.source_peer_ip; + + // Create a device and interface to simulate IP packets + // In essence: + // * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client' + // * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel + // * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface' + // * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded) + // * The TCP data read by the 'virtual client' is sent to the 'real' TCP client + + // Consumer for IP packets to send through the virtual interface + // Initialize the interface + let device = VirtualIpDevice::new(self.virtual_port, self.wg) + .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; + let mut virtual_interface = InterfaceBuilder::new(device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(IpAddress::from(source_peer_ip), 32), + IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32), + ]) + .finalize(); + + // Server socket: this is a placeholder for the interface to route new connections to. + let server_socket: anyhow::Result = { + static mut TCP_SERVER_RX_DATA: [u8; 0] = []; + static mut TCP_SERVER_TX_DATA: [u8; 0] = []; + let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + + socket + .listen(( + IpAddress::from(self.port_forward.destination.ip()), + self.port_forward.destination.port(), + )) + .with_context(|| "Virtual server socket failed to listen")?; + + Ok(socket) + }; + + let client_socket: anyhow::Result = { + let rx_data = vec![0u8; MAX_PACKET]; + let tx_data = vec![0u8; MAX_PACKET]; + let tcp_rx_buffer = TcpSocketBuffer::new(rx_data); + let tcp_tx_buffer = TcpSocketBuffer::new(tx_data); + let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + Ok(socket) + }; + + // Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server. + let mut socket_set_entries: [_; 2] = Default::default(); + let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); + let _server_handle = socket_set.add(server_socket?); + let client_handle = socket_set.add(client_socket?); + + // Any data that wasn't sent because it was over the sending buffer limit + let mut tx_extra = Vec::new(); + + // Counts the connection attempts by the virtual client + let mut connection_attempts = 0; + // Whether the client has successfully connected before. Prevents the case of connecting again. 
+ let mut has_connected = false; + + loop { + let loop_start = smoltcp::time::Instant::now(); + + // Shutdown occurs when the real client closes the connection, + // or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore. + // One last poll-loop iteration is executed so that the RST segment can be dispatched. + let shutdown = self.abort.load(Ordering::Relaxed); + + if shutdown { + // Shutdown: sends a RST packet. + trace!("[{}] Shutting down virtual interface", self.virtual_port); + let mut client_socket = socket_set.get::(client_handle); + client_socket.abort(); + } + + match virtual_interface.poll(&mut socket_set, loop_start) { + Ok(processed) if processed => { + trace!( + "[{}] Virtual interface polled some packets to be processed", + self.virtual_port + ); + } + Err(e) => { + error!( + "[{}] Virtual interface poll error: {:?}", + self.virtual_port, e + ); + } + _ => {} + } + + { + let mut client_socket = socket_set.get::(client_handle); + + if !shutdown && client_socket.state() == TcpState::Closed && !has_connected { + // Not shutting down, but the client socket is closed, and the client never successfully connected. + if connection_attempts < 10 { + // Try to connect + client_socket + .connect( + ( + IpAddress::from(self.port_forward.destination.ip()), + self.port_forward.destination.port(), + ), + (IpAddress::from(source_peer_ip), self.virtual_port), + ) + .with_context(|| "Virtual server socket failed to listen")?; + if connection_attempts > 0 { + debug!( + "[{}] Virtual client retrying connection in 500ms", + self.virtual_port + ); + // Not our first connection attempt, wait a little bit. + tokio::time::sleep(Duration::from_millis(500)).await; + } + } else { + // Too many connection attempts + self.abort.store(true, Ordering::Relaxed); + } + connection_attempts += 1; + continue; + } + + if client_socket.state() == TcpState::Established { + // Prevent reconnection if the server later closes. + has_connected = true; + } + + if client_socket.can_recv() { + match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { + Ok(data) => { + trace!( + "[{}] Virtual client received {} bytes of data", + self.virtual_port, + data.len() + ); + // Send it to the real client + if let Err(e) = self.data_to_real_client_tx.send(data).await { + error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e); + } + } + Err(e) => { + error!( + "[{}] Failed to read from virtual client socket: {:?}", + self.virtual_port, e + ); + } + } + } + if client_socket.can_send() { + if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() { + virtual_client_ready_tx + .send(()) + .expect("Failed to notify real client that virtual client is ready"); + } + + let mut to_transfer = None; + + if tx_extra.is_empty() { + // The payload segment from the previous loop is complete, + // we can now read the next payload in the queue. + if let Ok(data) = data_to_virtual_server_rx.try_recv() { + to_transfer = Some(data); + } else if client_socket.state() == TcpState::CloseWait { + // No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN), + // the interface is shutdown. 
+ trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port); + self.abort.store(true, Ordering::Relaxed); + continue; + } + } + + let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice(); + if !to_transfer_slice.is_empty() { + let total = to_transfer_slice.len(); + match client_socket.send_slice(to_transfer_slice) { + Ok(sent) => { + trace!( + "[{}] Sent {}/{} bytes via virtual client socket", + self.virtual_port, + sent, + total, + ); + tx_extra = Vec::from(&to_transfer_slice[sent..total]); + } + Err(e) => { + error!( + "[{}] Failed to send slice via virtual client socket: {:?}", + self.virtual_port, e + ); + } + } + } + } + } + + if shutdown { + break; + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + trace!("[{}] Virtual interface task terminated", self.virtual_port); + self.abort.store(true, Ordering::Relaxed); + Ok(()) + } +} diff --git a/src/wg.rs b/src/wg.rs index e2740e8..3d3cc5f 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -18,7 +18,7 @@ const DISPATCH_CAPACITY: usize = 1_000; /// to be sent to and received from a remote UDP endpoint. /// This tunnel supports at most 1 peer IP at a time, but supports simultaneous ports. pub struct WireGuardTunnel { - source_peer_ip: IpAddr, + pub(crate) source_peer_ip: IpAddr, /// `boringtun` peer/tunnel implementation, used for crypto & WG protocol. peer: Box, /// The UDP socket for the public WireGuard endpoint to connect to. From 703f2613449c6817bac00eadae61d7519e1802ca Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Tue, 19 Oct 2021 01:00:05 -0400 Subject: [PATCH 07/17] Move TCP tunneling code to separate module --- README.md | 2 +- src/main.rs | 218 +--------------------------------------------- src/tunnel/mod.rs | 29 ++++++ src/tunnel/tcp.rs | 196 +++++++++++++++++++++++++++++++++++++++++ src/wg.rs | 2 +- 5 files changed, 230 insertions(+), 217 deletions(-) create mode 100644 src/tunnel/mod.rs create mode 100644 src/tunnel/tcp.rs diff --git a/README.md b/README.md index f652574..1defff3 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ local port, say `127.0.0.1:8080`, that will tunnel through WireGuard to reach th You'll then see this log: ``` -INFO onetun > Tunnelling [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3) +INFO onetun > Tunneling [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3) ``` Which means you can now access the port locally! 
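[A note on how several forwards run at once: PATCH 03 above moved each port-forward onto its own OS thread with a dedicated Tokio runtime. The following is a reduced, self-contained sketch of that pattern, not code from the diff; `run_forward` is a stand-in for `tunnel::port_forward`.]

```
use std::thread;

async fn run_forward(label: &'static str) {
    // Stand-in for tunnel::port_forward: the real loop accepts connections
    // and shuttles data through the WireGuard tunnel until shutdown.
    println!("forward {} running", label);
}

fn main() {
    let mut handles = Vec::new();
    for label in ["tcp/8080", "tcp/8081"] {
        // One OS thread and one dedicated Tokio runtime per forward,
        // mirroring the std::thread::spawn + Runtime::new() calls in PATCH 03.
        handles.push(thread::spawn(move || {
            let rt = tokio::runtime::Runtime::new().expect("failed to build runtime");
            rt.block_on(run_forward(label));
        }));
    }
    // The real main() parks forever on futures::future::pending();
    // joining the threads stands in for that here.
    for handle in handles {
        handle.join().expect("forward thread panicked");
    }
}
```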
diff --git a/src/main.rs b/src/main.rs index 28846fa..b5da3ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,28 +1,22 @@ #[macro_use] extern crate log; -use std::net::IpAddr; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use anyhow::Context; -use tokio::net::{TcpListener, TcpStream}; -use crate::config::{Config, PortForwardConfig, PortProtocol}; +use crate::config::Config; use crate::port_pool::PortPool; -use crate::virtual_iface::tcp::TcpVirtualInterface; -use crate::virtual_iface::VirtualInterfacePoll; use crate::wg::WireGuardTunnel; pub mod config; pub mod ip_sink; pub mod port_pool; +pub mod tunnel; pub mod virtual_device; pub mod virtual_iface; pub mod wg; -pub const MAX_PACKET: usize = 65536; - #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Config::from_args().with_context(|| "Failed to read config")?; @@ -63,7 +57,7 @@ async fn main() -> anyhow::Result<()> { std::thread::spawn(move || { let cpu_pool = tokio::runtime::Runtime::new().unwrap(); cpu_pool.block_on(async move { - port_forward(pf, source_peer_ip, port_pool, wg) + tunnel::port_forward(pf, source_peer_ip, port_pool, wg) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }); @@ -74,212 +68,6 @@ async fn main() -> anyhow::Result<()> { futures::future::pending().await } -async fn port_forward( - port_forward: PortForwardConfig, - source_peer_ip: IpAddr, - port_pool: Arc, - wg: Arc, -) -> anyhow::Result<()> { - info!( - "Tunnelling {} [{}]->[{}] (via [{}] as peer {})", - port_forward.protocol, - port_forward.source, - port_forward.destination, - &wg.endpoint, - source_peer_ip - ); - - match port_forward.protocol { - PortProtocol::Tcp => tcp_proxy_server(port_forward, port_pool, wg).await, - PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), - } -} - -/// Starts the server that listens on TCP connections. -async fn tcp_proxy_server( - port_forward: PortForwardConfig, - port_pool: Arc, - wg: Arc, -) -> anyhow::Result<()> { - let listener = TcpListener::bind(port_forward.source) - .await - .with_context(|| "Failed to listen on TCP proxy server")?; - - loop { - let wg = wg.clone(); - let port_pool = port_pool.clone(); - let (socket, peer_addr) = listener - .accept() - .await - .with_context(|| "Failed to accept connection on TCP proxy server")?; - - // Assign a 'virtual port': this is a unique port number used to route IP packets - // received from the WireGuard tunnel. It is the port number that the virtual client will - // listen on. - let virtual_port = match port_pool.next() { - Ok(port) => port, - Err(e) => { - error!( - "Failed to assign virtual port number for connection [{}]: {:?}", - peer_addr, e - ); - continue; - } - }; - - info!("[{}] Incoming connection from {}", virtual_port, peer_addr); - - tokio::spawn(async move { - let port_pool = Arc::clone(&port_pool); - let result = - handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; - - if let Err(e) = result { - error!( - "[{}] Connection dropped un-gracefully: {:?}", - virtual_port, e - ); - } else { - info!("[{}] Connection closed by client", virtual_port); - } - - // Release port when connection drops - wg.release_virtual_interface(virtual_port); - port_pool.release(virtual_port); - }); - } -} - -/// Handles a new TCP connection with its assigned virtual port. 
-async fn handle_tcp_proxy_connection( - socket: TcpStream, - virtual_port: u16, - port_forward: PortForwardConfig, - wg: Arc, -) -> anyhow::Result<()> { - // Abort signal for stopping the Virtual Interface - let abort = Arc::new(AtomicBool::new(false)); - - // Signals that the Virtual Client is ready to send data - let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>(); - - // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back - // to the real client. - let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000); - - // data_to_real_server_(tx/rx): This task sends the data received from the real client to the - // virtual interface (virtual server socket). - let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000); - - // Spawn virtual interface - { - let abort = abort.clone(); - let virtual_interface = TcpVirtualInterface::new( - virtual_port, - port_forward, - wg, - abort.clone(), - data_to_real_client_tx, - data_to_virtual_server_rx, - virtual_client_ready_tx, - ); - - tokio::spawn(async move { - virtual_interface.poll_loop().await.unwrap_or_else(|e| { - error!("Virtual interface poll loop failed unexpectedly: {}", e); - abort.store(true, Ordering::Relaxed); - }) - }); - } - - // Wait for virtual client to be ready. - virtual_client_ready_rx - .await - .with_context(|| "Virtual client dropped before being ready.")?; - trace!("[{}] Virtual client is ready to send data", virtual_port); - - loop { - tokio::select! { - readable_result = socket.readable() => { - match readable_result { - Ok(_) => { - // Buffer for the individual TCP segment. - let mut buffer = Vec::with_capacity(MAX_PACKET); - match socket.try_read_buf(&mut buffer) { - Ok(size) if size > 0 => { - let data = &buffer[..size]; - debug!( - "[{}] Read {} bytes of TCP data from real client", - virtual_port, size - ); - if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await { - error!( - "[{}] Failed to dispatch data to virtual interface: {:?}", - virtual_port, e - ); - } - } - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - continue; - } - Err(e) => { - error!( - "[{}] Failed to read from client TCP socket: {:?}", - virtual_port, e - ); - break; - } - _ => { - break; - } - } - } - Err(e) => { - error!("[{}] Failed to check if readable: {:?}", virtual_port, e); - break; - } - } - } - data_recv_result = data_to_real_client_rx.recv() => { - match data_recv_result { - Some(data) => match socket.try_write(&data) { - Ok(size) => { - debug!( - "[{}] Wrote {} bytes of TCP data to real client", - virtual_port, size - ); - } - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - if abort.load(Ordering::Relaxed) { - break; - } else { - continue; - } - } - Err(e) => { - error!( - "[{}] Failed to write to client TCP socket: {:?}", - virtual_port, e - ); - } - }, - None => { - if abort.load(Ordering::Relaxed) { - break; - } else { - continue; - } - }, - } - } - } - } - - trace!("[{}] TCP socket handler task terminated", virtual_port); - abort.store(true, Ordering::Relaxed); - Ok(()) -} - fn init_logger(config: &Config) -> anyhow::Result<()> { let mut builder = pretty_env_logger::formatted_builder(); builder.parse_filters(&config.log); diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs new file mode 100644 index 0000000..7eab856 --- /dev/null +++ b/src/tunnel/mod.rs @@ -0,0 +1,29 @@ +use std::net::IpAddr; +use std::sync::Arc; + +use 
crate::config::{PortForwardConfig, PortProtocol}; +use crate::port_pool::PortPool; +use crate::wg::WireGuardTunnel; + +mod tcp; + +pub async fn port_forward( + port_forward: PortForwardConfig, + source_peer_ip: IpAddr, + port_pool: Arc, + wg: Arc, +) -> anyhow::Result<()> { + info!( + "Tunneling {} [{}]->[{}] (via [{}] as peer {})", + port_forward.protocol, + port_forward.source, + port_forward.destination, + &wg.endpoint, + source_peer_ip + ); + + match port_forward.protocol { + PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, port_pool, wg).await, + PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), + } +} diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs new file mode 100644 index 0000000..987aa5e --- /dev/null +++ b/src/tunnel/tcp.rs @@ -0,0 +1,196 @@ +use crate::config::PortForwardConfig; +use crate::port_pool::PortPool; +use crate::virtual_iface::tcp::TcpVirtualInterface; +use crate::virtual_iface::VirtualInterfacePoll; +use crate::wg::WireGuardTunnel; +use anyhow::Context; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream}; + +const MAX_PACKET: usize = 65536; + +/// Starts the server that listens on TCP connections. +pub async fn tcp_proxy_server( + port_forward: PortForwardConfig, + port_pool: Arc, + wg: Arc, +) -> anyhow::Result<()> { + let listener = TcpListener::bind(port_forward.source) + .await + .with_context(|| "Failed to listen on TCP proxy server")?; + + loop { + let wg = wg.clone(); + let port_pool = port_pool.clone(); + let (socket, peer_addr) = listener + .accept() + .await + .with_context(|| "Failed to accept connection on TCP proxy server")?; + + // Assign a 'virtual port': this is a unique port number used to route IP packets + // received from the WireGuard tunnel. It is the port number that the virtual client will + // listen on. + let virtual_port = match port_pool.next() { + Ok(port) => port, + Err(e) => { + error!( + "Failed to assign virtual port number for connection [{}]: {:?}", + peer_addr, e + ); + continue; + } + }; + + info!("[{}] Incoming connection from {}", virtual_port, peer_addr); + + tokio::spawn(async move { + let port_pool = Arc::clone(&port_pool); + let result = + handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; + + if let Err(e) = result { + error!( + "[{}] Connection dropped un-gracefully: {:?}", + virtual_port, e + ); + } else { + info!("[{}] Connection closed by client", virtual_port); + } + + // Release port when connection drops + wg.release_virtual_interface(virtual_port); + port_pool.release(virtual_port); + }); + } +} + +/// Handles a new TCP connection with its assigned virtual port. +async fn handle_tcp_proxy_connection( + socket: TcpStream, + virtual_port: u16, + port_forward: PortForwardConfig, + wg: Arc, +) -> anyhow::Result<()> { + // Abort signal for stopping the Virtual Interface + let abort = Arc::new(AtomicBool::new(false)); + + // Signals that the Virtual Client is ready to send data + let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>(); + + // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back + // to the real client. + let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000); + + // data_to_real_server_(tx/rx): This task sends the data received from the real client to the + // virtual interface (virtual server socket). 
+ let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000); + + // Spawn virtual interface + { + let abort = abort.clone(); + let virtual_interface = TcpVirtualInterface::new( + virtual_port, + port_forward, + wg, + abort.clone(), + data_to_real_client_tx, + data_to_virtual_server_rx, + virtual_client_ready_tx, + ); + + tokio::spawn(async move { + virtual_interface.poll_loop().await.unwrap_or_else(|e| { + error!("Virtual interface poll loop failed unexpectedly: {}", e); + abort.store(true, Ordering::Relaxed); + }) + }); + } + + // Wait for virtual client to be ready. + virtual_client_ready_rx + .await + .with_context(|| "Virtual client dropped before being ready.")?; + trace!("[{}] Virtual client is ready to send data", virtual_port); + + loop { + tokio::select! { + readable_result = socket.readable() => { + match readable_result { + Ok(_) => { + // Buffer for the individual TCP segment. + let mut buffer = Vec::with_capacity(MAX_PACKET); + match socket.try_read_buf(&mut buffer) { + Ok(size) if size > 0 => { + let data = &buffer[..size]; + debug!( + "[{}] Read {} bytes of TCP data from real client", + virtual_port, size + ); + if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await { + error!( + "[{}] Failed to dispatch data to virtual interface: {:?}", + virtual_port, e + ); + } + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + error!( + "[{}] Failed to read from client TCP socket: {:?}", + virtual_port, e + ); + break; + } + _ => { + break; + } + } + } + Err(e) => { + error!("[{}] Failed to check if readable: {:?}", virtual_port, e); + break; + } + } + } + data_recv_result = data_to_real_client_rx.recv() => { + match data_recv_result { + Some(data) => match socket.try_write(&data) { + Ok(size) => { + debug!( + "[{}] Wrote {} bytes of TCP data to real client", + virtual_port, size + ); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + if abort.load(Ordering::Relaxed) { + break; + } else { + continue; + } + } + Err(e) => { + error!( + "[{}] Failed to write to client TCP socket: {:?}", + virtual_port, e + ); + } + }, + None => { + if abort.load(Ordering::Relaxed) { + break; + } else { + continue; + } + }, + } + } + } + } + + trace!("[{}] TCP socket handler task terminated", virtual_port); + abort.store(true, Ordering::Relaxed); + Ok(()) +} diff --git a/src/wg.rs b/src/wg.rs index 3d3cc5f..498fb98 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -9,10 +9,10 @@ use tokio::net::UdpSocket; use tokio::sync::RwLock; use crate::config::Config; -use crate::MAX_PACKET; /// The capacity of the channel for received IP packets. const DISPATCH_CAPACITY: usize = 1_000; +const MAX_PACKET: usize = 65536; /// A WireGuard tunnel. Encapsulates and decapsulates IP packets /// to be sent to and received from a remote UDP endpoint. From 5cec6d4943557541fa5eff0afef4022e87987b3b Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Tue, 19 Oct 2021 01:55:04 -0400 Subject: [PATCH 08/17] Index ports with protocol in WG. Start writing UDP tunnel code with plans. 
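
Why the protocol belongs in the key: TCP and UDP port numbers live in
independent namespaces, so a TCP flow and a UDP flow can legitimately use
the same numeric virtual port. Keying the WireGuard dispatch map by
(port, protocol) keeps the two from colliding. A standalone illustration,
using simplified stand-ins for the `VirtualPort` and `PortProtocol` types
introduced in the diff below (not the tunnel code itself):

    use std::collections::HashMap;

    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    enum PortProtocol {
        Tcp,
        Udp,
    }

    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    struct VirtualPort(u16, PortProtocol);

    fn main() {
        // Same numeric port, two distinct dispatch entries.
        let mut dispatch: HashMap<VirtualPort, &'static str> = HashMap::new();
        dispatch.insert(VirtualPort(60000, PortProtocol::Tcp), "tcp sender");
        dispatch.insert(VirtualPort(60000, PortProtocol::Udp), "udp sender");
        assert_eq!(dispatch.len(), 2);
    }
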
--- src/config.rs | 2 +- src/main.rs | 14 ++++---- src/port_pool.rs | 62 ----------------------------------- src/tunnel/mod.rs | 12 ++++--- src/tunnel/tcp.rs | 70 +++++++++++++++++++++++++++++++++++++--- src/tunnel/udp.rs | 44 +++++++++++++++++++++++++ src/virtual_device.rs | 3 +- src/virtual_iface/mod.rs | 12 +++++++ src/virtual_iface/tcp.rs | 9 +++--- src/wg.rs | 19 +++++++---- 10 files changed, 156 insertions(+), 91 deletions(-) delete mode 100644 src/port_pool.rs create mode 100644 src/tunnel/udp.rs diff --git a/src/config.rs b/src/config.rs index 3a46881..0e3a727 100644 --- a/src/config.rs +++ b/src/config.rs @@ -332,7 +332,7 @@ impl Display for PortForwardConfig { } } -#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum PortProtocol { Tcp, Udp, diff --git a/src/main.rs b/src/main.rs index b5da3ad..0d301d0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,11 @@ use std::sync::Arc; use anyhow::Context; use crate::config::Config; -use crate::port_pool::PortPool; +use crate::tunnel::tcp::TcpPortPool; use crate::wg::WireGuardTunnel; pub mod config; pub mod ip_sink; -pub mod port_pool; pub mod tunnel; pub mod virtual_device; pub mod virtual_iface; @@ -21,7 +20,10 @@ pub mod wg; async fn main() -> anyhow::Result<()> { let config = Config::from_args().with_context(|| "Failed to read config")?; init_logger(&config)?; - let port_pool = Arc::new(PortPool::new()); + + // Initialize the port pool for each protocol + let tcp_port_pool = Arc::new(TcpPortPool::new()); + // TODO: udp_port_pool let wg = WireGuardTunnel::new(&config) .await @@ -52,12 +54,12 @@ async fn main() -> anyhow::Result<()> { port_forwards .into_iter() - .map(|pf| (pf, wg.clone(), port_pool.clone())) - .for_each(move |(pf, wg, port_pool)| { + .map(|pf| (pf, wg.clone(), tcp_port_pool.clone())) + .for_each(move |(pf, wg, tcp_port_pool)| { std::thread::spawn(move || { let cpu_pool = tokio::runtime::Runtime::new().unwrap(); cpu_pool.block_on(async move { - tunnel::port_forward(pf, source_peer_ip, port_pool, wg) + tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, wg) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }); diff --git a/src/port_pool.rs b/src/port_pool.rs deleted file mode 100644 index 7bff712..0000000 --- a/src/port_pool.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::ops::Range; - -use anyhow::Context; -use rand::seq::SliceRandom; -use rand::thread_rng; - -const MIN_PORT: u16 = 32768; -const MAX_PORT: u16 = 60999; -const PORT_RANGE: Range = MIN_PORT..MAX_PORT; - -/// A pool of virtual ports available. -/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`. -pub struct PortPool { - /// Remaining ports - inner: lockfree::queue::Queue, - /// Ports in use, with their associated IP channel sender. - taken: lockfree::set::Set, -} - -impl Default for PortPool { - fn default() -> Self { - Self::new() - } -} - -impl PortPool { - /// Initializes a new pool of virtual ports. - pub fn new() -> Self { - let inner = lockfree::queue::Queue::default(); - let mut ports: Vec = PORT_RANGE.collect(); - ports.shuffle(&mut thread_rng()); - ports.into_iter().for_each(|p| inner.push(p) as ()); - Self { - inner, - taken: lockfree::set::Set::new(), - } - } - - /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). 
- pub fn next(&self) -> anyhow::Result { - let port = self - .inner - .pop() - .with_context(|| "Virtual port pool is exhausted")?; - self.taken - .insert(port) - .ok() - .with_context(|| "Failed to insert taken")?; - Ok(port) - } - - /// Releases a port back into the pool. - pub fn release(&self, port: u16) { - self.inner.push(port); - self.taken.remove(&port); - } - - /// Whether the given port is in use by a virtual interface. - pub fn is_in_use(&self, port: u16) -> bool { - self.taken.contains(&port) - } -} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 7eab856..00fe4e8 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -2,15 +2,17 @@ use std::net::IpAddr; use std::sync::Arc; use crate::config::{PortForwardConfig, PortProtocol}; -use crate::port_pool::PortPool; +use crate::tunnel::tcp::TcpPortPool; use crate::wg::WireGuardTunnel; -mod tcp; +pub mod tcp; +#[allow(unused)] +pub mod udp; pub async fn port_forward( port_forward: PortForwardConfig, source_peer_ip: IpAddr, - port_pool: Arc, + tcp_port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { info!( @@ -23,7 +25,7 @@ pub async fn port_forward( ); match port_forward.protocol { - PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, port_pool, wg).await, - PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")), + PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await, + PortProtocol::Udp => udp::udp_proxy_server(port_forward, /* udp_port_pool, */ wg).await, } } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 987aa5e..1f01642 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -1,19 +1,26 @@ -use crate::config::PortForwardConfig; -use crate::port_pool::PortPool; +use crate::config::{PortForwardConfig, PortProtocol}; use crate::virtual_iface::tcp::TcpVirtualInterface; -use crate::virtual_iface::VirtualInterfacePoll; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; +use std::ops::Range; + +use rand::seq::SliceRandom; +use rand::thread_rng; + const MAX_PACKET: usize = 65536; +const MIN_PORT: u16 = 1000; +const MAX_PORT: u16 = 60999; +const PORT_RANGE: Range = MIN_PORT..MAX_PORT; /// Starts the server that listens on TCP connections. pub async fn tcp_proxy_server( port_forward: PortForwardConfig, - port_pool: Arc, + port_pool: Arc, wg: Arc, ) -> anyhow::Result<()> { let listener = TcpListener::bind(port_forward.source) @@ -59,7 +66,7 @@ pub async fn tcp_proxy_server( } // Release port when connection drops - wg.release_virtual_interface(virtual_port); + wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp)); port_pool.release(virtual_port); }); } @@ -194,3 +201,56 @@ async fn handle_tcp_proxy_connection( abort.store(true, Ordering::Relaxed); Ok(()) } + +/// A pool of virtual ports available for TCP connections. +/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`. +pub struct TcpPortPool { + /// Remaining ports + inner: lockfree::queue::Queue, + /// Ports in use, with their associated IP channel sender. + taken: lockfree::set::Set, +} + +impl Default for TcpPortPool { + fn default() -> Self { + Self::new() + } +} + +impl TcpPortPool { + /// Initializes a new pool of virtual ports. 
+ pub fn new() -> Self { + let inner = lockfree::queue::Queue::default(); + let mut ports: Vec = PORT_RANGE.collect(); + ports.shuffle(&mut thread_rng()); + ports.into_iter().for_each(|p| inner.push(p) as ()); + Self { + inner, + taken: lockfree::set::Set::new(), + } + } + + /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). + pub fn next(&self) -> anyhow::Result { + let port = self + .inner + .pop() + .with_context(|| "Virtual port pool is exhausted")?; + self.taken + .insert(port) + .ok() + .with_context(|| "Failed to insert taken")?; + Ok(port) + } + + /// Releases a port back into the pool. + pub fn release(&self, port: u16) { + self.inner.push(port); + self.taken.remove(&port); + } + + /// Whether the given port is in use by a virtual interface. + pub fn is_in_use(&self, port: u16) -> bool { + self.taken.contains(&port) + } +} diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs new file mode 100644 index 0000000..7f9fda4 --- /dev/null +++ b/src/tunnel/udp.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use anyhow::Context; +use tokio::net::UdpSocket; + +use crate::config::PortForwardConfig; +use crate::wg::WireGuardTunnel; + +const MAX_PACKET: usize = 65536; + +/// How long to keep the UDP peer address assigned to its virtual specified port, in seconds. +const UDP_TIMEOUT_SECONDS: u64 = 60; + +/// To prevent port-flooding, we set a limit on the amount of open ports per IP address. +const PORTS_PER_IP: usize = 100; + +pub async fn udp_proxy_server( + port_forward: PortForwardConfig, + wg: Arc, +) -> anyhow::Result<()> { + let socket = UdpSocket::bind(port_forward.source) + .await + .with_context(|| "Failed to bind on UDP proxy address")?; + + let mut buffer = [0u8; MAX_PACKET]; + loop { + let (size, peer_addr) = socket + .recv_from(&mut buffer) + .await + .with_context(|| "Failed to accept incoming UDP datagram")?; + + let _wg = wg.clone(); + let _data = &buffer[..size].to_vec(); + debug!("Received datagram of {} bytes from {}", size, peer_addr); + + // Assign a 'virtual port': this is a unique port number used to route IP packets + // received from the WireGuard tunnel. It is the port number that the virtual client will + // listen on. + // Since UDP is connection-less, the port is assigned to the source SocketAddr for up to `UDP_TIMEOUT_SECONDS`; + // every datagram resets the timer for that SocketAddr. Each IP address also has a limit of active connections, + // discarding the LRU ports. 
+ // TODO: UDP Port Pool + } +} diff --git a/src/virtual_device.rs b/src/virtual_device.rs index b8a61ce..992e7d0 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,3 +1,4 @@ +use crate::virtual_iface::VirtualPort; use crate::wg::WireGuardTunnel; use anyhow::Context; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; @@ -15,7 +16,7 @@ pub struct VirtualIpDevice { } impl VirtualIpDevice { - pub fn new(virtual_port: u16, wg: Arc) -> anyhow::Result { + pub fn new(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { let ip_dispatch_rx = wg .register_virtual_interface(virtual_port) .with_context(|| "Failed to register IP dispatch for virtual interface")?; diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index b9d3354..d2ceb53 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -1,6 +1,8 @@ pub mod tcp; +use crate::config::PortProtocol; use async_trait::async_trait; +use std::fmt::{Display, Formatter}; #[async_trait] pub trait VirtualInterfacePoll { @@ -8,3 +10,13 @@ pub trait VirtualInterfacePoll { /// to the WireGuard tunnel and to the real client. async fn poll_loop(mut self) -> anyhow::Result<()>; } + +/// Virtual port. +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub struct VirtualPort(pub u16, pub PortProtocol); + +impl Display for VirtualPort { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}:{}]", self.0, self.1) + } +} diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index c9fee95..fc8348c 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -1,6 +1,6 @@ -use crate::config::PortForwardConfig; +use crate::config::{PortForwardConfig, PortProtocol}; use crate::virtual_device::VirtualIpDevice; -use crate::virtual_iface::VirtualInterfacePoll; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; use async_trait::async_trait; @@ -74,8 +74,9 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Consumer for IP packets to send through the virtual interface // Initialize the interface - let device = VirtualIpDevice::new(self.virtual_port, self.wg) - .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; + let device = + VirtualIpDevice::new(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) + .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; let mut virtual_interface = InterfaceBuilder::new(device) .ip_addrs([ // Interface handles IP packets for the sender and recipient diff --git a/src/wg.rs b/src/wg.rs index 498fb98..d10ad50 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -8,7 +8,8 @@ use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket}; use tokio::net::UdpSocket; use tokio::sync::RwLock; -use crate::config::Config; +use crate::config::{Config, PortProtocol}; +use crate::virtual_iface::VirtualPort; /// The capacity of the channel for received IP packets. const DISPATCH_CAPACITY: usize = 1_000; @@ -26,7 +27,7 @@ pub struct WireGuardTunnel { /// The address of the public WireGuard endpoint (UDP). pub(crate) endpoint: SocketAddr, /// Maps virtual ports to the corresponding IP packet dispatcher. - virtual_port_ip_tx: lockfree::map::Map>>, + virtual_port_ip_tx: lockfree::map::Map>>, /// IP packet dispatcher for unroutable packets. `None` if not initialized. sink_ip_tx: RwLock>>>, } @@ -86,7 +87,7 @@ impl WireGuardTunnel { /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. 
pub fn register_virtual_interface( &self, - virtual_port: u16, + virtual_port: VirtualPort, ) -> anyhow::Result>> { let existing = self.virtual_port_ip_tx.get(&virtual_port); if existing.is_some() { @@ -111,7 +112,7 @@ impl WireGuardTunnel { } /// Releases the virtual interface from IP dispatch. - pub fn release_virtual_interface(&self, virtual_port: u16) { + pub fn release_virtual_interface(&self, virtual_port: VirtualPort) { self.virtual_port_ip_tx.remove(&virtual_port); } @@ -296,8 +297,12 @@ impl WireGuardTunnel { TcpPacket::new_checked(segment) .ok() .map(|tcp| { - if self.virtual_port_ip_tx.get(&tcp.dst_port()).is_some() { - RouteResult::Dispatch(tcp.dst_port()) + if self + .virtual_port_ip_tx + .get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp)) + .is_some() + { + RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp)) } else if tcp.rst() { RouteResult::Drop } else { @@ -347,7 +352,7 @@ fn trace_ip_packet(message: &str, packet: &[u8]) { enum RouteResult { /// Dispatch the packet to the virtual port. - Dispatch(u16), + Dispatch(VirtualPort), /// The packet is not routable, and should be sent to the sink interface. Sink, /// The packet is not routable, and can be safely ignored. From 11c5ec99fd5aef913f210558b03490a5fb6b88da Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 20 Oct 2021 16:05:04 -0400 Subject: [PATCH 09/17] Replace lockfree with tokio::sync --- Cargo.lock | 27 +++++++++------------ Cargo.toml | 2 +- src/main.rs | 2 +- src/tunnel/mod.rs | 2 +- src/tunnel/tcp.rs | 60 ++++++++++++++++++++++++----------------------- src/tunnel/udp.rs | 2 ++ src/wg.rs | 10 ++++---- 7 files changed, 52 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f3ba00d..885b05b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -203,6 +203,16 @@ dependencies = [ "libc", ] +[[package]] +name = "dashmap" +version = "4.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e77a43b28d0668df09411cb0bc9a8c2adc40f9a048afe863e05fd43251e8e39c" +dependencies = [ + "cfg-if", + "num_cpus", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -469,15 +479,6 @@ dependencies = [ "scopeguard", ] -[[package]] -name = "lockfree" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23" -dependencies = [ - "owned-alloc", -] - [[package]] name = "log" version = "0.4.14" @@ -609,8 +610,8 @@ dependencies = [ "async-trait", "boringtun", "clap", + "dashmap", "futures", - "lockfree", "log", "nom", "pretty_env_logger", @@ -619,12 +620,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "owned-alloc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6" - [[package]] name = "parking_lot" version = "0.11.2" diff --git a/Cargo.toml b/Cargo.toml index d552c2b..905ac32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,8 @@ pretty_env_logger = "0.3" anyhow = "1" smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" } tokio = { version = "1", features = ["full"] } -lockfree = "0.5.1" futures = "0.3.17" rand = "0.8.4" nom = "7" async-trait = "0.1.51" +dashmap = "4.0.2" diff --git a/src/main.rs b/src/main.rs index 0d301d0..b2f1f75 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> { init_logger(&config)?; // Initialize the port pool for each protocol - let tcp_port_pool = 
Arc::new(TcpPortPool::new()); + let tcp_port_pool = TcpPortPool::new(); // TODO: udp_port_pool let wg = WireGuardTunnel::new(&config) diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 00fe4e8..a4042c9 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -12,7 +12,7 @@ pub mod udp; pub async fn port_forward( port_forward: PortForwardConfig, source_peer_ip: IpAddr, - tcp_port_pool: Arc, + tcp_port_pool: TcpPortPool, wg: Arc, ) -> anyhow::Result<()> { info!( diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index 1f01642..fbbdbc2 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -3,6 +3,7 @@ use crate::virtual_iface::tcp::TcpVirtualInterface; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; +use std::collections::{HashSet, VecDeque}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; @@ -20,7 +21,7 @@ const PORT_RANGE: Range = MIN_PORT..MAX_PORT; /// Starts the server that listens on TCP connections. pub async fn tcp_proxy_server( port_forward: PortForwardConfig, - port_pool: Arc, + port_pool: TcpPortPool, wg: Arc, ) -> anyhow::Result<()> { let listener = TcpListener::bind(port_forward.source) @@ -38,7 +39,7 @@ pub async fn tcp_proxy_server( // Assign a 'virtual port': this is a unique port number used to route IP packets // received from the WireGuard tunnel. It is the port number that the virtual client will // listen on. - let virtual_port = match port_pool.next() { + let virtual_port = match port_pool.next().await { Ok(port) => port, Err(e) => { error!( @@ -52,7 +53,7 @@ pub async fn tcp_proxy_server( info!("[{}] Incoming connection from {}", virtual_port, peer_addr); tokio::spawn(async move { - let port_pool = Arc::clone(&port_pool); + let port_pool = port_pool.clone(); let result = handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await; @@ -67,7 +68,7 @@ pub async fn tcp_proxy_server( // Release port when connection drops wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp)); - port_pool.release(virtual_port); + port_pool.release(virtual_port).await; }); } } @@ -203,12 +204,9 @@ async fn handle_tcp_proxy_connection( } /// A pool of virtual ports available for TCP connections. -/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`. +#[derive(Clone)] pub struct TcpPortPool { - /// Remaining ports - inner: lockfree::queue::Queue, - /// Ports in use, with their associated IP channel sender. - taken: lockfree::set::Set, + inner: Arc>, } impl Default for TcpPortPool { @@ -220,37 +218,41 @@ impl Default for TcpPortPool { impl TcpPortPool { /// Initializes a new pool of virtual ports. pub fn new() -> Self { - let inner = lockfree::queue::Queue::default(); + let mut inner = TcpPortPoolInner::default(); let mut ports: Vec = PORT_RANGE.collect(); ports.shuffle(&mut thread_rng()); - ports.into_iter().for_each(|p| inner.push(p) as ()); + ports + .into_iter() + .for_each(|p| inner.queue.push_back(p) as ()); Self { - inner, - taken: lockfree::set::Set::new(), + inner: Arc::new(tokio::sync::RwLock::new(inner)), } } /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). 
- pub fn next(&self) -> anyhow::Result { - let port = self - .inner - .pop() + pub async fn next(&self) -> anyhow::Result { + let mut inner = self.inner.write().await; + let port = inner + .queue + .pop_front() .with_context(|| "Virtual port pool is exhausted")?; - self.taken - .insert(port) - .ok() - .with_context(|| "Failed to insert taken")?; + inner.taken.insert(port); Ok(port) } /// Releases a port back into the pool. - pub fn release(&self, port: u16) { - self.inner.push(port); - self.taken.remove(&port); - } - - /// Whether the given port is in use by a virtual interface. - pub fn is_in_use(&self, port: u16) -> bool { - self.taken.contains(&port) + pub async fn release(&self, port: u16) { + let mut inner = self.inner.write().await; + inner.queue.push_back(port); + inner.taken.remove(&port); } } + +/// Non thread-safe inner logic for TCP port pool. +#[derive(Debug, Clone, Default)] +struct TcpPortPoolInner { + /// Remaining ports in the pool. + queue: VecDeque, + /// Ports taken out of the pool. + taken: HashSet, +} diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 7f9fda4..2326351 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -9,9 +9,11 @@ use crate::wg::WireGuardTunnel; const MAX_PACKET: usize = 65536; /// How long to keep the UDP peer address assigned to its virtual specified port, in seconds. +/// TODO: Make this configurable by the CLI const UDP_TIMEOUT_SECONDS: u64 = 60; /// To prevent port-flooding, we set a limit on the amount of open ports per IP address. +/// TODO: Make this configurable by the CLI const PORTS_PER_IP: usize = 100; pub async fn udp_proxy_server( diff --git a/src/wg.rs b/src/wg.rs index d10ad50..653568e 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -27,7 +27,7 @@ pub struct WireGuardTunnel { /// The address of the public WireGuard endpoint (UDP). pub(crate) endpoint: SocketAddr, /// Maps virtual ports to the corresponding IP packet dispatcher. - virtual_port_ip_tx: lockfree::map::Map>>, + virtual_port_ip_tx: dashmap::DashMap>>, /// IP packet dispatcher for unroutable packets. `None` if not initialized. 
sink_ip_tx: RwLock>>>, } @@ -41,7 +41,7 @@ impl WireGuardTunnel { .await .with_context(|| "Failed to create UDP socket for WireGuard connection")?; let endpoint = config.endpoint_addr; - let virtual_port_ip_tx = lockfree::map::Map::new(); + let virtual_port_ip_tx = Default::default(); Ok(Self { source_peer_ip, @@ -89,8 +89,8 @@ impl WireGuardTunnel { &self, virtual_port: VirtualPort, ) -> anyhow::Result>> { - let existing = self.virtual_port_ip_tx.get(&virtual_port); - if existing.is_some() { + let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); + if existing { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); @@ -215,7 +215,7 @@ impl WireGuardTunnel { RouteResult::Dispatch(port) => { let sender = self.virtual_port_ip_tx.get(&port); if let Some(sender_guard) = sender { - let sender = sender_guard.val(); + let sender = sender_guard.value(); match sender.send(packet.to_vec()).await { Ok(_) => { trace!( From cc91cce169ff1fcbe883dc3dfaccfa356654b1a7 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 20 Oct 2021 16:49:24 -0400 Subject: [PATCH 10/17] Basic UDP port pool --- src/main.rs | 9 ++--- src/tunnel/mod.rs | 4 ++- src/tunnel/tcp.rs | 10 ++---- src/tunnel/udp.rs | 88 +++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 93 insertions(+), 18 deletions(-) diff --git a/src/main.rs b/src/main.rs index b2f1f75..d20138d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use anyhow::Context; use crate::config::Config; use crate::tunnel::tcp::TcpPortPool; +use crate::tunnel::udp::UdpPortPool; use crate::wg::WireGuardTunnel; pub mod config; @@ -23,7 +24,7 @@ async fn main() -> anyhow::Result<()> { // Initialize the port pool for each protocol let tcp_port_pool = TcpPortPool::new(); - // TODO: udp_port_pool + let udp_port_pool = UdpPortPool::new(); let wg = WireGuardTunnel::new(&config) .await @@ -54,12 +55,12 @@ async fn main() -> anyhow::Result<()> { port_forwards .into_iter() - .map(|pf| (pf, wg.clone(), tcp_port_pool.clone())) - .for_each(move |(pf, wg, tcp_port_pool)| { + .map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone())) + .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| { std::thread::spawn(move || { let cpu_pool = tokio::runtime::Runtime::new().unwrap(); cpu_pool.block_on(async move { - tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, wg) + tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg) .await .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e)) }); diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index a4042c9..c7f2a67 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::config::{PortForwardConfig, PortProtocol}; use crate::tunnel::tcp::TcpPortPool; +use crate::tunnel::udp::UdpPortPool; use crate::wg::WireGuardTunnel; pub mod tcp; @@ -13,6 +14,7 @@ pub async fn port_forward( port_forward: PortForwardConfig, source_peer_ip: IpAddr, tcp_port_pool: TcpPortPool, + udp_port_pool: UdpPortPool, wg: Arc, ) -> anyhow::Result<()> { info!( @@ -26,6 +28,6 @@ pub async fn port_forward( match port_forward.protocol { PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await, - PortProtocol::Udp => udp::udp_proxy_server(port_forward, /* udp_port_pool, */ wg).await, + PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, wg).await, } } diff 
--git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index fbbdbc2..f49aa7c 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -3,7 +3,7 @@ use crate::virtual_iface::tcp::TcpVirtualInterface; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; use anyhow::Context; -use std::collections::{HashSet, VecDeque}; +use std::collections::VecDeque; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; @@ -235,8 +235,7 @@ impl TcpPortPool { let port = inner .queue .pop_front() - .with_context(|| "Virtual port pool is exhausted")?; - inner.taken.insert(port); + .with_context(|| "TCP virtual port pool is exhausted")?; Ok(port) } @@ -244,15 +243,12 @@ impl TcpPortPool { pub async fn release(&self, port: u16) { let mut inner = self.inner.write().await; inner.queue.push_back(port); - inner.taken.remove(&port); } } /// Non thread-safe inner logic for TCP port pool. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Default)] struct TcpPortPoolInner { /// Remaining ports in the pool. queue: VecDeque, - /// Ports taken out of the pool. - taken: HashSet, } diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 2326351..a92bd5e 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -1,12 +1,22 @@ +use std::collections::{HashMap, VecDeque}; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Range; use std::sync::Arc; +use std::time::Instant; use anyhow::Context; +use rand::seq::SliceRandom; +use rand::thread_rng; use tokio::net::UdpSocket; -use crate::config::PortForwardConfig; +use crate::config::{PortForwardConfig, PortProtocol}; +use crate::virtual_iface::VirtualPort; use crate::wg::WireGuardTunnel; const MAX_PACKET: usize = 65536; +const MIN_PORT: u16 = 1000; +const MAX_PORT: u16 = 60999; +const PORT_RANGE: Range = MIN_PORT..MAX_PORT; /// How long to keep the UDP peer address assigned to its virtual specified port, in seconds. /// TODO: Make this configurable by the CLI @@ -18,6 +28,7 @@ const PORTS_PER_IP: usize = 100; pub async fn udp_proxy_server( port_forward: PortForwardConfig, + port_pool: UdpPortPool, wg: Arc, ) -> anyhow::Result<()> { let socket = UdpSocket::bind(port_forward.source) @@ -33,14 +44,79 @@ pub async fn udp_proxy_server( let _wg = wg.clone(); let _data = &buffer[..size].to_vec(); - debug!("Received datagram of {} bytes from {}", size, peer_addr); // Assign a 'virtual port': this is a unique port number used to route IP packets // received from the WireGuard tunnel. It is the port number that the virtual client will // listen on. - // Since UDP is connection-less, the port is assigned to the source SocketAddr for up to `UDP_TIMEOUT_SECONDS`; - // every datagram resets the timer for that SocketAddr. Each IP address also has a limit of active connections, - // discarding the LRU ports. - // TODO: UDP Port Pool + let port = match port_pool.next(peer_addr).await { + Ok(port) => port, + Err(e) => { + error!( + "Failed to assign virtual port number for UDP datagram from [{}]: {:?}", + peer_addr, e + ); + continue; + } + }; + + let port = VirtualPort(port, PortProtocol::Udp); + debug!( + "[{}] Received datagram of {} bytes from {}", + port, size, peer_addr + ); } } + +/// A pool of virtual ports available for TCP connections. +#[derive(Clone)] +pub struct UdpPortPool { + inner: Arc>, +} + +impl Default for UdpPortPool { + fn default() -> Self { + Self::new() + } +} + +impl UdpPortPool { + /// Initializes a new pool of virtual ports. 
+ pub fn new() -> Self { + let mut inner = UdpPortPoolInner::default(); + let mut ports: Vec = PORT_RANGE.collect(); + ports.shuffle(&mut thread_rng()); + ports + .into_iter() + .for_each(|p| inner.queue.push_back(p) as ()); + Self { + inner: Arc::new(tokio::sync::RwLock::new(inner)), + } + } + + /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). + pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result { + { + let inner = self.inner.read().await; + if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) { + return Ok(*port); + } + } + + let mut inner = self.inner.write().await; + let port = inner + .queue + .pop_front() + .with_context(|| "UDP virtual port pool is exhausted")?; + inner.port_by_peer_addr.insert(peer_addr, port); + Ok(port) + } +} + +/// Non thread-safe inner logic for UDP port pool. +#[derive(Debug, Default)] +struct UdpPortPoolInner { + /// Remaining ports in the pool. + queue: VecDeque, + /// The port assigned by peer IP/port. + port_by_peer_addr: HashMap, +} From fb50ee7113c138ac536dbda87492c028c82d79d0 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 20 Oct 2021 18:06:35 -0400 Subject: [PATCH 11/17] UDP virtual interface skeleton --- src/tunnel/udp.rs | 144 +++++++++++++++++++++++++++++++-------- src/virtual_device.rs | 8 ++- src/virtual_iface/mod.rs | 1 + src/virtual_iface/tcp.rs | 9 ++- src/virtual_iface/udp.rs | 68 ++++++++++++++++++ src/wg.rs | 8 +-- 6 files changed, 203 insertions(+), 35 deletions(-) create mode 100644 src/virtual_iface/udp.rs diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index a92bd5e..db232cf 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, VecDeque}; use std::net::{IpAddr, SocketAddr}; use std::ops::Range; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Instant; @@ -10,7 +11,8 @@ use rand::thread_rng; use tokio::net::UdpSocket; use crate::config::{PortForwardConfig, PortProtocol}; -use crate::virtual_iface::VirtualPort; +use crate::virtual_iface::udp::UdpVirtualInterface; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; use crate::wg::WireGuardTunnel; const MAX_PACKET: usize = 65536; @@ -31,40 +33,120 @@ pub async fn udp_proxy_server( port_pool: UdpPortPool, wg: Arc, ) -> anyhow::Result<()> { + // Abort signal + let abort = Arc::new(AtomicBool::new(false)); + + // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back + // to the real client. + let (data_to_real_client_tx, mut data_to_real_client_rx) = + tokio::sync::mpsc::channel::<(VirtualPort, Vec)>(1_000); + + // data_to_real_server_(tx/rx): This task sends the data received from the real client to the + // virtual interface (virtual server socket). 
+ let (data_to_virtual_server_tx, data_to_virtual_server_rx) = + tokio::sync::mpsc::channel::<(VirtualPort, Vec)>(1_000); + + { + // Spawn virtual interface + // Note: contrary to TCP, there is only one UDP virtual interface + let virtual_interface = UdpVirtualInterface::new( + port_forward, + wg, + data_to_real_client_tx, + data_to_virtual_server_rx, + ); + let abort = abort.clone(); + tokio::spawn(async move { + virtual_interface.poll_loop().await.unwrap_or_else(|e| { + error!("Virtual interface poll loop failed unexpectedly: {}", e); + abort.store(true, Ordering::Relaxed); + }); + }); + } + let socket = UdpSocket::bind(port_forward.source) .await .with_context(|| "Failed to bind on UDP proxy address")?; let mut buffer = [0u8; MAX_PACKET]; loop { - let (size, peer_addr) = socket - .recv_from(&mut buffer) - .await - .with_context(|| "Failed to accept incoming UDP datagram")?; - - let _wg = wg.clone(); - let _data = &buffer[..size].to_vec(); - - // Assign a 'virtual port': this is a unique port number used to route IP packets - // received from the WireGuard tunnel. It is the port number that the virtual client will - // listen on. - let port = match port_pool.next(peer_addr).await { - Ok(port) => port, - Err(e) => { - error!( - "Failed to assign virtual port number for UDP datagram from [{}]: {:?}", - peer_addr, e - ); - continue; + if abort.load(Ordering::Relaxed) { + break; + } + tokio::select! { + to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => { + match to_send_result { + Ok(Some((port, data))) => { + data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| { + error!( + "Failed to dispatch data to UDP virtual interface: {:?}", + e + ); + }); + } + Ok(None) => { + continue; + } + Err(e) => { + error!( + "Failed to read from client UDP socket: {:?}", + e + ); + break; + } + } } - }; - - let port = VirtualPort(port, PortProtocol::Udp); - debug!( - "[{}] Received datagram of {} bytes from {}", - port, size, peer_addr - ); + data_recv_result = data_to_real_client_rx.recv() => { + if let Some((port, data)) = data_recv_result { + if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await { + if let Err(e) = socket.send_to(&data, peer_addr).await { + error!( + "[{}] Failed to send UDP datagram to real client ({}): {:?}", + port, + peer_addr, + e, + ); + } + } + } + } + } } + Ok(()) +} + +async fn next_udp_datagram( + socket: &UdpSocket, + buffer: &mut [u8], + port_pool: UdpPortPool, +) -> anyhow::Result)>> { + let (size, peer_addr) = socket + .recv_from(buffer) + .await + .with_context(|| "Failed to accept incoming UDP datagram")?; + + // Assign a 'virtual port': this is a unique port number used to route IP packets + // received from the WireGuard tunnel. It is the port number that the virtual client will + // listen on. + let port = match port_pool.next(peer_addr).await { + Ok(port) => port, + Err(e) => { + error!( + "Failed to assign virtual port number for UDP datagram from [{}]: {:?}", + peer_addr, e + ); + return Ok(None); + } + }; + let port = VirtualPort(port, PortProtocol::Udp); + + debug!( + "[{}] Received datagram of {} bytes from {}", + port, size, peer_addr + ); + + let data = buffer[..size].to_vec(); + Ok(Some((port, data))) } /// A pool of virtual ports available for TCP connections. 
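
The proxy loop in the hunk above multiplexes both directions on one task:
`tokio::select!` alternates between datagrams arriving on the real UDP socket
(tagged with their assigned virtual port and forwarded to the virtual
interface) and replies coming back over the mpsc channel (sent to the peer
recorded in the port pool). A minimal sketch of that relay shape, with the
port pool and virtual interface replaced by an in-process echo channel
(assumes tokio with the `full` feature; the bind address is arbitrary):

    use std::net::SocketAddr;

    use tokio::net::UdpSocket;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let socket = UdpSocket::bind("127.0.0.1:0").await?;
        let (reply_tx, mut reply_rx) = mpsc::channel::<(SocketAddr, Vec<u8>)>(1_000);

        let mut buffer = [0u8; 65536];
        loop {
            tokio::select! {
                recv = socket.recv_from(&mut buffer) => {
                    let (size, peer_addr) = recv?;
                    // Stand-in for the virtual interface: echo the payload back.
                    let _ = reply_tx.send((peer_addr, buffer[..size].to_vec())).await;
                }
                Some((peer_addr, data)) = reply_rx.recv() => {
                    socket.send_to(&data, peer_addr).await?;
                }
            }
        }
    }
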
@@ -108,8 +190,14 @@ impl UdpPortPool { .pop_front() .with_context(|| "UDP virtual port pool is exhausted")?; inner.port_by_peer_addr.insert(peer_addr, port); + inner.peer_addr_by_port.insert(port, peer_addr); Ok(port) } + + pub async fn get_peer_addr(&self, port: u16) -> Option { + let inner = self.inner.read().await; + inner.peer_addr_by_port.get(&port).copied() + } } /// Non thread-safe inner logic for UDP port pool. @@ -119,4 +207,6 @@ struct UdpPortPoolInner { queue: VecDeque, /// The port assigned by peer IP/port. port_by_peer_addr: HashMap, + /// The socket address assigned to a peer IP/port. + peer_addr_by_port: HashMap, } diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 992e7d0..8cd0cdf 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -1,5 +1,5 @@ use crate::virtual_iface::VirtualPort; -use crate::wg::WireGuardTunnel; +use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; use anyhow::Context; use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; @@ -16,9 +16,11 @@ pub struct VirtualIpDevice { } impl VirtualIpDevice { + /// Registers a virtual IP device for a single virtual client. pub fn new(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { - let ip_dispatch_rx = wg - .register_virtual_interface(virtual_port) + let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); + + wg.register_virtual_interface(virtual_port, ip_dispatch_tx) .with_context(|| "Failed to register IP dispatch for virtual interface")?; Ok(Self { wg, ip_dispatch_rx }) diff --git a/src/virtual_iface/mod.rs b/src/virtual_iface/mod.rs index d2ceb53..f3796bd 100644 --- a/src/virtual_iface/mod.rs +++ b/src/virtual_iface/mod.rs @@ -1,4 +1,5 @@ pub mod tcp; +pub mod udp; use crate::config::PortProtocol; use async_trait::async_trait; diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index fc8348c..baa92a7 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -266,7 +266,14 @@ impl VirtualInterfacePoll for TcpVirtualInterface { break; } - tokio::time::sleep(Duration::from_millis(1)).await; + match virtual_interface.poll_delay(&socket_set, loop_start) { + Some(smoltcp::time::Duration::ZERO) => { + continue; + } + _ => { + tokio::time::sleep(Duration::from_millis(1)).await; + } + } } trace!("[{}] Virtual interface task terminated", self.virtual_port); self.abort.store(true, Ordering::Relaxed); diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs new file mode 100644 index 0000000..88f13b6 --- /dev/null +++ b/src/virtual_iface/udp.rs @@ -0,0 +1,68 @@ +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; + +use crate::config::PortForwardConfig; +use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; +use crate::wg::WireGuardTunnel; + +pub struct UdpVirtualInterface { + port_forward: PortForwardConfig, + wg: Arc, + data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec)>, + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec)>, +} + +impl UdpVirtualInterface { + pub fn new( + port_forward: PortForwardConfig, + wg: Arc, + data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec)>, + data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec)>, + ) -> Self { + Self { + port_forward, + wg, + data_to_real_client_tx, + data_to_virtual_server_rx, + } + } +} + +#[async_trait] +impl VirtualInterfacePoll for UdpVirtualInterface { + async fn poll_loop(self) -> anyhow::Result<()> { + // Data receiver to 
dispatch using virtual client sockets + let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; + + // The IP to bind client sockets to + let _source_peer_ip = self.wg.source_peer_ip; + + // The IP/port to bind the server socket to + let _destination = self.port_forward.destination; + + loop { + let _loop_start = smoltcp::time::Instant::now(); + // TODO: smoltcp UDP + + if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() { + // TODO: Find the matching client socket and send + // Echo for now + self.data_to_real_client_tx + .send((client_port, data)) + .await + .unwrap_or_else(|e| { + error!( + "[{}] Failed to dispatch data from virtual client to real client: {:?}", + client_port, e + ); + }); + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + // Ok(()) + } +} diff --git a/src/wg.rs b/src/wg.rs index 653568e..673da1c 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -12,7 +12,7 @@ use crate::config::{Config, PortProtocol}; use crate::virtual_iface::VirtualPort; /// The capacity of the channel for received IP packets. -const DISPATCH_CAPACITY: usize = 1_000; +pub const DISPATCH_CAPACITY: usize = 1_000; const MAX_PACKET: usize = 65536; /// A WireGuard tunnel. Encapsulates and decapsulates IP packets @@ -88,14 +88,14 @@ impl WireGuardTunnel { pub fn register_virtual_interface( &self, virtual_port: VirtualPort, - ) -> anyhow::Result>> { + sender: tokio::sync::mpsc::Sender>, + ) -> anyhow::Result<()> { let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); if existing { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { - let (sender, receiver) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); self.virtual_port_ip_tx.insert(virtual_port, sender); - Ok(receiver) + Ok(()) } } From 282d4f48eb0341d24fb2f5090d8771a0cfec3014 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Wed, 20 Oct 2021 19:04:56 -0400 Subject: [PATCH 12/17] Checkpoint --- src/virtual_device.rs | 15 ++++++++-- src/virtual_iface/tcp.rs | 2 +- src/virtual_iface/udp.rs | 65 ++++++++++++++++++++++++++++++++++++++-- src/wg.rs | 6 +++- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/virtual_device.rs b/src/virtual_device.rs index 8cd0cdf..02d60f9 100644 --- a/src/virtual_device.rs +++ b/src/virtual_device.rs @@ -5,6 +5,9 @@ use smoltcp::phy::{Device, DeviceCapabilities, Medium}; use smoltcp::time::Instant; use std::sync::Arc; +/// The max transmission unit for WireGuard. +const WG_MTU: usize = 1420; + /// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint /// are made available to this device using a channel receiver. IP packets sent from this device /// are asynchronously sent out to the WireGuard tunnel. @@ -16,8 +19,16 @@ pub struct VirtualIpDevice { } impl VirtualIpDevice { + /// Initializes a new virtual IP device. + pub fn new( + wg: Arc, + ip_dispatch_rx: tokio::sync::mpsc::Receiver>, + ) -> Self { + Self { wg, ip_dispatch_rx } + } + /// Registers a virtual IP device for a single virtual client. 
- pub fn new(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { + pub fn new_direct(virtual_port: VirtualPort, wg: Arc) -> anyhow::Result { let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); wg.register_virtual_interface(virtual_port, ip_dispatch_tx) @@ -60,7 +71,7 @@ impl<'a> Device<'a> for VirtualIpDevice { fn capabilities(&self) -> DeviceCapabilities { let mut cap = DeviceCapabilities::default(); cap.medium = Medium::Ip; - cap.max_transmission_unit = 1420; + cap.max_transmission_unit = WG_MTU; cap } } diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index baa92a7..a632792 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -75,7 +75,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { // Consumer for IP packets to send through the virtual interface // Initialize the interface let device = - VirtualIpDevice::new(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) + VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg) .with_context(|| "Failed to initialize TCP VirtualIpDevice")?; let mut virtual_interface = InterfaceBuilder::new(device) .ip_addrs([ diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs index 88f13b6..139610d 100644 --- a/src/virtual_iface/udp.rs +++ b/src/virtual_iface/udp.rs @@ -1,11 +1,18 @@ +use anyhow::Context; +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; +use dashmap::DashMap; +use smoltcp::iface::InterfaceBuilder; +use smoltcp::socket::{SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer}; +use smoltcp::wire::{IpAddress, IpCidr}; use crate::config::PortForwardConfig; +use crate::virtual_device::VirtualIpDevice; use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort}; -use crate::wg::WireGuardTunnel; +use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY}; pub struct UdpVirtualInterface { port_forward: PortForwardConfig, @@ -37,16 +44,68 @@ impl VirtualInterfacePoll for UdpVirtualInterface { let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx; // The IP to bind client sockets to - let _source_peer_ip = self.wg.source_peer_ip; + let source_peer_ip = self.wg.source_peer_ip; // The IP/port to bind the server socket to - let _destination = self.port_forward.destination; + let destination = self.port_forward.destination; + + // Initialize a channel for IP packets. + // The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel. + // The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel. + let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY); + + let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx); + let mut virtual_interface = InterfaceBuilder::new(device) + .ip_addrs([ + // Interface handles IP packets for the sender and recipient + IpCidr::new(source_peer_ip.into(), 32), + IpCidr::new(destination.ip().into(), 32), + ]) + .finalize(); + + // Server socket: this is a placeholder for the interface. + let server_socket: anyhow::Result = { + static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_RX_DATA: [u8; 0] = []; + static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = []; + static mut UDP_SERVER_TX_DATA: [u8; 0] = []; + let udp_rx_buffer = + UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe { + &mut UDP_SERVER_RX_DATA[..] 
+ }); + let udp_tx_buffer = + UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe { + &mut UDP_SERVER_TX_DATA[..] + }); + let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + + socket + .bind((IpAddress::from(destination.ip()), destination.port())) + .with_context(|| "UDP virtual server socket failed to listen")?; + + Ok(socket) + }; + + let mut socket_set = SocketSet::new(vec![]); + let _server_handle = socket_set.add(server_socket?); loop { let _loop_start = smoltcp::time::Instant::now(); + let wg = self.wg.clone(); // TODO: smoltcp UDP if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() { + // Register the socket in WireGuard Tunnel if not already + if !wg.is_registered(client_port) { + wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone()) + .unwrap_or_else(|e| { + error!( + "[{}] Failed to register UDP socket in WireGuard tunnel", + client_port + ); + }); + } + // TODO: Find the matching client socket and send // Echo for now self.data_to_real_client_tx diff --git a/src/wg.rs b/src/wg.rs index 673da1c..80c7272 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -90,7 +90,7 @@ impl WireGuardTunnel { virtual_port: VirtualPort, sender: tokio::sync::mpsc::Sender>, ) -> anyhow::Result<()> { - let existing = self.virtual_port_ip_tx.contains_key(&virtual_port); + let existing = self.is_registered(virtual_port); if existing { Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port)) } else { @@ -99,6 +99,10 @@ impl WireGuardTunnel { } } + pub fn is_registered(&self, virtual_port: VirtualPort) -> bool { + self.virtual_port_ip_tx.contains_key(&virtual_port) + } + /// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`. pub async fn register_sink_interface( &self, From d975efefaf787934bfe7f6f817ce3230f246c589 Mon Sep 17 00:00:00 2001 From: Aram Peres Date: Mon, 25 Oct 2021 19:05:40 -0400 Subject: [PATCH 13/17] End-to-end UDP implementation Port re-use still needs to be implemented to prevent exhaustion over time, and flooding. --- src/tunnel/udp.rs | 5 +- src/virtual_iface/udp.rs | 102 ++++++++++++++++++++++++++++++--------- src/wg.rs | 35 +++++++++----- 3 files changed, 106 insertions(+), 36 deletions(-) diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index db232cf..6eb7e02 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -175,7 +175,7 @@ impl UdpPortPool { } } - /// Requests a free port from the pool. An error is returned if none is available (exhaused max capacity). + /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity). pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result { { let inner = self.inner.read().await; @@ -184,6 +184,9 @@ impl UdpPortPool { } } + // TODO: When the port pool is exhausted, it should re-queue the least recently used port. 
+        // TODO: Limit number of ports in use by peer IP
+
         let mut inner = self.inner.write().await;
         let port = inner
             .queue

diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs
index 139610d..382d7c3 100644
--- a/src/virtual_iface/udp.rs
+++ b/src/virtual_iface/udp.rs
@@ -4,9 +4,8 @@ use std::sync::Arc;
 use std::time::Duration;

 use async_trait::async_trait;
-use dashmap::DashMap;
 use smoltcp::iface::InterfaceBuilder;
-use smoltcp::socket::{SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
+use smoltcp::socket::{SocketHandle, SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
 use smoltcp::wire::{IpAddress, IpCidr};

 use crate::config::PortForwardConfig;
@@ -14,6 +13,8 @@ use crate::virtual_device::VirtualIpDevice;
 use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
 use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};

+const MAX_PACKET: usize = 65536;
+
 pub struct UdpVirtualInterface {
     port_forward: PortForwardConfig,
     wg: Arc<WireGuardTunnel>,
@@ -89,31 +90,88 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
         let mut socket_set = SocketSet::new(vec![]);
         let _server_handle = socket_set.add(server_socket?);

+        // A map of virtual port to client socket.
+        let mut client_sockets: HashMap<VirtualPort, SocketHandle> = HashMap::new();
+
         loop {
-            let _loop_start = smoltcp::time::Instant::now();
+            let loop_start = smoltcp::time::Instant::now();
             let wg = self.wg.clone();
-            // TODO: smoltcp UDP
+
+            match virtual_interface.poll(&mut socket_set, loop_start) {
+                Ok(processed) if processed => {
+                    trace!("UDP virtual interface polled some packets to be processed");
+                }
+                Err(e) => error!("UDP virtual interface poll error: {:?}", e),
+                _ => {}
+            }
+
+            // Loop through each client socket and check if there is any data to send back
+            // to the real client.
+            for (virtual_port, client_socket_handle) in client_sockets.iter() {
+                let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
+                match client_socket.recv() {
+                    Ok((data, _peer)) => {
+                        // Send the data back to the real client using MPSC channel
+                        self.data_to_real_client_tx
+                            .send((*virtual_port, data.to_vec()))
+                            .await
+                            .unwrap_or_else(|e| {
+                                error!(
+                                    "[{}] Failed to dispatch data from virtual client to real client: {:?}",
+                                    virtual_port, e
+                                );
+                            });
+                    }
+                    Err(smoltcp::Error::Exhausted) => {}
+                    Err(e) => {
+                        error!(
+                            "[{}] Failed to read from virtual client socket: {:?}",
+                            virtual_port, e
+                        );
+                    }
+                }
+            }

             if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() {
-                // Register the socket in WireGuard Tunnel if not already registered
-                if !wg.is_registered(client_port) {
-                    wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
-                        .unwrap_or_else(|e| {
-                            error!(
-                                "[{}] Failed to register UDP socket in WireGuard tunnel",
-                                client_port
-                            );
-                        });
-                }
-
-                // TODO: Find the matching client socket and send
-                // Echo for now
-                self.data_to_real_client_tx
-                    .send((client_port, data))
-                    .await
+                // Register the socket in WireGuard Tunnel (overrides any previous registration as well)
+                wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
                     .unwrap_or_else(|e| {
                         error!(
-                            "[{}] Failed to dispatch data from virtual client to real client: {:?}",
+                            "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
+                            client_port, e
+                        );
+                    });
+
+                let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
+                    let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
+                    let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
+                    let rx_data = vec![0u8; MAX_PACKET];
+                    let tx_data = vec![0u8; MAX_PACKET];
+                    let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
+                    let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
+                    let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
+
+                    socket
+                        .bind((IpAddress::from(wg.source_peer_ip), client_port.0))
+                        .unwrap_or_else(|e| {
+                            error!(
+                                "[{}] UDP virtual client socket failed to bind: {:?}",
+                                client_port, e
+                            );
+                        });
+
+                    socket_set.add(socket)
+                });
+
+                let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
+                client_socket
+                    .send_slice(
+                        &data,
+                        (IpAddress::from(destination.ip()), destination.port()).into(),
+                    )
+                    .unwrap_or_else(|e| {
+                        error!(
+                            "[{}] Failed to send data to virtual server: {:?}",
                             client_port, e
                         );
                     });
@@ -121,7 +179,5 @@ impl VirtualInterfacePoll for UdpVirtualInterface {

             tokio::time::sleep(Duration::from_millis(1)).await;
         }
-
-        // Ok(())
     }
 }

diff --git a/src/wg.rs b/src/wg.rs
index 80c7272..2dc20d3 100644
--- a/src/wg.rs
+++ b/src/wg.rs
@@ -4,7 +4,7 @@ use std::time::Duration;
 use anyhow::Context;
 use boringtun::noise::{Tunn, TunnResult};
 use log::Level;
-use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket};
+use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket};
 use tokio::net::UdpSocket;
 use tokio::sync::RwLock;
@@ -90,17 +90,8 @@ impl WireGuardTunnel {
         virtual_port: VirtualPort,
         sender: tokio::sync::mpsc::Sender<Vec<u8>>,
     ) -> anyhow::Result<()> {
-        let existing = self.is_registered(virtual_port);
-        if existing {
-            Err(anyhow::anyhow!("Cannot register virtual interface with virtual port {} because it is already registered", virtual_port))
-        } else {
-            self.virtual_port_ip_tx.insert(virtual_port, sender);
-            Ok(())
-        }
-    }
-
-    pub fn is_registered(&self, virtual_port: VirtualPort) -> bool {
-        self.virtual_port_ip_tx.contains_key(&virtual_port)
+        self.virtual_port_ip_tx.insert(virtual_port, sender);
+        Ok(())
     }
@@ -276,6 +267,7 @@ impl WireGuardTunnel {
             .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
             .map(|packet| match packet.protocol() {
                 IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
+                IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
                 // Unrecognized protocol, so we cannot determine where to route
                 _ => Some(RouteResult::Drop),
             })
@@ -287,6 +279,7 @@ impl WireGuardTunnel {
             .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
             .map(|packet| match packet.next_header() {
                 IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
+                IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
                 // Unrecognized protocol, so we cannot determine where to route
                 _ => Some(RouteResult::Drop),
             })
@@ -316,6 +309,24 @@ impl WireGuardTunnel {
             .unwrap_or(RouteResult::Drop)
     }

+    /// Makes a decision on the handling of an incoming UDP datagram.
+    fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult {
+        UdpPacket::new_checked(datagram)
+            .ok()
+            .map(|udp| {
+                if self
+                    .virtual_port_ip_tx
+                    .get(&VirtualPort(udp.dst_port(), PortProtocol::Udp))
+                    .is_some()
+                {
+                    RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp))
+                } else {
+                    RouteResult::Drop
+                }
+            })
+            .unwrap_or(RouteResult::Drop)
+    }
+
     /// Route a packet to the IP sink interface.
     async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
         let ip_sink_tx = self.sink_ip_tx.read().await;

From faf157cfeb759d20130895e10feb77b01ca0fa13 Mon Sep 17 00:00:00 2001
From: Aram Peres
Date: Tue, 26 Oct 2021 00:03:44 -0400
Subject: [PATCH 14/17] UDP port re-use during flooding

---
 Cargo.lock        | 27 +++++++++++++
 Cargo.toml        |  1 +
 src/tunnel/udp.rs | 86 +++++++++++++++++++++++++++++++++++------
 3 files changed, 104 insertions(+), 10 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 885b05b..851e8c5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -374,6 +374,12 @@
 version = "0.25.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7"

+[[package]]
+name = "hashbrown"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
+
 [[package]]
 name = "hermit-abi"
 version = "0.1.19"
@@ -398,6 +404,16 @@
 dependencies = [
  "quick-error",
 ]

+[[package]]
+name = "indexmap"
+version = "1.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5"
+dependencies = [
+ "autocfg",
+ "hashbrown",
+]
+
 [[package]]
 name = "instant"
 version = "0.1.11"
@@ -615,6 +631,7 @@
 dependencies = [
  "log",
  "nom",
  "pretty_env_logger",
+ "priority-queue",
  "rand",
  "smoltcp",
  "tokio",
@@ -674,6 +691,16 @@
 dependencies = [
  "log",
 ]

+[[package]]
+name = "priority-queue"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cf40e51ccefb72d42720609e1d3c518de8b5800d723a09358d4a6d6245e1f8ca"
+dependencies = [
+ "autocfg",
+ "indexmap",
+]
+
 [[package]]
 name = "proc-macro-hack"
 version = "0.5.19"

diff --git a/Cargo.toml b/Cargo.toml
index 905ac32..9549a80 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,3 +18,4 @@ rand = "0.8.4"
 nom = "7"
 async-trait = "0.1.51"
 dashmap = "4.0.2"
+priority-queue = "1.2.0"

diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs
index 6eb7e02..bedcdf7 100644
--- a/src/tunnel/udp.rs
+++ b/src/tunnel/udp.rs
@@ -1,4 +1,4 @@
-use std::collections::{HashMap, VecDeque};
+use std::collections::{BTreeMap, HashMap, VecDeque};
 use std::net::{IpAddr, SocketAddr};
 use std::ops::Range;
 use std::sync::atomic::{AtomicBool, Ordering};
@@ -6,6 +6,8 @@
 use std::sync::Arc;
 use std::time::Instant;

 use anyhow::Context;
+use priority_queue::double_priority_queue::DoublePriorityQueue;
+use priority_queue::priority_queue::PriorityQueue;
 use rand::seq::SliceRandom;
 use rand::thread_rng;
 use tokio::net::UdpSocket;
@@ -107,6 +109,7 @@ pub async fn udp_proxy_server(
                         e,
                     );
                 }
+                port_pool.update_last_transmit(port.0).await;
             }
         }
     }
@@ -145,6 +148,8 @@ async fn next_udp_datagram(
         port, size, peer_addr
     );

+    port_pool.update_last_transmit(port.0).await;
+
     let data = buffer[..size].to_vec();
     Ok(Some((port, data)))
 }
@@ -177,26 +182,81 @@ impl UdpPortPool {
     /// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
     pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<u16> {
+        // The port to be reused, if any. This is declared outside of the block below because the read lock cannot be upgraded to a write lock.
+        let mut port_reuse: Option<u16> = None;
+
         {
             let inner = self.inner.read().await;
             if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) {
                 return Ok(*port);
             }
+
+            // Count how many ports are being used by the peer IP
+            let peer_ip = peer_addr.ip();
+            let peer_port_count = inner
+                .peer_port_usage
+                .get(&peer_ip)
+                .map(|v| v.len())
+                .unwrap_or_default();
+
+            if peer_port_count >= PORTS_PER_IP {
+                // Return the least recently used port in this IP's pool
+                port_reuse = Some(
+                    *(inner
+                        .peer_port_usage
+                        .get(&peer_ip)
+                        .unwrap()
+                        .peek_min()
+                        .unwrap()
+                        .0),
+                );
+                warn!(
+                    "Peer [{}] is re-using active virtual port {} due to self-exhaustion.",
+                    peer_addr,
+                    port_reuse.unwrap()
+                );
+            }
         }

-        // TODO: When the port pool is exhausted, it should re-queue the least recently used port.
-        // TODO: Limit number of ports in use by peer IP

         let mut inner = self.inner.write().await;
-        let port = inner
-            .queue
-            .pop_front()
-            .with_context(|| "UDP virtual port pool is exhausted")?;
+
+        let port = port_reuse
+            .or_else(|| inner.queue.pop_front())
+            .or_else(|| {
+                // If there is no port to reuse, and the port pool is exhausted, take the least recently used port overall,
+                // as long as the time since its last transmission exceeds the deadline
+                let last: (&u16, &Instant) = inner.port_usage.peek_min().unwrap();
+                if Instant::now().duration_since(*last.1).as_secs() > UDP_TIMEOUT_SECONDS {
+                    warn!(
+                        "Peer [{}] is re-using inactive virtual port {} due to global exhaustion.",
+                        peer_addr, last.0
+                    );
+                    Some(*last.0)
+                } else {
+                    None
+                }
+            })
+            .with_context(|| "virtual port pool is exhausted")?;
+
         inner.port_by_peer_addr.insert(peer_addr, port);
         inner.peer_addr_by_port.insert(port, peer_addr);
         Ok(port)
     }

+    /// Notify that the given virtual port has received or transmitted a UDP datagram.
+    pub async fn update_last_transmit(&self, port: u16) {
+        let mut inner = self.inner.write().await;
+        if let Some(peer) = inner.peer_addr_by_port.get(&port).copied() {
+            let mut pq: &mut DoublePriorityQueue<u16, Instant> = inner
+                .peer_port_usage
+                .entry(peer.ip())
+                .or_insert_with(Default::default);
+            pq.push(port, Instant::now());
+        }
+        let mut pq: &mut DoublePriorityQueue<u16, Instant> = &mut inner.port_usage;
+        pq.push(port, Instant::now());
+    }
+
     pub async fn get_peer_addr(&self, port: u16) -> Option<SocketAddr> {
         let inner = self.inner.read().await;
         inner.peer_addr_by_port.get(&port).copied()
@@ -208,8 +268,14 @@
 struct UdpPortPoolInner {
     /// Remaining ports in the pool.
     queue: VecDeque<u16>,
-    /// The port assigned by peer IP/port.
+    /// The port assigned by peer IP/port. This is used to look up an existing virtual port
+    /// for an incoming UDP datagram.
     port_by_peer_addr: HashMap<SocketAddr, u16>,
-    /// The socket address assigned to a peer IP/port.
+    /// The socket address assigned to a peer IP/port. This is used to send a UDP datagram to
+    /// the real peer address, given the virtual port.
     peer_addr_by_port: HashMap<u16, SocketAddr>,
+    /// Keeps an ordered map of the most recently used virtual ports by a peer (client) IP.
+    peer_port_usage: HashMap<IpAddr, DoublePriorityQueue<u16, Instant>>,
+    /// Keeps an ordered map of the most recently used virtual ports in general.
+    port_usage: DoublePriorityQueue<u16, Instant>,
 }

From 0da6fa51de35ef1d323673e4273ea9eeb545d707 Mon Sep 17 00:00:00 2001
From: Aram Peres
Date: Tue, 26 Oct 2021 00:38:22 -0400
Subject: [PATCH 15/17] udp: use tokio select instead of 1ms loop

---
 src/virtual_iface/udp.rs | 166 ++++++++++++++++++++++-----------------
 1 file changed, 92 insertions(+), 74 deletions(-)

diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs
index 382d7c3..5b275de 100644
--- a/src/virtual_iface/udp.rs
+++ b/src/virtual_iface/udp.rs
@@ -93,91 +93,109 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
         // A map of virtual port to client socket.
         let mut client_sockets: HashMap<VirtualPort, SocketHandle> = HashMap::new();

+        // The next instant required to poll the virtual interface.
+        // None means "immediate poll required".
+        let mut next_poll: Option<tokio::time::Instant> = None;
+
         loop {
-            let loop_start = smoltcp::time::Instant::now();
             let wg = self.wg.clone();
+            tokio::select! {
+                // Wait the amount of time recommended by smoltcp, and poll again.
+                _ = match next_poll {
+                    None => tokio::time::sleep(Duration::ZERO),
+                    Some(until) => tokio::time::sleep_until(until)
+                } => {
+                    let loop_start = smoltcp::time::Instant::now();

                     match virtual_interface.poll(&mut socket_set, loop_start) {
                         Ok(processed) if processed => {
                             trace!("UDP virtual interface polled some packets to be processed");
+                        }
+                        Err(e) => error!("UDP virtual interface poll error: {:?}", e),
+                        _ => {}
+                    }
+
+                    // Loop through each client socket and check if there is any data to send back
+                    // to the real client.
+                    for (virtual_port, client_socket_handle) in client_sockets.iter() {
+                        let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
+                        match client_socket.recv() {
+                            Ok((data, _peer)) => {
+                                // Send the data back to the real client using MPSC channel
+                                self.data_to_real_client_tx
+                                    .send((*virtual_port, data.to_vec()))
+                                    .await
+                                    .unwrap_or_else(|e| {
+                                        error!(
+                                            "[{}] Failed to dispatch data from virtual client to real client: {:?}",
+                                            virtual_port, e
+                                        );
+                                    });
+                            }
+                            Err(smoltcp::Error::Exhausted) => {}
+                            Err(e) => {
+                                error!(
+                                    "[{}] Failed to read from virtual client socket: {:?}",
+                                    virtual_port, e
+                                );
+                            }
+                        }
+                    }
+
+                    next_poll = match virtual_interface.poll_delay(&socket_set, loop_start) {
+                        Some(smoltcp::time::Duration::ZERO) => None,
+                        Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())),
+                        None => None,
+                    }
+                }
-                Err(e) => error!("UDP virtual interface poll error: {:?}", e),
-                _ => {}
-            }
-
-            // Loop through each client socket and check if there is any data to send back
-            // to the real client.
-            for (virtual_port, client_socket_handle) in client_sockets.iter() {
-                let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
-                match client_socket.recv() {
-                    Ok((data, _peer)) => {
-                        // Send the data back to the real client using MPSC channel
-                        self.data_to_real_client_tx
-                            .send((*virtual_port, data.to_vec()))
-                            .await
+                // Wait for data to be received from the real client
+                data_recv_result = data_to_virtual_server_rx.recv() => {
+                    if let Some((client_port, data)) = data_recv_result {
+                        // Register the socket in WireGuard Tunnel (overrides any previous registration as well)
+                        wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
                             .unwrap_or_else(|e| {
                                 error!(
-                                    "[{}] Failed to dispatch data from virtual client to real client: {:?}",
-                                    virtual_port, e
+                                    "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
+                                    client_port, e
+                                );
+                            });
+
+                        let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
+                            let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
+                            let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
+                            let rx_data = vec![0u8; MAX_PACKET];
+                            let tx_data = vec![0u8; MAX_PACKET];
+                            let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
+                            let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
+                            let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
+
+                            socket
+                                .bind((IpAddress::from(wg.source_peer_ip), client_port.0))
+                                .unwrap_or_else(|e| {
+                                    error!(
+                                        "[{}] UDP virtual client socket failed to bind: {:?}",
+                                        client_port, e
+                                    );
+                                });
+
+                            socket_set.add(socket)
+                        });
+
+                        let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
+                        client_socket
+                            .send_slice(
+                                &data,
+                                (IpAddress::from(destination.ip()), destination.port()).into(),
+                            )
+                            .unwrap_or_else(|e| {
+                                error!(
+                                    "[{}] Failed to send data to virtual server: {:?}",
+                                    client_port, e
                                 );
                             });
                     }
-                    Err(smoltcp::Error::Exhausted) => {}
-                    Err(e) => {
-                        error!(
-                            "[{}] Failed to read from virtual client socket: {:?}",
-                            virtual_port, e
-                        );
-                    }
                 }
             }
-
-            if let Ok((client_port, data)) = data_to_virtual_server_rx.try_recv() {
-                // Register the socket in WireGuard Tunnel (overrides any previous registration as well)
-                wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
-                    .unwrap_or_else(|e| {
-                        error!(
-                            "[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
-                            client_port, e
-                        );
-                    });
-
-                let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
-                    let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
-                    let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
-                    let rx_data = vec![0u8; MAX_PACKET];
-                    let tx_data = vec![0u8; MAX_PACKET];
-                    let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
-                    let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
-                    let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
-
-                    socket
-                        .bind((IpAddress::from(wg.source_peer_ip), client_port.0))
-                        .unwrap_or_else(|e| {
-                            error!(
-                                "[{}] UDP virtual client socket failed to bind: {:?}",
-                                client_port, e
-                            );
-                        });
-
-                    socket_set.add(socket)
-                });
-
-                let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
-                client_socket
-                    .send_slice(
-                        &data,
-                        (IpAddress::from(destination.ip()), destination.port()).into(),
-                    )
-                    .unwrap_or_else(|e| {
-                        error!(
-                            "[{}] Failed to send data to virtual server: {:?}",
-                            client_port, e
-                        );
-                    });
-            }
-
-            tokio::time::sleep(Duration::from_millis(1)).await;
         }
     }
 }
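For readers skimming the diff: stripped of the onetun-specific handling, the select-based event loop this patch introduces boils down to the pattern below. This is an editor's sketch, not code from the patch — the channel payload type and the fixed 1 ms stand-in for smoltcp's `poll_delay` result are assumptions made for illustration.

```rust
use std::time::Duration;

use tokio::sync::mpsc;
use tokio::time::Instant;

/// Poll-or-wake loop: `next_poll == None` means "poll immediately";
/// `Some(t)` means smoltcp asked us to wait until `t` before polling again.
async fn poll_loop(mut rx: mpsc::Receiver<Vec<u8>>) {
    let mut next_poll: Option<Instant> = None;
    loop {
        tokio::select! {
            // Branch 1: the virtual interface is due for a poll.
            _ = match next_poll {
                None => tokio::time::sleep(Duration::ZERO),
                Some(until) => tokio::time::sleep_until(until),
            } => {
                // ... poll the smoltcp interface and drain virtual sockets here ...
                // Then ask smoltcp when it next needs attention; the 1 ms below
                // stands in for whatever `poll_delay` returns.
                next_poll = Some(Instant::now() + Duration::from_millis(1));
            }
            // Branch 2: a real client sent data; handle it and poll right away.
            Some(data) = rx.recv() => {
                let _ = data; // ... enqueue onto the corresponding virtual socket ...
                next_poll = None;
            }
        }
    }
}
```

The gain over the previous loop is that the task no longer wakes every millisecond unconditionally: it sleeps exactly as long as smoltcp recommends, and a message on the channel interrupts the sleep immediately.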
From 1493feb1844fd93bbf6419d095fbd3905625b630 Mon Sep 17 00:00:00 2001
From: Aram Peres
Date: Tue, 26 Oct 2021 01:20:02 -0400
Subject: [PATCH 16/17] Reduce udp client socket meta buffer

---
 src/main.rs              | 11 ++++-------
 src/virtual_iface/udp.rs |  4 ++--
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index d20138d..146a65b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -57,13 +57,10 @@ async fn main() -> anyhow::Result<()> {
         .into_iter()
         .map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone()))
         .for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| {
-            std::thread::spawn(move || {
-                let cpu_pool = tokio::runtime::Runtime::new().unwrap();
-                cpu_pool.block_on(async move {
-                    tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg)
-                        .await
-                        .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e))
-                });
+            tokio::spawn(async move {
+                tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg)
+                    .await
+                    .unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e))
             });
         });
 }

diff --git a/src/virtual_iface/udp.rs b/src/virtual_iface/udp.rs
index 5b275de..212cb8d 100644
--- a/src/virtual_iface/udp.rs
+++ b/src/virtual_iface/udp.rs
@@ -161,8 +161,8 @@ impl VirtualInterfacePoll for UdpVirtualInterface {
                         });

                     let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
-                        let rx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
-                        let tx_meta = vec![UdpPacketMetadata::EMPTY; MAX_PACKET];
+                        let rx_meta = vec![UdpPacketMetadata::EMPTY; 10];
+                        let tx_meta = vec![UdpPacketMetadata::EMPTY; 10];
                         let rx_data = vec![0u8; MAX_PACKET];
                         let tx_data = vec![0u8; MAX_PACKET];
                         let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);

From 4ecf16bc3fc35927e66222b27654960881193342 Mon Sep 17 00:00:00 2001
From: Aram Peres
Date: Tue, 26 Oct 2021 01:47:48 -0400
Subject: [PATCH 17/17] Update readme

---
 README.md | 85 ++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 74 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index 1defff3..eaade12 100644
--- a/README.md
+++ b/README.md
@@ -8,8 +8,8 @@ A cross-platform, user-space WireGuard port-forwarder that requires no system ne
 ## Use-case

 - You have an existing WireGuard endpoint (router), accessible using its UDP endpoint (typically port 51820); and
-- You have a peer on the WireGuard network, running a TCP server on a port accessible to the WireGuard network; and
-- You want to access this TCP service from a second computer, on which you can't install WireGuard because you
+- You have a peer on the WireGuard network, running a TCP or UDP service on a port accessible to the WireGuard network; and
+- You want to access this TCP or UDP service from a second computer, on which you can't install WireGuard because you
   can't (no root access) or don't want to (polluting OS configs).

 For example, this can be useful to forward a port from a Kubernetes cluster to a server behind WireGuard,
 without needing to install WireGuard in a Pod.

 ## Usage

-**onetun** opens a TCP port on your local system, from which traffic is forwarded to a TCP port on a peer in your
+**onetun** opens a TCP or UDP port on your local system, from which traffic is forwarded to a TCP or UDP port on a peer in your
 WireGuard network. It requires no changes to your operating system's network interfaces: you don't need to have `root`
 access, or install any WireGuard tool on your local system for it to work.
@@ -25,12 +25,12 @@ The only prerequisite is to register a peer IP and public key on the remote Wire
 the WireGuard endpoint to trust the onetun peer and for packets to be routed.

 ```
-./onetun <source address> <destination address> \
+./onetun [src_host:]<src_port>:<dst_host>:<dst_port>[:TCP,UDP,...] [...] \
     --endpoint-addr <public WireGuard endpoint address (IP + port)> \
     --endpoint-public-key <public key of the peer on the endpoint> \
     --private-key <private key assigned to onetun> \
     --source-peer-ip <IP assigned to onetun> \
     --keep-alive <optional persistent keep-alive in seconds> \
     --log <log level>

-Tunneling [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+INFO onetun > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
 ```

 Which means you can now access the port locally!

@@ -84,6 +84,53 @@
 $ curl 127.0.0.1:8080
 Hello world!
 ```

+### Multiple tunnels in parallel
+
+**onetun** supports running multiple tunnels in parallel. For example:
+
+```
+$ ./onetun 127.0.0.1:8080:192.168.4.2:8080 127.0.0.1:8081:192.168.4.4:8081
+INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8081]->[192.168.4.4:8081] (via [140.30.3.182:51820] as peer 192.168.4.3)
+```
+
+... would open TCP ports 8080 and 8081 locally, which forward to their respective ports on the different peers.
+
+### UDP Support
+
+**onetun** supports UDP forwarding. You can add `:UDP` at the end of the port-forward configuration, or `UDP,TCP` to support
+both protocols on the same port (note that this opens two separate tunnels on the same port number).
+
+```
+$ ./onetun 127.0.0.1:8080:192.168.4.2:8080:UDP
+INFO onetun::tunnel > Tunneling UDP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+
+$ ./onetun 127.0.0.1:8080:192.168.4.2:8080:UDP,TCP
+INFO onetun::tunnel > Tunneling UDP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+```
+
+Note: UDP support is still experimental. You should read the UDP portion of the **Architecture** section before using
+it in any production capacity.
+
+### IPv6 Support
+
+**onetun** supports both IPv4 and IPv6. In fact, you can use onetun to forward one IP version to another, e.g. 6-to-4:
+
+```
+$ ./onetun [::1]:8080:192.168.4.2:8080
+INFO onetun::tunnel > Tunneling TCP [[::1]:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+```
+
+Note that each tunnel can only support one "source" IP version and one "destination" IP version. If you want to support
+both IPv4 and IPv6 on the same port, you should create a second port-forward:
+
+```
+$ ./onetun [::1]:8080:192.168.4.2:8080 127.0.0.1:8080:192.168.4.2:8080
+INFO onetun::tunnel > Tunneling TCP [[::1]:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
+```
+
 ## Download

 Normally I would publish `onetun` to crates.io. However, it depends on some features
@@ -150,6 +197,22 @@ the virtual client to read it. When the virtual client reads data, it simply pus
 This work is all made possible by [smoltcp](https://github.com/smoltcp-rs/smoltcp) and
 [boringtun](https://github.com/cloudflare/boringtun), so special thanks to the developers of those libraries.

+### UDP
+
+UDP support is experimental.
+Since UDP messages are stateless, there is no perfect way for onetun to know when to release an
+assigned virtual port back to the pool for a new peer to use. This becomes a problem over time: once the pool of
+virtual ports runs out, new datagrams are dropped. To alleviate this, onetun caps the number of ports used by one
+peer IP address; if another datagram comes in from a different port on the same IP, the least recently used virtual
+port is freed and assigned to the new peer port. At that point, any datagrams destined for the reused virtual port
+are routed to the new peer, and any datagrams received from the old peer are dropped.
+
+In addition, in cases where many IPs are exhausting the UDP virtual port pool in tandem, and a totally new peer IP
+sends data, onetun has to pick the least recently used virtual port from _any_ peer IP and reuse it. However, this is
+only allowed if the least recently used port hasn't been used for a certain amount of time. If all virtual ports are
+truly "active" (with at least one transmission within that time limit), the new datagram is dropped due to exhaustion.
+
+All in all, I would not recommend using UDP forwarding for public services, since it is most likely vulnerable to
+simple DoS or DDoS attacks.
+
 ## License

 MIT. See `LICENSE` for details.
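Appendix (editor's sketch): to make the UDP port-reuse policy of [PATCH 14/17] concrete, its decision order can be summarized in the self-contained sketch below. This is an illustration, not onetun code — `PortPool`, its fields, and the two constants are assumed stand-ins for the internals of `UdpPortPool`, and the bookkeeping that unbinds the evicted peer is left to the caller.

```rust
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};

use priority_queue::double_priority_queue::DoublePriorityQueue;

// Assumed values for illustration only; the patches define their own constants.
const PORTS_PER_IP: usize = 100;
const UDP_TIMEOUT: Duration = Duration::from_secs(60);

struct PortPool {
    /// Ports never handed out yet.
    free: Vec<u16>,
    /// Stable peer -> virtual port mapping.
    by_peer: HashMap<SocketAddr, u16>,
    /// Least-recently-used ordering of ports, per client IP.
    per_ip_lru: HashMap<IpAddr, DoublePriorityQueue<u16, Instant>>,
    /// Least-recently-used ordering of ports, across all clients.
    global_lru: DoublePriorityQueue<u16, Instant>,
}

impl PortPool {
    /// Decision order for assigning a virtual port to `peer`; None means "drop".
    fn next(&mut self, peer: SocketAddr) -> Option<u16> {
        // 1. Known peer: keep its existing port so return traffic stays routable.
        if let Some(port) = self.by_peer.get(&peer) {
            return Some(*port);
        }
        // 2. This IP hit its cap: evict the IP's own least recently used port.
        if let Some(lru) = self.per_ip_lru.get(&peer.ip()) {
            if lru.len() >= PORTS_PER_IP {
                if let Some((port, _)) = lru.peek_min() {
                    return Some(*port);
                }
            }
        }
        // 3. Otherwise, hand out a fresh port if any remain.
        if let Some(port) = self.free.pop() {
            return Some(port);
        }
        // 4. Global exhaustion: steal the overall least recently used port,
        //    but only if it has been idle longer than the deadline;
        // 5. ...otherwise every port is active, and the datagram is dropped.
        match self.global_lru.peek_min() {
            Some((port, last_used)) if last_used.elapsed() > UDP_TIMEOUT => Some(*port),
            _ => None,
        }
    }
}
```

The two priority queues mirror the `peer_port_usage` and `port_usage` fields added in the patch: each `update_last_transmit`-style push refreshes a port's priority, so `peek_min` always yields the least recently used candidate.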