Skip to content

Commit

Permalink
Improved hole-punching process
Browse files Browse the repository at this point in the history
  • Loading branch information
manforowicz committed Apr 14, 2024
1 parent 67c252a commit 890c107
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 131 deletions.
2 changes: 2 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Not all of them are desirable or necessary.

- Make sure reader returns an EOF error if interrupted?

- Have hole punch error say what connection path each error occured from.

- Think: What other functionality can I pull out into gday_file_offer_protocol.

## Low-priority ideas
Expand Down
1 change: 1 addition & 0 deletions gday/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ license = "MIT"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.82"
clap = { version = "4.5.4", features = ["derive"] }
env_logger = "0.11.3"
gday_encryption = { version = "0.1.0", path = "../gday_encryption" }
Expand Down
4 changes: 2 additions & 2 deletions gday/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ fn run(args: Args) -> Result<(), Box<dyn std::error::Error>> {
let port = if let Some(port) = args.port {
port
} else if args.unencrypted {
gday_hole_punch::server_connector::DEFAULT_TCP_PORT
gday_hole_punch::DEFAULT_TCP_PORT
} else {
gday_hole_punch::server_connector::DEFAULT_TLS_PORT
gday_hole_punch::DEFAULT_TLS_PORT
};
(
server_connector::connect_to_domain_name(&domain_name, port, !args.unencrypted)?,
Expand Down
8 changes: 8 additions & 0 deletions gday_contact_exchange_protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ use std::{
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

/// The port that contact exchange servers
/// using unencrypted TCP should listen on
pub const DEFAULT_TCP_PORT: u16 = 2310;

/// The port that contact exchange servers
/// using encrypted TLS should listen on
pub const DEFAULT_TLS_PORT: u16 = 2311;

/// A message from client to server.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
pub enum ClientMsg {
Expand Down
18 changes: 14 additions & 4 deletions gday_contact_exchange_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod tests;

use clap::Parser;
use connection_handler::handle_connection;
use gday_contact_exchange_protocol::{DEFAULT_TCP_PORT, DEFAULT_TLS_PORT};
use log::{debug, error, info, warn};
use socket2::{SockRef, TcpKeepalive};
use state::State;
Expand Down Expand Up @@ -41,9 +42,10 @@ struct Args {
#[arg(short, long, conflicts_with_all(["key", "certificate"]))]
unencrypted: bool,

/// The socket address from which to listen
#[arg(short, long, default_value = "[::]:234")]
address: String,
/// Custom socket address on which to listen.
/// Default: [::]:2311 for TLS, [::]:2310 when unencrypted
#[arg(short, long)]
address: Option<String>,

/// Number of seconds before a new room is deleted
#[arg(short, long, default_value = "300")]
Expand All @@ -67,8 +69,16 @@ async fn main() {
// set the log level according to the command line argument
env_logger::builder().filter_level(args.verbosity).init();

let addr = if let Some(addr) = args.address {
addr
} else if args.unencrypted {
format!("[::]:{DEFAULT_TCP_PORT}")
} else {
format!("[::]:{DEFAULT_TLS_PORT}")
};

// get tcp listener
let tcp_listener = get_tcp_listener(args.address).await;
let tcp_listener = get_tcp_listener(addr).await;

// get the TLS acceptor if applicable
let tls_acceptor = if let (Some(k), Some(c)) = (args.key, args.certificate) {
Expand Down
58 changes: 58 additions & 0 deletions gday_file_offer_protocol/src/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#![cfg(test)]

use std::path::PathBuf;

use crate::{FileMeta, FileOfferMsg, FileResponseMsg};

/// Test serializing and deserializing messages.
#[test]
fn sending_messages() {
let mut bytes = std::collections::VecDeque::new();

for msg in get_offer_msg_examples() {
crate::to_writer(msg, &mut bytes).unwrap();
}

for msg in get_offer_msg_examples() {
let deserialized_msg: FileOfferMsg = crate::from_reader(&mut bytes).unwrap();
assert_eq!(msg, deserialized_msg);
}

for msg in get_response_msg_examples() {
crate::to_writer(msg, &mut bytes).unwrap();
}

for msg in get_response_msg_examples() {
let deserialized_msg: FileResponseMsg = crate::from_reader(&mut bytes).unwrap();
assert_eq!(msg, deserialized_msg);
}
}

fn get_offer_msg_examples() -> Vec<FileOfferMsg> {
vec![
FileOfferMsg {
files: vec![
FileMeta {
short_path: PathBuf::from("example/path"),
len: 43,
},
FileMeta {
short_path: PathBuf::from("/foo/hello"),
len: 50,
},
],
},
FileOfferMsg { files: Vec::new() },
]
}

fn get_response_msg_examples() -> Vec<FileResponseMsg> {
vec![
FileResponseMsg {
accepted: vec![None, Some(0), Some(100)],
},
FileResponseMsg {
accepted: vec![None, None, None],
},
]
}
171 changes: 61 additions & 110 deletions gday_hole_punch/src/hole_puncher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use tokio::{

type PeerConnection = (std::net::TcpStream, [u8; 32]);

const RETRY_INTERVAL: Duration = Duration::from_millis(100);
const RETRY_INTERVAL: Duration = Duration::from_millis(200);

// TODO: Update all comments here!

// TODO: ADD BETTER ERROR REPORTING.
// add a timeout.
Expand All @@ -35,54 +37,60 @@ pub fn try_connect_to_peer(
peer_contact: FullContact,
shared_secret: &[u8],
timeout: std::time::Duration,
) -> Result<PeerConnection, HolePunchErrors> {
// time at which to give up hole punching
let end_time = tokio::time::Instant::now() + timeout;

// shorten the variable name for conciseness
let p = shared_secret;

) -> Result<PeerConnection, Error> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.expect("Tokio async runtime error.");

// hole punch asynchronously
runtime.block_on(async {
let mut futs = tokio::task::JoinSet::new();
if let Some(local) = local_contact.v4 {
futs.spawn(try_accept(local, p.to_vec(), end_time));

if let Some(peer) = peer_contact.private.v4 {
futs.spawn(try_connect(local, peer, p.to_vec(), end_time));
}

if let Some(peer) = peer_contact.public.v4 {
futs.spawn(try_connect(local, peer, p.to_vec(), end_time));
}
match runtime.block_on(tokio::time::timeout(
timeout,
hole_punch(local_contact, peer_contact, shared_secret),
)) {
Ok(result) => result,
Err(..) => Err(Error::HolePunchTimeout),
}
}

/// TODO: Comment
async fn hole_punch(
local_contact: Contact,
peer_contact: FullContact,
shared_secret: &[u8],
) -> Result<PeerConnection, Error> {
// shorten the variable name for conciseness
let p = shared_secret;

let mut futs = tokio::task::JoinSet::new();
if let Some(local) = local_contact.v4 {
futs.spawn(try_accept(local, p.to_vec()));

if let Some(peer) = peer_contact.private.v4 {
futs.spawn(try_connect(local, peer, p.to_vec()));
}

if let Some(local) = local_contact.v6 {
futs.spawn(try_accept(local, p.to_vec(), end_time));
if let Some(peer) = peer_contact.public.v4 {
futs.spawn(try_connect(local, peer, p.to_vec()));
}
}

if let Some(peer) = peer_contact.private.v6 {
futs.spawn(try_connect(local, peer, p.to_vec(), end_time));
}
if let Some(peer) = peer_contact.public.v6 {
futs.spawn(try_connect(local, peer, p.to_vec(), end_time));
}
if let Some(local) = local_contact.v6 {
futs.spawn(try_accept(local, p.to_vec()));

if let Some(peer) = peer_contact.private.v6 {
futs.spawn(try_connect(local, peer, p.to_vec()));
}
let mut errors = Vec::new();
trace!("Starting hole-punching");
while let Some(outcome) = futs.join_next().await {
match outcome.expect("Async error") {
Ok(connection) => return Ok(connection),
Err(err) => errors.push(err),
}
if let Some(peer) = peer_contact.public.v6 {
futs.spawn(try_connect(local, peer, p.to_vec()));
}
Err(HolePunchErrors { errors })
})
}
match futs.join_next().await {
Some(Ok(result)) => result,
Some(Err(..)) => panic!("Tokio join error."),
None => Err(Error::ContactEmpty),
}
}

/// Tries to TCP connect to `peer` from `local`.
Expand All @@ -91,102 +99,51 @@ async fn try_connect<T: Into<SocketAddr>>(
local: T,
peer: T,
shared_secret: Vec<u8>,
end_time: tokio::time::Instant,
) -> Result<PeerConnection, Error> {
let local = local.into();
let peer = peer.into();
let mut last_error = Error::HolePunchTimeout;
let mut interval = tokio::time::interval(RETRY_INTERVAL);

trace!("Trying to connect from {local} to {peer}.");

while tokio::time::Instant::now() < end_time {
// try connecting
match tokio::time::timeout_at(end_time, try_connect_once(local, peer, &shared_secret)).await
{
// return successfully connection
Ok(Ok(connection)) => return Ok(connection),
// update `last_error`
Ok(Err(err)) => {
debug!("Error when trying to connect from {local} to {peer}: {err}");
last_error = err;
}
// passed `end_time`
Err(..) => break,
let stream = loop {
let local_socket = get_local_socket(local)?;
if let Ok(stream) = local_socket.connect(peer).await {
break stream;
}
interval.tick().await;
};

// wait some time to avoid flooding the network
match tokio::time::timeout_at(end_time, interval.tick()).await {
// done waiting
Ok(..) => (),
// passed `end_time`
Err(..) => break,
};
}
Err(last_error)
debug!("Connected to {peer} from {local}. Will try to authenticate.");
verify_peer(&shared_secret, stream).await
}

/// Tries to accept a peer TCP connection on `local`.
/// Returns the most recent error if not successful by `end_time`.
async fn try_accept(
local: impl Into<SocketAddr>,
shared_secret: Vec<u8>,
end_time: tokio::time::Instant,
) -> Result<PeerConnection, Error> {
let local = local.into();
trace!("Trying accept peer's connection on {local}.");
let local_socket = get_local_socket(local)?;
let listener = local_socket.listen(1024)?;
let mut last_error = Error::HolePunchTimeout;
let mut interval = tokio::time::interval(RETRY_INTERVAL);

while tokio::time::Instant::now() < end_time {
// try accepting
match tokio::time::timeout_at(end_time, try_accept_once(&listener, &shared_secret)).await {
// return successful connection
Ok(Ok(connection)) => return Ok(connection),
// update `last_error`
Ok(Err(err)) => {
debug!("Error when trying to accept peer's connection on {local}: {err}");
last_error = err;
}
// passed `end_time`
Err(..) => break,
let (stream, addr) = loop {
if let Ok(ok) = listener.accept().await {
break ok;
}

// wait some time to avoid flooding the network
match tokio::time::timeout_at(end_time, interval.tick()).await {
// done waiting
Ok(..) => (),
// passed `end_time`
Err(..) => break,
};
}
Err(last_error)
}

async fn try_connect_once(
local: SocketAddr,
peer: SocketAddr,
shared_secret: &[u8],
) -> Result<PeerConnection, Error> {
let local_socket = get_local_socket(local)?;
let stream = local_socket.connect(peer).await?;
debug!("Connected to {peer} from {local}. Will try to authenticate.");
verify_peer(shared_secret, stream).await
}
interval.tick().await;
};

async fn try_accept_once(
listener: &tokio::net::TcpListener,
shared_secret: &[u8],
) -> Result<PeerConnection, Error> {
let (stream, addr) = listener.accept().await?;
debug!(
"Connected from {} to {}. Will try to authenticate.",
addr,
stream.local_addr()?
);
verify_peer(shared_secret, stream).await

verify_peer(&shared_secret, stream).await
}

/// Uses [SPAKE 2](https://docs.rs/spake2/latest/spake2/)
Expand Down Expand Up @@ -272,16 +229,10 @@ fn get_local_socket(local_addr: SocketAddr) -> std::io::Result<TcpSocket> {

let keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(5))
.with_interval(Duration::from_secs(1))
.with_interval(Duration::from_secs(2))
.with_retries(5);
let _ = sock.set_tcp_keepalive(&keepalive);

socket.bind(local_addr)?;
Ok(socket)
}

#[derive(thiserror::Error, Debug)]
#[error("Couldn't establish connection to peer. {:#?}", errors)]
pub struct HolePunchErrors {
errors: Vec<crate::Error>,
}
Loading

0 comments on commit 890c107

Please sign in to comment.