diff --git a/src/config.rs b/src/config.rs index 5629b02..499ea3b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,7 +19,7 @@ pub struct Config { pub(crate) private_key: Arc, pub(crate) endpoint_public_key: Arc, pub(crate) endpoint_addr: SocketAddr, - pub(crate) host_addr: Option, + pub(crate) host_addr: SocketAddr, pub(crate) source_peer_ip: IpAddr, pub(crate) keepalive_seconds: Option, pub(crate) max_transmission_unit: usize, @@ -232,6 +232,28 @@ impl Config { .with_context(|| "Missing private key") }?; + let endpoint_addr = parse_addr(matches.value_of("endpoint-addr")) + .with_context(|| "Invalid endpoint address")?; + + let host_addr = if let Some(addr) = matches.value_of("host-addr") { + let addr = parse_addr(Some(addr)).with_context(|| "Invalid host address")?; + // Make sure the host address and endpoint address are the same IP version + if addr.ip().is_ipv6() && endpoint_addr.ip().is_ipv6() + || (addr.ip().is_ipv4() && endpoint_addr.ip().is_ipv4()) + { + return Err(anyhow::anyhow!( + "Host address and endpoint address must be the same IP version" + )); + } + addr + } else { + // Return the IP version of the endpoint address + match endpoint_addr { + SocketAddr::V4(_) => parse_addr(Some("0.0.0.0:0"))?, + SocketAddr::V6(_) => parse_addr(Some("[::]:0"))?, + } + }; + Ok(Self { port_forwards, remote_port_forwards, @@ -242,14 +264,8 @@ impl Config { parse_public_key(matches.value_of("endpoint-public-key")) .with_context(|| "Invalid endpoint public key")?, ), - endpoint_addr: parse_addr(matches.value_of("endpoint-addr")) - .with_context(|| "Invalid endpoint address")?, - host_addr: match matches.value_of("host-addr") { - Some(host_addr) => { - Some(parse_addr(Some(host_addr)).with_context(|| "Invalid host address")?) - } - None => None, - }, + endpoint_addr, + host_addr, source_peer_ip, keepalive_seconds: parse_keep_alive(matches.value_of("keep-alive")) .with_context(|| "Invalid keep-alive value")?, diff --git a/src/wg.rs b/src/wg.rs index a55fd99..3ee8ca7 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -36,16 +36,9 @@ impl WireGuardTunnel { let source_peer_ip = config.source_peer_ip; let peer = Self::create_tunnel(config)?; let endpoint = config.endpoint_addr; - let udp = if let Some(host) = config.host_addr { - UdpSocket::bind(host).await - } else { - UdpSocket::bind(match endpoint { - SocketAddr::V4(_) => "0.0.0.0:0", - SocketAddr::V6(_) => "[::]:0", - }) + let udp = UdpSocket::bind(config.host_addr) .await - } - .with_context(|| "Failed to create UDP socket for WireGuard connection")?; + .with_context(|| "Failed to create UDP socket for WireGuard connection")?; Ok(Self { source_peer_ip,