Cleanup usage of anyhow with_context

This commit is contained in:
Aram 🍐 2023-12-24 15:05:33 -05:00
parent 3ccd000ea8
commit ce40f85efa
9 changed files with 55 additions and 69 deletions

View file

@ -161,14 +161,14 @@ impl Config {
.map(|s| PortForwardConfig::from_notation(&s, DEFAULT_PORT_FORWARD_SOURCE)) .map(|s| PortForwardConfig::from_notation(&s, DEFAULT_PORT_FORWARD_SOURCE))
.collect(); .collect();
let port_forwards: Vec<PortForwardConfig> = port_forwards let port_forwards: Vec<PortForwardConfig> = port_forwards
.with_context(|| "Failed to parse port forward config")? .context("Failed to parse port forward config")?
.into_iter() .into_iter()
.flatten() .flatten()
.collect(); .collect();
// Read source-peer-ip // Read source-peer-ip
let source_peer_ip = parse_ip(matches.get_one::<String>("source-peer-ip")) let source_peer_ip = parse_ip(matches.get_one::<String>("source-peer-ip"))
.with_context(|| "Invalid source peer IP")?; .context("Invalid source peer IP")?;
// Combined `remote` arg and `ONETUN_REMOTE_PORT_FORWARD_#` envs // Combined `remote` arg and `ONETUN_REMOTE_PORT_FORWARD_#` envs
let mut port_forward_strings = HashSet::new(); let mut port_forward_strings = HashSet::new();
@ -196,7 +196,7 @@ impl Config {
}) })
.collect(); .collect();
let mut remote_port_forwards: Vec<PortForwardConfig> = remote_port_forwards let mut remote_port_forwards: Vec<PortForwardConfig> = remote_port_forwards
.with_context(|| "Failed to parse remote port forward config")? .context("Failed to parse remote port forward config")?
.into_iter() .into_iter()
.flatten() .flatten()
.collect(); .collect();
@ -229,7 +229,7 @@ impl Config {
{ {
read_to_string(private_key_file) read_to_string(private_key_file)
.map(|s| s.trim().to_string()) .map(|s| s.trim().to_string())
.with_context(|| "Failed to read private key file") .context("Failed to read private key file")
} else { } else {
if std::env::var("ONETUN_PRIVATE_KEY").is_err() { if std::env::var("ONETUN_PRIVATE_KEY").is_err() {
warnings.push("Private key was passed using CLI. This is insecure. \ warnings.push("Private key was passed using CLI. This is insecure. \
@ -238,20 +238,18 @@ impl Config {
matches matches
.get_one::<String>("private-key") .get_one::<String>("private-key")
.cloned() .cloned()
.with_context(|| "Missing private key") .context("Missing private key")
}?; }?;
let endpoint_addr = parse_addr(matches.get_one::<String>("endpoint-addr")) let endpoint_addr = parse_addr(matches.get_one::<String>("endpoint-addr"))
.with_context(|| "Invalid endpoint address")?; .context("Invalid endpoint address")?;
let endpoint_bind_addr = if let Some(addr) = matches.get_one::<String>("endpoint-bind-addr") let endpoint_bind_addr = if let Some(addr) = matches.get_one::<String>("endpoint-bind-addr")
{ {
let addr = parse_addr(Some(addr)).with_context(|| "Invalid bind address")?; let addr = parse_addr(Some(addr)).context("Invalid bind address")?;
// Make sure the bind address and endpoint address are the same IP version // Make sure the bind address and endpoint address are the same IP version
if addr.ip().is_ipv4() != endpoint_addr.ip().is_ipv4() { if addr.ip().is_ipv4() != endpoint_addr.ip().is_ipv4() {
return Err(anyhow::anyhow!( bail!("Endpoint and bind addresses must be the same IP version");
"Endpoint and bind addresses must be the same IP version"
));
} }
addr addr
} else { } else {
@ -265,21 +263,19 @@ impl Config {
Ok(Self { Ok(Self {
port_forwards, port_forwards,
remote_port_forwards, remote_port_forwards,
private_key: Arc::new( private_key: Arc::new(parse_private_key(&private_key).context("Invalid private key")?),
parse_private_key(&private_key).with_context(|| "Invalid private key")?,
),
endpoint_public_key: Arc::new( endpoint_public_key: Arc::new(
parse_public_key(matches.get_one::<String>("endpoint-public-key")) parse_public_key(matches.get_one::<String>("endpoint-public-key"))
.with_context(|| "Invalid endpoint public key")?, .context("Invalid endpoint public key")?,
), ),
preshared_key: parse_preshared_key(matches.get_one::<String>("preshared-key"))?, preshared_key: parse_preshared_key(matches.get_one::<String>("preshared-key"))?,
endpoint_addr, endpoint_addr,
endpoint_bind_addr, endpoint_bind_addr,
source_peer_ip, source_peer_ip,
keepalive_seconds: parse_keep_alive(matches.get_one::<String>("keep-alive")) keepalive_seconds: parse_keep_alive(matches.get_one::<String>("keep-alive"))
.with_context(|| "Invalid keep-alive value")?, .context("Invalid keep-alive value")?,
max_transmission_unit: parse_mtu(matches.get_one::<String>("max-transmission-unit")) max_transmission_unit: parse_mtu(matches.get_one::<String>("max-transmission-unit"))
.with_context(|| "Invalid max-transmission-unit value")?, .context("Invalid max-transmission-unit value")?,
log: matches log: matches
.get_one::<String>("log") .get_one::<String>("log")
.cloned() .cloned()
@ -291,22 +287,22 @@ impl Config {
} }
fn parse_addr<T: AsRef<str>>(s: Option<T>) -> anyhow::Result<SocketAddr> { fn parse_addr<T: AsRef<str>>(s: Option<T>) -> anyhow::Result<SocketAddr> {
s.with_context(|| "Missing address")? s.context("Missing address")?
.as_ref() .as_ref()
.to_socket_addrs() .to_socket_addrs()
.with_context(|| "Invalid address")? .context("Invalid address")?
.next() .next()
.with_context(|| "Could not lookup address") .context("Could not lookup address")
} }
fn parse_ip(s: Option<&String>) -> anyhow::Result<IpAddr> { fn parse_ip(s: Option<&String>) -> anyhow::Result<IpAddr> {
s.with_context(|| "Missing IP")? s.context("Missing IP address")?
.parse::<IpAddr>() .parse::<IpAddr>()
.with_context(|| "Invalid IP address") .context("Invalid IP address")
} }
fn parse_private_key(s: &str) -> anyhow::Result<StaticSecret> { fn parse_private_key(s: &str) -> anyhow::Result<StaticSecret> {
let decoded = base64::decode(s).with_context(|| "Failed to decode private key")?; let decoded = base64::decode(s).context("Failed to decode private key")?;
if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() { if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() {
Ok(StaticSecret::from(bytes)) Ok(StaticSecret::from(bytes))
} else { } else {
@ -315,8 +311,8 @@ fn parse_private_key(s: &str) -> anyhow::Result<StaticSecret> {
} }
fn parse_public_key(s: Option<&String>) -> anyhow::Result<PublicKey> { fn parse_public_key(s: Option<&String>) -> anyhow::Result<PublicKey> {
let encoded = s.with_context(|| "Missing public key")?; let encoded = s.context("Missing public key")?;
let decoded = base64::decode(encoded).with_context(|| "Failed to decode public key")?; let decoded = base64::decode(encoded).context("Failed to decode public key")?;
if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() { if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() {
Ok(PublicKey::from(bytes)) Ok(PublicKey::from(bytes))
} else { } else {
@ -326,7 +322,7 @@ fn parse_public_key(s: Option<&String>) -> anyhow::Result<PublicKey> {
fn parse_preshared_key(s: Option<&String>) -> anyhow::Result<Option<[u8; 32]>> { fn parse_preshared_key(s: Option<&String>) -> anyhow::Result<Option<[u8; 32]>> {
if let Some(s) = s { if let Some(s) = s {
let decoded = base64::decode(s).with_context(|| "Failed to decode preshared key")?; let decoded = base64::decode(s).context("Failed to decode preshared key")?;
if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() { if let Ok::<[u8; 32], _>(bytes) = decoded.try_into() {
Ok(Some(bytes)) Ok(Some(bytes))
} else { } else {
@ -352,9 +348,7 @@ fn parse_keep_alive(s: Option<&String>) -> anyhow::Result<Option<u16>> {
} }
fn parse_mtu(s: Option<&String>) -> anyhow::Result<usize> { fn parse_mtu(s: Option<&String>) -> anyhow::Result<usize> {
s.with_context(|| "Missing MTU")? s.context("Missing MTU")?.parse().context("Invalid MTU")
.parse()
.with_context(|| "Invalid MTU")
} }
#[cfg(unix)] #[cfg(unix)]
@ -483,27 +477,21 @@ impl PortForwardConfig {
let source = ( let source = (
src_addr.0.unwrap_or(default_source), src_addr.0.unwrap_or(default_source),
src_addr src_addr.1.parse::<u16>().context("Invalid source port")?,
.1
.parse::<u16>()
.with_context(|| "Invalid source port")?,
) )
.to_socket_addrs() .to_socket_addrs()
.with_context(|| "Invalid source address")? .context("Invalid source address")?
.next() .next()
.with_context(|| "Could not resolve source address")?; .context("Could not resolve source address")?;
let destination = ( let destination = (
dst_addr.0, dst_addr.0,
dst_addr dst_addr.1.parse::<u16>().context("Invalid source port")?,
.1
.parse::<u16>()
.with_context(|| "Invalid source port")?,
) )
.to_socket_addrs() // TODO: Pass this as given and use DNS config instead (issue #15) .to_socket_addrs() // TODO: Pass this as given and use DNS config instead (issue #15)
.with_context(|| "Invalid destination address")? .context("Invalid destination address")?
.next() .next()
.with_context(|| "Could not resolve destination address")?; .context("Could not resolve destination address")?;
// Parse protocols // Parse protocols
let protocols = if let Some(protocols) = protocols { let protocols = if let Some(protocols) = protocols {
@ -513,7 +501,7 @@ impl PortForwardConfig {
} else { } else {
Ok(vec![PortProtocol::Tcp]) Ok(vec![PortProtocol::Tcp])
} }
.with_context(|| "Failed to parse protocols")?; .context("Failed to parse protocols")?;
// Returns an config for each protocol // Returns an config for each protocol
Ok(protocols Ok(protocols

View file

@ -41,7 +41,7 @@ pub async fn start_tunnels(config: Config, bus: Bus) -> anyhow::Result<()> {
let wg = WireGuardTunnel::new(&config, bus.clone()) let wg = WireGuardTunnel::new(&config, bus.clone())
.await .await
.with_context(|| "Failed to initialize WireGuard tunnel")?; .context("Failed to initialize WireGuard tunnel")?;
let wg = Arc::new(wg); let wg = Arc::new(wg);
{ {

View file

@ -8,7 +8,7 @@ async fn main() -> anyhow::Result<()> {
use anyhow::Context; use anyhow::Context;
use onetun::{config::Config, events::Bus}; use onetun::{config::Config, events::Bus};
let config = Config::from_args().with_context(|| "Failed to read config")?; let config = Config::from_args().context("Configuration has errors")?;
init_logger(&config)?; init_logger(&config)?;
for warning in &config.warnings { for warning in &config.warnings {
@ -32,7 +32,5 @@ fn init_logger(config: &onetun::config::Config) -> anyhow::Result<()> {
let mut builder = pretty_env_logger::formatted_timed_builder(); let mut builder = pretty_env_logger::formatted_timed_builder();
builder.parse_filters(&config.log); builder.parse_filters(&config.log);
builder builder.try_init().context("Failed to initialize logger")
.try_init()
.with_context(|| "Failed to initialize logger")
} }

View file

@ -16,7 +16,7 @@ impl Pcap {
self.writer self.writer
.flush() .flush()
.await .await
.with_context(|| "Failed to flush pcap writer") .context("Failed to flush pcap writer")
} }
async fn write(&mut self, data: &[u8]) -> anyhow::Result<usize> { async fn write(&mut self, data: &[u8]) -> anyhow::Result<usize> {
@ -30,14 +30,14 @@ impl Pcap {
self.writer self.writer
.write_u16(value) .write_u16(value)
.await .await
.with_context(|| "Failed to write u16 to pcap writer") .context("Failed to write u16 to pcap writer")
} }
async fn write_u32(&mut self, value: u32) -> anyhow::Result<()> { async fn write_u32(&mut self, value: u32) -> anyhow::Result<()> {
self.writer self.writer
.write_u32(value) .write_u32(value)
.await .await
.with_context(|| "Failed to write u32 to pcap writer") .context("Failed to write u32 to pcap writer")
} }
async fn global_header(&mut self) -> anyhow::Result<()> { async fn global_header(&mut self) -> anyhow::Result<()> {
@ -64,14 +64,14 @@ impl Pcap {
async fn packet(&mut self, timestamp: Instant, packet: &[u8]) -> anyhow::Result<()> { async fn packet(&mut self, timestamp: Instant, packet: &[u8]) -> anyhow::Result<()> {
self.packet_header(timestamp, packet.len()) self.packet_header(timestamp, packet.len())
.await .await
.with_context(|| "Failed to write packet header to pcap writer")?; .context("Failed to write packet header to pcap writer")?;
self.write(packet) self.write(packet)
.await .await
.with_context(|| "Failed to write packet to pcap writer")?; .context("Failed to write packet to pcap writer")?;
self.writer self.writer
.flush() .flush()
.await .await
.with_context(|| "Failed to flush pcap writer")?; .context("Failed to flush pcap writer")?;
self.flush().await self.flush().await
} }
} }
@ -81,14 +81,14 @@ pub async fn capture(pcap_file: String, bus: Bus) -> anyhow::Result<()> {
let mut endpoint = bus.new_endpoint(); let mut endpoint = bus.new_endpoint();
let file = File::create(&pcap_file) let file = File::create(&pcap_file)
.await .await
.with_context(|| "Failed to create pcap file")?; .context("Failed to create pcap file")?;
let writer = BufWriter::new(file); let writer = BufWriter::new(file);
let mut writer = Pcap { writer }; let mut writer = Pcap { writer };
writer writer
.global_header() .global_header()
.await .await
.with_context(|| "Failed to write global header to pcap writer")?; .context("Failed to write global header to pcap writer")?;
info!("Capturing WireGuard IP packets to {}", &pcap_file); info!("Capturing WireGuard IP packets to {}", &pcap_file);
loop { loop {
@ -98,14 +98,14 @@ pub async fn capture(pcap_file: String, bus: Bus) -> anyhow::Result<()> {
writer writer
.packet(instant, &ip) .packet(instant, &ip)
.await .await
.with_context(|| "Failed to write inbound IP packet to pcap writer")?; .context("Failed to write inbound IP packet to pcap writer")?;
} }
Event::OutboundInternetPacket(ip) => { Event::OutboundInternetPacket(ip) => {
let instant = Instant::now(); let instant = Instant::now();
writer writer
.packet(instant, &ip) .packet(instant, &ip)
.await .await
.with_context(|| "Failed to write output IP packet to pcap writer")?; .context("Failed to write output IP packet to pcap writer")?;
} }
_ => {} _ => {}
} }

View file

@ -27,14 +27,14 @@ pub async fn tcp_proxy_server(
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let listener = TcpListener::bind(port_forward.source) let listener = TcpListener::bind(port_forward.source)
.await .await
.with_context(|| "Failed to listen on TCP proxy server")?; .context("Failed to listen on TCP proxy server")?;
loop { loop {
let port_pool = port_pool.clone(); let port_pool = port_pool.clone();
let (socket, peer_addr) = listener let (socket, peer_addr) = listener
.accept() .accept()
.await .await
.with_context(|| "Failed to accept connection on TCP proxy server")?; .context("Failed to accept connection on TCP proxy server")?;
// Assign a 'virtual port': this is a unique port number used to route IP packets // Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will // received from the WireGuard tunnel. It is the port number that the virtual client will
@ -192,7 +192,7 @@ impl TcpPortPool {
let port = inner let port = inner
.queue .queue
.pop_front() .pop_front()
.with_context(|| "TCP virtual port pool is exhausted")?; .context("TCP virtual port pool is exhausted")?;
Ok(VirtualPort::new(port, PortProtocol::Tcp)) Ok(VirtualPort::new(port, PortProtocol::Tcp))
} }

View file

@ -37,7 +37,7 @@ pub async fn udp_proxy_server(
let mut endpoint = bus.new_endpoint(); let mut endpoint = bus.new_endpoint();
let socket = UdpSocket::bind(port_forward.source) let socket = UdpSocket::bind(port_forward.source)
.await .await
.with_context(|| "Failed to bind on UDP proxy address")?; .context("Failed to bind on UDP proxy address")?;
let mut buffer = [0u8; MAX_PACKET]; let mut buffer = [0u8; MAX_PACKET];
loop { loop {
@ -103,7 +103,7 @@ async fn next_udp_datagram(
let (size, peer_addr) = socket let (size, peer_addr) = socket
.recv_from(buffer) .recv_from(buffer)
.await .await
.with_context(|| "Failed to accept incoming UDP datagram")?; .context("Failed to accept incoming UDP datagram")?;
// Assign a 'virtual port': this is a unique port number used to route IP packets // Assign a 'virtual port': this is a unique port number used to route IP packets
// received from the WireGuard tunnel. It is the port number that the virtual client will // received from the WireGuard tunnel. It is the port number that the virtual client will
@ -212,7 +212,7 @@ impl UdpPortPool {
None None
} }
}) })
.with_context(|| "virtual port pool is exhausted")?; .context("Virtual port pool is exhausted")?;
inner.port_by_peer_addr.insert(peer_addr, port); inner.port_by_peer_addr.insert(peer_addr, port);
inner.peer_addr_by_port.insert(port, peer_addr); inner.peer_addr_by_port.insert(port, peer_addr);

View file

@ -56,7 +56,7 @@ impl TcpVirtualInterface {
IpAddress::from(port_forward.destination.ip()), IpAddress::from(port_forward.destination.ip()),
port_forward.destination.port(), port_forward.destination.port(),
)) ))
.with_context(|| "Virtual server socket failed to listen")?; .context("Virtual server socket failed to listen")?;
Ok(socket) Ok(socket)
} }
@ -218,7 +218,7 @@ impl VirtualInterfacePoll for TcpVirtualInterface {
), ),
(IpAddress::from(self.source_peer_ip), virtual_port.num()), (IpAddress::from(self.source_peer_ip), virtual_port.num()),
) )
.with_context(|| "Virtual server socket failed to listen")?; .context("Virtual server socket failed to listen")?;
next_poll = None; next_poll = None;
} }

View file

@ -61,7 +61,7 @@ impl UdpVirtualInterface {
IpAddress::from(port_forward.destination.ip()), IpAddress::from(port_forward.destination.ip()),
port_forward.destination.port(), port_forward.destination.port(),
)) ))
.with_context(|| "UDP virtual server socket failed to bind")?; .context("UDP virtual server socket failed to bind")?;
Ok(socket) Ok(socket)
} }
@ -78,7 +78,7 @@ impl UdpVirtualInterface {
let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); let mut socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer);
socket socket
.bind((IpAddress::from(source_peer_ip), client_port.num())) .bind((IpAddress::from(source_peer_ip), client_port.num()))
.with_context(|| "UDP virtual client failed to bind")?; .context("UDP virtual client failed to bind")?;
Ok(socket) Ok(socket)
} }

View file

@ -41,7 +41,7 @@ impl WireGuardTunnel {
let endpoint = config.endpoint_addr; let endpoint = config.endpoint_addr;
let udp = UdpSocket::bind(config.endpoint_bind_addr) let udp = UdpSocket::bind(config.endpoint_bind_addr)
.await .await
.with_context(|| "Failed to create UDP socket for WireGuard connection")?; .context("Failed to create UDP socket for WireGuard connection")?;
Ok(Self { Ok(Self {
source_peer_ip, source_peer_ip,
@ -65,7 +65,7 @@ impl WireGuardTunnel {
self.udp self.udp
.send_to(packet, self.endpoint) .send_to(packet, self.endpoint)
.await .await
.with_context(|| "Failed to send encrypted IP packet to WireGuard endpoint.")?; .context("Failed to send encrypted IP packet to WireGuard endpoint.")?;
debug!( debug!(
"Sent {} bytes to WireGuard endpoint (encrypted IP packet)", "Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
packet.len() packet.len()
@ -244,7 +244,7 @@ impl WireGuardTunnel {
None, None,
) )
.map_err(|s| anyhow::anyhow!("{}", s)) .map_err(|s| anyhow::anyhow!("{}", s))
.with_context(|| "Failed to initialize boringtun Tunn") .context("Failed to initialize boringtun Tunn")
} }
/// Determine the inner protocol of the incoming IP packet (TCP/UDP). /// Determine the inner protocol of the incoming IP packet (TCP/UDP).