Merge pull request #16 from aramperes/1-udp

Aram 🍐 2021-10-26 00:50:20 -05:00 committed by GitHub
commit 1ed555a98c
14 changed files with 1663 additions and 533 deletions

Cargo.lock (generated)

@@ -38,6 +38,17 @@ version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e"
[[package]]
name = "async-trait"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44318e776df68115a881de9a8fd1b9e53368d7a4a5ce4cc48517da3393233a5e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "atty"
version = "0.2.14"
@@ -192,6 +203,16 @@ dependencies = [
"libc",
]
[[package]]
name = "dashmap"
version = "4.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e77a43b28d0668df09411cb0bc9a8c2adc40f9a048afe863e05fd43251e8e39c"
dependencies = [
"cfg-if",
"num_cpus",
]
[[package]]
name = "dirs-next"
version = "2.0.0"
@@ -353,6 +374,12 @@ version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7"
[[package]]
name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
[[package]]
name = "hermit-abi"
version = "0.1.19"
@@ -377,6 +404,16 @@ dependencies = [
"quick-error",
]
[[package]]
name = "indexmap"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5"
dependencies = [
"autocfg",
"hashbrown",
]
[[package]]
name = "instant"
version = "0.1.11"
@@ -458,15 +495,6 @@ dependencies = [
"scopeguard",
]
[[package]]
name = "lockfree"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23"
dependencies = [
"owned-alloc",
]
[[package]]
name = "log"
version = "0.4.14"
@@ -488,6 +516,12 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
[[package]]
name = "minimal-lexical"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c64630dcdd71f1a64c435f54885086a0de5d6a12d104d69b165fb7d5286d677"
[[package]]
name = "miniz_oxide"
version = "0.4.4"
@@ -520,6 +554,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "nom"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffd9d26838a953b4af82cbeb9f1592c6798916983959be223a7124e992742c1"
dependencies = [
"memchr",
"minimal-lexical",
"version_check",
]
[[package]]
name = "ntapi"
version = "0.3.6"
@@ -578,23 +623,20 @@ name = "onetun"
version = "0.1.11"
dependencies = [
"anyhow",
"async-trait",
"boringtun", "boringtun",
"clap", "clap",
"dashmap",
"futures", "futures",
"lockfree",
"log", "log",
"nom",
"pretty_env_logger", "pretty_env_logger",
"priority-queue",
"rand", "rand",
"smoltcp", "smoltcp",
"tokio", "tokio",
] ]
[[package]]
name = "owned-alloc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6"
[[package]]
name = "parking_lot"
version = "0.11.2"
@@ -649,6 +691,16 @@ dependencies = [
"log",
]
[[package]]
name = "priority-queue"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf40e51ccefb72d42720609e1d3c518de8b5800d723a09358d4a6d6245e1f8ca"
dependencies = [
"autocfg",
"indexmap",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.19"
@@ -846,13 +898,14 @@ checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
[[package]]
name = "smoltcp"
version = "0.8.0"
source = "git+https://github.com/smoltcp-rs/smoltcp?branch=master#35e833e33dfd3e4efc3eb7d5de06bec17c54b011"
source = "git+https://github.com/smoltcp-rs/smoltcp?branch=master#25c539bb7c96789270f032ede2a967cf0fe5cf57"
dependencies = [
"bitflags",
"byteorder",
"libc",
"log",
"managed",
"rand_core",
]
[[package]]

Cargo.toml

@@ -13,6 +13,9 @@ pretty_env_logger = "0.3"
anyhow = "1"
smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", branch = "master" }
tokio = { version = "1", features = ["full"] }
lockfree = "0.5.1"
futures = "0.3.17"
rand = "0.8.4"
nom = "7"
async-trait = "0.1.51"
dashmap = "4.0.2"
priority-queue = "1.2.0"

README.md

@@ -10,8 +10,8 @@ A cross-platform, user-space WireGuard port-forwarder that requires no system ne
## Use-case
- You have an existing WireGuard endpoint (router), accessible using its UDP endpoint (typically port 51820); and
- You have a peer on the WireGuard network, running a TCP server on a port accessible to the WireGuard network; and
- You have a peer on the WireGuard network, running a TCP or UDP service on a port accessible to the WireGuard network; and
- You want to access this TCP service from a second computer, on which you can't install WireGuard because you
- You want to access this TCP or UDP service from a second computer, on which you can't install WireGuard because you
can't (no root access) or don't want to (polluting OS configs).
For example, this can be useful to forward a port from a Kubernetes cluster to a server behind WireGuard, For example, this can be useful to forward a port from a Kubernetes cluster to a server behind WireGuard,
@@ -19,7 +19,7 @@ without needing to install WireGuard in a Pod.
## Usage
**onetun** opens a TCP port on your local system, from which traffic is forwarded to a TCP port on a peer in your
**onetun** opens a TCP or UDP port on your local system, from which traffic is forwarded to a TCP port on a peer in your
WireGuard network. It requires no changes to your operating system's network interfaces: you don't need to have `root`
access, or install any WireGuard tool on your local system for it to work.
@@ -27,7 +27,7 @@ The only prerequisite is to register a peer IP and public key on the remote Wire
the WireGuard endpoint to trust the onetun peer and for packets to be routed.
```
./onetun <SOURCE_ADDR> <DESTINATION_ADDR> \
./onetun [src_host:]<src_port>:<dst_host>:<dst_port>[:TCP,UDP,...] [...] \
--endpoint-addr <public WireGuard endpoint address> \
--endpoint-public-key <the public key of the peer on the endpoint> \
--private-key <private key assigned to onetun> \
@@ -65,7 +65,7 @@ We want to access a web server on the friendly peer (`192.168.4.2`) on port `808
local port, say `127.0.0.1:8080`, that will tunnel through WireGuard to reach the peer web server:
```shell
./onetun 127.0.0.1:8080 192.168.4.2:8080 \
./onetun 127.0.0.1:8080:192.168.4.2:8080 \
--endpoint-addr 140.30.3.182:51820 \
--endpoint-public-key 'PUB_****************************************' \
--private-key 'PRIV_BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB' \
@@ -76,7 +76,7 @@ local port, say `127.0.0.1:8080`, that will tunnel through WireGuard to reach th
You'll then see this log:
```
INFO onetun > Tunnelling [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
INFO onetun > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
```
Which means you can now access the port locally!
@@ -86,6 +86,53 @@ $ curl 127.0.0.1:8080
Hello world!
```
### Multiple tunnels in parallel
**onetun** supports running multiple tunnels in parallel. For example:
```
$ ./onetun 127.0.0.1:8080:192.168.4.2:8080 127.0.0.1:8081:192.168.4.4:8081
INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8081]->[192.168.4.4:8081] (via [140.30.3.182:51820] as peer 192.168.4.3)
```
... would open TCP ports 8080 and 8081 locally, which forward to their respective ports on the different peers.
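Port-forwards can also be supplied through numbered environment variables, per the `ONETUN_PORT_FORWARD_[#]` convention added to the CLI help in this commit; numbering starts at 1 and stops at the first missing variable. A hypothetical equivalent of the example above (other required flags omitted, as above):
```shell
$ ONETUN_PORT_FORWARD_1=127.0.0.1:8080:192.168.4.2:8080 \
  ONETUN_PORT_FORWARD_2=127.0.0.1:8081:192.168.4.4:8081 \
  ./onetun
```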
### UDP Support
**onetun** supports UDP forwarding. You can add `:UDP` at the end of the port-forward configuration, or `UDP,TCP` to support
both protocols on the same port (note that this opens 2 separate tunnels, just on the same port):
```
$ ./onetun 127.0.0.1:8080:192.168.4.2:8080:UDP
INFO onetun::tunnel > Tunneling UDP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
$ ./onetun 127.0.0.1:8080:192.168.4.2:8080:UDP,TCP
INFO onetun::tunnel > Tunneling UDP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
```
Note: UDP support is totally experimental. You should read the UDP portion of the **Architecture** section before using
it in any production capacity.
### IPv6 Support
**onetun** supports both IPv4 and IPv6. In fact, you can use onetun to forward one IP version to another, e.g. 6-to-4:
```
$ ./onetun [::1]:8080:192.168.4.2:8080
INFO onetun::tunnel > Tunneling TCP [[::1]:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
```
Note that each tunnel can only support one "source" IP version and one "destination" IP version. If you want to support
both IPv4 and IPv6 on the same port, you should create a second port-forward:
```
$ ./onetun [::1]:8080:192.168.4.2:8080 127.0.0.1:8080:192.168.4.2:8080
INFO onetun::tunnel > Tunneling TCP [[::1]:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
INFO onetun::tunnel > Tunneling TCP [127.0.0.1:8080]->[192.168.4.2:8080] (via [140.30.3.182:51820] as peer 192.168.4.3)
```
## Download
Normally I would publish `onetun` to crates.io. However, it depends on some features Normally I would publish `onetun` to crates.io. However, it depends on some features
@@ -152,6 +199,22 @@ the virtual client to read it. When the virtual client reads data, it simply pus
This work is all made possible by [smoltcp](https://github.com/smoltcp-rs/smoltcp) and [boringtun](https://github.com/cloudflare/boringtun),
so special thanks to the developers of those libraries.
### UDP
UDP support is experimental. Since UDP messages are stateless, there is no perfect way for onetun to know when to release an
assigned virtual port back to the pool for a new peer to use. Over time, this would cause issues: running out of virtual ports
would mean new datagrams get dropped. To alleviate this, onetun caps the number of ports used by any one peer IP address;
when a datagram comes in from a new port on an IP that has reached the cap, the least recently used virtual port is freed and assigned
to the new peer port. From that point on, any datagrams destined for the reused virtual port are routed to the new peer,
and any datagrams that would have gone to the old peer are dropped.
In addition, in cases where many IPs are exhausting the UDP virtual port pool in tandem, and a totally new peer IP sends data,
onetun will have to pick the least recently used virtual port from _any_ peer IP and reuse it. However, this is only allowed
if the least recently used port hasn't been used for a certain amount of time. If all virtual ports are truly "active"
(with at least one transmission within that time limit), the new datagram gets dropped due to exhaustion.
All in all, I would not recommend using UDP forwarding for public services, since it's most likely prone to simple DoS or DDoS.
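To make the policy concrete, here is a minimal sketch of the port-assignment decision order, using this PR's default limits (100 ports per peer IP, 60-second inactivity deadline). The types and names are simplified stand-ins, not the actual implementation; the real logic is `UdpPortPool::next` in `src/tunnel/udp.rs` below, which tracks recency with a `DoublePriorityQueue` instead of sorting a `Vec`.
```rust
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};

const PORTS_PER_IP: usize = 100;
const UDP_TIMEOUT: Duration = Duration::from_secs(60);

#[derive(Default)]
struct Pool {
    free: Vec<u16>,                                 // unassigned virtual ports
    assigned: HashMap<SocketAddr, u16>,             // peer address -> virtual port
    last_use: HashMap<IpAddr, Vec<(u16, Instant)>>, // per-IP (port, last transmit)
}

impl Pool {
    /// Decides the virtual port for a datagram from `peer`; None means the datagram is dropped.
    fn next(&mut self, peer: SocketAddr) -> Option<u16> {
        // 1. A known peer address keeps its assigned port.
        if let Some(&port) = self.assigned.get(&peer) {
            return Some(port);
        }
        // 2. Peer IP at its cap: evict that IP's least recently used port.
        let used = self.last_use.entry(peer.ip()).or_default();
        if used.len() >= PORTS_PER_IP {
            used.sort_by_key(|&(_, t)| t);
            return Some(used[0].0);
        }
        // 3. Otherwise, hand out a fresh port while any remain.
        if let Some(port) = self.free.pop() {
            return Some(port);
        }
        // 4. Global exhaustion: reuse the overall least recently used port,
        //    but only if it has been idle longer than the deadline.
        self.last_use
            .values()
            .flatten()
            .min_by_key(|&&(_, t)| t)
            .filter(|&&(_, t)| t.elapsed() > UDP_TIMEOUT)
            .map(|&(port, _)| port)
    }
}
```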
## License
MIT. See `LICENSE` for details.

src/config.rs

@@ -1,3 +1,6 @@
use std::collections::HashSet;
use std::convert::TryFrom;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::sync::Arc;
@@ -7,8 +10,7 @@ use clap::{App, Arg};
#[derive(Clone, Debug)]
pub struct Config {
pub(crate) source_addr: SocketAddr,
pub(crate) dest_addr: SocketAddr,
pub(crate) port_forwards: Vec<PortForwardConfig>,
pub(crate) private_key: Arc<X25519SecretKey>,
pub(crate) endpoint_public_key: Arc<X25519PublicKey>,
pub(crate) endpoint_addr: SocketAddr,
@@ -23,16 +25,23 @@ impl Config {
.author("Aram Peres <aram.peres@gmail.com>")
.version(env!("CARGO_PKG_VERSION"))
.args(&[
Arg::with_name("SOURCE_ADDR")
.required(true)
.takes_value(true)
.env("ONETUN_SOURCE_ADDR")
.help("The source address (IP + port) to forward from. Example: 127.0.0.1:2115"),
Arg::with_name("DESTINATION_ADDR")
.required(true)
.takes_value(true)
.env("ONETUN_DESTINATION_ADDR")
.help("The destination address (IP + port) to forward to. The IP should be a peer registered in the Wireguard endpoint. Example: 192.168.4.2:2116"),
Arg::with_name("PORT_FORWARD")
.required(false)
.multiple(true)
.takes_value(true)
.help("Port forward configurations. The format of each argument is [src_host:]<src_port>:<dst_host>:<dst_port>[:TCP,UDP,...], \
where [src_host] is the local IP to listen on, <src_port> is the local port to listen on, <dst_host> is the remote peer IP to forward to, and <dst_port> is the remote port to forward to. \
Environment variables of the form 'ONETUN_PORT_FORWARD_[#]' are also accepted, where [#] starts at 1.\n\
Examples:\n\
\t127.0.0.1:8080:192.168.4.1:8081:TCP,UDP\n\
\t127.0.0.1:8080:192.168.4.1:8081:TCP\n\
\t0.0.0.0:8080:192.168.4.1:8081\n\
\t[::1]:8080:192.168.4.1:8081\n\
\t8080:192.168.4.1:8081\n\
\t8080:192.168.4.1:8081:TCP\n\
\tlocalhost:8080:192.168.4.1:8081:TCP\n\
\tlocalhost:8080:peer.intranet:8081:TCP\
"),
Arg::with_name("private-key") Arg::with_name("private-key")
.required(true) .required(true)
.takes_value(true) .takes_value(true)
@@ -72,11 +81,37 @@
.help("Configures the log level and format.")
]).get_matches();
// Combine `PORT_FORWARD` arg and `ONETUN_PORT_FORWARD_#` envs
let mut port_forward_strings = HashSet::new();
if let Some(values) = matches.values_of("PORT_FORWARD") {
for value in values {
port_forward_strings.insert(value.to_owned());
}
}
for n in 1.. {
if let Ok(env) = std::env::var(format!("ONETUN_PORT_FORWARD_{}", n)) {
port_forward_strings.insert(env);
} else {
break;
}
}
if port_forward_strings.is_empty() {
return Err(anyhow::anyhow!("No port forward configurations given."));
}
// Parse `PORT_FORWARD` strings into `PortForwardConfig`
let port_forwards: anyhow::Result<Vec<Vec<PortForwardConfig>>> = port_forward_strings
.into_iter()
.map(|s| PortForwardConfig::from_notation(&s))
.collect();
let port_forwards: Vec<PortForwardConfig> = port_forwards
.with_context(|| "Failed to parse port forward config")?
.into_iter()
.flatten()
.collect();
Ok(Self {
source_addr: parse_addr(matches.value_of("SOURCE_ADDR"))
.with_context(|| "Invalid source address")?,
dest_addr: parse_addr(matches.value_of("DESTINATION_ADDR"))
.with_context(|| "Invalid destination address")?,
port_forwards,
private_key: Arc::new(
parse_private_key(matches.value_of("private-key"))
.with_context(|| "Invalid private key")?,
@@ -137,3 +172,310 @@ fn parse_keep_alive(s: Option<&str>) -> anyhow::Result<Option<u16>> {
Ok(None)
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PortForwardConfig {
/// The source IP and port where the local server will run.
pub source: SocketAddr,
/// The destination IP and port to which traffic will be forwarded.
pub destination: SocketAddr,
/// The transport protocol to use for the port (Layer 4).
pub protocol: PortProtocol,
}
impl PortForwardConfig {
/// Converts a string representation into `PortForwardConfig`.
///
/// Sample formats:
/// - `127.0.0.1:8080:192.168.4.1:8081:TCP,UDP`
/// - `127.0.0.1:8080:192.168.4.1:8081:TCP`
/// - `0.0.0.0:8080:192.168.4.1:8081`
/// - `[::1]:8080:192.168.4.1:8081`
/// - `8080:192.168.4.1:8081`
/// - `8080:192.168.4.1:8081:TCP`
/// - `localhost:8080:192.168.4.1:8081:TCP`
/// - `localhost:8080:peer.intranet:8081:TCP`
///
/// Implementation Notes:
/// - The format is formalized as `[src_host:]<src_port>:<dst_host>:<dst_port>[:PROTO1,PROTO2,...]`
/// - `src_host` is optional and defaults to `127.0.0.1`.
/// - `src_host` and `dst_host` may be specified as IPv4, IPv6, or a FQDN to be resolved by DNS.
/// - IPv6 addresses must be prefixed with `[` and suffixed with `]`. Example: `[::1]`.
/// - Any `u16` is accepted as `src_port` and `dst_port`
/// - Specifying protocols (`PROTO1,PROTO2,...`) is optional and defaults to `TCP`. Values must be separated by commas.
pub fn from_notation(s: &str) -> anyhow::Result<Vec<PortForwardConfig>> {
mod parsers {
use nom::branch::alt;
use nom::bytes::complete::is_not;
use nom::character::complete::{alpha1, char, digit1};
use nom::combinator::{complete, map, opt, success};
use nom::error::ErrorKind;
use nom::multi::separated_list1;
use nom::sequence::{delimited, preceded, separated_pair, tuple};
use nom::IResult;
fn ipv6(s: &str) -> IResult<&str, &str> {
delimited(char('['), is_not("]"), char(']'))(s)
}
fn ipv4_or_fqdn(s: &str) -> IResult<&str, &str> {
let s = is_not(":")(s)?;
if s.1.chars().all(|c| c.is_ascii_digit()) {
// If ipv4 or fqdn is all digits, it's not valid.
Err(nom::Err::Error(nom::error::ParseError::from_error_kind(
s.1,
ErrorKind::Fail,
)))
} else {
Ok(s)
}
}
fn port(s: &str) -> IResult<&str, &str> {
digit1(s)
}
fn ip_or_fqdn(s: &str) -> IResult<&str, &str> {
alt((ipv6, ipv4_or_fqdn))(s)
}
fn no_ip(s: &str) -> IResult<&str, Option<&str>> {
success(None)(s)
}
fn src_addr(s: &str) -> IResult<&str, (Option<&str>, &str)> {
let with_ip = separated_pair(map(ip_or_fqdn, Some), char(':'), port);
let without_ip = tuple((no_ip, port));
alt((with_ip, without_ip))(s)
}
fn dst_addr(s: &str) -> IResult<&str, (&str, &str)> {
separated_pair(ip_or_fqdn, char(':'), port)(s)
}
fn protocol(s: &str) -> IResult<&str, &str> {
alpha1(s)
}
fn protocols(s: &str) -> IResult<&str, Option<Vec<&str>>> {
opt(preceded(char(':'), separated_list1(char(','), protocol)))(s)
}
#[allow(clippy::type_complexity)]
pub fn port_forward(
s: &str,
) -> IResult<&str, ((Option<&str>, &str), (), (&str, &str), Option<Vec<&str>>)>
{
complete(tuple((
src_addr,
map(char(':'), |_| ()),
dst_addr,
protocols,
)))(s)
}
}
// TODO: Could improve error management with custom errors, so that the messages are more helpful.
let (src_addr, _, dst_addr, protocols) = parsers::port_forward(s)
.map_err(|e| anyhow::anyhow!("Invalid port-forward definition: {}", e))?
.1;
let source = (
src_addr.0.unwrap_or("127.0.0.1"),
src_addr
.1
.parse::<u16>()
.with_context(|| "Invalid source port")?,
)
.to_socket_addrs()
.with_context(|| "Invalid source address")?
.next()
.with_context(|| "Could not resolve source address")?;
let destination = (
dst_addr.0,
dst_addr
.1
.parse::<u16>()
.with_context(|| "Invalid source port")?,
)
.to_socket_addrs() // TODO: Pass this as given and use DNS config instead (issue #15)
.with_context(|| "Invalid destination address")?
.next()
.with_context(|| "Could not resolve destination address")?;
// Parse protocols
let protocols = if let Some(protocols) = protocols {
let protocols: anyhow::Result<Vec<PortProtocol>> =
protocols.into_iter().map(PortProtocol::try_from).collect();
protocols
} else {
Ok(vec![PortProtocol::Tcp])
}
.with_context(|| "Failed to parse protocols")?;
// Return one config per requested protocol
Ok(protocols
.into_iter()
.map(|protocol| Self {
source,
destination,
protocol,
})
.collect())
}
}
impl Display for PortForwardConfig {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}:{}", self.source, self.destination, self.protocol)
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub enum PortProtocol {
Tcp,
Udp,
}
impl TryFrom<&str> for PortProtocol {
type Error = anyhow::Error;
fn try_from(value: &str) -> anyhow::Result<Self> {
match value.to_uppercase().as_str() {
"TCP" => Ok(Self::Tcp),
"UDP" => Ok(Self::Udp),
_ => Err(anyhow::anyhow!("Invalid protocol specifier: {}", value)),
}
}
}
impl Display for PortProtocol {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Self::Tcp => "TCP",
Self::Udp => "UDP",
}
)
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use super::*;
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_1() {
assert_eq!(
PortForwardConfig::from_notation("192.168.0.1:8080:192.168.4.1:8081:TCP,UDP")
.expect("Failed to parse"),
vec![
PortForwardConfig {
source: SocketAddr::from_str("192.168.0.1:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
},
PortForwardConfig {
source: SocketAddr::from_str("192.168.0.1:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Udp
}
]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_2() {
assert_eq!(
PortForwardConfig::from_notation("192.168.0.1:8080:192.168.4.1:8081:TCP")
.expect("Failed to parse"),
vec![PortForwardConfig {
source: SocketAddr::from_str("192.168.0.1:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_3() {
assert_eq!(
PortForwardConfig::from_notation("0.0.0.0:8080:192.168.4.1:8081")
.expect("Failed to parse"),
vec![PortForwardConfig {
source: SocketAddr::from_str("0.0.0.0:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_4() {
assert_eq!(
PortForwardConfig::from_notation("[::1]:8080:192.168.4.1:8081")
.expect("Failed to parse"),
vec![PortForwardConfig {
source: SocketAddr::from_str("[::1]:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_5() {
assert_eq!(
PortForwardConfig::from_notation("8080:192.168.4.1:8081").expect("Failed to parse"),
vec![PortForwardConfig {
source: SocketAddr::from_str("127.0.0.1:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_6() {
assert_eq!(
PortForwardConfig::from_notation("8080:192.168.4.1:8081:TCP").expect("Failed to parse"),
vec![PortForwardConfig {
source: SocketAddr::from_str("127.0.0.1:8080").unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_7() {
assert_eq!(
PortForwardConfig::from_notation("localhost:8080:192.168.4.1:8081")
.expect("Failed to parse"),
vec![PortForwardConfig {
source: "localhost:8080".to_socket_addrs().unwrap().next().unwrap(),
destination: SocketAddr::from_str("192.168.4.1:8081").unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
/// Tests the parsing of `PortForwardConfig`.
#[test]
fn test_parse_port_forward_config_8() {
assert_eq!(
PortForwardConfig::from_notation("localhost:8080:localhost:8081:TCP")
.expect("Failed to parse"),
vec![PortForwardConfig {
source: "localhost:8080".to_socket_addrs().unwrap().next().unwrap(),
destination: "localhost:8081".to_socket_addrs().unwrap().next().unwrap(),
protocol: PortProtocol::Tcp
}]
);
}
}

View file

@@ -1,35 +1,30 @@
#[macro_use]
extern crate log;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState};
use smoltcp::wire::{IpAddress, IpCidr};
use tokio::net::{TcpListener, TcpStream};
use crate::config::Config;
use crate::port_pool::PortPool;
use crate::tunnel::tcp::TcpPortPool;
use crate::virtual_device::VirtualIpDevice;
use crate::tunnel::udp::UdpPortPool;
use crate::wg::WireGuardTunnel;
pub mod config;
pub mod ip_sink;
pub mod port_pool;
pub mod tunnel;
pub mod virtual_device;
pub mod virtual_iface;
pub mod wg;
pub const MAX_PACKET: usize = 65536;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let config = Config::from_args().with_context(|| "Failed to read config")?;
init_logger(&config)?;
let port_pool = Arc::new(PortPool::new());
// Initialize the port pool for each protocol
let tcp_port_pool = TcpPortPool::new();
let udp_port_pool = UdpPortPool::new();
let wg = WireGuardTunnel::new(&config)
.await
@@ -54,395 +49,23 @@ async fn main() -> anyhow::Result<()> {
tokio::spawn(async move { ip_sink::run_ip_sink_interface(wg).await });
}
info!(
"Tunnelling [{}]->[{}] (via [{}] as peer {})",
&config.source_addr, &config.dest_addr, &config.endpoint_addr, &config.source_peer_ip
);
tcp_proxy_server(
config.source_addr,
config.source_peer_ip,
config.dest_addr,
port_pool.clone(),
wg,
)
.await
}
/// Starts the server that listens on TCP connections.
async fn tcp_proxy_server(
listen_addr: SocketAddr,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
port_pool: Arc<PortPool>,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(listen_addr)
.await
.with_context(|| "Failed to listen on TCP proxy server")?;
loop {
let wg = wg.clone();
let port_pool = port_pool.clone();
let (socket, peer_addr) = listener
.accept()
.await
.with_context(|| "Failed to accept connection on TCP proxy server")?;
// Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will
// listen on.
let virtual_port = match port_pool.next() {
Ok(port) => port,
Err(e) => {
error!(
"Failed to assign virtual port number for connection [{}]: {:?}",
peer_addr, e
);
continue;
}
};
info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
tokio::spawn(async move {
let port_pool = Arc::clone(&port_pool);
let result = handle_tcp_proxy_connection(
socket,
virtual_port,
source_peer_ip,
dest_addr,
wg.clone(),
)
.await;
if let Err(e) = result {
error!(
"[{}] Connection dropped un-gracefully: {:?}",
virtual_port, e
);
} else {
info!("[{}] Connection closed by client", virtual_port);
}
// Release port when connection drops
wg.release_virtual_interface(virtual_port);
port_pool.release(virtual_port);
});
}
}
/// Handles a new TCP connection with its assigned virtual port.
async fn handle_tcp_proxy_connection(
socket: TcpStream,
virtual_port: u16,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
// Abort signal for stopping the Virtual Interface
let abort = Arc::new(AtomicBool::new(false));
// Signals that the Virtual Client is ready to send data
let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>();
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000);
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000);
// Spawn virtual interface
{
let abort = abort.clone();
tokio::spawn(async move {
virtual_tcp_interface(
virtual_port,
source_peer_ip,
dest_addr,
wg,
abort,
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
)
.await
});
}
{
let port_forwards = config.port_forwards;
let source_peer_ip = config.source_peer_ip;
port_forwards
.into_iter()
.map(|pf| (pf, wg.clone(), tcp_port_pool.clone(), udp_port_pool.clone()))
.for_each(move |(pf, wg, tcp_port_pool, udp_port_pool)| {
tokio::spawn(async move {
tunnel::port_forward(pf, source_peer_ip, tcp_port_pool, udp_port_pool, wg)
.await
.unwrap_or_else(|e| error!("Port-forward failed for {} : {}", pf, e))
});
});
}
futures::future::pending().await
// Wait for virtual client to be ready.
virtual_client_ready_rx
.await
.expect("failed to wait for virtual client to be ready");
trace!("[{}] Virtual client is ready to send data", virtual_port);
loop {
tokio::select! {
readable_result = socket.readable() => {
match readable_result {
Ok(_) => {
// Buffer for the individual TCP segment.
let mut buffer = Vec::with_capacity(MAX_PACKET);
match socket.try_read_buf(&mut buffer) {
Ok(size) if size > 0 => {
let data = &buffer[..size];
debug!(
"[{}] Read {} bytes of TCP data from real client",
virtual_port, size
);
if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await {
error!(
"[{}] Failed to dispatch data to virtual interface: {:?}",
virtual_port, e
);
}
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue;
}
Err(e) => {
error!(
"[{}] Failed to read from client TCP socket: {:?}",
virtual_port, e
);
break;
}
_ => {
break;
}
}
}
Err(e) => {
error!("[{}] Failed to check if readable: {:?}", virtual_port, e);
break;
}
}
}
data_recv_result = data_to_real_client_rx.recv() => {
match data_recv_result {
Some(data) => match socket.try_write(&data) {
Ok(size) => {
debug!(
"[{}] Wrote {} bytes of TCP data to real client",
virtual_port, size
);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if abort.load(Ordering::Relaxed) {
break;
} else {
continue;
}
}
Err(e) => {
error!(
"[{}] Failed to write to client TCP socket: {:?}",
virtual_port, e
);
}
},
None => {
if abort.load(Ordering::Relaxed) {
break;
} else {
continue;
}
},
}
}
}
}
trace!("[{}] TCP socket handler task terminated", virtual_port);
abort.store(true, Ordering::Relaxed);
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn virtual_tcp_interface(
virtual_port: u16,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
wg: Arc<WireGuardTunnel>,
abort: Arc<AtomicBool>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
mut data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
) -> anyhow::Result<()> {
let mut virtual_client_ready_tx = Some(virtual_client_ready_tx);
// Create a device and interface to simulate IP packets
// In essence:
// * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
// * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
// * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
// * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
// * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device = VirtualIpDevice::new(virtual_port, wg)
.with_context(|| "Failed to initialize VirtualIpDevice")?;
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr.ip()), 32),
])
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.
// TODO: Determine if we even need buffers here.
let server_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((IpAddress::from(dest_addr.ip()), dest_addr.port()))
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
};
let client_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.connect(
(IpAddress::from(dest_addr.ip()), dest_addr.port()),
(IpAddress::from(source_peer_ip), virtual_port),
)
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
};
// Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server.
let mut socket_set_entries: [_; 2] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let _server_handle = socket_set.add(server_socket?);
let client_handle = socket_set.add(client_socket?);
// Any data that wasn't sent because it was over the sending buffer limit
let mut tx_extra = Vec::new();
loop {
let loop_start = smoltcp::time::Instant::now();
// Shutdown occurs when the real client closes the connection,
// or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore.
// One last poll-loop iteration is executed so that the RST segment can be dispatched.
let shutdown = abort.load(Ordering::Relaxed);
if shutdown {
// Shutdown: sends a RST packet.
trace!("[{}] Shutting down virtual interface", virtual_port);
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
client_socket.abort();
}
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
trace!(
"[{}] Virtual interface polled some packets to be processed",
virtual_port
);
}
Err(e) => {
error!("[{}] Virtual interface poll error: {:?}", virtual_port, e);
}
_ => {}
}
{
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
if client_socket.can_recv() {
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => {
trace!(
"[{}] Virtual client received {} bytes of data",
virtual_port,
data.len()
);
// Send it to the real client
if let Err(e) = data_to_real_client_tx.send(data).await {
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e);
}
}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
virtual_port, e
);
}
}
}
if client_socket.can_send() {
if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() {
virtual_client_ready_tx
.send(())
.expect("Failed to notify real client that virtual client is ready");
}
let mut to_transfer = None;
if tx_extra.is_empty() {
// The payload segment from the previous loop is complete,
// we can now read the next payload in the queue.
if let Ok(data) = data_to_virtual_server_rx.try_recv() {
to_transfer = Some(data);
} else if client_socket.state() == TcpState::CloseWait {
// No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN),
// the interface is shutdown.
trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", virtual_port);
abort.store(true, Ordering::Relaxed);
continue;
}
}
let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
if !to_transfer_slice.is_empty() {
let total = to_transfer_slice.len();
match client_socket.send_slice(to_transfer_slice) {
Ok(sent) => {
trace!(
"[{}] Sent {}/{} bytes via virtual client socket",
virtual_port,
sent,
total,
);
tx_extra = Vec::from(&to_transfer_slice[sent..total]);
}
Err(e) => {
error!(
"[{}] Failed to send slice via virtual client socket: {:?}",
virtual_port, e
);
}
}
}
}
}
if shutdown {
break;
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
trace!("[{}] Virtual interface task terminated", virtual_port);
abort.store(true, Ordering::Relaxed);
Ok(())
}
fn init_logger(config: &Config) -> anyhow::Result<()> {

src/port_pool.rs (deleted)

@@ -1,62 +0,0 @@
use std::ops::Range;
use anyhow::Context;
use rand::seq::SliceRandom;
use rand::thread_rng;
const MIN_PORT: u16 = 32768;
const MAX_PORT: u16 = 60999;
const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
/// A pool of virtual ports available.
/// This structure is thread-safe and lock-free; you can use it safely in an `Arc`.
pub struct PortPool {
/// Remaining ports
inner: lockfree::queue::Queue<u16>,
/// Ports in use, with their associated IP channel sender.
taken: lockfree::set::Set<u16>,
}
impl Default for PortPool {
fn default() -> Self {
Self::new()
}
}
impl PortPool {
/// Initializes a new pool of virtual ports.
pub fn new() -> Self {
let inner = lockfree::queue::Queue::default();
let mut ports: Vec<u16> = PORT_RANGE.collect();
ports.shuffle(&mut thread_rng());
ports.into_iter().for_each(|p| inner.push(p) as ());
Self {
inner,
taken: lockfree::set::Set::new(),
}
}
/// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
pub fn next(&self) -> anyhow::Result<u16> {
let port = self
.inner
.pop()
.with_context(|| "Virtual port pool is exhausted")?;
self.taken
.insert(port)
.ok()
.with_context(|| "Failed to insert taken")?;
Ok(port)
}
/// Releases a port back into the pool.
pub fn release(&self, port: u16) {
self.inner.push(port);
self.taken.remove(&port);
}
/// Whether the given port is in use by a virtual interface.
pub fn is_in_use(&self, port: u16) -> bool {
self.taken.contains(&port)
}
}

src/tunnel/mod.rs (new file)

@@ -0,0 +1,33 @@
use std::net::IpAddr;
use std::sync::Arc;
use crate::config::{PortForwardConfig, PortProtocol};
use crate::tunnel::tcp::TcpPortPool;
use crate::tunnel::udp::UdpPortPool;
use crate::wg::WireGuardTunnel;
pub mod tcp;
#[allow(unused)]
pub mod udp;
pub async fn port_forward(
port_forward: PortForwardConfig,
source_peer_ip: IpAddr,
tcp_port_pool: TcpPortPool,
udp_port_pool: UdpPortPool,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
info!(
"Tunneling {} [{}]->[{}] (via [{}] as peer {})",
port_forward.protocol,
port_forward.source,
port_forward.destination,
&wg.endpoint,
source_peer_ip
);
match port_forward.protocol {
PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, wg).await,
PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, wg).await,
}
}

src/tunnel/tcp.rs (new file)

@@ -0,0 +1,254 @@
use crate::config::{PortForwardConfig, PortProtocol};
use crate::virtual_iface::tcp::TcpVirtualInterface;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
use anyhow::Context;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use std::ops::Range;
use rand::seq::SliceRandom;
use rand::thread_rng;
const MAX_PACKET: usize = 65536;
const MIN_PORT: u16 = 1000;
const MAX_PORT: u16 = 60999;
const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
/// Starts the server that listens on TCP connections.
pub async fn tcp_proxy_server(
port_forward: PortForwardConfig,
port_pool: TcpPortPool,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(port_forward.source)
.await
.with_context(|| "Failed to listen on TCP proxy server")?;
loop {
let wg = wg.clone();
let port_pool = port_pool.clone();
let (socket, peer_addr) = listener
.accept()
.await
.with_context(|| "Failed to accept connection on TCP proxy server")?;
// Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will
// listen on.
let virtual_port = match port_pool.next().await {
Ok(port) => port,
Err(e) => {
error!(
"Failed to assign virtual port number for connection [{}]: {:?}",
peer_addr, e
);
continue;
}
};
info!("[{}] Incoming connection from {}", virtual_port, peer_addr);
tokio::spawn(async move {
let port_pool = port_pool.clone();
let result =
handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
if let Err(e) = result {
error!(
"[{}] Connection dropped un-gracefully: {:?}",
virtual_port, e
);
} else {
info!("[{}] Connection closed by client", virtual_port);
}
// Release port when connection drops
wg.release_virtual_interface(VirtualPort(virtual_port, PortProtocol::Tcp));
port_pool.release(virtual_port).await;
});
}
}
/// Handles a new TCP connection with its assigned virtual port.
async fn handle_tcp_proxy_connection(
socket: TcpStream,
virtual_port: u16,
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
// Abort signal for stopping the Virtual Interface
let abort = Arc::new(AtomicBool::new(false));
// Signals that the Virtual Client is ready to send data
let (virtual_client_ready_tx, virtual_client_ready_rx) = tokio::sync::oneshot::channel::<()>();
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000);
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000);
// Spawn virtual interface
{
let abort = abort.clone();
let virtual_interface = TcpVirtualInterface::new(
virtual_port,
port_forward,
wg,
abort.clone(),
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
);
tokio::spawn(async move {
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
error!("Virtual interface poll loop failed unexpectedly: {}", e);
abort.store(true, Ordering::Relaxed);
})
});
}
// Wait for virtual client to be ready.
virtual_client_ready_rx
.await
.with_context(|| "Virtual client dropped before being ready.")?;
trace!("[{}] Virtual client is ready to send data", virtual_port);
loop {
tokio::select! {
readable_result = socket.readable() => {
match readable_result {
Ok(_) => {
// Buffer for the individual TCP segment.
let mut buffer = Vec::with_capacity(MAX_PACKET);
match socket.try_read_buf(&mut buffer) {
Ok(size) if size > 0 => {
let data = &buffer[..size];
debug!(
"[{}] Read {} bytes of TCP data from real client",
virtual_port, size
);
if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await {
error!(
"[{}] Failed to dispatch data to virtual interface: {:?}",
virtual_port, e
);
}
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue;
}
Err(e) => {
error!(
"[{}] Failed to read from client TCP socket: {:?}",
virtual_port, e
);
break;
}
_ => {
break;
}
}
}
Err(e) => {
error!("[{}] Failed to check if readable: {:?}", virtual_port, e);
break;
}
}
}
data_recv_result = data_to_real_client_rx.recv() => {
match data_recv_result {
Some(data) => match socket.try_write(&data) {
Ok(size) => {
debug!(
"[{}] Wrote {} bytes of TCP data to real client",
virtual_port, size
);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if abort.load(Ordering::Relaxed) {
break;
} else {
continue;
}
}
Err(e) => {
error!(
"[{}] Failed to write to client TCP socket: {:?}",
virtual_port, e
);
}
},
None => {
if abort.load(Ordering::Relaxed) {
break;
} else {
continue;
}
},
}
}
}
}
trace!("[{}] TCP socket handler task terminated", virtual_port);
abort.store(true, Ordering::Relaxed);
Ok(())
}
/// A pool of virtual ports available for TCP connections.
#[derive(Clone)]
pub struct TcpPortPool {
inner: Arc<tokio::sync::RwLock<TcpPortPoolInner>>,
}
impl Default for TcpPortPool {
fn default() -> Self {
Self::new()
}
}
impl TcpPortPool {
/// Initializes a new pool of virtual ports.
pub fn new() -> Self {
let mut inner = TcpPortPoolInner::default();
let mut ports: Vec<u16> = PORT_RANGE.collect();
ports.shuffle(&mut thread_rng());
ports
.into_iter()
.for_each(|p| inner.queue.push_back(p) as ());
Self {
inner: Arc::new(tokio::sync::RwLock::new(inner)),
}
}
/// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
pub async fn next(&self) -> anyhow::Result<u16> {
let mut inner = self.inner.write().await;
let port = inner
.queue
.pop_front()
.with_context(|| "TCP virtual port pool is exhausted")?;
Ok(port)
}
/// Releases a port back into the pool.
pub async fn release(&self, port: u16) {
let mut inner = self.inner.write().await;
inner.queue.push_back(port);
}
}
/// Non thread-safe inner logic for TCP port pool.
#[derive(Debug, Default)]
struct TcpPortPoolInner {
/// Remaining ports in the pool.
queue: VecDeque<u16>,
}

src/tunnel/udp.rs (new file)

@@ -0,0 +1,281 @@
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::net::{IpAddr, SocketAddr};
use std::ops::Range;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use anyhow::Context;
use priority_queue::double_priority_queue::DoublePriorityQueue;
use priority_queue::priority_queue::PriorityQueue;
use rand::seq::SliceRandom;
use rand::thread_rng;
use tokio::net::UdpSocket;
use crate::config::{PortForwardConfig, PortProtocol};
use crate::virtual_iface::udp::UdpVirtualInterface;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
const MAX_PACKET: usize = 65536;
const MIN_PORT: u16 = 1000;
const MAX_PORT: u16 = 60999;
const PORT_RANGE: Range<u16> = MIN_PORT..MAX_PORT;
/// How long to keep a UDP peer address assigned to its virtual port, in seconds.
/// TODO: Make this configurable by the CLI
const UDP_TIMEOUT_SECONDS: u64 = 60;
/// To prevent port-flooding, we set a limit on the number of open ports per IP address.
/// TODO: Make this configurable by the CLI
const PORTS_PER_IP: usize = 100;
pub async fn udp_proxy_server(
port_forward: PortForwardConfig,
port_pool: UdpPortPool,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
// Abort signal
let abort = Arc::new(AtomicBool::new(false));
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) =
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
// data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) =
tokio::sync::mpsc::channel::<(VirtualPort, Vec<u8>)>(1_000);
{
// Spawn virtual interface
// Note: contrary to TCP, there is only one UDP virtual interface
let virtual_interface = UdpVirtualInterface::new(
port_forward,
wg,
data_to_real_client_tx,
data_to_virtual_server_rx,
);
let abort = abort.clone();
tokio::spawn(async move {
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
error!("Virtual interface poll loop failed unexpectedly: {}", e);
abort.store(true, Ordering::Relaxed);
});
});
}
let socket = UdpSocket::bind(port_forward.source)
.await
.with_context(|| "Failed to bind on UDP proxy address")?;
let mut buffer = [0u8; MAX_PACKET];
loop {
if abort.load(Ordering::Relaxed) {
break;
}
tokio::select! {
to_send_result = next_udp_datagram(&socket, &mut buffer, port_pool.clone()) => {
match to_send_result {
Ok(Some((port, data))) => {
data_to_virtual_server_tx.send((port, data)).await.unwrap_or_else(|e| {
error!(
"Failed to dispatch data to UDP virtual interface: {:?}",
e
);
});
}
Ok(None) => {
continue;
}
Err(e) => {
error!(
"Failed to read from client UDP socket: {:?}",
e
);
break;
}
}
}
data_recv_result = data_to_real_client_rx.recv() => {
if let Some((port, data)) = data_recv_result {
if let Some(peer_addr) = port_pool.get_peer_addr(port.0).await {
if let Err(e) = socket.send_to(&data, peer_addr).await {
error!(
"[{}] Failed to send UDP datagram to real client ({}): {:?}",
port,
peer_addr,
e,
);
}
port_pool.update_last_transmit(port.0).await;
}
}
}
}
}
Ok(())
}
async fn next_udp_datagram(
socket: &UdpSocket,
buffer: &mut [u8],
port_pool: UdpPortPool,
) -> anyhow::Result<Option<(VirtualPort, Vec<u8>)>> {
let (size, peer_addr) = socket
.recv_from(buffer)
.await
.with_context(|| "Failed to accept incoming UDP datagram")?;
// Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will
// listen on.
let port = match port_pool.next(peer_addr).await {
Ok(port) => port,
Err(e) => {
error!(
"Failed to assign virtual port number for UDP datagram from [{}]: {:?}",
peer_addr, e
);
return Ok(None);
}
};
let port = VirtualPort(port, PortProtocol::Udp);
debug!(
"[{}] Received datagram of {} bytes from {}",
port, size, peer_addr
);
port_pool.update_last_transmit(port.0).await;
let data = buffer[..size].to_vec();
Ok(Some((port, data)))
}
/// A pool of virtual ports available for UDP connections.
#[derive(Clone)]
pub struct UdpPortPool {
inner: Arc<tokio::sync::RwLock<UdpPortPoolInner>>,
}
impl Default for UdpPortPool {
fn default() -> Self {
Self::new()
}
}
impl UdpPortPool {
/// Initializes a new pool of virtual ports.
pub fn new() -> Self {
let mut inner = UdpPortPoolInner::default();
let mut ports: Vec<u16> = PORT_RANGE.collect();
ports.shuffle(&mut thread_rng());
ports
.into_iter()
.for_each(|p| inner.queue.push_back(p) as ());
Self {
inner: Arc::new(tokio::sync::RwLock::new(inner)),
}
}
/// Requests a free port from the pool. An error is returned if none is available (exhausted max capacity).
pub async fn next(&self, peer_addr: SocketAddr) -> anyhow::Result<u16> {
// A port to be reused, if any. Determined outside the block below because the read lock cannot be upgraded to a write lock.
let mut port_reuse: Option<u16> = None;
{
let inner = self.inner.read().await;
if let Some(port) = inner.port_by_peer_addr.get(&peer_addr) {
return Ok(*port);
}
// Count how many ports are being used by the peer IP
let peer_ip = peer_addr.ip();
let peer_port_count = inner
.peer_port_usage
.get(&peer_ip)
.map(|v| v.len())
.unwrap_or_default();
if peer_port_count >= PORTS_PER_IP {
// Return least recently used port in this IP's pool
port_reuse = Some(
*(inner
.peer_port_usage
.get(&peer_ip)
.unwrap()
.peek_min()
.unwrap()
.0),
);
warn!(
"Peer [{}] is re-using active virtual port {} due to self-exhaustion.",
peer_addr,
port_reuse.unwrap()
);
}
}
let mut inner = self.inner.write().await;
let port = port_reuse
.or_else(|| inner.queue.pop_front())
.or_else(|| {
// If there is no port to reuse, and the port pool is exhausted, take the least recently used port overall,
// as long as its last transmission is older than the deadline
let last: (&u16, &Instant) = inner.port_usage.peek_min().unwrap();
if Instant::now().duration_since(*last.1).as_secs() > UDP_TIMEOUT_SECONDS {
warn!(
"Peer [{}] is re-using inactive virtual port {} due to global exhaustion.",
peer_addr, last.0
);
Some(*last.0)
} else {
None
}
})
.with_context(|| "virtual port pool is exhausted")?;
inner.port_by_peer_addr.insert(peer_addr, port);
inner.peer_addr_by_port.insert(port, peer_addr);
Ok(port)
}
/// Notify that the given virtual port has received or transmitted a UDP datagram.
pub async fn update_last_transmit(&self, port: u16) {
let mut inner = self.inner.write().await;
if let Some(peer) = inner.peer_addr_by_port.get(&port).copied() {
let mut pq: &mut DoublePriorityQueue<u16, Instant> = inner
.peer_port_usage
.entry(peer.ip())
.or_insert_with(Default::default);
pq.push(port, Instant::now());
}
let mut pq: &mut DoublePriorityQueue<u16, Instant> = &mut inner.port_usage;
pq.push(port, Instant::now());
}
pub async fn get_peer_addr(&self, port: u16) -> Option<SocketAddr> {
let inner = self.inner.read().await;
inner.peer_addr_by_port.get(&port).copied()
}
}
/// Non thread-safe inner logic for UDP port pool.
#[derive(Debug, Default)]
struct UdpPortPoolInner {
/// Remaining ports in the pool.
queue: VecDeque<u16>,
/// The port assigned by peer IP/port. This is used to lookup an existing virtual port
/// for an incoming UDP datagram.
port_by_peer_addr: HashMap<SocketAddr, u16>,
/// The socket address assigned to a peer IP/port. This is used to send a UDP datagram to
/// the real peer address, given the virtual port.
peer_addr_by_port: HashMap<u16, SocketAddr>,
/// Keeps an ordered map of the most recently used virtual ports by a peer (client) IP.
peer_port_usage: HashMap<IpAddr, DoublePriorityQueue<u16, Instant>>,
/// Keeps an ordered map of the most recently used virtual ports in general.
port_usage: DoublePriorityQueue<u16, Instant>,
}

src/virtual_device.rs

@@ -1,9 +1,13 @@
use crate::wg::WireGuardTunnel;
use crate::virtual_iface::VirtualPort;
use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
use anyhow::Context;
use smoltcp::phy::{Device, DeviceCapabilities, Medium};
use smoltcp::time::Instant;
use std::sync::Arc;
/// The max transmission unit for WireGuard.
const WG_MTU: usize = 1420;
/// A virtual device that processes IP packets. IP packets received from the WireGuard endpoint
/// are made available to this device using a channel receiver. IP packets sent from this device
/// are asynchronously sent out to the WireGuard tunnel.
@@ -15,9 +19,19 @@ pub struct VirtualIpDevice {
}
impl VirtualIpDevice {
pub fn new(virtual_port: u16, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let ip_dispatch_rx = wg
.register_virtual_interface(virtual_port)
/// Initializes a new virtual IP device.
pub fn new(
wg: Arc<WireGuardTunnel>,
ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
) -> Self {
Self { wg, ip_dispatch_rx }
}
/// Registers a virtual IP device for a single virtual client.
pub fn new_direct(virtual_port: VirtualPort, wg: Arc<WireGuardTunnel>) -> anyhow::Result<Self> {
let (ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
wg.register_virtual_interface(virtual_port, ip_dispatch_tx)
.with_context(|| "Failed to register IP dispatch for virtual interface")?;
Ok(Self { wg, ip_dispatch_rx })
@@ -57,7 +71,7 @@ impl<'a> Device<'a> for VirtualIpDevice {
fn capabilities(&self) -> DeviceCapabilities {
let mut cap = DeviceCapabilities::default();
cap.medium = Medium::Ip;
cap.max_transmission_unit = WG_MTU;
cap
}
}
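To illustrate the channel-backed device pattern introduced here: the tunnel side holds the `Sender`, and the device drains the `Receiver` when polled. The sketch below is a hypothetical stand-in, not the real `VirtualIpDevice`; `ChannelDevice` and its `receive()` method only mimic the shape of a smoltcp `Device::receive()` impl, and it assumes tokio 1.9+ for `Receiver::try_recv`:

const DISPATCH_CAPACITY: usize = 1_000;

struct ChannelDevice {
    ip_dispatch_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
}

impl ChannelDevice {
    // Non-blocking receive, as a smoltcp Device::receive() impl would do.
    fn receive(&mut self) -> Option<Vec<u8>> {
        self.ip_dispatch_rx.try_recv().ok()
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
    let mut device = ChannelDevice { ip_dispatch_rx: rx };
    // Pretend the tunnel decapsulated an IP packet and dispatched it.
    tx.send(vec![0x45, 0x00]).await.unwrap();
    assert!(device.receive().is_some());
}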

23
src/virtual_iface/mod.rs Normal file

@@ -0,0 +1,23 @@
pub mod tcp;
pub mod udp;
use crate::config::PortProtocol;
use async_trait::async_trait;
use std::fmt::{Display, Formatter};
#[async_trait]
pub trait VirtualInterfacePoll {
/// Initializes the virtual interface and processes incoming data to be dispatched
/// to the WireGuard tunnel and to the real client.
async fn poll_loop(mut self) -> anyhow::Result<()>;
}
/// Virtual port.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct VirtualPort(pub u16, pub PortProtocol);
impl Display for VirtualPort {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[{}:{}]", self.0, self.1)
}
}
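For reference, this Display impl produces the bracketed log prefixes seen throughout this change (e.g. "[60000:UDP]"). A self-contained sketch of the convention, where the local Protocol enum stands in for crate::config::PortProtocol, whose Display impl is assumed to print "TCP"/"UDP":

use std::fmt::{Display, Formatter};

#[derive(Copy, Clone, Debug)]
enum Protocol { Tcp, Udp }

impl Display for Protocol {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self { Protocol::Tcp => "TCP", Protocol::Udp => "UDP" })
    }
}

#[derive(Copy, Clone, Debug)]
struct VirtualPort(u16, Protocol);

impl Display for VirtualPort {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{}:{}]", self.0, self.1)
    }
}

fn main() {
    println!("{}", VirtualPort(60000, Protocol::Udp)); // prints "[60000:UDP]"
}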

282
src/virtual_iface/tcp.rs Normal file

@@ -0,0 +1,282 @@
use crate::config::{PortForwardConfig, PortProtocol};
use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::WireGuardTunnel;
use anyhow::Context;
use async_trait::async_trait;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState};
use smoltcp::wire::{IpAddress, IpCidr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
const MAX_PACKET: usize = 65536;
/// A virtual interface for proxying Layer 7 data to Layer 3 packets, and vice-versa.
pub struct TcpVirtualInterface {
/// The virtual port assigned to the virtual client, used to
/// route Layer 4 segments/datagrams to and from the WireGuard tunnel.
virtual_port: u16,
/// The overall port-forward configuration: used for the destination address (on which
/// the virtual server listens) and the protocol in use.
port_forward: PortForwardConfig,
/// The WireGuard tunnel to send IP packets to.
wg: Arc<WireGuardTunnel>,
/// Abort signal to shutdown the virtual interface and its parent task.
abort: Arc<AtomicBool>,
/// Channel sender for pushing Layer 7 data back to the real client.
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
/// Channel receiver for processing Layer 7 data through the virtual interface.
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
/// One-shot sender to notify the parent task that the virtual client is ready to send Layer 7 data.
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
}
impl TcpVirtualInterface {
/// Initialize the parameters for a new virtual interface.
/// Use the `poll_loop()` future to start the virtual interface poll loop.
pub fn new(
virtual_port: u16,
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
abort: Arc<AtomicBool>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
) -> Self {
Self {
virtual_port,
port_forward,
wg,
abort,
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
}
}
}
#[async_trait]
impl VirtualInterfacePoll for TcpVirtualInterface {
async fn poll_loop(self) -> anyhow::Result<()> {
let mut virtual_client_ready_tx = Some(self.virtual_client_ready_tx);
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
let source_peer_ip = self.wg.source_peer_ip;
// Create a device and interface to simulate IP packets
// In essence:
// * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
// * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
// * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
// * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
// * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device =
VirtualIpDevice::new_direct(VirtualPort(self.virtual_port, PortProtocol::Tcp), self.wg)
.with_context(|| "Failed to initialize TCP VirtualIpDevice")?;
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(self.port_forward.destination.ip()), 32),
])
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.
let server_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; 0] = [];
static mut TCP_SERVER_TX_DATA: [u8; 0] = [];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((
IpAddress::from(self.port_forward.destination.ip()),
self.port_forward.destination.port(),
))
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
};
let client_socket: anyhow::Result<TcpSocket> = {
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
Ok(socket)
};
// Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server.
let mut socket_set_entries: [_; 2] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let _server_handle = socket_set.add(server_socket?);
let client_handle = socket_set.add(client_socket?);
// Any data that wasn't sent because it was over the sending buffer limit
let mut tx_extra = Vec::new();
// Counts the connection attempts by the virtual client
let mut connection_attempts = 0;
// Whether the client has successfully connected before. Prevents the case of connecting again.
let mut has_connected = false;
loop {
let loop_start = smoltcp::time::Instant::now();
// Shutdown occurs when the real client closes the connection,
// or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore.
// One last poll-loop iteration is executed so that the RST segment can be dispatched.
let shutdown = self.abort.load(Ordering::Relaxed);
if shutdown {
// Shutdown: sends a RST packet.
trace!("[{}] Shutting down virtual interface", self.virtual_port);
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
client_socket.abort();
}
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
trace!(
"[{}] Virtual interface polled some packets to be processed",
self.virtual_port
);
}
Err(e) => {
error!(
"[{}] Virtual interface poll error: {:?}",
self.virtual_port, e
);
}
_ => {}
}
{
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
if !shutdown && client_socket.state() == TcpState::Closed && !has_connected {
// Not shutting down, but the client socket is closed, and the client never successfully connected.
if connection_attempts < 10 {
// Try to connect
client_socket
.connect(
(
IpAddress::from(self.port_forward.destination.ip()),
self.port_forward.destination.port(),
),
(IpAddress::from(source_peer_ip), self.virtual_port),
)
.with_context(|| "Virtual server socket failed to listen")?;
if connection_attempts > 0 {
debug!(
"[{}] Virtual client retrying connection in 500ms",
self.virtual_port
);
// Not our first connection attempt, wait a little bit.
tokio::time::sleep(Duration::from_millis(500)).await;
}
} else {
// Too many connection attempts
self.abort.store(true, Ordering::Relaxed);
}
connection_attempts += 1;
continue;
}
if client_socket.state() == TcpState::Established {
// Prevent reconnection if the server later closes.
has_connected = true;
}
if client_socket.can_recv() {
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => {
trace!(
"[{}] Virtual client received {} bytes of data",
self.virtual_port,
data.len()
);
// Send it to the real client
if let Err(e) = self.data_to_real_client_tx.send(data).await {
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", self.virtual_port, e);
}
}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
self.virtual_port, e
);
}
}
}
if client_socket.can_send() {
if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() {
virtual_client_ready_tx
.send(())
.expect("Failed to notify real client that virtual client is ready");
}
let mut to_transfer = None;
if tx_extra.is_empty() {
// The payload segment from the previous loop is complete,
// we can now read the next payload in the queue.
if let Ok(data) = data_to_virtual_server_rx.try_recv() {
to_transfer = Some(data);
} else if client_socket.state() == TcpState::CloseWait {
// No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN),
// the interface is shutdown.
trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", self.virtual_port);
self.abort.store(true, Ordering::Relaxed);
continue;
}
}
let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
if !to_transfer_slice.is_empty() {
let total = to_transfer_slice.len();
match client_socket.send_slice(to_transfer_slice) {
Ok(sent) => {
trace!(
"[{}] Sent {}/{} bytes via virtual client socket",
self.virtual_port,
sent,
total,
);
tx_extra = Vec::from(&to_transfer_slice[sent..total]);
}
Err(e) => {
error!(
"[{}] Failed to send slice via virtual client socket: {:?}",
self.virtual_port, e
);
}
}
}
}
}
if shutdown {
break;
}
match virtual_interface.poll_delay(&socket_set, loop_start) {
Some(smoltcp::time::Duration::ZERO) => {
continue;
}
_ => {
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
}
trace!("[{}] Virtual interface task terminated", self.virtual_port);
self.abort.store(true, Ordering::Relaxed);
Ok(())
}
}
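The `tx_extra` handling above implements a simple carry-over: whatever `send_slice` could not fit into the socket's send buffer is retried on the next loop iteration before any new payload is dequeued. A pure-std sketch of the same bookkeeping, where `try_send` is a hypothetical stand-in for smoltcp's `TcpSocket::send_slice` returning the number of bytes accepted:

fn try_send(buffer_space: usize, data: &[u8]) -> usize {
    data.len().min(buffer_space)
}

fn main() {
    let mut tx_extra: Vec<u8> = Vec::new();
    let mut to_transfer = Some(vec![1u8; 100_000]);
    loop {
        // Retry leftover bytes first; only dequeue a new payload once they are gone.
        let slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
        if slice.is_empty() {
            break;
        }
        let sent = try_send(65_536, slice); // pretend the buffer fits 64 KiB
        // Keep whatever did not fit for the next iteration.
        tx_extra = Vec::from(&slice[sent..]);
        to_transfer = None;
    }
    assert!(tx_extra.is_empty());
}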

201
src/virtual_iface/udp.rs Normal file

@@ -0,0 +1,201 @@
use anyhow::Context;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::socket::{SocketHandle, SocketSet, UdpPacketMetadata, UdpSocket, UdpSocketBuffer};
use smoltcp::wire::{IpAddress, IpCidr};
use crate::config::PortForwardConfig;
use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::{VirtualInterfacePoll, VirtualPort};
use crate::wg::{WireGuardTunnel, DISPATCH_CAPACITY};
const MAX_PACKET: usize = 65536;
pub struct UdpVirtualInterface {
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
}
impl UdpVirtualInterface {
pub fn new(
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<(VirtualPort, Vec<u8>)>,
data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<(VirtualPort, Vec<u8>)>,
) -> Self {
Self {
port_forward,
wg,
data_to_real_client_tx,
data_to_virtual_server_rx,
}
}
}
#[async_trait]
impl VirtualInterfacePoll for UdpVirtualInterface {
async fn poll_loop(self) -> anyhow::Result<()> {
// Data receiver to dispatch using virtual client sockets
let mut data_to_virtual_server_rx = self.data_to_virtual_server_rx;
// The IP to bind client sockets to
let source_peer_ip = self.wg.source_peer_ip;
// The IP/port to bind the server socket to
let destination = self.port_forward.destination;
// Initialize a channel for IP packets.
// The "base transmitted" is cloned so that each virtual port can register a sender in the tunnel.
// The receiver is given to the device so that the Virtual Interface can process incoming IP packets from the tunnel.
let (base_ip_dispatch_tx, ip_dispatch_rx) = tokio::sync::mpsc::channel(DISPATCH_CAPACITY);
let device = VirtualIpDevice::new(self.wg.clone(), ip_dispatch_rx);
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(source_peer_ip.into(), 32),
IpCidr::new(destination.ip().into(), 32),
])
.finalize();
// Server socket: this is a placeholder for the interface.
let server_socket: anyhow::Result<UdpSocket> = {
static mut UDP_SERVER_RX_META: [UdpPacketMetadata; 0] = [];
static mut UDP_SERVER_RX_DATA: [u8; 0] = [];
static mut UDP_SERVER_TX_META: [UdpPacketMetadata; 0] = [];
static mut UDP_SERVER_TX_DATA: [u8; 0] = [];
let udp_rx_buffer =
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_RX_META[..] }, unsafe {
&mut UDP_SERVER_RX_DATA[..]
});
let udp_tx_buffer =
UdpSocketBuffer::new(unsafe { &mut UDP_SERVER_TX_META[..] }, unsafe {
&mut UDP_SERVER_TX_DATA[..]
});
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
socket
.bind((IpAddress::from(destination.ip()), destination.port()))
.with_context(|| "UDP virtual server socket failed to listen")?;
Ok(socket)
};
let mut socket_set = SocketSet::new(vec![]);
let _server_handle = socket_set.add(server_socket?);
// A map of virtual port to client socket.
let mut client_sockets: HashMap<VirtualPort, SocketHandle> = HashMap::new();
// The next instant at which the virtual interface must be polled.
// `None` means an immediate poll is required.
let mut next_poll: Option<tokio::time::Instant> = None;
loop {
let wg = self.wg.clone();
tokio::select! {
// Wait for the amount of time recommended by smoltcp, then poll again.
_ = match next_poll {
None => tokio::time::sleep(Duration::ZERO),
Some(until) => tokio::time::sleep_until(until)
} => {
let loop_start = smoltcp::time::Instant::now();
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
trace!("UDP virtual interface polled some packets to be processed");
}
Err(e) => error!("UDP virtual interface poll error: {:?}", e),
_ => {}
}
// Loop through each client socket and check if there is any data to send back
// to the real client.
for (virtual_port, client_socket_handle) in client_sockets.iter() {
let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
match client_socket.recv() {
Ok((data, _peer)) => {
// Send the data back to the real client using MPSC channel
self.data_to_real_client_tx
.send((*virtual_port, data.to_vec()))
.await
.unwrap_or_else(|e| {
error!(
"[{}] Failed to dispatch data from virtual client to real client: {:?}",
virtual_port, e
);
});
}
Err(smoltcp::Error::Exhausted) => {}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
virtual_port, e
);
}
}
}
next_poll = match virtual_interface.poll_delay(&socket_set, loop_start) {
Some(smoltcp::time::Duration::ZERO) => None,
Some(delay) => Some(tokio::time::Instant::now() + Duration::from_millis(delay.millis())),
None => None,
}
}
// Wait for data to be received from the real client
data_recv_result = data_to_virtual_server_rx.recv() => {
if let Some((client_port, data)) = data_recv_result {
// Register the socket in WireGuard Tunnel (overrides any previous registration as well)
wg.register_virtual_interface(client_port, base_ip_dispatch_tx.clone())
.unwrap_or_else(|e| {
error!(
"[{}] Failed to register UDP socket in WireGuard tunnel: {:?}",
client_port, e
);
});
let client_socket_handle = client_sockets.entry(client_port).or_insert_with(|| {
let rx_meta = vec![UdpPacketMetadata::EMPTY; 10];
let tx_meta = vec![UdpPacketMetadata::EMPTY; 10];
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let udp_rx_buffer = UdpSocketBuffer::new(rx_meta, rx_data);
let udp_tx_buffer = UdpSocketBuffer::new(tx_meta, tx_data);
let mut socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
socket
.bind((IpAddress::from(wg.source_peer_ip), client_port.0))
.unwrap_or_else(|e| {
error!(
"[{}] UDP virtual client socket failed to bind: {:?}",
client_port, e
);
});
socket_set.add(socket)
});
let mut client_socket = socket_set.get::<UdpSocket>(*client_socket_handle);
client_socket
.send_slice(
&data,
(IpAddress::from(destination.ip()), destination.port()).into(),
)
.unwrap_or_else(|e| {
error!(
"[{}] Failed to send data to virtual server: {:?}",
client_port, e
);
});
}
}
}
}
}
}
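The select! above races two wake-up sources: the smoltcp-suggested poll deadline and new data arriving from the real client. A standalone sketch of that scheduling pattern (the fixed 5ms delay stands in for whatever poll_delay() returns; assumes tokio with the "full" feature so that sleep, sleep_until, and mpsc are available):

use std::time::Duration;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
    tokio::spawn(async move {
        tx.send(b"hello".to_vec()).await.ok();
    });
    // `None` means "poll immediately", mirroring the loop above.
    let mut next_poll: Option<tokio::time::Instant> = None;
    loop {
        tokio::select! {
            // Wake up when the suggested poll deadline passes...
            _ = match next_poll {
                None => tokio::time::sleep(Duration::ZERO),
                Some(until) => tokio::time::sleep_until(until),
            } => {
                // (virtual_interface.poll() would run here)
                next_poll = Some(tokio::time::Instant::now() + Duration::from_millis(5));
            }
            // ...or as soon as the real client hands us a datagram.
            data = rx.recv() => {
                if let Some(data) = data {
                    println!("got {} bytes from the real client", data.len());
                    break;
                }
            }
        }
    }
}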


@@ -4,29 +4,30 @@ use std::time::Duration;
use anyhow::Context;
use boringtun::noise::{Tunn, TunnResult};
use log::Level;
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket};
use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use crate::config::{Config, PortProtocol};
use crate::virtual_iface::VirtualPort;
/// The capacity of the channel for received IP packets.
pub const DISPATCH_CAPACITY: usize = 1_000;
const MAX_PACKET: usize = 65536;
/// A WireGuard tunnel. Encapsulates and decapsulates IP packets
/// to be sent to and received from a remote UDP endpoint.
/// This tunnel supports at most 1 peer IP at a time, but supports simultaneous ports.
pub struct WireGuardTunnel {
pub(crate) source_peer_ip: IpAddr,
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
peer: Box<Tunn>,
/// The UDP socket for the public WireGuard endpoint to connect to.
udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP).
pub(crate) endpoint: SocketAddr,
/// Maps virtual ports to the corresponding IP packet dispatcher.
virtual_port_ip_tx: dashmap::DashMap<VirtualPort, tokio::sync::mpsc::Sender<Vec<u8>>>,
/// IP packet dispatcher for unroutable packets. `None` if not initialized.
sink_ip_tx: RwLock<Option<tokio::sync::mpsc::Sender<Vec<u8>>>>,
}
@@ -40,7 +41,7 @@ impl WireGuardTunnel {
.await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?;
let endpoint = config.endpoint_addr;
let virtual_port_ip_tx = Default::default();
Ok(Self {
source_peer_ip,
@@ -86,16 +87,11 @@ impl WireGuardTunnel {
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
pub fn register_virtual_interface(
&self,
virtual_port: VirtualPort,
sender: tokio::sync::mpsc::Sender<Vec<u8>>,
) -> anyhow::Result<()> {
self.virtual_port_ip_tx.insert(virtual_port, sender);
Ok(())
}
/// Register a virtual interface (using its assigned virtual port) with the given IP packet `Sender`.
@@ -111,7 +107,7 @@ impl WireGuardTunnel {
}
/// Releases the virtual interface from IP dispatch.
pub fn release_virtual_interface(&self, virtual_port: VirtualPort) {
self.virtual_port_ip_tx.remove(&virtual_port);
}
@@ -214,7 +210,7 @@ impl WireGuardTunnel {
RouteResult::Dispatch(port) => {
let sender = self.virtual_port_ip_tx.get(&port);
if let Some(sender_guard) = sender {
let sender = sender_guard.value();
match sender.send(packet.to_vec()).await {
Ok(_) => {
trace!(
@@ -271,6 +267,7 @@
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() {
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
})
@@ -282,6 +279,7 @@
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.next_header() {
IpProtocol::Tcp => Some(self.route_tcp_segment(packet.payload())),
IpProtocol::Udp => Some(self.route_udp_datagram(packet.payload())),
// Unrecognized protocol, so we cannot determine where to route
_ => Some(RouteResult::Drop),
})
@@ -296,8 +294,12 @@
TcpPacket::new_checked(segment)
.ok()
.map(|tcp| {
if self
.virtual_port_ip_tx
.get(&VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
.is_some()
{
RouteResult::Dispatch(VirtualPort(tcp.dst_port(), PortProtocol::Tcp))
} else if tcp.rst() {
RouteResult::Drop
} else {
@@ -307,6 +309,24 @@
.unwrap_or(RouteResult::Drop)
}
/// Makes a decision on the handling of an incoming UDP datagram.
fn route_udp_datagram(&self, datagram: &[u8]) -> RouteResult {
UdpPacket::new_checked(datagram)
.ok()
.map(|udp| {
if self
.virtual_port_ip_tx
.get(&VirtualPort(udp.dst_port(), PortProtocol::Udp))
.is_some()
{
RouteResult::Dispatch(VirtualPort(udp.dst_port(), PortProtocol::Udp))
} else {
RouteResult::Drop
}
})
.unwrap_or(RouteResult::Drop)
}
/// Route a packet to the IP sink interface.
async fn route_ip_sink(&self, packet: &[u8]) -> anyhow::Result<()> {
let ip_sink_tx = self.sink_ip_tx.read().await;
@@ -347,7 +367,7 @@ fn trace_ip_packet(message: &str, packet: &[u8]) {
enum RouteResult {
/// Dispatch the packet to the virtual port.
Dispatch(VirtualPort),
/// The packet is not routable, and should be sent to the sink interface.
Sink,
/// The packet is not routable, and can be safely ignored.
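To make the routing decision concrete, here is a small self-contained sketch of the dst-port check performed by route_udp_datagram, using smoltcp's UdpPacket parser. The `route` helper is hypothetical, and port 60000 plays the role of a registered virtual port:

use smoltcp::wire::UdpPacket;

/// Returns the destination port if the datagram parses and matches a
/// registered virtual port; `None` corresponds to RouteResult::Drop.
fn route(datagram: &[u8], registered: u16) -> Option<u16> {
    UdpPacket::new_checked(datagram)
        .ok()
        .map(|udp| udp.dst_port())
        .filter(|port| *port == registered)
}

fn main() {
    // Build a minimal 8-byte UDP header (no payload) for the demo.
    let mut bytes = [0u8; 8];
    let mut packet = UdpPacket::new_unchecked(&mut bytes[..]);
    packet.set_src_port(12345);
    packet.set_dst_port(60000);
    packet.set_len(8);
    assert_eq!(route(&bytes, 60000), Some(60000)); // would be Dispatch
    assert_eq!(route(&bytes, 60001), None); // would be Drop
}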