use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use crate::config::PortForwardConfig; use crate::virtual_iface::VirtualPort; use crate::PortProtocol; /// Events that go on the bus between the local server, smoltcp, and WireGuard. #[derive(Debug, Clone)] pub enum Event { /// Dumb event with no data. Dumb, /// A new connection with the local server was initiated, and the given virtual port was assigned. ClientConnectionInitiated(PortForwardConfig, VirtualPort), /// A connection was dropped from the pool and should be closed in all interfaces. ClientConnectionDropped(VirtualPort), /// Data received by the local server that should be sent to the virtual server. LocalData(VirtualPort, Vec), /// Data received by the remote server that should be sent to the local client. RemoteData(VirtualPort, Vec), /// IP packet received from the WireGuard tunnel that should be passed through the corresponding virtual device. InboundInternetPacket(PortProtocol, Vec), /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device. OutboundInternetPacket(Vec), /// Notifies that a virtual device read an IP packet. VirtualDeviceFed(PortProtocol), } #[derive(Clone)] pub struct Bus { counter: Arc, bus: Arc>, } impl Bus { /// Creates a new event bus. pub fn new() -> Self { let (bus, _) = tokio::sync::broadcast::channel(1000); let bus = Arc::new(bus); let counter = Arc::new(AtomicU32::default()); Self { bus, counter } } /// Creates a new endpoint on the event bus. pub fn new_endpoint(&self) -> BusEndpoint { let id = self.counter.fetch_add(1, Ordering::Relaxed); let tx = (*self.bus).clone(); let rx = self.bus.subscribe(); let tx = BusSender { id, tx }; BusEndpoint { id, tx, rx } } } impl Default for Bus { fn default() -> Self { Self::new() } } pub struct BusEndpoint { id: u32, tx: BusSender, rx: tokio::sync::broadcast::Receiver<(u32, Event)>, } impl BusEndpoint { /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. pub fn send(&self, event: Event) { self.tx.send(event) } /// Returns the unique sequential ID of this endpoint. pub fn id(&self) -> u32 { self.id } /// Awaits the next `Event` on the bus to be read. pub async fn recv(&mut self) -> Event { loop { match self.rx.recv().await { Ok((id, event)) => { if id == self.id { // If the event was sent by this endpoint, it is skipped continue; } else { trace!("#{} <- {:?}", self.id, event); return event; } } Err(_) => { error!("Failed to read event bus from endpoint #{}", self.id); return futures::future::pending().await; } } } } /// Creates a new sender for this endpoint that can be cloned. pub fn sender(&self) -> BusSender { self.tx.clone() } } #[derive(Clone)] pub struct BusSender { id: u32, tx: tokio::sync::broadcast::Sender<(u32, Event)>, } impl BusSender { /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. pub fn send(&self, event: Event) { trace!("#{} -> {:?}", self.id, event); match self.tx.send((self.id, event)) { Ok(_) => {} Err(_) => error!("Failed to send event to bus from endpoint #{}", self.id), } } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_bus() { let bus = Bus::new(); let mut endpoint_1 = bus.new_endpoint(); let mut endpoint_2 = bus.new_endpoint(); let mut endpoint_3 = bus.new_endpoint(); assert_eq!(endpoint_1.id(), 0); assert_eq!(endpoint_2.id(), 1); assert_eq!(endpoint_3.id(), 2); endpoint_1.send(Event::Dumb); let recv_2 = endpoint_2.recv().await; let recv_3 = endpoint_3.recv().await; assert!(matches!(recv_2, Event::Dumb)); assert!(matches!(recv_3, Event::Dumb)); endpoint_2.send(Event::Dumb); let recv_1 = endpoint_1.recv().await; let recv_3 = endpoint_3.recv().await; assert!(matches!(recv_1, Event::Dumb)); assert!(matches!(recv_3, Event::Dumb)); } }