From 74b7b78025e4811d6f9fbcddf2a42c24e4f8c79e Mon Sep 17 00:00:00 2001 From: Marcin Anforowicz Date: Sat, 6 Apr 2024 08:57:29 -0700 Subject: [PATCH] Working on adding TCP support to client --- README.md | 4 +- TODO.md | 22 ++- gday/src/base32.rs | 30 +++- gday/src/main.rs | 9 +- gday_encryption/src/helper_buf.rs | 135 +++++++++++----- gday_encryption/src/lib.rs | 206 +++++++++++------------- gday_encryption/src/test.rs | 82 +++++----- gday_hole_punch/src/contact_sharer.rs | 2 +- gday_hole_punch/src/server_connector.rs | 155 ++++++++++++------ gday_server/Cargo.toml | 10 +- gday_server/src/connection_handler.rs | 56 +++++-- gday_server/src/main.rs | 28 ++-- 12 files changed, 449 insertions(+), 290 deletions(-) diff --git a/README.md b/README.md index 7391ba5..3b987e2 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ We stand on the shoulders of giants. ✅ - croc + magic-wormhole ❌ ✅ ✅ @@ -43,7 +43,7 @@ We stand on the shoulders of giants. ✅ - magic-wormhole + croc ❌ ✅ ✅ diff --git a/TODO.md b/TODO.md index a729998..a18bbca 100644 --- a/TODO.md +++ b/TODO.md @@ -1,19 +1,11 @@ # To-Do's -Items to consider implementing. Not all of them are desirable or necessary. -These are just some quick notes that might not make sense. - -- Have the hole puncher actively prefer local sockets over public sockets. -But I don't think this matters much since -most NATs don't support hairpin translation, and if they do, I doubt its much slower than a direct connection. +Quick notes on items to consider implementing. +Not all of them are desirable or necessary. - Deduplicate errors in the error vec the hole puncher can return. This might give a more helpful error message to the user. -- Give the client and server the option to use plain TCP instead of TLS. -This might be difficult because various inner functions require a get address function. -Maybe I can create a trait that allows for such a function call, and implement this trait -for both raw TCP and TLS? That sounds overly complicated, but maybe it's the only option? -Or potentially just pass an address parameter everywhere?? +- Give the client the option to select a server to use. - Restructure the hole puncher to force keeping connection to server open during hole-punching. That might please some NATs that lose state when TCP connection is closed. @@ -29,4 +21,10 @@ the peer's device is acting as some sort of server. - Add some sort of end-to-end integration tests. -- Allow sending a simple text string instead of only files. \ No newline at end of file +- Maybe add some versioning to the protocols? + +## Low-priority ideas + +- Allow sending a simple text string instead of only files. + Though, I don't think this is a common use case, so will only + add if I get requests. \ No newline at end of file diff --git a/gday/src/base32.rs b/gday/src/base32.rs index f3104e8..f600c82 100644 --- a/gday/src/base32.rs +++ b/gday/src/base32.rs @@ -162,16 +162,33 @@ pub enum Error { mod tests { use super::*; + /// Test encoding a message. #[test] - fn test_general() { + fn test_encode() { let peer_code = PeerCode { - room_code: 17535328141421925132, - server_id: 4358432574238545432, - shared_secret: 9175435743820743890, + room_code: 1, + server_id: 288, + shared_secret: 1, + }; + + let message = peer_code.to_str(); + assert_eq!(message, "90.1.1.E"); + } + + /// Test decoding a message with lowercase letters and + /// - o instead of 0 + /// - i or l instead of 1 + #[test] + fn test_decode() { + let message = "9o.i.l.e"; + let received = PeerCode::from_str(message).unwrap(); + + let peer_code = PeerCode { + room_code: 1, + server_id: 288, + shared_secret: 1, }; - let str: String = peer_code.to_str(); - let received = PeerCode::from_str(&str).unwrap(); assert_eq!(peer_code, received); } @@ -184,7 +201,6 @@ mod tests { }; let str: String = peer_code.to_str(); - println!("{str}"); let received = PeerCode::from_str(&str).unwrap(); assert_eq!(peer_code, received); } diff --git a/gday/src/main.rs b/gday/src/main.rs index 345f97c..db329ae 100644 --- a/gday/src/main.rs +++ b/gday/src/main.rs @@ -40,6 +40,10 @@ struct Args { #[arg(long)] room: Option, + /// Use unencrypted TCP instead of TLS. TODO + #[arg(long)] + unencrypted: bool, + /// Use a custom shared secret #[arg(long)] secret: Option, @@ -80,7 +84,10 @@ fn run(args: Args) -> Result<(), Box> { // use custom server if the user provided one, // otherwise pick a random default server let (mut server_connection, server_id) = if let Some(domain_name) = args.server { - (server_connector::connect_to_domain_name(&domain_name)?, 0) + ( + server_connector::connect_to_domain_name(&domain_name, true)?, + 0, + ) } else { server_connector::connect_to_random_server(DEFAULT_SERVERS)? }; diff --git a/gday_encryption/src/helper_buf.rs b/gday_encryption/src/helper_buf.rs index a3061e0..fefa0b4 100644 --- a/gday_encryption/src/helper_buf.rs +++ b/gday_encryption/src/helper_buf.rs @@ -2,8 +2,11 @@ use chacha20poly1305::aead; use std::ops::{Deref, DerefMut}; /// Buffer for storing bytes. +/// - Implemented as a heap-allocated array +/// with a left and right cursor defining +/// the in-use portion. pub struct HelperBuf { - buf: Box<[u8]>, + inner: Box<[u8]>, l_cursor: usize, r_cursor: usize, } @@ -12,14 +15,15 @@ impl HelperBuf { /// Creates a new [`HelperBuf`] with `capacity`. pub fn with_capacity(capacity: usize) -> Self { Self { - buf: vec![0; capacity].into_boxed_slice(), + inner: vec![0; capacity].into_boxed_slice(), l_cursor: 0, r_cursor: 0, } } - /// Removes the first `num_bytes` bytes. - /// Panics if `num_bytes` > `self.len()` + /// Increments the left cursor by `num_bytes` bytes. + /// - Effectively "removes" the first `num_bytes`. + /// - Panics if `num_bytes` > `self.len()`. pub fn consume(&mut self, num_bytes: usize) { self.l_cursor += num_bytes; assert!(self.l_cursor <= self.r_cursor); @@ -32,26 +36,28 @@ impl HelperBuf { } } - /// Use after putting data to `spare_capacity()`. - pub fn increase_len(&mut self, size: usize) { - self.r_cursor += size; - assert!(self.r_cursor <= self.buf.len()); + /// Returns the internal spare capacity after the right cursor. + /// - Put data to the spare capacity, then use [`Self::increase_len()`] + pub fn spare_capacity(&mut self) -> &mut [u8] { + &mut self.inner[self.r_cursor..] } - /// Returns the internal spare capacity after the stored data. - pub fn spare_capacity(&mut self) -> &mut [u8] { - &mut self.buf[self.r_cursor..] + /// Increment the right cursor by `num_bytes`. + /// - Do this after putting data to [`Self::spare_capacity()`]. + pub fn increase_len(&mut self, num_bytes: usize) { + self.r_cursor += num_bytes; + assert!(self.r_cursor <= self.inner.len()); } - /// Moves the stored data to the beginning of the internal buffer. + /// Shifts the stored data to the beginning of the internal buffer. /// Maximizes `spare_capacity_len()` without changing anything else. pub fn left_align(&mut self) { - self.buf.copy_within(self.l_cursor..self.r_cursor, 0); + self.inner.copy_within(self.l_cursor..self.r_cursor, 0); self.r_cursor -= self.l_cursor; self.l_cursor = 0; } - /// Returns a mutable view into the part of this + /// Returns a mutable [`aead::Buffer`] view into the part of this /// buffer starting at index `i`. pub fn split_off_aead_buf(&mut self, i: usize) -> HelperBufPart { let start_i = self.l_cursor + i; @@ -63,16 +69,21 @@ impl HelperBuf { } impl aead::Buffer for HelperBuf { + /// Extends the [`HelperBuf`] with `other`. + /// - Returns an [`aead::Error`] if there's not enough capacity. fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> { let new_r_cursor = self.r_cursor + other.len(); - if new_r_cursor > self.buf.len() { + if new_r_cursor > self.inner.len() { return Err(aead::Error); } - self.buf[self.r_cursor..new_r_cursor].copy_from_slice(other); + self.inner[self.r_cursor..new_r_cursor].copy_from_slice(other); self.r_cursor = new_r_cursor; Ok(()) } + /// Shortens the length of [`HelperBuf`] to `len` + /// by cutting off data at the end. + /// - Panics if `len > self.len()` fn truncate(&mut self, len: usize) { let new_r_cursor = self.l_cursor + len; assert!(new_r_cursor <= self.r_cursor); @@ -86,13 +97,13 @@ impl Deref for HelperBuf { type Target = [u8]; fn deref(&self) -> &Self::Target { - &self.buf[self.l_cursor..self.r_cursor] + &self.inner[self.l_cursor..self.r_cursor] } } impl DerefMut for HelperBuf { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.buf[self.l_cursor..self.r_cursor] + &mut self.inner[self.l_cursor..self.r_cursor] } } @@ -108,25 +119,30 @@ impl AsMut<[u8]> for HelperBuf { } } -/// A mutable view into the back part of a `HelperBuf`. +/// A mutable view into the back part of a [`HelperBuf`]. pub struct HelperBufPart<'a> { - /// The `HelperBuf` this struct references. + /// The [`HelperBuf`] this struct references. parent: &'a mut HelperBuf, /// The index in `parent` where this view begins. start_i: usize, } impl<'a> aead::Buffer for HelperBufPart<'a> { + /// Extends the [`HelperBufPart`] with `other`. + /// - Returns an [`aead::Error`] if there's not enough capacity. fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> { let new_r_cursor = self.parent.r_cursor + other.len(); - if new_r_cursor > self.parent.buf.len() { + if new_r_cursor > self.parent.inner.len() { return Err(aead::Error); } - self.parent.buf[self.parent.r_cursor..new_r_cursor].copy_from_slice(other); + self.parent.inner[self.parent.r_cursor..new_r_cursor].copy_from_slice(other); self.parent.r_cursor = new_r_cursor; Ok(()) } + /// Shortens the length of this [`HelperBufPart`] to `len` + /// by cutting off data at the end. + /// - Panics if `len > self.len()` fn truncate(&mut self, len: usize) { let new_r_cursor = self.start_i + len; assert!(new_r_cursor <= self.parent.r_cursor); @@ -140,13 +156,13 @@ impl<'a> Deref for HelperBufPart<'a> { type Target = [u8]; fn deref(&self) -> &Self::Target { - &self.parent.buf[self.start_i..self.parent.r_cursor] + &self.parent.inner[self.start_i..self.parent.r_cursor] } } impl<'a> DerefMut for HelperBufPart<'a> { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.parent.buf[self.start_i..self.parent.r_cursor] + &mut self.parent.inner[self.start_i..self.parent.r_cursor] } } @@ -165,33 +181,72 @@ impl<'a> AsMut<[u8]> for HelperBufPart<'a> { #[cfg(test)] mod tests { use crate::helper_buf::HelperBuf; - use chacha20poly1305::aead::Buffer; + use chacha20poly1305::aead::{self, Buffer}; #[test] fn test_helper_buf() { let mut buf = HelperBuf::with_capacity(4); - assert_eq!(buf.buf.len(), 4); assert!(buf.is_empty()); - assert_eq!(buf.spare_capacity().len(), 4); + assert!(buf[..].is_empty()); + assert_eq!(buf.spare_capacity(), [0, 0, 0, 0]); + assert_eq!(*buf.inner, [0, 0, 0, 0]); buf.extend_from_slice(&[1, 2, 3]).unwrap(); - assert_eq!(buf[..], [1, 2, 3][..]); - assert_eq!(*buf.buf, [1, 2, 3, 0]); - assert_eq!(buf.spare_capacity().len(), 1); + assert_eq!(*buf, [1, 2, 3]); + assert_eq!(buf.spare_capacity(), [0]); + assert_eq!(*buf.inner, [1, 2, 3, 0]); buf.consume(1); - assert_eq!(buf[..], [2, 3][..]); - assert_eq!(*buf.buf, [1, 2, 3, 0]); - assert_eq!(buf.spare_capacity().len(), 1); + assert_eq!(*buf, [2, 3]); + assert_eq!(buf.spare_capacity(), [0]); + assert_eq!(*buf.inner, [1, 2, 3, 0]); + + buf[0] = 7; + assert_eq!(*buf, [7, 3]); + assert_eq!(buf.spare_capacity(), [0]); + assert_eq!(*buf.inner, [1, 7, 3, 0]); buf.left_align(); - assert_eq!(buf[..], [2, 3][..]); - assert_eq!(*buf.buf, [2, 3, 3, 0]); - assert_eq!(buf.spare_capacity().len(), 2); + assert_eq!(*buf, [7, 3]); + assert_eq!(buf.spare_capacity(), [3, 0]); + assert_eq!(*buf.inner, [7, 3, 3, 0]); - buf.consume(1); - assert_eq!(buf[..], [3][..]); - assert_eq!(*buf.buf, [2, 3, 3, 0]); - assert_eq!(buf.spare_capacity().len(), 2); + buf.spare_capacity()[0] = 5; + buf.increase_len(1); + assert_eq!(*buf, [7, 3, 5]); + assert_eq!(buf.spare_capacity(), [0]); + assert_eq!(*buf.inner, [7, 3, 5, 0]); + + // Trying to extend by slice longer than spare capacity + // results in an error + assert_eq!(buf.extend_from_slice(&[2, 2, 2, 2]), Err(aead::Error)); + + buf.truncate(1); + assert_eq!(*buf, [7]); + assert_eq!(buf.spare_capacity(), [3, 5, 0]); + assert_eq!(*buf.inner, [7, 3, 5, 0]); + } + + #[test] + fn test_helper_buf_part() { + let mut buf = HelperBuf::with_capacity(4); + + buf.extend_from_slice(&[1, 2, 3]).unwrap(); + assert_eq!(*buf, [1, 2, 3]); + let mut part = buf.split_off_aead_buf(1); + assert_eq!(*part, [2, 3]); + + part[0] = 5; + assert_eq!(*part, [5, 3]); + + part.extend_from_slice(&[6]).unwrap(); + assert_eq!(*part, [5, 3, 6]); + + assert_eq!(part.extend_from_slice(&[0]), Err(aead::Error)); + + part.truncate(1); + assert_eq!(*part, [5]); + + assert_eq!(*buf, [1, 5]); } } diff --git a/gday_encryption/src/lib.rs b/gday_encryption/src/lib.rs index d9ae701..ef05baf 100644 --- a/gday_encryption/src/lib.rs +++ b/gday_encryption/src/lib.rs @@ -23,8 +23,11 @@ pub struct EncryptedStream { /// Stream decryptor decryptor: DecryptorBE32, + /// Stream encryptor + encryptor: EncryptorBE32, + /// Encrypted data received from the inner IO stream. - /// Invariant: Always stores only an incomplete chunk. + /// - Invariant: Never stores a complete chunk(s). /// As soon as the full chunk arrives, moves and decrypts it /// into `decrypted`. received: HelperBuf, @@ -32,38 +35,88 @@ pub struct EncryptedStream { /// Data that has been decrypted from `received` decrypted: HelperBuf, - /// Stream encryptor - encryptor: EncryptorBE32, - - /// Data to be sent. `is_flushing` indicates it's encrypted + /// Data to be sent. Encrypted when flushing. + /// - Invariant: the first 2 bytes are always + /// reserved for the length header to_send: HelperBuf, - - /// Indicates whether `to_send` has been encrypted - is_flushing: bool, } impl EncryptedStream { - /// Wraps `inner` in an `EncryptedStream`. - /// Both sides must have the same `key` and `nonce`. - /// The `key` must be a secure random secret. - /// The `nonce` should be random, but doesn't need to be secret. - pub fn new(inner: T, key: &[u8; 32], nonce: &[u8; 7]) -> Self { + /// Wraps `io_stream` in an `EncryptedStream`. + /// - The sender and receiver must have the same `key` and `nonce`. + /// - The `key` must be a secure cryptographically random secret. + /// - The `nonce` shouldn't be reused, but doesn't need to be secret. + pub fn new(io_stream: T, key: &[u8; 32], nonce: &[u8; 7]) -> Self { + let mut to_send = HelperBuf::with_capacity(u16::MAX as usize + 2); + // add 2 bytes for length header to uphold invariant + to_send.extend_from_slice(&[0, 0]).expect("unreachable"); + Self { - inner, + inner: io_stream, decryptor: DecryptorBE32::new(key.into(), nonce.into()), encryptor: EncryptorBE32::new(key.into(), nonce.into()), - to_send: HelperBuf::with_capacity(u16::MAX as usize + 2), received: HelperBuf::with_capacity(u16::MAX as usize + 2), decrypted: HelperBuf::with_capacity(u16::MAX as usize + 2), - is_flushing: true, + to_send, } } } +impl Read for EncryptedStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // if we're out of decrypted data, read more + if self.decrypted.is_empty() { + self.inner_read()?; + } + + let num_bytes = std::cmp::min(self.decrypted.len(), buf.len()); + buf[0..num_bytes].copy_from_slice(&self.decrypted[0..num_bytes]); + self.decrypted.consume(num_bytes); + Ok(num_bytes) + } +} + +impl BufRead for EncryptedStream { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + // if we're out of plaintext, read more + if self.decrypted.is_empty() { + self.inner_read()?; + } + + Ok(&self.decrypted) + } + + fn consume(&mut self, amt: usize) { + self.decrypted.consume(amt); + } +} + +impl Write for EncryptedStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let bytes_taken = std::cmp::min(buf.len(), self.to_send.spare_capacity().len() - TAG_SIZE); + self.to_send + .extend_from_slice(&buf[0..bytes_taken]) + .expect("unreachable"); + + // if `to_send` is full, flush it + if self.to_send.spare_capacity().len() == TAG_SIZE { + self.flush_write_buf()?; + } + Ok(bytes_taken) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.flush_write_buf()?; + self.inner.flush() + } +} + impl EncryptedStream { - /// Reads at least 1 new chunk into `self.plaintext`. - /// Otherwise returns `Poll::pending` + /// Reads and decrypts at least 1 new chunks into `self.decrypted`. + /// - Must only be called when `self.decrypted` is empty, + /// so that it has space to decrypt into. fn inner_read(&mut self) -> std::io::Result<()> { + debug_assert!(self.decrypted.is_empty()); // ensure at least a 2-byte header will fit in // the spare `received` capacity if self.received.len() + self.received.spare_capacity().len() < 2 { @@ -93,22 +146,17 @@ impl EncryptedStream { self.received.increase_len(bytes_read); } - self.decrypt_all_available()?; - Ok(()) - } + /// If there is a full chunk at the beginning of `data`, + /// returns it. + fn peek_cipher_chunk(data: &[u8]) -> Option<&[u8]> { + let len: [u8; 2] = data.get(0..2)?.try_into().expect("unreachable"); + let len = u16::from_be_bytes(len) as usize; + data.get(2..2 + len) + } - /// Decrypts all the full chunks in `self.ciphertext`, and - /// moves them into `self.plaintext` - fn decrypt_all_available(&mut self) -> std::io::Result<()> { // while there's another full encrypted chunk: while let Some(cipher_chunk) = peek_cipher_chunk(&self.received) { - // exit if there isn't enough room to put the - // decrypted plaintext - if self.decrypted.spare_capacity().len() < cipher_chunk.len() { - return Ok(()); - } - - // decrypt in `self.plaintext` + // decrypt in `self.decrypted` let mut decryption_space = self.decrypted.split_off_aead_buf(self.decrypted.len()); decryption_space @@ -128,39 +176,25 @@ impl EncryptedStream { impl EncryptedStream { fn flush_write_buf(&mut self) -> std::io::Result<()> { - // no need to flush if there's no data. - if self.to_send.len() == 2 { - self.is_flushing = false; - return Ok(()); - } - - // if not flushing, begin flushing - if !self.is_flushing { - // encrypt in place - let mut msg = self.to_send.split_off_aead_buf(2); - - self.encryptor - .encrypt_next_in_place(&[], &mut msg) - .map_err(|_| std::io::Error::new(ErrorKind::InvalidData, "Encryption error"))?; + // encrypt in place + let mut msg = self.to_send.split_off_aead_buf(2); + self.encryptor + .encrypt_next_in_place(&[], &mut msg) + .map_err(|_| std::io::Error::new(ErrorKind::InvalidData, "Encryption error"))?; - let len = u16::try_from(msg.len()) - .expect("unreachable: Length of message buffer should always fit in u16") - .to_be_bytes(); + let len = u16::try_from(msg.len()) + .expect("unreachable: Length of message buffer should always fit in u16") + .to_be_bytes(); - // write length to header - self.to_send[0..2].copy_from_slice(&len); + // write length to header + self.to_send[0..2].copy_from_slice(&len); - self.is_flushing = true; - } - - // write until empty or `Poll::Pending` + // write until empty while !self.to_send.is_empty() { let bytes_written = self.inner.write(&self.to_send)?; self.to_send.consume(bytes_written); } - self.is_flushing = false; - // make space for new header self.to_send .extend_from_slice(&[0, 0]) @@ -168,63 +202,3 @@ impl EncryptedStream { Ok(()) } } - -/// If there is a full chunk at the beginning of `data`, -/// returns it. -fn peek_cipher_chunk(data: &[u8]) -> Option<&[u8]> { - let len: [u8; 2] = data.get(0..2)?.try_into().expect("unreachable"); - let len = u16::from_be_bytes(len) as usize; - data.get(2..2 + len) -} - -impl Read for EncryptedStream { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - // if we're out of decrypted data, read more - if self.decrypted.is_empty() { - self.inner_read()?; - } - - let num_bytes = std::cmp::min(self.decrypted.len(), buf.len()); - buf[0..num_bytes].copy_from_slice(&self.decrypted[0..num_bytes]); - self.decrypted.consume(num_bytes); - Ok(num_bytes) - } -} - -impl BufRead for EncryptedStream { - fn fill_buf(&mut self) -> std::io::Result<&[u8]> { - // if we're out of plaintext, read more - if self.decrypted.is_empty() { - self.inner_read()?; - } - - Ok(&self.decrypted) - } - - fn consume(&mut self, amt: usize) { - self.decrypted.consume(amt); - } -} - -impl Write for EncryptedStream { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - if self.is_flushing { - self.flush_write_buf()?; - } - - let bytes_taken = std::cmp::min(buf.len(), self.to_send.spare_capacity().len() - TAG_SIZE); - self.to_send.extend_from_slice(&buf[..bytes_taken]).expect( - "unreachable: bytes_taken is less than or equal to to_send.spare_capacity().len()", - ); - - if self.to_send.spare_capacity().len() == TAG_SIZE { - self.flush_write_buf()?; - } - Ok(bytes_taken) - } - - fn flush(&mut self) -> std::io::Result<()> { - self.flush_write_buf()?; - self.inner.flush() - } -} diff --git a/gday_encryption/src/test.rs b/gday_encryption/src/test.rs index 2b2805f..81cdc12 100644 --- a/gday_encryption/src/test.rs +++ b/gday_encryption/src/test.rs @@ -3,50 +3,58 @@ use std::{ io::{Read, Write}, }; +const TEST_DATA: &[&[u8]] = &[ + b"abc5423gsgdds43", + b"def432gfd2354", + b"ggdsgdst43646543hi", + b"g", + b"mgresgdfgno", + b"463prs", + b"tufdxb5436w", + b"y4325tzz", + b"132ddsagasfa", + b"vds dagdsfa", + b"ete243yfdga", + b"dbasbalp35", + b";kbfagp98845", + b"bjkdal;f023590qjva", + b"balkdlsaj353osdfa.b", + b"bfaa;489ajdfakl;db", +]; + +/// Test sending and receiving many small messages. #[test] fn test_small_messages() { - let nonce = [5; 7]; - let key = [5; 32]; + let nonce: [u8; 7] = [42; 7]; + let key: [u8; 32] = [123; 32]; + let mut pipe = VecDeque::new(); + let mut stream = crate::EncryptedStream::new(&mut pipe, &key, &nonce); - let pipe = VecDeque::new(); + for &msg in TEST_DATA { + stream.write_all(msg).unwrap(); + stream.flush().unwrap(); + let mut buf = vec![0; msg.len()]; + stream.read_exact(&mut buf).unwrap(); + assert_eq!(buf, msg); + } +} +/// Try to spot edge-cases that occur when sending +/// several large messages. +#[test] +fn test_large_messages() { + let nonce: [u8; 7] = [75; 7]; + let key: [u8; 32] = [22; 32]; + let pipe = VecDeque::new(); let mut stream = crate::EncryptedStream::new(pipe, &key, &nonce); - let test_data = [ - &b"abc5423gsgdds43"[..], - &b"def432gfd2354"[..], - &b"ggdsgdst43646543hi"[..], - &b"g"[..], - &b"mgresgdfgno"[..], - &b"463prs"[..], - &b"tufdxb5436w"[..], - &b"y4325tzz"[..], - &b"a"[..], - &b"b"[..], - &b"132ddsagasfa"[..], - &b"vds dagdsfa"[..], - &b" dfsafsadf fsa "[..], - &b"ete243yfdga"[..], - &b"dbasbalp35"[..], - &b";kbfdbaj;dsjagp98845"[..], - &b"bjkdal;f023590qjva"[..], - &b"balkdlsaj353osdfa.b"[..], - &b"bfaa;489ajdfakl;db"[..], - &b"bsafsda;498fasklj"[..], - &b";adosp0fspag098b"[..], - &b"10e92fsa"[..], - &b"9402389054va"[..], - &b"xcznvm,.zva"[..], - &b"0-90`=`=.;[.["[..], - &b"m.xzc[];][./21"[..], - &b"10-9].k],.;./,aks"[..], - ]; + let msg = vec![123; 70_000]; - for msg in test_data { - stream.write_all(msg).unwrap(); + for _ in 0..5 { + stream.write_all(&msg).unwrap(); stream.flush().unwrap(); - let mut buf = vec![0; msg.len()]; - stream.read_exact(&mut buf).unwrap(); - assert_eq!(buf, msg[..]); + let mut received = vec![0; msg.len()]; + stream.read_exact(&mut received).unwrap(); + assert_eq!(msg, received); } } diff --git a/gday_hole_punch/src/contact_sharer.rs b/gday_hole_punch/src/contact_sharer.rs index cb243c4..039965e 100644 --- a/gday_hole_punch/src/contact_sharer.rs +++ b/gday_hole_punch/src/contact_sharer.rs @@ -77,7 +77,7 @@ impl<'a> ContactSharer<'a> { let mut streams = self.connection.streams(); for stream in &mut streams { - let private_addr = Some(stream.get_ref().local_addr()?); + let private_addr = Some(stream.local_addr()?); let msg = ClientMsg::SendAddr { room_code: self.room_code, is_creator: self.is_creator, diff --git a/gday_hole_punch/src/server_connector.rs b/gday_hole_punch/src/server_connector.rs index ee38c90..1f58320 100644 --- a/gday_hole_punch/src/server_connector.rs +++ b/gday_hole_punch/src/server_connector.rs @@ -1,9 +1,11 @@ //! Functions for connecting to a Gday server. +//! TODO: Tidy up this file use crate::Error; use log::{debug, error}; use rand::seq::SliceRandom; use socket2::SockRef; +use std::io::{Read, Write}; use std::net::SocketAddr::{V4, V6}; use std::{ net::{SocketAddr, TcpStream, ToSocketAddrs}, @@ -18,6 +20,9 @@ pub const DEFAULT_SERVERS: &[ServerInfo] = &[ServerInfo { prefer: true, }]; +/// The port that public Gday servers listen on. +pub const DEFAULT_PORT: u16 = 234; + /// Information about a single Gday server. pub struct ServerInfo { /// The domain name of the server. @@ -36,13 +41,65 @@ pub struct ServerInfo { pub prefer: bool, } -/// A single [`rustls`] TLS TCP stream to a Gday server. -pub type TLSStream = rustls::StreamOwned; +#[allow(clippy::large_enum_variant)] +pub enum ServerStream { + TCP(TcpStream), + TLS(rustls::StreamOwned), +} + +impl Read for ServerStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + Self::TCP(stream) => stream.read(buf), + Self::TLS(stream) => stream.read(buf), + } + } +} + +impl Write for ServerStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + Self::TCP(stream) => stream.write(buf), + Self::TLS(stream) => stream.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + Self::TCP(stream) => stream.flush(), + Self::TLS(stream) => stream.flush(), + } + } +} + +impl ServerStream { + /// Returns the local socket address of this stream + pub fn local_addr(&self) -> std::io::Result { + match self { + Self::TCP(stream) => stream.local_addr(), + Self::TLS(stream) => stream.get_ref().local_addr(), + } + } + + /// Enables SO_REUSEADDR and SO_REUSEPORT + /// So that this socket can be reused for + /// hole-punching. + fn enable_reuse(&self) { + let stream = match self { + Self::TCP(stream) => stream, + Self::TLS(stream) => stream.get_ref(), + }; + + let sock = SockRef::from(stream); + let _ = sock.set_reuse_address(true); + let _ = sock.set_reuse_port(true); + } +} -/// Can hold both a IPv4 and IPv6 [`TLSStream`] to a Gday server. +/// Can hold both a IPv4 and IPv6 [`ServerStream`] to a Gday server. pub struct ServerConnection { - pub v4: Option, - pub v6: Option, + pub v4: Option, + pub v6: Option, } /// Some private helper functions used by [`ContactSharer`] @@ -58,45 +115,38 @@ impl ServerConnection { } if let Some(stream) = &self.v4 { - let addr = stream.get_ref().local_addr()?; + let addr = stream.local_addr()?; if !matches!(addr, V4(_)) { return Err(Error::ExpectedIPv4); }; - Self::configure_stream(stream); + stream.enable_reuse(); } if let Some(stream) = &self.v6 { - let addr = stream.get_ref().local_addr()?; + let addr = stream.local_addr()?; if !matches!(addr, V6(_)) { return Err(Error::ExpectedIPv6); }; - Self::configure_stream(stream); + stream.enable_reuse(); } Ok(()) } - /// Returns a [`Vec`] of all the [`ServerStream`]s in this connection. - /// Order is guaranteed to always be the same. - pub(super) fn streams(&mut self) -> Vec<&mut TLSStream> { + /// Returns a [`Vec`] of all the [`TLSStream`]s in this connection. + /// Will return IPV6 followed by IPV4 + pub(super) fn streams(&mut self) -> Vec<&mut ServerStream> { let mut streams = Vec::new(); - if let Some(messenger) = &mut self.v4 { - streams.push(messenger); + if let Some(stream) = &mut self.v6 { + streams.push(stream); } - if let Some(messenger) = &mut self.v6 { - streams.push(messenger); + + if let Some(stream) = &mut self.v4 { + streams.push(stream); } streams } - - /// Enables `SO_REUSEADDR` and `SO_REUSEPORT` so that the port of - /// this stream can be reused for hole punching. - fn configure_stream(stream: &TLSStream) { - let sock = SockRef::from(stream.get_ref()); - let _ = sock.set_reuse_address(true); - let _ = sock.set_reuse_port(true); - } } /// Sequentially try connecting to the given servers, returning the first successful connection. @@ -125,7 +175,7 @@ pub fn connect_to_server_id( let Some(server) = servers.iter().find(|server| server.id == server_id) else { return Err(Error::ServerIDNotFound(server_id)); }; - connect_to_domain_name(server.domain_name) + connect_to_domain_name(server.domain_name, true) } /// Sequentially try connecting to the given addresses, returning the first successful connection. @@ -144,7 +194,7 @@ pub fn connect_to_random_address( for i in indices { let server = domain_names[i]; - let streams = match connect_to_domain_name(server) { + let streams = match connect_to_domain_name(server, true) { Ok(streams) => streams, Err(err) => { error!("Couldn't connect to \"{}\": {}", server, err); @@ -157,8 +207,9 @@ pub fn connect_to_random_address( } /// Try connecting to this `domain_name` and returning a [`ServerConnection`] -pub fn connect_to_domain_name(domain_name: &str) -> Result { - let address = format!("{domain_name}:8080"); +/// TODO: Add info about `encrypt` +pub fn connect_to_domain_name(domain_name: &str, encrypt: bool) -> Result { + let address = format!("{domain_name}:{DEFAULT_PORT}"); debug!("Connecting to '{address}`"); let addrs: Vec = address.to_socket_addrs()?.collect(); @@ -191,34 +242,46 @@ pub fn connect_to_domain_name(domain_name: &str) -> Result rustls::ClientConfig { +fn get_tls_config() -> Arc { let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth() + Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(), + ) } diff --git a/gday_server/Cargo.toml b/gday_server/Cargo.toml index 6e0e5fa..8ad7c03 100644 --- a/gday_server/Cargo.toml +++ b/gday_server/Cargo.toml @@ -3,7 +3,7 @@ name = "gday_server" version = "0.1.0" authors = ["Marcin Anforowicz"] edition = "2021" -description = "Lets 2 peers exchange their private and public addresses." +description = "A server that lets 2 peers exchange their private and public addresses via the gday contact exchange protocol." license = "MIT" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -11,7 +11,13 @@ license = "MIT" [dependencies] clap = { version = "4.5.4", features = ["derive"] } socket2 = "0.5.6" -tokio = { version = "1.37.0", features = ["rt-multi-thread", "macros", "net", "time", "sync"] } +tokio = { version = "1.37.0", features = [ + "rt-multi-thread", + "macros", + "net", + "time", + "sync", +] } tokio-rustls = "0.26.0" gday_contact_exchange_protocol = { path = "../gday_contact_exchange_protocol" } thiserror = "1.0.58" diff --git a/gday_server/src/connection_handler.rs b/gday_server/src/connection_handler.rs index a716436..6ab93b9 100644 --- a/gday_server/src/connection_handler.rs +++ b/gday_server/src/connection_handler.rs @@ -3,38 +3,62 @@ use gday_contact_exchange_protocol::{ deserialize_from_async, serialize_into_async, ClientMsg, ServerMsg, }; use log::{debug, warn}; -use std::fmt::Debug; -use tokio::net::TcpStream; -use tokio_rustls::{server::TlsStream, TlsAcceptor}; +use std::net::SocketAddr; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; +use tokio_rustls::TlsAcceptor; /// Establishes a tls connection with the `tls_acceptor` on this `tcp_stream`. /// Handles all incoming requests. /// Exits with an error message if an issue is encountered. -pub async fn handle_connection(tcp_stream: TcpStream, tls_acceptor: TlsAcceptor, state: State) { +pub async fn handle_connection( + mut tcp_stream: TcpStream, + tls_acceptor: Option, + state: State, +) { // try establishing a TLS connection - let mut tls_stream = match tls_acceptor.accept(tcp_stream).await { - Ok(tls_stream) => tls_stream, + + let origin = match tcp_stream.peer_addr() { + Ok(origin) => origin, Err(err) => { - warn!("Error establishing TLS connection: {err}"); + warn!("Couldn't get client's IP address: {err}"); return; } }; - // try handling the requests - - if let Err(err) = handle_requests(&mut tls_stream, state).await { - debug!("Dropping connection because: {err}"); + if let Some(tls_acceptor) = tls_acceptor { + let mut tls_stream = match tls_acceptor.accept(tcp_stream).await { + Ok(tls_stream) => tls_stream, + Err(err) => { + warn!("Error establishing TLS connection: {err}"); + return; + } + }; + handle_requests(&mut tls_stream, state, origin) + .await + .unwrap_or_else(|err| { + debug!("Dropping connection because: {err}"); + }); + } else { + handle_requests(&mut tcp_stream, state, origin) + .await + .unwrap_or_else(|err| { + debug!("Dropping connection because: {err}"); + }); } } /// Handles requests from this connection. /// Returns an error if any problem is encountered. async fn handle_requests( - tls: &mut TlsStream, + tls: &mut (impl AsyncRead + AsyncWrite + Unpin), mut state: State, + origin: SocketAddr, ) -> Result<(), HandleMessageError> { loop { - let result = handle_message(tls, &mut state).await; + let result = handle_message(tls, &mut state, origin).await; match result { Ok(()) => (), Err(HandleMessageError::State(state::Error::NoSuchRoomCode)) => { @@ -63,12 +87,10 @@ async fn handle_requests( } async fn handle_message( - tls: &mut TlsStream, + tls: &mut (impl AsyncRead + AsyncWrite + Unpin), state: &mut State, + origin: SocketAddr, ) -> Result<(), HandleMessageError> { - // get this connection's ip address - let origin = tls.get_ref().0.peer_addr()?; - // try to deserialize the message let msg: ClientMsg = deserialize_from_async(tls).await?; diff --git a/gday_server/src/main.rs b/gday_server/src/main.rs index c026bcd..3326473 100644 --- a/gday_server/src/main.rs +++ b/gday_server/src/main.rs @@ -30,23 +30,27 @@ use tokio_rustls::{ #[command(author, version, about)] struct Args { /// PEM-encoded private TLS server key - #[arg(short, long)] - key: PathBuf, + #[arg(short, long, required_unless_present("unencrypted"))] + key: Option, /// PEM-encoded signed TLS server certificate - #[arg(short, long)] - certificate: PathBuf, + #[arg(short, long, required_unless_present("unencrypted"))] + certificate: Option, + + /// Use unencrypted TCP instead of TLS + #[arg(short, long, conflicts_with_all(["key", "certificate"]))] + unencrypted: bool, /// The socket address from which to listen - #[arg(short, long, default_value = "[::]:8080")] + #[arg(short, long, default_value = "[::]:234")] address: String, - /// Number of seconds before a new room is deleted. + /// Number of seconds before a new room is deleted #[arg(short, long, default_value = "300")] timeout: u64, /// Max number of requests an IP address can - /// send in a minute before they're rejected. + /// send in a minute before they're rejected #[arg(short, long, default_value = "60")] request_limit: u32, @@ -63,9 +67,15 @@ async fn main() { // set the log level according to the command line argument env_logger::builder().filter_level(args.verbosity).init(); - // get tcp listener and acceptor + // get tcp listener let tcp_listener = get_tcp_listener(args.address).await; - let tls_acceptor = get_tls_acceptor(&args.key, &args.certificate); + + // get the TLS acceptor if applicable + let tls_acceptor = if let (Some(k), Some(c)) = (args.key, args.certificate) { + Some(get_tls_acceptor(&k, &c)) + } else { + None + }; // create the shared global state object let state = State::new(args.request_limit, args.timeout);