diff --git a/src/events.rs b/src/events.rs index d6582ce..23a62e7 100644 --- a/src/events.rs +++ b/src/events.rs @@ -1,5 +1,6 @@ use bytes::Bytes; use std::fmt::{Display, Formatter}; +use std::net::SocketAddr; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; @@ -20,12 +21,16 @@ pub enum Event { LocalData(PortForwardConfig, VirtualPort, Bytes), /// Data received by the remote server that should be sent to the local client. RemoteData(VirtualPort, Bytes), + /// Data received from a remote client to send to a local server. + RemoteClientData(SocketAddr, PortForwardConfig, Bytes), /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. InboundInternetPacket(PortProtocol, Bytes), /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device. OutboundInternetPacket(Bytes), /// Notifies that a virtual device read an IP packet. VirtualDeviceFed(PortProtocol), + /// A new remote connection is trying to open to local. + RemoteConnectionRequest(PortForwardConfig), } impl Display for Event { @@ -48,6 +53,14 @@ impl Display for Event { let size = data.len(); write!(f, "RemoteData{{ vp={} size={} }}", vp, size) } + Event::RemoteClientData(remote, pf, data) => { + let size = data.len(); + write!( + f, + "RemoteClientData{{ remote={}, pf={} size={} }}", + remote, pf, size + ) + } Event::InboundInternetPacket(proto, data) => { let size = data.len(); write!( @@ -63,6 +76,9 @@ impl Display for Event { Event::VirtualDeviceFed(proto) => { write!(f, "VirtualDeviceFed{{ proto={} }}", proto) } + Event::RemoteConnectionRequest(pf) => { + write!(f, "RemoteConnectionRequest{{ pf={} }}", pf) + } } } } diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index eadf8b0..38c5f59 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -28,7 +28,19 @@ pub async fn port_forward( ); match port_forward.protocol { - PortProtocol::Tcp => tcp::tcp_proxy_server(port_forward, tcp_port_pool, bus).await, - PortProtocol::Udp => udp::udp_proxy_server(port_forward, udp_port_pool, bus).await, + PortProtocol::Tcp => { + if port_forward.remote { + tcp::tcp_remote_dispatcher(port_forward, bus).await + } else { + tcp::tcp_proxy_server(port_forward, tcp_port_pool, bus).await + } + } + PortProtocol::Udp => { + if port_forward.remote { + udp::udp_remote_dispatcher(port_forward, udp_port_pool, bus).await + } else { + udp::udp_proxy_server(port_forward, udp_port_pool, bus).await + } + } } } diff --git a/src/tunnel/tcp.rs b/src/tunnel/tcp.rs index b5e1ec5..ed63b71 100644 --- a/src/tunnel/tcp.rs +++ b/src/tunnel/tcp.rs @@ -160,6 +160,23 @@ async fn handle_tcp_proxy_connection( Ok(()) } +/// Listens for incoming remote connections and creates new tasks to handle them. +pub async fn tcp_remote_dispatcher( + port_forward: PortForwardConfig, + bus: Bus, +) -> anyhow::Result<()> { + let mut endpoint = bus.new_endpoint(); + + loop { + match endpoint.recv().await { + Event::RemoteConnectionRequest(pf) if pf == port_forward => { + info!("New remote connection: {}", pf); + } + _ => continue, + } + } +} + /// A pool of virtual ports available for TCP connections. #[derive(Clone)] pub struct TcpPortPool { diff --git a/src/tunnel/udp.rs b/src/tunnel/udp.rs index 32fef15..428ce7c 100644 --- a/src/tunnel/udp.rs +++ b/src/tunnel/udp.rs @@ -1,3 +1,4 @@ +use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; use std::net::{IpAddr, SocketAddr}; use std::ops::Range; @@ -130,6 +131,40 @@ async fn next_udp_datagram( Ok(Some((port, data.into()))) } +pub async fn udp_remote_dispatcher( + port_forward: PortForwardConfig, + port_pool: UdpPortPool, + bus: Bus, +) -> anyhow::Result<()> { + let mut endpoint = bus.new_endpoint(); + let mut socket_map: HashMap = HashMap::default(); + loop { + match endpoint.recv().await { + Event::RemoteClientData(remote, pf, data) if pf == port_forward => { + // TODO: IPv6 supprt + let socket = match socket_map.entry(remote) { + Entry::Occupied(o) => o.into_mut(), + Entry::Vacant(v) => { + let udp = UdpSocket::bind(( + if pf.destination.is_ipv4() { + "0.0.0.0" + } else { + "::" + }, + port_pool.next(remote).await.unwrap().num(), + )) + .await + .with_context(|| "Failed to create UDP socket for local client"); + v.insert(udp.unwrap()) + } + }; + socket.send_to(&data, remote).await.unwrap(); + } + _ => continue, + } + } +} + /// A pool of virtual ports available for TCP connections. #[derive(Clone)] pub struct UdpPortPool {