diff --git a/src/config.rs b/src/config.rs index eb48510..ccae551 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,6 +19,7 @@ pub struct Config { pub(crate) private_key: Arc, pub(crate) endpoint_public_key: Arc, pub(crate) endpoint_addr: SocketAddr, + pub(crate) endpoint_bind_addr: SocketAddr, pub(crate) source_peer_ip: IpAddr, pub(crate) keepalive_seconds: Option, pub(crate) max_transmission_unit: usize, @@ -76,6 +77,12 @@ impl Config { .long("endpoint-addr") .env("ONETUN_ENDPOINT_ADDR") .help("The address (IP + port) of the WireGuard endpoint (remote). Example: 1.2.3.4:51820"), + Arg::with_name("endpoint-bind-addr") + .required(false) + .takes_value(true) + .long("endpoint-bind-addr") + .env("ONETUN_ENDPOINT_BIND_ADDR") + .help("The address (IP + port) used to bind the local UDP socket for the WireGuard tunnel. Example: 1.2.3.4:30000. Defaults to 0.0.0.0:0 for IPv4 endpoints, or [::]:0 for IPv6 endpoints."), Arg::with_name("source-peer-ip") .required(true) .takes_value(true) @@ -225,6 +232,26 @@ impl Config { .with_context(|| "Missing private key") }?; + let endpoint_addr = parse_addr(matches.value_of("endpoint-addr")) + .with_context(|| "Invalid endpoint address")?; + + let endpoint_bind_addr = if let Some(addr) = matches.value_of("endpoint-bind-addr") { + let addr = parse_addr(Some(addr)).with_context(|| "Invalid bind address")?; + // Make sure the bind address and endpoint address are the same IP version + if addr.ip().is_ipv4() != endpoint_addr.ip().is_ipv4() { + return Err(anyhow::anyhow!( + "Endpoint and bind addresses must be the same IP version" + )); + } + addr + } else { + // Return the IP version of the endpoint address + match endpoint_addr { + SocketAddr::V4(_) => parse_addr(Some("0.0.0.0:0"))?, + SocketAddr::V6(_) => parse_addr(Some("[::]:0"))?, + } + }; + Ok(Self { port_forwards, remote_port_forwards, @@ -235,8 +262,8 @@ impl Config { parse_public_key(matches.value_of("endpoint-public-key")) .with_context(|| "Invalid endpoint public key")?, ), - endpoint_addr: parse_addr(matches.value_of("endpoint-addr")) - .with_context(|| "Invalid endpoint address")?, + endpoint_addr, + endpoint_bind_addr, source_peer_ip, keepalive_seconds: parse_keep_alive(matches.value_of("keep-alive")) .with_context(|| "Invalid keep-alive value")?, diff --git a/src/wg.rs b/src/wg.rs index 1d06444..2bacc22 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -36,12 +36,9 @@ impl WireGuardTunnel { let source_peer_ip = config.source_peer_ip; let peer = Self::create_tunnel(config)?; let endpoint = config.endpoint_addr; - let udp = UdpSocket::bind(match endpoint { - SocketAddr::V4(_) => "0.0.0.0:0", - SocketAddr::V6(_) => "[::]:0", - }) - .await - .with_context(|| "Failed to create UDP socket for WireGuard connection")?; + let udp = UdpSocket::bind(config.endpoint_bind_addr) + .await + .with_context(|| "Failed to create UDP socket for WireGuard connection")?; Ok(Self { source_peer_ip,