Skip to content

Commit

Permalink
Rechannel: don't ack fragment packet on the first received fragment
Browse files Browse the repository at this point in the history
Only ack the packet when all fragments are received.

The connection was considering the packet received on the first fragment received, this would cause sync errors when one fragment is missed, the other end would think the packet was received and would never resend the messages that were in that packet.

Added test to considers check this case, also updated the usage test to fragment packets and add rng to have better packet loss simulation.
  • Loading branch information
lucaspoffo committed Jul 11, 2022
1 parent 2339e56 commit 207091a
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 21 deletions.
1 change: 1 addition & 0 deletions rechannel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ bytes = { version = "1.1", features = ["serde"] }

[dev-dependencies]
env_logger = "0.9.0"
rand = "0.8.5"
12 changes: 6 additions & 6 deletions rechannel/src/reassembly_fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl SequenceBuffer<ReassemblyFragment> {
fragment_data: FragmentData,
max_packet_size: u64,
config: &FragmentConfig,
) -> Result<Vec<ChannelPacketData>, FragmentError> {
) -> Result<Option<Vec<ChannelPacketData>>, FragmentError> {
let FragmentData {
fragment_id,
num_fragments,
Expand Down Expand Up @@ -187,10 +187,10 @@ impl SequenceBuffer<ReassemblyFragment> {
let messages: Vec<ChannelPacketData> = bincode::options().deserialize(&reassembly_fragment.buffer)?;

log::trace!("Completed the reassembly of packet {}.", reassembly_fragment.sequence);
return Ok(messages);
return Ok(Some(messages));
}

Ok(vec![])
Ok(None)
}
}

Expand Down Expand Up @@ -255,18 +255,18 @@ mod tests {

let result = fragments_reassembly.handle_fragment(sequence, fragments[0].clone(), 250_000, &config);
match result {
Ok(payloads) => assert!(payloads.is_empty()),
Ok(payloads) => assert!(payloads.is_none()),
_ => unreachable!(),
}

let result = fragments_reassembly.handle_fragment(sequence, fragments[1].clone(), 250_000, &config);
match result {
Ok(payloads) => assert!(payloads.is_empty()),
Ok(payloads) => assert!(payloads.is_none()),
_ => unreachable!(),
}

let result = fragments_reassembly.handle_fragment(sequence, fragments[2].clone(), 250_000, &config);
let result = result.unwrap();
let result = result.unwrap().unwrap();

assert_eq!(messages.len(), result.len());

Expand Down
41 changes: 33 additions & 8 deletions rechannel/src/remote_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,22 @@ impl RemoteConnection {
ack_data,
fragment_data,
} => {
if self.received_buffer.get_mut(sequence).is_none() {
self.received_buffer.insert(sequence, ());
}

self.update_acket_packets(ack_data.ack, ack_data.ack_bits);

self.reassembly_buffer.handle_fragment(
let packet = self.reassembly_buffer.handle_fragment(
sequence,
fragment_data,
self.config.max_packet_size,
&self.config.fragment_config,
)?
)?;
match packet {
None => return Ok(()),
Some(packet) => {
// Only consider the packet received when the fragment is completed
self.received_buffer.insert(sequence, ());
packet
}
}
}
Packet::Heartbeat { ack_data } => {
self.update_acket_packets(ack_data.ack, ack_data.ack_bits);
Expand Down Expand Up @@ -317,15 +321,16 @@ impl RemoteConnection {
Ok(vec![])
}

fn update_acket_packets(&mut self, ack: u16, ack_bits: u32) {
let mut ack_bits = ack_bits;
fn update_acket_packets(&mut self, ack: u16, mut ack_bits: u32) {
for i in 0..32 {
if ack_bits & 1 != 0 {
let ack_sequence = ack.wrapping_sub(i);
if let Some(ref mut sent_packet) = self.sent_buffer.get_mut(ack_sequence) {
if !sent_packet.ack {
self.acks.push(ack_sequence);
sent_packet.ack = true;

// Update RTT
let rtt = (self.current_time - sent_packet.time).as_secs_f32() * 1000.;

if self.rtt == 0.0 || self.rtt < f32::EPSILON {
Expand Down Expand Up @@ -416,4 +421,24 @@ mod tests {
connection.update_packet_loss();
assert_eq!(connection.packet_loss(), 0.5);
}

#[test]
fn confirm_only_completed_fragmented_packet() {
let config = ConnectionConfig::default();
let mut connection = RemoteConnection::new(Duration::ZERO, config);
let message = vec![7u8; 2500];
connection.send_message(0, message.clone().into());

let packets = connection.get_packets_to_send().unwrap();
assert!(packets.len() > 1);
for packet in packets.iter() {
assert!(!connection.received_buffer.exists(0));
connection.process_packet(packet).unwrap();
}
// After all fragments are received it should be considered received
assert!(connection.received_buffer.exists(0));

let received_message = connection.receive_message(0).unwrap();
assert_eq!(message, received_message);
}
}
18 changes: 11 additions & 7 deletions rechannel/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use rechannel::{

use bincode::{self, Options};
use serde::{Deserialize, Serialize};
use rand::prelude::*;

use std::time::Duration;

Expand Down Expand Up @@ -148,7 +149,7 @@ struct TestUsage {

impl Default for TestUsage {
fn default() -> Self {
Self { value: vec![255; 400] }
Self { value: vec![255; 2500] }
}
}

Expand All @@ -158,6 +159,7 @@ use std::collections::HashMap;
fn test_usage() {
// TODO: we can't distinguish the log between the clients
init_log();
let mut rng = rand::thread_rng();
let mut server = RechannelServer::new(Duration::ZERO, ConnectionConfig::default());

let mut clients_status: HashMap<usize, ClientStatus> = HashMap::new();
Expand All @@ -173,9 +175,7 @@ fn test_usage() {
server.add_connection(&i);
}

let mut count: u64 = 0;
loop {
count += 1;
for (connection_id, status) in clients_status.iter_mut() {
status.connection.update().unwrap();
if status.connection.receive_message(0).is_some() {
Expand All @@ -195,12 +195,16 @@ fn test_usage() {
let client_packets = status.connection.get_packets_to_send().unwrap();
let server_packets = server.get_packets_to_send(connection_id).unwrap();

// 66% packet loss emulation
if count % 3 == 0 {
for packet in client_packets.iter() {
for packet in client_packets.iter() {
// 10% packet loss emulation
if rng.gen::<f64>() < 0.9 {
server.process_packet_from(packet, connection_id).unwrap();
}
for packet in server_packets.iter() {
}

for packet in server_packets.iter() {
// 10% packet loss emulation
if rng.gen::<f64>() < 0.9 {
status.connection.process_packet(packet).unwrap();
}
}
Expand Down

0 comments on commit 207091a

Please sign in to comment.