diff --git a/src/virtual_iface/tcp.rs b/src/virtual_iface/tcp.rs index 190af1f..3005c3c 100644 --- a/src/virtual_iface/tcp.rs +++ b/src/virtual_iface/tcp.rs @@ -113,6 +113,20 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } => { let loop_start = smoltcp::time::Instant::now(); + // Find closed sockets + port_client_handle_map.retain(|virtual_port, client_handle| { + let client_socket = iface.get_socket::(*client_handle); + if client_socket.state() == TcpState::Closed { + endpoint.send(Event::ClientConnectionDropped(*virtual_port)); + send_queue.remove(virtual_port); + iface.remove_socket(*client_handle); + false + } else { + // Not closed, retain + true + } + }); + match iface.poll(loop_start) { Ok(processed) if processed => { trace!("TCP virtual interface polled some packets to be processed"); @@ -121,7 +135,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface { _ => {} } - // Find client socket send data to for (virtual_port, client_handle) in port_client_handle_map.iter() { let client_socket = iface.get_socket::(*client_handle); if client_socket.can_send() { @@ -143,24 +156,16 @@ impl VirtualInterfacePoll for TcpVirtualInterface { ); } } - break; } else if client_socket.state() == TcpState::CloseWait { client_socket.close(); - break; } } } - } - - // Find client socket recv data from - for (virtual_port, client_handle) in port_client_handle_map.iter() { - let client_socket = iface.get_socket::(*client_handle); if client_socket.can_recv() { match client_socket.recv(|buffer| (buffer.len(), buffer.to_vec())) { Ok(data) => { if !data.is_empty() { endpoint.send(Event::RemoteData(*virtual_port, data)); - break; } } Err(e) => { @@ -172,20 +177,6 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } } - // Find closed sockets - port_client_handle_map.retain(|virtual_port, client_handle| { - let client_socket = iface.get_socket::(*client_handle); - if client_socket.state() == TcpState::Closed { - endpoint.send(Event::ClientConnectionDropped(*virtual_port)); - send_queue.remove(virtual_port); - iface.remove_socket(*client_handle); - false - } else { - // Not closed, retain - true - } - }); - // The virtual interface determines the next time to poll (this is to reduce unnecessary polls) next_poll = match iface.poll_delay(loop_start) { Some(smoltcp::time::Duration::ZERO) => None, @@ -223,11 +214,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface { } Event::ClientConnectionDropped(virtual_port) => { if let Some(client_handle) = port_client_handle_map.get(&virtual_port) { - let client_handle = *client_handle; - port_client_handle_map.remove(&virtual_port); - send_queue.remove(&virtual_port); - - let client_socket = iface.get_socket::(client_handle); + let client_socket = iface.get_socket::(*client_handle); client_socket.close(); next_poll = None; }