Skip to content

Commit

Permalink
Implement locking of Tunn in WireGuardTunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
aramperes committed Dec 24, 2023
1 parent e23cfc3 commit 1d703fa
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions src/wg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use boringtun::noise::{Tunn, TunnResult};
use log::Level;
use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet};
use tokio::net::UdpSocket;
use tokio::sync::Mutex;

use crate::config::{Config, PortProtocol};
use crate::events::Event;
Expand All @@ -23,7 +24,7 @@ const MAX_PACKET: usize = 65536;
pub struct WireGuardTunnel {
pub(crate) source_peer_ip: IpAddr,
/// `boringtun` peer/tunnel implementation, used for crypto & WG protocol.
peer: Box<Tunn>,
peer: Mutex<Box<Tunn>>,
/// The UDP socket for the public WireGuard endpoint to connect to.
udp: UdpSocket,
/// The address of the public WireGuard endpoint (UDP).
Expand All @@ -36,7 +37,7 @@ impl WireGuardTunnel {
/// Initialize a new WireGuard tunnel.
pub async fn new(config: &Config, bus: Bus) -> anyhow::Result<Self> {
let source_peer_ip = config.source_peer_ip;
let peer = Self::create_tunnel(config)?;
let peer = Mutex::new(Box::new(Self::create_tunnel(config)?));
let endpoint = config.endpoint_addr;
let udp = UdpSocket::bind(config.endpoint_bind_addr)
.await
Expand All @@ -55,7 +56,11 @@ impl WireGuardTunnel {
pub async fn send_ip_packet(&self, packet: &[u8]) -> anyhow::Result<()> {
trace_ip_packet("Sending IP packet", packet);
let mut send_buf = [0u8; MAX_PACKET];
match self.peer.encapsulate(packet, &mut send_buf) {
let encapsulate_result = {
let mut peer = self.peer.lock().await;
peer.encapsulate(packet, &mut send_buf)
};
match encapsulate_result {
TunnResult::WriteToNetwork(packet) => {
self.udp
.send_to(packet, self.endpoint)
Expand Down Expand Up @@ -104,7 +109,7 @@ impl WireGuardTunnel {

loop {
let mut send_buf = [0u8; MAX_PACKET];
let tun_result = self.peer.update_timers(&mut send_buf);
let tun_result = { self.peer.lock().await.update_timers(&mut send_buf) };
self.handle_routine_tun_result(tun_result).await;
}
}
Expand All @@ -131,7 +136,11 @@ impl WireGuardTunnel {
warn!("Wireguard handshake has expired!");

let mut buf = vec![0u8; MAX_PACKET];
let result = self.peer.format_handshake_initiation(&mut buf[..], false);
let result = self
.peer
.lock()
.await
.format_handshake_initiation(&mut buf[..], false);

self.handle_routine_tun_result(result).await
}
Expand Down Expand Up @@ -172,7 +181,11 @@ impl WireGuardTunnel {
};

let data = &recv_buf[..size];
match self.peer.decapsulate(None, data, &mut send_buf) {
let decapsulate_result = {
let mut peer = self.peer.lock().await;
peer.decapsulate(None, data, &mut send_buf)
};
match decapsulate_result {
TunnResult::WriteToNetwork(packet) => {
match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {}
Expand All @@ -181,9 +194,10 @@ impl WireGuardTunnel {
continue;
}
};
let mut peer = self.peer.lock().await;
loop {
let mut send_buf = [0u8; MAX_PACKET];
match self.peer.decapsulate(None, &[], &mut send_buf) {
match peer.decapsulate(None, &[], &mut send_buf) {
TunnResult::WriteToNetwork(packet) => {
match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {}
Expand Down Expand Up @@ -217,10 +231,13 @@ impl WireGuardTunnel {
}
}

fn create_tunnel(config: &Config) -> anyhow::Result<Box<Tunn>> {
fn create_tunnel(config: &Config) -> anyhow::Result<Tunn> {
let private = config.private_key.as_ref().clone();
let public = *config.endpoint_public_key.as_ref();

Tunn::new(
config.private_key.clone(),
config.endpoint_public_key.clone(),
private,
public,
config.preshared_key,
config.keepalive_seconds,
0,
Expand Down

0 comments on commit 1d703fa

Please sign in to comment.