use bytes::Bytes; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use crate::config::PortForwardConfig; use crate::virtual_iface::VirtualPort; use crate::PortProtocol; /// Events that go on the bus between the local server, smoltcp, and WireGuard. #[derive(Debug, Clone)] pub enum Event { /// Dumb event with no data. Dumb, /// A new connection with the local server was initiated, and the given virtual port was assigned. ClientConnectionInitiated(PortForwardConfig, VirtualPort), /// A connection was dropped from the pool and should be closed in all interfaces. ClientConnectionDropped(VirtualPort), /// Data received by the local server that should be sent to the virtual server. LocalData(PortForwardConfig, VirtualPort, Bytes), /// Data received by the remote server that should be sent to the local client. RemoteData(VirtualPort, Bytes), /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. InboundInternetPacket(PortProtocol, Bytes), /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device. OutboundInternetPacket(Bytes), /// Notifies that a virtual device read an IP packet. VirtualDeviceFed(PortProtocol), } impl Display for Event { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Event::Dumb => { write!(f, "Dumb{{}}") } Event::ClientConnectionInitiated(pf, vp) => { write!(f, "ClientConnectionInitiated{{ pf={} vp={} }}", pf, vp) } Event::ClientConnectionDropped(vp) => { write!(f, "ClientConnectionDropped{{ vp={} }}", vp) } Event::LocalData(pf, vp, data) => { let size = data.len(); write!(f, "LocalData{{ pf={} vp={} size={} }}", pf, vp, size) } Event::RemoteData(vp, data) => { let size = data.len(); write!(f, "RemoteData{{ vp={} size={} }}", vp, size) } Event::InboundInternetPacket(proto, data) => { let size = data.len(); write!( f, "InboundInternetPacket{{ proto={} size={} }}", proto, size ) } Event::OutboundInternetPacket(data) => { let size = data.len(); write!(f, "OutboundInternetPacket{{ size={} }}", size) } Event::VirtualDeviceFed(proto) => { write!(f, "VirtualDeviceFed{{ proto={} }}", proto) } } } } #[derive(Clone)] pub struct Bus { counter: Arc, bus: Arc>, } impl Bus { /// Creates a new event bus. pub fn new() -> Self { let (bus, _) = tokio::sync::broadcast::channel(1000); let bus = Arc::new(bus); let counter = Arc::new(AtomicU32::default()); Self { bus, counter } } /// Creates a new endpoint on the event bus. pub fn new_endpoint(&self) -> BusEndpoint { let id = self.counter.fetch_add(1, Ordering::Relaxed); let tx = (*self.bus).clone(); let rx = self.bus.subscribe(); let tx = BusSender { id, tx }; BusEndpoint { id, tx, rx } } } impl Default for Bus { fn default() -> Self { Self::new() } } pub struct BusEndpoint { id: u32, tx: BusSender, rx: tokio::sync::broadcast::Receiver<(u32, Event)>, } impl BusEndpoint { /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. pub fn send(&self, event: Event) { self.tx.send(event) } /// Returns the unique sequential ID of this endpoint. pub fn id(&self) -> u32 { self.id } /// Awaits the next `Event` on the bus to be read. pub async fn recv(&mut self) -> Event { loop { match self.rx.recv().await { Ok((id, event)) => { if id == self.id { // If the event was sent by this endpoint, it is skipped continue; } else { return event; } } Err(_) => { error!("Failed to read event bus from endpoint #{}", self.id); return futures::future::pending().await; } } } } /// Creates a new sender for this endpoint that can be cloned. pub fn sender(&self) -> BusSender { self.tx.clone() } } #[derive(Clone)] pub struct BusSender { id: u32, tx: tokio::sync::broadcast::Sender<(u32, Event)>, } impl BusSender { /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. pub fn send(&self, event: Event) { trace!("#{} -> {}", self.id, event); match self.tx.send((self.id, event)) { Ok(_) => {} Err(_) => error!("Failed to send event to bus from endpoint #{}", self.id), } } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_bus() { let bus = Bus::new(); let mut endpoint_1 = bus.new_endpoint(); let mut endpoint_2 = bus.new_endpoint(); let mut endpoint_3 = bus.new_endpoint(); assert_eq!(endpoint_1.id(), 0); assert_eq!(endpoint_2.id(), 1); assert_eq!(endpoint_3.id(), 2); endpoint_1.send(Event::Dumb); let recv_2 = endpoint_2.recv().await; let recv_3 = endpoint_3.recv().await; assert!(matches!(recv_2, Event::Dumb)); assert!(matches!(recv_3, Event::Dumb)); endpoint_2.send(Event::Dumb); let recv_1 = endpoint_1.recv().await; let recv_3 = endpoint_3.recv().await; assert!(matches!(recv_1, Event::Dumb)); assert!(matches!(recv_3, Event::Dumb)); } }