Ipv6 support for RST packets. Clippy

This commit is contained in:
Aram 🍐 2021-10-15 18:52:27 -04:00
parent 3318e30d98
commit 0bb6c27d86
2 changed files with 79 additions and 51 deletions

View file

@ -132,10 +132,11 @@ async fn handle_tcp_proxy_connection(
// data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back // data_to_real_client_(tx/rx): This task reads the data from this mpsc channel to send back
// to the real client. // to the real client.
let (data_to_real_client_tx, mut data_to_real_client_rx) = let (data_to_real_client_tx, mut data_to_real_client_rx) = tokio::sync::mpsc::channel(1_000);
tokio::sync::mpsc::channel(1_000_000);
let (data_to_real_server_tx, data_to_real_server_rx) = tokio::sync::mpsc::channel(1_000_000); // data_to_real_server_(tx/rx): This task sends the data received from the real client to the
// virtual interface (virtual server socket).
let (data_to_virtual_server_tx, data_to_virtual_server_rx) = tokio::sync::mpsc::channel(1_000);
// Spawn virtual interface // Spawn virtual interface
{ {
@ -148,7 +149,7 @@ async fn handle_tcp_proxy_connection(
wg, wg,
abort, abort,
data_to_real_client_tx, data_to_real_client_tx,
data_to_real_server_rx, data_to_virtual_server_rx,
) )
.await .await
}); });
@ -171,15 +172,12 @@ async fn handle_tcp_proxy_connection(
"[{}] Read {} bytes of TCP data from real client", "[{}] Read {} bytes of TCP data from real client",
virtual_port, size virtual_port, size
); );
match data_to_real_server_tx.send(data.to_vec()).await { if let Err(e) = data_to_virtual_server_tx.send(data.to_vec()).await {
Err(e) => {
error!( error!(
"[{}] Failed to dispatch data to virtual interface: {:?}", "[{}] Failed to dispatch data to virtual interface: {:?}",
virtual_port, e virtual_port, e
); );
} }
_ => {}
}
} }
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue; continue;
@ -239,7 +237,7 @@ async fn virtual_tcp_interface(
wg: Arc<WireGuardTunnel>, wg: Arc<WireGuardTunnel>,
abort: Arc<AtomicBool>, abort: Arc<AtomicBool>,
data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>, data_to_real_client_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
mut data_to_real_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>, mut data_to_virtual_server_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create a device and interface to simulate IP packets // Create a device and interface to simulate IP packets
// In essence: // In essence:
@ -326,12 +324,9 @@ async fn virtual_tcp_interface(
match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) {
Ok(data) => { Ok(data) => {
// Send it to the real client // Send it to the real client
match data_to_real_client_tx.send(data).await { if let Err(e) = data_to_real_client_tx.send(data).await {
Err(e) => {
error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e); error!("[{}] Failed to dispatch data from virtual client to real client: {:?}", virtual_port, e);
} }
_ => {}
}
} }
Err(e) => { Err(e) => {
error!( error!(
@ -343,17 +338,13 @@ async fn virtual_tcp_interface(
} }
if client_socket.can_send() { if client_socket.can_send() {
// Check if there is anything to send // Check if there is anything to send
match data_to_real_server_rx.try_recv() { if let Ok(data) = data_to_virtual_server_rx.try_recv() {
Ok(data) => match client_socket.send_slice(&data) { if let Err(e) = client_socket.send_slice(&data) {
Err(e) => {
error!( error!(
"[{}] Failed to send slice via virtual client socket: {:?}", "[{}] Failed to send slice via virtual client socket: {:?}",
virtual_port, e virtual_port, e
); );
} }
_ => {}
},
Err(_) => {}
} }
} }
} }

View file

@ -1,4 +1,4 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -17,7 +17,7 @@ use crate::port_pool::PortPool;
use crate::MAX_PACKET; use crate::MAX_PACKET;
/// The capacity of the broadcast channel for received IP packets. /// The capacity of the broadcast channel for received IP packets.
const BROADCAST_CAPACITY: usize = 1_000_000; const BROADCAST_CAPACITY: usize = 1_000;
pub struct WireGuardTunnel { pub struct WireGuardTunnel {
source_peer_ip: IpAddr, source_peer_ip: IpAddr,
@ -189,7 +189,7 @@ impl WireGuardTunnel {
// For debugging purposes: parse packet // For debugging purposes: parse packet
trace_ip_packet("Received IP packet", packet); trace_ip_packet("Received IP packet", packet);
match self.route(packet) { match self.route_ip_packet(packet) {
RouteResult::Broadcast => { RouteResult::Broadcast => {
// Broadcast IP packet // Broadcast IP packet
if self.ip_broadcast_tx.receiver_count() > 1 { if self.ip_broadcast_tx.receiver_count() > 1 {
@ -238,14 +238,56 @@ impl WireGuardTunnel {
.with_context(|| "Failed to initialize boringtun Tunn") .with_context(|| "Failed to initialize boringtun Tunn")
} }
fn route(&self, packet: &[u8]) -> RouteResult { /// Makes a decision on the handling of an incoming IP packet.
fn route_ip_packet(&self, packet: &[u8]) -> RouteResult {
match IpVersion::of_packet(packet) { match IpVersion::of_packet(packet) {
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
.ok() .ok()
// Only care if the packet is destined for this tunnel // Only care if the packet is destined for this tunnel
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| match packet.protocol() { .map(|packet| match packet.protocol() {
IpProtocol::Tcp => TcpPacket::new_checked(packet.payload()).ok().map(|tcp| { IpProtocol::Tcp => Some(
self.route_tcp_segment(
packet.src_addr().into(),
packet.dst_addr().into(),
packet.payload(),
)
// Note: Ipv4 drops invalid TCP packets when the specified protocol says that it should be TCP
.unwrap_or(RouteResult::Drop),
),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
})
.flatten()
.unwrap_or(RouteResult::Drop),
Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet)
.ok()
// Only care if the packet is destined for this tunnel
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.map(|packet| {
self.route_tcp_segment(
packet.src_addr().into(),
packet.dst_addr().into(),
packet.payload(),
)
// Note: Since Ipv6 doesn't inform us of the protocol at this layer,
// we should broadcast unrecognized packets.
.unwrap_or(RouteResult::Broadcast)
})
.unwrap_or(RouteResult::Drop),
_ => RouteResult::Drop,
}
}
/// Makes a decision on the handling of an incoming TCP segment.
/// When the given segment is an invalid TCP packet, it returns `None`.
fn route_tcp_segment(
&self,
src_addr: IpAddress,
dst_addr: IpAddress,
segment: &[u8],
) -> Option<RouteResult> {
TcpPacket::new_checked(segment).ok().map(|tcp| {
if self.port_pool.is_in_use(tcp.dst_port()) { if self.port_pool.is_in_use(tcp.dst_port()) {
RouteResult::Broadcast RouteResult::Broadcast
} else if tcp.rst() { } else if tcp.rst() {
@ -254,25 +296,20 @@ impl WireGuardTunnel {
// Port is not in use, but it's a TCP packet so we'll craft a RST. // Port is not in use, but it's a TCP packet so we'll craft a RST.
RouteResult::TcpReset(craft_tcp_rst_reply( RouteResult::TcpReset(craft_tcp_rst_reply(
IpVersion::Ipv4, IpVersion::Ipv4,
packet.src_addr().into(), src_addr,
tcp.src_port(), tcp.src_port(),
packet.dst_addr().into(), dst_addr,
tcp.dst_port(), tcp.dst_port(),
tcp.ack_number(), tcp.ack_number(),
)) ))
} }
}),
// Unrecognized protocol, so we'll allow it.
_ => Some(RouteResult::Broadcast),
}) })
.flatten()
.unwrap_or(RouteResult::Drop),
// TODO: IPv6
_ => RouteResult::Drop,
}
} }
} }
/// Craft an IP packet containing a TCP RST segment, given an IP version,
/// source address (the one to reply to), destination address (the one the reply comes from),
/// and the ACK number received in the initiating TCP segment.
fn craft_tcp_rst_reply( fn craft_tcp_rst_reply(
ip_version: IpVersion, ip_version: IpVersion,
source_addr: IpAddress, source_addr: IpAddress,