Move host address resolution logic to config

This commit is contained in:
Jackson Coxson 2022-06-23 22:59:19 -06:00
parent b108b5f404
commit 3ab108ad04
2 changed files with 27 additions and 18 deletions

View file

@ -19,7 +19,7 @@ pub struct Config {
pub(crate) private_key: Arc<X25519SecretKey>, pub(crate) private_key: Arc<X25519SecretKey>,
pub(crate) endpoint_public_key: Arc<X25519PublicKey>, pub(crate) endpoint_public_key: Arc<X25519PublicKey>,
pub(crate) endpoint_addr: SocketAddr, pub(crate) endpoint_addr: SocketAddr,
pub(crate) host_addr: Option<SocketAddr>, pub(crate) host_addr: SocketAddr,
pub(crate) source_peer_ip: IpAddr, pub(crate) source_peer_ip: IpAddr,
pub(crate) keepalive_seconds: Option<u16>, pub(crate) keepalive_seconds: Option<u16>,
pub(crate) max_transmission_unit: usize, pub(crate) max_transmission_unit: usize,
@ -232,6 +232,28 @@ impl Config {
.with_context(|| "Missing private key") .with_context(|| "Missing private key")
}?; }?;
let endpoint_addr = parse_addr(matches.value_of("endpoint-addr"))
.with_context(|| "Invalid endpoint address")?;
let host_addr = if let Some(addr) = matches.value_of("host-addr") {
let addr = parse_addr(Some(addr)).with_context(|| "Invalid host address")?;
// Make sure the host address and endpoint address are the same IP version
if addr.ip().is_ipv6() && endpoint_addr.ip().is_ipv6()
|| (addr.ip().is_ipv4() && endpoint_addr.ip().is_ipv4())
{
return Err(anyhow::anyhow!(
"Host address and endpoint address must be the same IP version"
));
}
addr
} else {
// Return the IP version of the endpoint address
match endpoint_addr {
SocketAddr::V4(_) => parse_addr(Some("0.0.0.0:0"))?,
SocketAddr::V6(_) => parse_addr(Some("[::]:0"))?,
}
};
Ok(Self { Ok(Self {
port_forwards, port_forwards,
remote_port_forwards, remote_port_forwards,
@ -242,14 +264,8 @@ impl Config {
parse_public_key(matches.value_of("endpoint-public-key")) parse_public_key(matches.value_of("endpoint-public-key"))
.with_context(|| "Invalid endpoint public key")?, .with_context(|| "Invalid endpoint public key")?,
), ),
endpoint_addr: parse_addr(matches.value_of("endpoint-addr")) endpoint_addr,
.with_context(|| "Invalid endpoint address")?, host_addr,
host_addr: match matches.value_of("host-addr") {
Some(host_addr) => {
Some(parse_addr(Some(host_addr)).with_context(|| "Invalid host address")?)
}
None => None,
},
source_peer_ip, source_peer_ip,
keepalive_seconds: parse_keep_alive(matches.value_of("keep-alive")) keepalive_seconds: parse_keep_alive(matches.value_of("keep-alive"))
.with_context(|| "Invalid keep-alive value")?, .with_context(|| "Invalid keep-alive value")?,

View file

@ -36,15 +36,8 @@ impl WireGuardTunnel {
let source_peer_ip = config.source_peer_ip; let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?; let peer = Self::create_tunnel(config)?;
let endpoint = config.endpoint_addr; let endpoint = config.endpoint_addr;
let udp = if let Some(host) = config.host_addr { let udp = UdpSocket::bind(config.host_addr)
UdpSocket::bind(host).await
} else {
UdpSocket::bind(match endpoint {
SocketAddr::V4(_) => "0.0.0.0:0",
SocketAddr::V6(_) => "[::]:0",
})
.await .await
}
.with_context(|| "Failed to create UDP socket for WireGuard connection")?; .with_context(|| "Failed to create UDP socket for WireGuard connection")?;
Ok(Self { Ok(Self {