Refactor TCP virtual interface code out of main. Removed unused server socket buffer.

This commit is contained in:
Aram 🍐 2021-10-19 00:43:59 -04:00
parent 070c0f5162
commit c2d0b9719a
6 changed files with 322 additions and 255 deletions

View file

@ -1,26 +1,24 @@
#[macro_use]
extern crate log;
use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use smoltcp::iface::InterfaceBuilder;
use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer, TcpState};
use smoltcp::wire::{IpAddress, IpCidr};
use tokio::net::{TcpListener, TcpStream};
use crate::config::{Config, PortForwardConfig, PortProtocol};
use crate::port_pool::PortPool;
use crate::virtual_device::VirtualIpDevice;
use crate::virtual_iface::tcp::TcpVirtualInterface;
use crate::virtual_iface::VirtualInterfacePoll;
use crate::wg::WireGuardTunnel;
pub mod config;
pub mod ip_sink;
pub mod port_pool;
pub mod virtual_device;
pub mod virtual_iface;
pub mod wg;
pub const MAX_PACKET: usize = 65536;
@ -92,29 +90,18 @@ async fn port_forward(
);
match port_forward.protocol {
PortProtocol::Tcp => {
tcp_proxy_server(
port_forward.source,
port_forward.destination,
source_peer_ip,
port_pool,
wg,
)
.await
}
PortProtocol::Tcp => tcp_proxy_server(port_forward, port_pool, wg).await,
PortProtocol::Udp => Err(anyhow::anyhow!("UDP isn't supported just yet.")),
}
}
/// Starts the server that listens on TCP connections.
async fn tcp_proxy_server(
listen_addr: SocketAddr,
dest_addr: SocketAddr,
source_peer_ip: IpAddr,
port_forward: PortForwardConfig,
port_pool: Arc<PortPool>,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(listen_addr)
let listener = TcpListener::bind(port_forward.source)
.await
.with_context(|| "Failed to listen on TCP proxy server")?;
@ -144,14 +131,8 @@ async fn tcp_proxy_server(
tokio::spawn(async move {
let port_pool = Arc::clone(&port_pool);
let result = handle_tcp_proxy_connection(
socket,
virtual_port,
source_peer_ip,
dest_addr,
wg.clone(),
)
.await;
let result =
handle_tcp_proxy_connection(socket, virtual_port, port_forward, wg.clone()).await;
if let Err(e) = result {
error!(
@ -173,8 +154,7 @@ async fn tcp_proxy_server(
async fn handle_tcp_proxy_connection(
socket: TcpStream,
virtual_port: u16,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
port_forward: PortForwardConfig,
wg: Arc<WireGuardTunnel>,
) -> anyhow::Result<()> {
// Abort signal for stopping the Virtual Interface
@ -194,18 +174,21 @@ async fn handle_tcp_proxy_connection(
// Spawn virtual interface
{
let abort = abort.clone();
let virtual_interface = TcpVirtualInterface::new(
virtual_port,
port_forward,
wg,
abort.clone(),
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
);
tokio::spawn(async move {
virtual_tcp_interface(
virtual_port,
source_peer_ip,
dest_addr,
wg,
abort,
data_to_real_client_tx,
data_to_virtual_server_rx,
virtual_client_ready_tx,
)
.await
virtual_interface.poll_loop().await.unwrap_or_else(|e| {
error!("Virtual interface poll loop failed unexpectedly: {}", e);
abort.store(true, Ordering::Relaxed);
})
});
}
@ -297,219 +280,6 @@ async fn handle_tcp_proxy_connection(
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn virtual_tcp_interface(
virtual_port: u16,
source_peer_ip: IpAddr,
dest_addr: SocketAddr,
wg: Arc<WireGuardTunnel>,
abort: Arc<AtomicBool>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
mut data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
virtual_client_ready_tx: tokio::sync::oneshot::Sender<()>,
) -> anyhow::Result<()> {
let mut virtual_client_ready_tx = Some(virtual_client_ready_tx);
// Create a device and interface to simulate IP packets
// In essence:
// * TCP packets received from the 'real' client are 'sent' to the 'virtual server' via the 'virtual client'
// * Those TCP packets generate IP packets, which are captured from the interface and sent to the WireGuardTunnel
// * IP packets received by the WireGuardTunnel (from the endpoint) are fed into this 'virtual interface'
// * The interface processes those IP packets and routes them to the 'virtual client' (the rest is discarded)
// * The TCP data read by the 'virtual client' is sent to the 'real' TCP client
// Consumer for IP packets to send through the virtual interface
// Initialize the interface
let device = VirtualIpDevice::new(virtual_port, wg)
.with_context(|| "Failed to initialize VirtualIpDevice")?;
let mut virtual_interface = InterfaceBuilder::new(device)
.ip_addrs([
// Interface handles IP packets for the sender and recipient
IpCidr::new(IpAddress::from(source_peer_ip), 32),
IpCidr::new(IpAddress::from(dest_addr.ip()), 32),
])
.finalize();
// Server socket: this is a placeholder for the interface to route new connections to.
// TODO: Determine if we even need buffers here.
let server_socket: anyhow::Result<TcpSocket> = {
static mut TCP_SERVER_RX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
static mut TCP_SERVER_TX_DATA: [u8; MAX_PACKET] = [0; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] });
let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] });
let mut socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
socket
.listen((IpAddress::from(dest_addr.ip()), dest_addr.port()))
.with_context(|| "Virtual server socket failed to listen")?;
Ok(socket)
};
let client_socket: anyhow::Result<TcpSocket> = {
let rx_data = vec![0u8; MAX_PACKET];
let tx_data = vec![0u8; MAX_PACKET];
let tcp_rx_buffer = TcpSocketBuffer::new(rx_data);
let tcp_tx_buffer = TcpSocketBuffer::new(tx_data);
let socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
Ok(socket)
};
// Socket set: there are always 2 sockets: 1 virtual client and 1 virtual server.
let mut socket_set_entries: [_; 2] = Default::default();
let mut socket_set = SocketSet::new(&mut socket_set_entries[..]);
let _server_handle = socket_set.add(server_socket?);
let client_handle = socket_set.add(client_socket?);
// Any data that wasn't sent because it was over the sending buffer limit
let mut tx_extra = Vec::new();
// Counts the connection attempts by the virtual client
let mut connection_attempts = 0;
// Whether the client has successfully connected before. Prevents the case of connecting again.
let mut has_connected = false;
loop {
let loop_start = smoltcp::time::Instant::now();
// Shutdown occurs when the real client closes the connection,
// or if the client was in a CLOSE-WAIT state (after a server FIN) and had no data to send anymore.
// One last poll-loop iteration is executed so that the RST segment can be dispatched.
let shutdown = abort.load(Ordering::Relaxed);
if shutdown {
// Shutdown: sends a RST packet.
trace!("[{}] Shutting down virtual interface", virtual_port);
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
client_socket.abort();
}
match virtual_interface.poll(&mut socket_set, loop_start) {
Ok(processed) if processed => {
trace!(
"[{}] Virtual interface polled some packets to be processed",
virtual_port
);
}
Err(e) => {
error!("[{}] Virtual interface poll error: {:?}", virtual_port, e);
}
_ => {}
}
{
let mut client_socket = socket_set.get::<TcpSocket>(client_handle);
if !shutdown && client_socket.state() == TcpState::Closed && !has_connected {
// Not shutting down, but the client socket is closed, and the client never successfully connected.
if connection_attempts < 10 {
// Try to connect
client_socket
.connect(
(IpAddress::from(dest_addr.ip()), dest_addr.port()),
(IpAddress::from(source_peer_ip), virtual_port),
)
.with_context(|| "Virtual server socket failed to listen")?;
if connection_attempts > 0 {
debug!(
"[{}] Virtual client retrying connection in 500ms",
virtual_port
);
// Not our first connection attempt, wait a little bit.
tokio::time::sleep(Duration::from_millis(500)).await;
}
} else {
// Too many connection attempts
abort.store(true, Ordering::Relaxed);
}
connection_attempts += 1;
continue;
}
if client_socket.state() == TcpState::Established {
// Prevent reconnection if the server later closes.
has_connected = true;
}
if client_socket.can_recv() {
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => {
trace!(
"[{}] Virtual client received {} bytes of data",
virtual_port,
data.len()
);
// Send it to the real client
if let Err(e) = data_to_real_client_tx.send(data).await {
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e);
}
}
Err(e) => {
error!(
"[{}] Failed to read from virtual client socket: {:?}",
virtual_port, e
);
}
}
}
if client_socket.can_send() {
if let Some(virtual_client_ready_tx) = virtual_client_ready_tx.take() {
virtual_client_ready_tx
.send(())
.expect("Failed to notify real client that virtual client is ready");
}
let mut to_transfer = None;
if tx_extra.is_empty() {
// The payload segment from the previous loop is complete,
// we can now read the next payload in the queue.
if let Ok(data) = data_to_virtual_server_rx.try_recv() {
to_transfer = Some(data);
} else if client_socket.state() == TcpState::CloseWait {
// No data to be sent in this loop. If the client state is CLOSE-WAIT (because of a server FIN),
// the interface is shutdown.
trace!("[{}] Shutting down virtual interface because client sent no more data, and server sent FIN (CLOSE-WAIT)", virtual_port);
abort.store(true, Ordering::Relaxed);
continue;
}
}
let to_transfer_slice = to_transfer.as_ref().unwrap_or(&tx_extra).as_slice();
if !to_transfer_slice.is_empty() {
let total = to_transfer_slice.len();
match client_socket.send_slice(to_transfer_slice) {
Ok(sent) => {
trace!(
"[{}] Sent {}/{} bytes via virtual client socket",
virtual_port,
sent,
total,
);
tx_extra = Vec::from(&to_transfer_slice[sent..total]);
}
Err(e) => {
error!(
"[{}] Failed to send slice via virtual client socket: {:?}",
virtual_port, e
);
}
}
}
}
}
if shutdown {
break;
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
trace!("[{}] Virtual interface task terminated", virtual_port);
abort.store(true, Ordering::Relaxed);
Ok(())
}
fn init_logger(config: &Config) -> anyhow::Result<()> {
let mut builder = pretty_env_logger::formatted_builder();
builder.parse_filters(&config.log);