diff --git a/rechannel/src/channel/block.rs b/rechannel/src/channel/block.rs index 1212bee2..46798a4c 100644 --- a/rechannel/src/channel/block.rs +++ b/rechannel/src/channel/block.rs @@ -10,9 +10,9 @@ use crate::{ sequence_buffer::SequenceBuffer, timer::Timer, }; -use log::{error, info}; +use log::{debug, error, info}; -use super::{Channel, ChannelNetworkInfo}; +use super::{ReceiveChannel, SendChannel}; #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct SliceMessage { @@ -52,43 +52,52 @@ pub struct BlockChannelConfig { } #[derive(Debug)] -struct ChunkSender { - sending: bool, +enum Sending { + Yes { + num_slices: usize, + current_slice_id: usize, + num_acked_slices: usize, + acked: Vec, + data: Bytes, + resend_timers: Vec, + }, + No, +} + +#[derive(Debug)] +pub struct SendBlockChannel { + channel_id: u8, chunk_id: u16, + sending: Sending, slice_size: usize, - num_slices: usize, - current_slice_id: usize, - num_acked_slices: usize, - acked: Vec, - chunk_data: Bytes, - resend_timers: Vec, - packets_sent: SequenceBuffer, resend_time: Duration, - max_message_size: u64, packet_budget: u64, + max_message_size: u64, + message_send_queue_size: usize, + packets_sent: SequenceBuffer, + messages_to_send: VecDeque, + error: Option, } #[derive(Debug)] -struct ChunkReceiver { - receiving: bool, - chunk_id: u16, - slice_size: usize, - num_slices: usize, - num_received_slices: usize, - received: Vec, - max_message_size: u64, - chunk_data: Payload, +enum Receiving { + Yes { + chunk_id: u16, + num_slices: usize, + num_received_slices: usize, + received: Vec, + chunk_data: Payload, + }, + No, } #[derive(Debug)] -pub(crate) struct BlockChannel { +pub struct ReceiveBlockChannel { channel_id: u8, - sender: ChunkSender, - receiver: ChunkReceiver, + receiving: Receiving, messages_received: VecDeque, - messages_to_send: VecDeque, - message_send_queue_size: usize, - info: ChannelNetworkInfo, + slice_size: usize, + max_message_size: u64, error: Option, } @@ -116,268 +125,96 @@ impl PacketSent { } } -impl ChunkSender { - fn new(slice_size: usize, sent_packet_buffer_size: usize, resend_time: Duration, packet_budget: u64, max_message_size: u64) -> Self { +impl SendBlockChannel { + pub fn new(config: BlockChannelConfig) -> Self { + assert!((config.slice_size as u64) <= config.packet_budget); + Self { - sending: false, chunk_id: 0, - slice_size, - num_slices: 0, - current_slice_id: 0, - num_acked_slices: 0, - acked: Vec::new(), - chunk_data: Bytes::new(), - resend_timers: Vec::with_capacity(sent_packet_buffer_size), - packets_sent: SequenceBuffer::with_capacity(sent_packet_buffer_size), - resend_time, - max_message_size, - packet_budget, + max_message_size: config.max_message_size, + slice_size: config.slice_size, + packet_budget: config.packet_budget, + resend_time: config.resend_time, + channel_id: config.channel_id, + message_send_queue_size: config.message_send_queue_size, + sending: Sending::No, + packets_sent: SequenceBuffer::with_capacity(config.sent_packet_buffer_size), + messages_to_send: VecDeque::with_capacity(config.message_send_queue_size), + error: None, } } - fn send_message(&mut self, data: Bytes) { - assert!(!self.sending); - - self.sending = true; - self.num_acked_slices = 0; - self.num_slices = (data.len() + self.slice_size - 1) / self.slice_size; - - self.acked = vec![false; self.num_slices]; - let mut resend_timer = Timer::new(self.resend_time); - resend_timer.finish(); - self.resend_timers.clear(); - self.resend_timers.resize(self.num_slices, resend_timer); - self.chunk_data = data; - } - fn generate_slice_packets(&mut self, mut available_bytes: u64) -> Result, bincode::Error> { let mut slice_messages: Vec = vec![]; - if !self.sending { - return Ok(slice_messages); - } - - available_bytes = available_bytes.min(self.packet_budget); - - for i in 0..self.num_slices { - let slice_id = (self.current_slice_id + i) % self.num_slices; - - if self.acked[slice_id] { - continue; - } - let resend_timer = &mut self.resend_timers[slice_id]; - if !resend_timer.is_finished() { - continue; - } - - let start = slice_id * self.slice_size; - let end = if slice_id == self.num_slices - 1 { self.chunk_data.len() } else { (slice_id + 1) * self.slice_size }; - - let data = self.chunk_data[start..end].to_vec(); - - let message = SliceMessage { - chunk_id: self.chunk_id, - slice_id: slice_id as u32, - num_slices: self.num_slices as u32, + match &mut self.sending { + Sending::No => Ok(slice_messages), + Sending::Yes { + num_slices, + current_slice_id, + acked, + resend_timers, data, - }; + .. + } => { + available_bytes = available_bytes.min(self.packet_budget); - let message_size = bincode::options().serialized_size(&message)?; - let message_size = message_size as u64; + for i in 0..*num_slices { + let slice_id = (*current_slice_id + i) % *num_slices; - if available_bytes < message_size { - break; - } + if acked[slice_id] { + continue; + } + let resend_timer = &mut resend_timers[slice_id]; + if !resend_timer.is_finished() { + continue; + } - available_bytes -= message_size; - resend_timer.reset(); + let start = slice_id * self.slice_size; + let end = if slice_id == *num_slices - 1 { data.len() } else { (slice_id + 1) * self.slice_size }; - info!( - "Generated SliceMessage {} from chunk_id {}. ({}/{})", - message.slice_id, self.chunk_id, message.slice_id, self.num_slices - ); + let data = data[start..end].to_vec(); - slice_messages.push(message); - } - self.current_slice_id = (self.current_slice_id + slice_messages.len()) % self.num_slices; + let message = SliceMessage { + chunk_id: self.chunk_id, + slice_id: slice_id as u32, + num_slices: *num_slices as u32, + data, + }; - Ok(slice_messages) - } + let message_size = bincode::options().serialized_size(&message)?; + let message_size = message_size as u64; - fn process_ack(&mut self, ack: u16) { - if let Some(sent_packet) = self.packets_sent.get_mut(ack) { - if sent_packet.acked || sent_packet.chunk_id != self.chunk_id { - return; - } - sent_packet.acked = true; + if available_bytes < message_size { + break; + } + + available_bytes -= message_size; + resend_timer.reset(); - for &slice_id in sent_packet.slice_ids.iter() { - if !self.acked[slice_id as usize] { - self.acked[slice_id as usize] = true; - self.num_acked_slices += 1; info!( - "Acked SliceMessage {} from chunk_id {}. ({}/{})", - slice_id, self.chunk_id, self.num_acked_slices, self.num_slices + "Generated SliceMessage {} from chunk_id {}. ({}/{})", + message.slice_id, self.chunk_id, message.slice_id, num_slices ); - } - } - - if self.num_acked_slices == self.num_slices { - self.sending = false; - info!("Finished sending block message {}.", self.chunk_id); - self.chunk_id += 1; - } - } - } -} -impl ChunkReceiver { - fn new(slice_size: usize, max_message_size: u64) -> Self { - Self { - receiving: false, - chunk_id: 0, - slice_size, - max_message_size, - num_slices: 0, - num_received_slices: 0, - received: Vec::new(), - chunk_data: Vec::new(), - } - } - - fn process_slice_message(&mut self, message: &SliceMessage) -> Result, ChannelError> { - if !self.receiving { - if message.num_slices == 0 { - error!("Cannot initialize block message with zero slices."); - return Err(ChannelError::InvalidSliceMessage); - } - - let total_size = message.num_slices as u64 * self.slice_size as u64; - if total_size > self.max_message_size { - error!( - "Cannot initialize block message above the channel limit size, got {}, expected less than {}", - total_size, self.max_message_size - ); - return Err(ChannelError::ReceivedMessageAboveMaxSize); - } - - self.receiving = true; - self.num_slices = message.num_slices as usize; - self.chunk_id = message.chunk_id; - self.num_received_slices = 0; - self.received = vec![false; self.num_slices]; - info!( - "Receiving Block message with id {} with {} slices.", - message.chunk_id, message.num_slices - ); - self.chunk_data = vec![0; self.num_slices * self.slice_size]; - } - - if message.chunk_id != self.chunk_id { - info!( - "Invalid chunk id for SliceMessage, expected {}, got {}.", - self.chunk_id, message.chunk_id - ); - // Not an error since this could be an old chunk id. - return Ok(None); - } - - if message.num_slices != self.num_slices as u32 { - error!( - "Invalid number of slices for SliceMessage, got {}, expected {}.", - message.num_slices, self.num_slices, - ); - return Err(ChannelError::InvalidSliceMessage); - } - - let slice_id = message.slice_id as usize; - let is_last_slice = slice_id == self.num_slices - 1; - if is_last_slice { - if message.data.len() > self.slice_size { - error!( - "Invalid last slice_size for SliceMessage, got {}, expected less than {}.", - message.data.len(), - self.slice_size, - ); - return Err(ChannelError::InvalidSliceMessage); - } - } else if message.data.len() != self.slice_size { - error!( - "Invalid slice_size for SliceMessage, expected {}, got {}.", - self.slice_size, - message.data.len() - ); - return Err(ChannelError::InvalidSliceMessage); - } - - if !self.received[slice_id] { - self.received[slice_id] = true; - self.num_received_slices += 1; + slice_messages.push(message); + } + *current_slice_id = (*current_slice_id + slice_messages.len()) % *num_slices; - if is_last_slice { - let len = (self.num_slices - 1) * self.slice_size + message.data.len(); - self.chunk_data.resize(len, 0); + Ok(slice_messages) } - - let start = slice_id * self.slice_size; - let end = if slice_id == self.num_slices - 1 { - (self.num_slices - 1) * self.slice_size + message.data.len() - } else { - (slice_id + 1) * self.slice_size - }; - - self.chunk_data[start..end].copy_from_slice(&message.data); - info!( - "Received slice {} from chunk {}. ({}/{})", - slice_id, self.chunk_id, self.num_received_slices, self.num_slices - ); - } - - if self.num_received_slices == self.num_slices { - info!("Received all slices for chunk {}.", self.chunk_id); - let block = mem::take(&mut self.chunk_data); - self.receiving = false; - return Ok(Some(block)); - } - - Ok(None) - } -} - -impl BlockChannel { - pub fn new(config: BlockChannelConfig) -> Self { - assert!((config.slice_size as u64) <= config.packet_budget); - - let sender = ChunkSender::new( - config.slice_size, - config.sent_packet_buffer_size, - config.resend_time, - config.packet_budget, - config.max_message_size, - ); - let receiver = ChunkReceiver::new(config.slice_size, config.max_message_size); - - Self { - channel_id: config.channel_id, - sender, - receiver, - messages_received: VecDeque::new(), - messages_to_send: VecDeque::with_capacity(config.message_send_queue_size), - message_send_queue_size: config.message_send_queue_size, - info: ChannelNetworkInfo::default(), - error: None, } } } -impl Channel for BlockChannel { +impl SendChannel for SendBlockChannel { fn get_messages_to_send(&mut self, available_bytes: u64, sequence: u16) -> Option { - if !self.sender.sending { + if let Sending::No = self.sending { if let Some(message) = self.messages_to_send.pop_front() { self.send_message(message); } } - let slice_messages: Vec = match self.sender.generate_slice_packets(available_bytes) { + let slice_messages: Vec = match self.generate_slice_packets(available_bytes) { Ok(messages) => messages, Err(e) => { log::error!("Failed serialize message in block channel {}: {}", self.channel_id, e); @@ -396,8 +233,6 @@ impl Channel for BlockChannel { let slice_id = message.slice_id; match bincode::options().serialize(message) { Ok(message) => { - self.info.messages_sent += 1; - self.info.bytes_sent += message.len() as u64; slice_ids.push(slice_id); messages.push(message); } @@ -409,8 +244,8 @@ impl Channel for BlockChannel { } } - let packet_sent = PacketSent::new(self.sender.chunk_id, slice_ids); - self.sender.packets_sent.insert(sequence, packet_sent); + let packet_sent = PacketSent::new(self.chunk_id, slice_ids); + self.packets_sent.insert(sequence, packet_sent); Some(ChannelPacketData { channel_id: self.channel_id, @@ -418,38 +253,48 @@ impl Channel for BlockChannel { }) } - fn advance_time(&mut self, duration: Duration) { - for timer in self.sender.resend_timers.iter_mut() { - timer.advance(duration); - } - } - - fn process_messages(&mut self, messages: Vec) { - if self.error.is_some() { - return; - } - - for message in messages.iter() { - match bincode::options().deserialize::(message) { - Ok(slice_message) => match self.receiver.process_slice_message(&slice_message) { - Ok(Some(message)) => self.messages_received.push_back(message), - Ok(None) => {} - Err(e) => { - self.error = Some(e); + fn process_ack(&mut self, ack: u16) { + match &mut self.sending { + Sending::No => {} + Sending::Yes { + num_acked_slices, + num_slices, + acked, + .. + } => { + if let Some(sent_packet) = self.packets_sent.get_mut(ack) { + if sent_packet.acked || sent_packet.chunk_id != self.chunk_id { return; } - }, - Err(e) => { - error!("Failed to deserialize slice message in channel {}: {}", self.channel_id, e); - self.error = Some(ChannelError::FailedToSerialize); - return; + sent_packet.acked = true; + + for &slice_id in sent_packet.slice_ids.iter() { + if !acked[slice_id as usize] { + acked[slice_id as usize] = true; + *num_acked_slices += 1; + info!( + "Acked SliceMessage {} from chunk_id {}. ({}/{})", + slice_id, self.chunk_id, num_acked_slices, num_slices + ); + } + } + + if num_acked_slices == num_slices { + self.sending = Sending::No; + info!("Finished sending block message {}.", self.chunk_id); + self.chunk_id += 1; + } } } } } - fn process_ack(&mut self, ack: u16) { - self.sender.process_ack(ack); + fn advance_time(&mut self, duration: Duration) { + if let Sending::Yes { resend_timers, .. } = &mut self.sending { + for timer in resend_timers.iter_mut() { + timer.advance(duration); + } + } } fn send_message(&mut self, payload: Bytes) { @@ -457,17 +302,17 @@ impl Channel for BlockChannel { return; } - if payload.len() as u64 > self.sender.max_message_size { + if payload.len() as u64 > self.max_message_size { log::error!( "Tried to send block message with size above the limit, got {} bytes, expected less than {}", payload.len(), - self.sender.max_message_size + self.max_message_size ); self.error = Some(ChannelError::SentMessageAboveMaxSize); return; } - if self.sender.sending { + if matches!(self.sending, Sending::Yes { .. }) { if self.messages_to_send.len() >= self.message_send_queue_size { log::error!( "Tried to send block message but the message queue is full, the limit is {} messages.", @@ -480,19 +325,185 @@ impl Channel for BlockChannel { return; } - self.sender.send_message(payload); - } + let num_slices = (payload.len() + self.slice_size - 1) / self.slice_size; + let mut resend_timer = Timer::new(self.resend_time); + resend_timer.finish(); + let mut resend_timers = Vec::with_capacity(num_slices); + resend_timers.resize(num_slices, resend_timer); - fn receive_message(&mut self) -> Option { - self.messages_received.pop_front() + self.sending = Sending::Yes { + current_slice_id: 0, + num_acked_slices: 0, + acked: vec![false; num_slices], + num_slices, + resend_timers, + data: payload, + }; } fn can_send_message(&self) -> bool { self.messages_to_send.len() < self.message_send_queue_size } - fn channel_network_info(&self) -> ChannelNetworkInfo { - self.info + fn error(&self) -> Option { + self.error + } +} + +impl ReceiveBlockChannel { + pub fn new(config: BlockChannelConfig) -> Self { + assert!((config.slice_size as u64) <= config.packet_budget); + + Self { + slice_size: config.slice_size, + max_message_size: config.max_message_size, + channel_id: config.channel_id, + receiving: Receiving::No, + messages_received: VecDeque::new(), + error: None, + } + } + + fn process_slice_message(&mut self, message: &SliceMessage) -> Result, ChannelError> { + if matches!(self.receiving, Receiving::No) { + if message.num_slices == 0 { + error!("Cannot initialize block message with zero slices."); + return Err(ChannelError::InvalidSliceMessage); + } + + let total_size = message.num_slices as u64 * self.slice_size as u64; + if total_size > self.max_message_size { + error!( + "Cannot initialize block message above the channel limit size, got {}, expected less than {}", + total_size, self.max_message_size + ); + return Err(ChannelError::ReceivedMessageAboveMaxSize); + } + + let num_slices = message.num_slices as usize; + + self.receiving = Receiving::Yes { + num_slices, + chunk_id: message.chunk_id, + num_received_slices: 0, + received: vec![false; num_slices], + chunk_data: vec![0; num_slices * self.slice_size], + }; + info!( + "Receiving Block message with id {} with {} slices.", + message.chunk_id, message.num_slices + ); + } + + match &mut self.receiving { + Receiving::No => unreachable!(), + Receiving::Yes { + chunk_id, + num_slices, + chunk_data, + received, + num_received_slices, + } => { + if message.chunk_id != *chunk_id { + debug!( + "Invalid chunk id for SliceMessage, expected {}, got {}.", + chunk_id, message.chunk_id + ); + // Not an error since this could be an old chunk id. + return Ok(None); + } + + if message.num_slices != *num_slices as u32 { + error!( + "Invalid number of slices for SliceMessage, got {}, expected {}.", + message.num_slices, num_slices, + ); + return Err(ChannelError::InvalidSliceMessage); + } + + let slice_id = message.slice_id as usize; + let is_last_slice = slice_id == *num_slices - 1; + if is_last_slice { + if message.data.len() > self.slice_size { + error!( + "Invalid last slice_size for SliceMessage, got {}, expected less than {}.", + message.data.len(), + self.slice_size, + ); + return Err(ChannelError::InvalidSliceMessage); + } + } else if message.data.len() != self.slice_size { + error!( + "Invalid slice_size for SliceMessage, expected {}, got {}.", + self.slice_size, + message.data.len() + ); + return Err(ChannelError::InvalidSliceMessage); + } + + if !received[slice_id] { + received[slice_id] = true; + *num_received_slices += 1; + + if is_last_slice { + let len = (*num_slices - 1) * self.slice_size + message.data.len(); + chunk_data.resize(len, 0); + } + + let start = slice_id * self.slice_size; + let end = if slice_id == *num_slices - 1 { + (*num_slices - 1) * self.slice_size + message.data.len() + } else { + (slice_id + 1) * self.slice_size + }; + + chunk_data[start..end].copy_from_slice(&message.data); + info!( + "Received slice {} from chunk {}. ({}/{})", + slice_id, chunk_id, num_received_slices, num_slices + ); + } + + if *num_received_slices == *num_slices { + info!("Received all slices for chunk {}.", chunk_id); + let block = mem::take(chunk_data); + self.receiving = Receiving::No; + return Ok(Some(block)); + } + + Ok(None) + } + } + } +} + +impl ReceiveChannel for ReceiveBlockChannel { + fn process_messages(&mut self, messages: Vec) { + if self.error.is_some() { + return; + } + + for message in messages.iter() { + match bincode::options().deserialize::(message) { + Ok(slice_message) => match self.process_slice_message(&slice_message) { + Ok(Some(message)) => self.messages_received.push_back(message), + Ok(None) => {} + Err(e) => { + self.error = Some(e); + return; + } + }, + Err(e) => { + error!("Failed to deserialize slice message in channel {}: {}", self.channel_id, e); + self.error = Some(ChannelError::FailedToSerialize); + return; + } + } + } + } + + fn receive_message(&mut self) -> Option { + self.messages_received.pop_front() } fn error(&self) -> Option { @@ -507,120 +518,126 @@ mod tests { #[test] fn split_chunk() { const SLICE_SIZE: usize = 10; - let mut sender = ChunkSender::new(SLICE_SIZE, 100, Duration::from_millis(100), 30, 5000); + let config = BlockChannelConfig { + slice_size: SLICE_SIZE, + packet_budget: 30, + ..Default::default() + }; + let mut send_channel = SendBlockChannel::new(config.clone()); + let mut receive_channel = ReceiveBlockChannel::new(config); let message = Bytes::from(vec![255u8; 30]); - sender.send_message(message.clone()); + send_channel.send_message(message.clone()); - let mut receiver = ChunkReceiver::new(SLICE_SIZE, 5000); - - let slice_messages = sender.generate_slice_packets(u64::MAX).unwrap(); + let slice_messages = send_channel.generate_slice_packets(u64::MAX).unwrap(); assert_eq!(slice_messages.len(), 2); - sender.process_ack(0); - sender.process_ack(1); + send_channel.process_ack(0); + send_channel.process_ack(1); for slice_message in slice_messages.into_iter() { - receiver.process_slice_message(&slice_message).unwrap(); + receive_channel.process_slice_message(&slice_message).unwrap(); } - let last_message = sender.generate_slice_packets(u64::MAX).unwrap(); - let result = receiver.process_slice_message(&last_message[0]); + let last_message = send_channel.generate_slice_packets(u64::MAX).unwrap(); + let result = receive_channel.process_slice_message(&last_message[0]); assert_eq!(message, result.unwrap().unwrap()); } #[test] fn block_chunk() { let config = BlockChannelConfig::default(); - let mut sender_channel = BlockChannel::new(config.clone()); - let mut receiver_channel = BlockChannel::new(config); + let mut send_channel = SendBlockChannel::new(config.clone()); + let mut receive_channel = ReceiveBlockChannel::new(config); let payload = Bytes::from(vec![7u8; 102400]); - sender_channel.send_message(payload.clone()); + send_channel.send_message(payload.clone()); let mut sequence = 0; loop { - let channel_data = sender_channel.get_messages_to_send(1600, sequence); + let channel_data = send_channel.get_messages_to_send(1600, sequence); match channel_data { None => break, Some(data) => { - receiver_channel.process_messages(data.messages); - sender_channel.process_ack(sequence); + receive_channel.process_messages(data.messages); + send_channel.process_ack(sequence); sequence += 1; } } } - let received_payload = receiver_channel.receive_message().unwrap(); - assert_eq!(payload.len(), received_payload.len()); + let received_payload = receive_channel.receive_message().unwrap(); + assert_eq!(payload, received_payload); } #[test] fn block_channel_queue() { - let mut channel = BlockChannel::new(BlockChannelConfig { + let config = BlockChannelConfig { resend_time: Duration::ZERO, ..Default::default() - }); + }; + let mut send_channel = SendBlockChannel::new(config.clone()); + let mut receive_channel = ReceiveBlockChannel::new(config); + let first_message = Bytes::from(vec![3; 2000]); let second_message = Bytes::from(vec![5; 2000]); - channel.send_message(first_message.clone()); - channel.send_message(second_message.clone()); + send_channel.send_message(first_message.clone()); + send_channel.send_message(second_message.clone()); // First message - let block_channel_data = channel.get_messages_to_send(u64::MAX, 0).unwrap(); + let block_channel_data = send_channel.get_messages_to_send(u64::MAX, 0).unwrap(); assert!(!block_channel_data.messages.is_empty()); - channel.process_messages(block_channel_data.messages); - let received_first_message = channel.receive_message().unwrap(); + receive_channel.process_messages(block_channel_data.messages); + let received_first_message = receive_channel.receive_message().unwrap(); assert_eq!(first_message, received_first_message); - channel.process_ack(0); + send_channel.process_ack(0); // Second message - let block_channel_data = channel.get_messages_to_send(u64::MAX, 1).unwrap(); + let block_channel_data = send_channel.get_messages_to_send(u64::MAX, 1).unwrap(); assert!(!block_channel_data.messages.is_empty()); - channel.process_messages(block_channel_data.messages); - let received_second_message = channel.receive_message().unwrap(); + receive_channel.process_messages(block_channel_data.messages); + let received_second_message = receive_channel.receive_message().unwrap(); assert_eq!(second_message, received_second_message); - channel.process_ack(1); + send_channel.process_ack(1); // Check there is no message to send - assert!(!channel.sender.sending); - let block_channel_data = channel.get_messages_to_send(u64::MAX, 2); - assert!(block_channel_data.is_none()); + assert!(matches!(send_channel.sending, Sending::No)); } #[test] fn acking_packet_with_old_chunk_id() { - let mut channel = BlockChannel::new(BlockChannelConfig { + let config = BlockChannelConfig { resend_time: Duration::ZERO, ..Default::default() - }); + }; + let mut send_channel = SendBlockChannel::new(config); let first_message = Bytes::from(vec![5; 400 * 3]); let second_message = Bytes::from(vec![3; 400]); - channel.send_message(first_message); - channel.send_message(second_message); + send_channel.send_message(first_message); + send_channel.send_message(second_message); - let _ = channel.get_messages_to_send(u64::MAX, 0).unwrap(); - let _ = channel.get_messages_to_send(u64::MAX, 1).unwrap(); + let _ = send_channel.get_messages_to_send(u64::MAX, 0).unwrap(); + let _ = send_channel.get_messages_to_send(u64::MAX, 1).unwrap(); - channel.process_ack(0); - let _ = channel.get_messages_to_send(u64::MAX, 2).unwrap(); + send_channel.process_ack(0); + let _ = send_channel.get_messages_to_send(u64::MAX, 2).unwrap(); - channel.process_ack(1); - assert!(channel.sender.sending); + send_channel.process_ack(1); + assert!(matches!(send_channel.sending, Sending::Yes { .. })); - channel.process_ack(2); - assert!(!channel.sender.sending); + send_channel.process_ack(2); + assert!(matches!(send_channel.sending, Sending::No)); } #[test] fn initialize_block_with_zero_slices() { - let mut channel = BlockChannel::new(Default::default()); + let mut receive_channel = ReceiveBlockChannel::new(Default::default()); let slice_message = SliceMessage { chunk_id: 0, slice_id: 0, num_slices: 0, data: vec![], }; - assert!(channel.receiver.process_slice_message(&slice_message).is_err()); - assert!(!channel.receiver.receiving); + assert!(receive_channel.process_slice_message(&slice_message).is_err()); + assert!(matches!(receive_channel.receiving, Receiving::No)); } } diff --git a/rechannel/src/channel/mod.rs b/rechannel/src/channel/mod.rs index 594a79a9..5bf78d47 100644 --- a/rechannel/src/channel/mod.rs +++ b/rechannel/src/channel/mod.rs @@ -11,7 +11,11 @@ pub use unreliable::UnreliableChannelConfig; use bytes::Bytes; use crate::{ - channel::{block::BlockChannel, reliable::ReliableChannel, unreliable::UnreliableChannel}, + channel::{ + block::{ReceiveBlockChannel, SendBlockChannel}, + reliable::{ReceiveReliableChannel, SendReliableChannel}, + unreliable::{ReceiveUnreliableChannel, SendUnreliableChannel}, + }, error::ChannelError, packet::{ChannelPacketData, Payload}, }; @@ -24,31 +28,28 @@ pub enum ChannelConfig { Block(BlockChannelConfig), } -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] -pub struct ChannelNetworkInfo { - pub messages_sent: u64, - pub messages_received: u64, - pub bytes_sent: u64, - pub bytes_received: u64, -} - -impl ChannelNetworkInfo { - pub fn reset(&mut self) { - self.messages_received = 0; - self.messages_sent = 0; - self.bytes_received = 0; - self.bytes_sent = 0; - } -} - impl ChannelConfig { - pub(crate) fn new_channel(&self) -> Box { + pub(crate) fn new_channels( + &self, + ) -> ( + Box, + Box, + ) { use ChannelConfig::*; match self { - Reliable(config) => Box::new(ReliableChannel::new(config.clone())), - Unreliable(config) => Box::new(UnreliableChannel::new(config.clone())), - Block(config) => Box::new(BlockChannel::new(config.clone())), + Unreliable(config) => ( + Box::new(SendUnreliableChannel::new(config.clone())), + Box::new(ReceiveUnreliableChannel::new(config.clone())), + ), + Reliable(config) => ( + Box::new(SendReliableChannel::new(config.clone())), + Box::new(ReceiveReliableChannel::new(config.clone())), + ), + Block(config) => ( + Box::new(SendBlockChannel::new(config.clone())), + Box::new(ReceiveBlockChannel::new(config.clone())), + ), } } @@ -61,15 +62,20 @@ impl ChannelConfig { } } -pub(crate) trait Channel: std::fmt::Debug { +pub(crate) trait SendChannel: std::fmt::Debug { fn get_messages_to_send(&mut self, available_bytes: u64, sequence: u16) -> Option; - fn advance_time(&mut self, duration: Duration); - fn process_messages(&mut self, messages: Vec); fn process_ack(&mut self, ack: u16); + // TODO: maybe the timer should check with current_time to see if is completed instead of advancing it by the delta every tick + // So we would pass the current_time in the get_messages_to_send + fn advance_time(&mut self, duration: Duration); fn send_message(&mut self, payload: Bytes); - fn receive_message(&mut self) -> Option; fn can_send_message(&self) -> bool; - fn channel_network_info(&self) -> ChannelNetworkInfo; + fn error(&self) -> Option; +} + +pub(crate) trait ReceiveChannel: std::fmt::Debug { + fn process_messages(&mut self, messages: Vec); + fn receive_message(&mut self) -> Option; fn error(&self) -> Option; } diff --git a/rechannel/src/channel/reliable.rs b/rechannel/src/channel/reliable.rs index dcacec14..3c69a40f 100644 --- a/rechannel/src/channel/reliable.rs +++ b/rechannel/src/channel/reliable.rs @@ -1,8 +1,7 @@ use crate::{ - channel::Channel, error::ChannelError, packet::{ChannelPacketData, Payload}, - sequence_buffer::SequenceBuffer, + sequence_buffer::{sequence_greater_than, sequence_less_than, SequenceBuffer}, timer::Timer, }; @@ -12,7 +11,7 @@ use serde::{Deserialize, Serialize}; use std::time::Duration; -use super::ChannelNetworkInfo; +use super::{ReceiveChannel, SendChannel}; #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub(crate) struct ReliableMessage { @@ -56,17 +55,26 @@ pub struct ReliableChannelConfig { } #[derive(Debug)] -pub(crate) struct ReliableChannel { - config: ReliableChannelConfig, +pub(crate) struct SendReliableChannel { + channel_id: u8, + packet_budget: u64, + max_message_size: u64, + message_resend_time: Duration, packets_sent: SequenceBuffer, messages_send: SequenceBuffer, - messages_received: SequenceBuffer, send_message_id: u16, - received_message_id: u16, num_messages_sent: u64, - num_messages_received: u64, oldest_unacked_message_id: u16, - info: ChannelNetworkInfo, + error: Option, +} + +#[derive(Debug)] +pub(crate) struct ReceiveReliableChannel { + channel_id: u8, + max_message_size: u64, + messages_received: SequenceBuffer, + received_message_id: u16, + num_messages_received: u64, error: Option, } @@ -107,59 +115,42 @@ impl Default for ReliableChannelConfig { } } -impl ReliableChannel { +impl SendReliableChannel { + pub fn has_messages_to_send(&self) -> bool { + self.oldest_unacked_message_id != self.send_message_id + } +} + +impl SendReliableChannel { pub fn new(config: ReliableChannelConfig) -> Self { + assert!(config.max_message_size <= config.packet_budget); + Self { + channel_id: config.channel_id, + packet_budget: config.packet_budget, + max_message_size: config.max_message_size, + send_message_id: 0, + oldest_unacked_message_id: 0, packets_sent: SequenceBuffer::with_capacity(config.sent_packet_buffer_size), messages_send: SequenceBuffer::with_capacity(config.message_send_queue_size), - messages_received: SequenceBuffer::with_capacity(config.message_receive_queue_size), - send_message_id: 0, - received_message_id: 0, - num_messages_received: 0, + message_resend_time: config.message_resend_time, num_messages_sent: 0, - oldest_unacked_message_id: 0, - config, - info: ChannelNetworkInfo::default(), error: None, } } - - pub fn has_messages_to_send(&self) -> bool { - self.oldest_unacked_message_id != self.send_message_id - } - - fn update_oldest_message_ack(&mut self) { - let stop_id = self.messages_send.sequence(); - - while self.oldest_unacked_message_id != stop_id && !self.messages_send.exists(self.oldest_unacked_message_id) { - self.oldest_unacked_message_id = self.oldest_unacked_message_id.wrapping_add(1); - } - } } -impl Channel for ReliableChannel { - fn advance_time(&mut self, duration: Duration) { - self.info.reset(); - let message_limit = self.config.message_send_queue_size; - for i in 0..message_limit { - let message_id = self.oldest_unacked_message_id.wrapping_add(i as u16); - if let Some(message) = self.messages_send.get_mut(message_id) { - message.resend_timer.advance(duration); - } - } - } - +impl SendChannel for SendReliableChannel { fn get_messages_to_send(&mut self, mut available_bytes: u64, sequence: u16) -> Option { if !self.has_messages_to_send() || self.error.is_some() { return None; } - available_bytes = available_bytes.min(self.config.packet_budget); + available_bytes = available_bytes.min(self.packet_budget); let mut messages: Vec = vec![]; let mut message_ids: Vec = vec![]; - let message_limit = self.config.message_send_queue_size; - for i in 0..message_limit { + for i in 0..self.messages_send.size() { let message_id = self.oldest_unacked_message_id.wrapping_add(i as u16); let message_send = self.messages_send.get_mut(message_id); if let Some(message_send) = message_send { @@ -170,7 +161,7 @@ impl Channel for ReliableChannel { let serialized_size = match bincode::options().serialized_size(&message_send.reliable_message) { Ok(size) => size as u64, Err(e) => { - log::error!("Failed to get message size in channel {}: {}", self.config.channel_id, e); + log::error!("Failed to get message size in channel {}: {}", self.channel_id, e); self.error = Some(ChannelError::FailedToSerialize); return None; } @@ -180,12 +171,10 @@ impl Channel for ReliableChannel { available_bytes -= serialized_size; message_send.resend_timer.reset(); message_ids.push(message_id); - self.info.messages_sent += 1; - self.info.bytes_sent += serialized_size; let message = match bincode::options().serialize(&message_send.reliable_message) { Ok(message) => message, Err(e) => { - log::error!("Failed to serialize message in channel {}: {}", self.config.channel_id, e); + log::error!("Failed to serialize message in channel {}: {}", self.channel_id, e); self.error = Some(ChannelError::FailedToSerialize); return None; } @@ -203,44 +192,11 @@ impl Channel for ReliableChannel { self.packets_sent.insert(sequence, packet_sent); Some(ChannelPacketData { - channel_id: self.config.channel_id, + channel_id: self.channel_id, messages, }) } - fn process_messages(&mut self, messages: Vec) { - if self.error.is_some() { - return; - } - - for message in messages.iter() { - self.info.bytes_received += message.len() as u64; - self.info.messages_received += 1; - match bincode::options().deserialize::(message) { - Ok(message) => { - if message.payload.len() as u64 > self.config.max_message_size { - log::error!( - "Received reliable message with size above the limit, got {} bytes, expected less than {}", - message.payload.len(), - self.config.max_message_size - ); - self.error = Some(ChannelError::ReceivedMessageAboveMaxSize); - return; - } - - if !self.messages_received.exists(message.id) { - self.messages_received.insert(message.id, message); - } - } - Err(e) => { - log::error!("Failed to deserialize reliable message {}: {}", self.config.channel_id, e); - self.error = Some(ChannelError::FailedToSerialize); - return; - } - } - } - } - fn process_ack(&mut self, ack: u16) { if let Some(sent_packet) = self.packets_sent.get_mut(ack) { if sent_packet.acked { @@ -253,7 +209,22 @@ impl Channel for ReliableChannel { self.messages_send.remove(message_id); } } - self.update_oldest_message_ack(); + + // Update oldest message ack + let stop_id = self.messages_send.sequence(); + + while self.oldest_unacked_message_id != stop_id && !self.messages_send.exists(self.oldest_unacked_message_id) { + self.oldest_unacked_message_id = self.oldest_unacked_message_id.wrapping_add(1); + } + } + } + + fn advance_time(&mut self, duration: Duration) { + for i in 0..self.messages_send.size() { + let message_id = self.oldest_unacked_message_id.wrapping_add(i as u16); + if let Some(message) = self.messages_send.get_mut(message_id) { + message.resend_timer.advance(duration); + } } } @@ -268,11 +239,11 @@ impl Channel for ReliableChannel { return; } - if payload.len() as u64 > self.config.packet_budget { + if payload.len() as u64 > self.max_message_size { log::error!( "Tried to send reliable message with size above the limit, got {} bytes, expected less than {}", payload.len(), - self.config.max_message_size + self.max_message_size ); self.error = Some(ChannelError::SentMessageAboveMaxSize); return; @@ -281,12 +252,79 @@ impl Channel for ReliableChannel { self.send_message_id = self.send_message_id.wrapping_add(1); let reliable_message = ReliableMessage::new(message_id, payload); - let entry = ReliableMessageSent::new(reliable_message, self.config.message_resend_time); + let entry = ReliableMessageSent::new(reliable_message, self.message_resend_time); self.messages_send.insert(message_id, entry); self.num_messages_sent += 1; } + fn can_send_message(&self) -> bool { + self.messages_send.available(self.send_message_id) + } + + fn error(&self) -> Option { + self.error + } +} + +impl ReceiveReliableChannel { + pub fn new(config: ReliableChannelConfig) -> Self { + assert!(config.max_message_size <= config.packet_budget); + + Self { + channel_id: config.channel_id, + max_message_size: config.max_message_size, + received_message_id: 0, + num_messages_received: 0, + messages_received: SequenceBuffer::with_capacity(config.message_receive_queue_size), + error: None, + } + } +} + +impl ReceiveChannel for ReceiveReliableChannel { + fn process_messages(&mut self, messages: Vec) { + if self.error.is_some() { + return; + } + + for message in messages.iter() { + match bincode::options().deserialize::(message) { + Ok(message) => { + if message.payload.len() as u64 > self.max_message_size { + log::error!( + "Received reliable message with size above the limit, got {} bytes, expected less than {}", + message.payload.len(), + self.max_message_size + ); + self.error = Some(ChannelError::ReceivedMessageAboveMaxSize); + return; + } + + if sequence_less_than(message.id, self.received_message_id) { + // Discard old message + continue; + } + + let max_message_id = self.received_message_id + self.messages_received.size() as u16 - 1; + if sequence_greater_than(message.id, max_message_id) { + // Out of messages to add + self.error = Some(ChannelError::ReliableChannelOutOfSync); + } + + if !self.messages_received.exists(message.id) { + self.messages_received.insert(message.id, message); + } + } + Err(e) => { + log::error!("Failed to deserialize reliable message {}: {}", self.channel_id, e); + self.error = Some(ChannelError::FailedToSerialize); + return; + } + } + } + } + fn receive_message(&mut self) -> Option { if self.error.is_some() { return None; @@ -304,14 +342,6 @@ impl Channel for ReliableChannel { self.messages_received.remove(received_message_id).map(|m| m.payload.to_vec()) } - fn can_send_message(&self) -> bool { - self.messages_send.available(self.send_message_id) - } - - fn channel_network_info(&self) -> ChannelNetworkInfo { - self.info - } - fn error(&self) -> Option { self.error } @@ -344,75 +374,46 @@ mod tests { } #[test] - fn send_message() { + fn send_receive_message() { let config = ReliableChannelConfig::default(); - let mut channel: ReliableChannel = ReliableChannel::new(config); + let mut send_channel = SendReliableChannel::new(config.clone()); + let mut receive_channel = ReceiveReliableChannel::new(config); let sequence = 0; - assert!(!channel.has_messages_to_send()); - assert_eq!(channel.num_messages_sent, 0); + assert!(!send_channel.has_messages_to_send()); + assert_eq!(send_channel.num_messages_sent, 0); - channel.send_message(TestMessages::Second(0).serialize()); - assert_eq!(channel.num_messages_sent, 1); - assert!(channel.receive_message().is_none()); + let message = TestMessages::Second(0).serialize(); - let channel_data = channel.get_messages_to_send(u64::MAX, sequence).unwrap(); + send_channel.send_message(message.clone()); + assert_eq!(send_channel.num_messages_sent, 1); + assert!(receive_channel.receive_message().is_none()); + let channel_data = send_channel.get_messages_to_send(u64::MAX, sequence).unwrap(); assert_eq!(channel_data.messages.len(), 1); - assert_eq!( - channel_data, - ChannelPacketData { - channel_id: 0, - messages: vec![bincode::options() - .serialize(&ReliableMessage::new(0, TestMessages::Second(0).serialize())) - .unwrap()] - } - ); - assert!(channel.has_messages_to_send()); - - channel.process_ack(sequence); - assert!(!channel.has_messages_to_send()); - } - - #[test] - fn receive_message() { - let config = ReliableChannelConfig::default(); - let mut channel: ReliableChannel = ReliableChannel::new(config); - - let messages = vec![ - bincode::options() - .serialize(&ReliableMessage::new(0, TestMessages::First.serialize())) - .unwrap(), - bincode::options() - .serialize(&ReliableMessage::new(1, TestMessages::Second(0).serialize())) - .unwrap(), - ]; - - channel.process_messages(messages); - let message = channel.receive_message().unwrap(); - assert_eq!(message, TestMessages::First.serialize()); + receive_channel.process_messages(channel_data.messages); + let received_message = receive_channel.receive_message().unwrap(); + assert_eq!(received_message, message); - let message = channel.receive_message().unwrap(); - assert_eq!(message, TestMessages::Second(0).serialize()); - - assert_eq!(channel.num_messages_received, 2); + assert!(send_channel.has_messages_to_send()); + send_channel.process_ack(sequence); + assert!(!send_channel.has_messages_to_send()); } #[test] fn over_budget() { - let first_message = TestMessages::Third(0); - let second_message = TestMessages::Third(1); + let first_message = TestMessages::Third(0).serialize(); + let second_message = TestMessages::Third(1).serialize(); - let message = ReliableMessage::new(0, first_message.serialize()); + let message = ReliableMessage::new(0, first_message.clone()); + let message_size = bincode::options().serialized_size(&message).unwrap() as u64; let config = ReliableChannelConfig::default(); - let mut channel: ReliableChannel = ReliableChannel::new(config); + let mut channel = SendReliableChannel::new(config); - channel.send_message(first_message.serialize()); - channel.send_message(second_message.serialize()); - - let message_size = bincode::options().serialized_size(&message).unwrap() as u64; + channel.send_message(first_message); + channel.send_message(second_message); let channel_data = channel.get_messages_to_send(message_size, 0).unwrap(); assert_eq!(channel_data.messages.len(), 1); @@ -425,11 +426,12 @@ mod tests { #[test] fn resend_message() { + let message_resend_time = Duration::from_millis(100); let config = ReliableChannelConfig { - message_resend_time: Duration::from_millis(100), + message_resend_time, ..Default::default() }; - let mut channel: ReliableChannel = ReliableChannel::new(config); + let mut channel = SendReliableChannel::new(config); channel.send_message(TestMessages::First.serialize()); @@ -437,7 +439,7 @@ mod tests { assert_eq!(channel_data.messages.len(), 1); assert!(channel.get_messages_to_send(u64::MAX, 1).is_none()); - channel.advance_time(Duration::from_millis(100)); + channel.advance_time(message_resend_time); let channel_data = channel.get_messages_to_send(u64::MAX, 2).unwrap(); assert_eq!(channel_data.messages.len(), 1); @@ -445,14 +447,28 @@ mod tests { #[test] fn out_of_sync() { - let config = ReliableChannelConfig { - message_send_queue_size: 1, + let send_config = ReliableChannelConfig { + message_send_queue_size: 2, ..Default::default() }; - let mut channel: ReliableChannel = ReliableChannel::new(config); + let receive_config = ReliableChannelConfig { + message_receive_queue_size: 1, + ..Default::default() + }; + let mut send_channel = SendReliableChannel::new(send_config); + let mut receive_channel = ReceiveReliableChannel::new(receive_config); + let message = TestMessages::Second(0).serialize(); + + send_channel.send_message(message.clone()); + let first_channel_data = send_channel.get_messages_to_send(u64::MAX, 0).unwrap(); + send_channel.send_message(message.clone()); + let second_channel_data = send_channel.get_messages_to_send(u64::MAX, 0).unwrap(); + + send_channel.send_message(message.clone()); + assert!(matches!(send_channel.error(), Some(ChannelError::ReliableChannelOutOfSync))); - channel.send_message(TestMessages::Second(0).serialize()); - channel.send_message(TestMessages::Second(0).serialize()); - assert!(matches!(channel.error(), Some(ChannelError::ReliableChannelOutOfSync))); + receive_channel.process_messages(first_channel_data.messages); + receive_channel.process_messages(second_channel_data.messages); + assert!(matches!(receive_channel.error(), Some(ChannelError::ReliableChannelOutOfSync))); } } diff --git a/rechannel/src/channel/unreliable.rs b/rechannel/src/channel/unreliable.rs index 0126d890..978570f6 100644 --- a/rechannel/src/channel/unreliable.rs +++ b/rechannel/src/channel/unreliable.rs @@ -1,5 +1,4 @@ use crate::{ - channel::Channel, error::ChannelError, packet::{ChannelPacketData, Payload}, }; @@ -8,7 +7,7 @@ use std::{collections::VecDeque, time::Duration}; use bytes::Bytes; -use super::ChannelNetworkInfo; +use super::{ReceiveChannel, SendChannel}; /// Configuration for a unreliable and unordered channel. /// Messages sent in this channel will behave like a udp packet, @@ -29,11 +28,21 @@ pub struct UnreliableChannelConfig { } #[derive(Debug)] -pub(crate) struct UnreliableChannel { - config: UnreliableChannelConfig, +pub(crate) struct SendUnreliableChannel { + channel_id: u8, + packet_budget: u64, + max_message_size: u64, + message_send_queue_size: usize, messages_to_send: VecDeque, + error: Option, +} + +#[derive(Debug)] +pub(crate) struct ReceiveUnreliableChannel { + channel_id: u8, + max_message_size: u64, + message_receive_queue_size: usize, messages_received: VecDeque, - info: ChannelNetworkInfo, error: Option, } @@ -49,29 +58,29 @@ impl Default for UnreliableChannelConfig { } } -impl UnreliableChannel { +impl SendUnreliableChannel { pub fn new(config: UnreliableChannelConfig) -> Self { - assert!(config.max_message_size < config.packet_budget); + assert!(config.max_message_size <= config.packet_budget); Self { + channel_id: config.channel_id, + packet_budget: config.packet_budget, + max_message_size: config.max_message_size, + message_send_queue_size: config.message_send_queue_size, messages_to_send: VecDeque::with_capacity(config.message_send_queue_size), - messages_received: VecDeque::with_capacity(config.message_receive_queue_size), - config, - info: ChannelNetworkInfo::default(), error: None, } } } -impl Channel for UnreliableChannel { +impl SendChannel for SendUnreliableChannel { fn get_messages_to_send(&mut self, mut available_bytes: u64, _sequence: u16) -> Option { if self.error.is_some() { return None; } let mut messages = vec![]; - - available_bytes = available_bytes.min(self.config.packet_budget); + available_bytes = available_bytes.min(self.packet_budget); while let Some(message) = self.messages_to_send.pop_front() { let message_size = message.len() as u64; @@ -80,8 +89,6 @@ impl Channel for UnreliableChannel { } available_bytes -= message_size; - self.info.messages_sent += 1; - self.info.bytes_sent += message_size; messages.push(message.to_vec()); } @@ -90,80 +97,94 @@ impl Channel for UnreliableChannel { } Some(ChannelPacketData { - channel_id: self.config.channel_id, + channel_id: self.channel_id, messages, }) } - fn advance_time(&mut self, _duration: Duration) { - self.info.reset(); - } - - fn process_messages(&mut self, mut messages: Vec) { - if self.error.is_some() { - return; - } - - while let Some(message) = messages.pop() { - if message.len() as u64 > self.config.max_message_size { - log::error!( - "Received unreliable message with size above the limit, got {} bytes, expected less than {}", - message.len(), - self.config.max_message_size - ); - self.error = Some(ChannelError::ReceivedMessageAboveMaxSize); - return; - } - if self.messages_received.len() == self.config.message_receive_queue_size { - log::warn!( - "Received message was dropped in unreliable channel {}, reached maximum number of messages {}", - self.config.channel_id, - self.config.message_receive_queue_size - ); - return; - } - self.info.messages_received += 1; - self.info.bytes_received += message.len() as u64; - self.messages_received.push_back(message); - } - } - fn process_ack(&mut self, _ack: u16) {} + fn advance_time(&mut self, _duration: Duration) {} + fn send_message(&mut self, payload: Bytes) { if self.error.is_some() { return; } - if payload.len() as u64 > self.config.max_message_size { + if payload.len() as u64 > self.max_message_size { log::error!( "Tried to send unreliable message with size above the limit, got {} bytes, expected less than {}", payload.len(), - self.config.max_message_size + self.max_message_size ); self.error = Some(ChannelError::SentMessageAboveMaxSize); return; } - if self.messages_to_send.len() >= self.config.message_send_queue_size { + if self.messages_to_send.len() >= self.message_send_queue_size { self.error = Some(ChannelError::SendQueueFull); - log::warn!("Unreliable channel {} has reached the maximum queue size", self.config.channel_id); + log::warn!("Unreliable channel {} has reached the maximum queue size", self.channel_id); return; } self.messages_to_send.push_back(payload); } - fn receive_message(&mut self) -> Option { - self.messages_received.pop_front() + fn can_send_message(&self) -> bool { + self.messages_to_send.len() < self.message_send_queue_size } - fn can_send_message(&self) -> bool { - self.messages_to_send.len() < self.config.message_send_queue_size + fn error(&self) -> Option { + self.error } +} + +impl ReceiveUnreliableChannel { + pub fn new(config: UnreliableChannelConfig) -> Self { + assert!(config.max_message_size <= config.packet_budget); + + Self { + channel_id: config.channel_id, + max_message_size: config.max_message_size, + message_receive_queue_size: config.message_receive_queue_size, + messages_received: VecDeque::with_capacity(config.message_receive_queue_size), + error: None, + } + } +} + +impl ReceiveChannel for ReceiveUnreliableChannel { + fn process_messages(&mut self, mut messages: Vec) { + if self.error.is_some() { + return; + } + + while let Some(message) = messages.pop() { + if message.len() as u64 > self.max_message_size { + log::error!( + "Received unreliable message with size above the limit, got {} bytes, expected less than {}", + message.len(), + self.max_message_size + ); + self.error = Some(ChannelError::ReceivedMessageAboveMaxSize); + return; + } + + if self.messages_received.len() == self.message_receive_queue_size { + log::warn!( + "Received message was dropped in unreliable channel {}, reached maximum number of messages {}", + self.channel_id, + self.message_receive_queue_size + ); + return; + } - fn channel_network_info(&self) -> ChannelNetworkInfo { - self.info + self.messages_received.push_back(message); + } + } + + fn receive_message(&mut self) -> Option { + self.messages_received.pop_front() } fn error(&self) -> Option { diff --git a/rechannel/src/error.rs b/rechannel/src/error.rs index 5ca5d3ce..9632d6b0 100644 --- a/rechannel/src/error.rs +++ b/rechannel/src/error.rs @@ -13,8 +13,10 @@ pub enum DisconnectionReason { DisconnectedByClient, /// Channel with given Id was not found InvalidChannelId(u8), - /// Error occurred in a channel - ChannelError { channel_id: u8, error: ChannelError }, + /// Error occurred in a send channel + SendChannelError { channel_id: u8, error: ChannelError }, + /// Error occurred in a receive channel + ReceiveChannelError { channel_id: u8, error: ChannelError }, } /// Possibles errors that can occur in a channel. @@ -25,6 +27,7 @@ pub enum ChannelError { /// The channel send queue has reach it's maximum SendQueueFull, /// Error occurred during (de)serialization + // TODO: rename to SerializationFailure FailedToSerialize, /// Tried to send a message that is above the channel max message size. SentMessageAboveMaxSize, @@ -57,7 +60,8 @@ impl fmt::Display for DisconnectionReason { DisconnectedByServer => write!(fmt, "connection terminated by server"), DisconnectedByClient => write!(fmt, "connection terminated by client"), InvalidChannelId(id) => write!(fmt, "received message with invalid channel {}", id), - ChannelError { channel_id, error } => write!(fmt, "channel {} with error: {}", channel_id, error), + SendChannelError { channel_id, error } => write!(fmt, "send channel {} with error: {}", channel_id, error), + ReceiveChannelError { channel_id, error } => write!(fmt, "receive channel {} with error: {}", channel_id, error), } } } diff --git a/rechannel/src/remote_connection.rs b/rechannel/src/remote_connection.rs index d8d377cb..0b9cd3c7 100644 --- a/rechannel/src/remote_connection.rs +++ b/rechannel/src/remote_connection.rs @@ -1,4 +1,4 @@ -use crate::channel::{Channel, ChannelConfig, ChannelNetworkInfo}; +use crate::channel::{ChannelConfig, ReceiveChannel, SendChannel}; use crate::error::{DisconnectionReason, RechannelError}; use crate::packet::{Packet, Payload}; @@ -34,14 +34,16 @@ pub struct ConnectionConfig { pub packet_loss_smoothing_factor: f32, pub heartbeat_time: Duration, pub fragment_config: FragmentConfig, - pub channels_config: Vec, + pub send_channels_config: Vec, + pub receive_channels_config: Vec, } #[derive(Debug)] pub struct RemoteConnection { state: ConnectionState, sequence: u16, - channels: HashMap>, + send_channels: HashMap>, + receive_channels: HashMap>, heartbeat_timer: Timer, config: ConnectionConfig, reassembly_buffer: SequenceBuffer, @@ -61,6 +63,12 @@ impl SentPacket { impl Default for ConnectionConfig { fn default() -> Self { + let channels = vec![ + ChannelConfig::Reliable(Default::default()), + ChannelConfig::Unreliable(Default::default()), + ChannelConfig::Block(Default::default()), + ]; + Self { max_packet_size: 16 * 1024, sent_packets_buffer_size: 256, @@ -69,11 +77,8 @@ impl Default for ConnectionConfig { packet_loss_smoothing_factor: 0.1, heartbeat_time: Duration::from_millis(100), fragment_config: FragmentConfig::default(), - channels_config: vec![ - ChannelConfig::Reliable(Default::default()), - ChannelConfig::Unreliable(Default::default()), - ChannelConfig::Block(Default::default()), - ], + send_channels_config: channels.clone(), + receive_channels_config: channels, } } } @@ -85,17 +90,26 @@ impl RemoteConnection { let sent_buffer = SequenceBuffer::with_capacity(config.sent_packets_buffer_size); let received_buffer = SequenceBuffer::with_capacity(config.received_packets_buffer_size); - let mut channels = HashMap::new(); - for channel_config in config.channels_config.iter() { - let channel = channel_config.new_channel(); + let mut send_channels = HashMap::new(); + for channel_config in config.send_channels_config.iter() { + let (send_channel, _) = channel_config.new_channels(); + let channel_id = channel_config.channel_id(); + let old_channel = send_channels.insert(channel_id, send_channel); + assert!(old_channel.is_none(), "already exists send channel with id {}", channel_id); + } + + let mut receive_channels = HashMap::new(); + for channel_config in config.send_channels_config.iter() { + let (_, receive_channel) = channel_config.new_channels(); let channel_id = channel_config.channel_id(); - let old_channel = channels.insert(channel_id, channel); - assert!(old_channel.is_none(), "already exists channel with id {}", channel_id); + let old_channel = receive_channels.insert(channel_id, receive_channel); + assert!(old_channel.is_none(), "already exists receive channel with id {}", channel_id); } Self { state: ConnectionState::Connected, - channels, + send_channels, + receive_channels, heartbeat_timer, sequence: 0, reassembly_buffer, @@ -117,13 +131,6 @@ impl RemoteConnection { self.packet_loss } - pub fn channels_network_info(&self) -> Vec<(u8, ChannelNetworkInfo)> { - self.channels - .iter() - .map(|(channel_id, channel)| (*channel_id, channel.channel_network_info())) - .collect() - } - pub fn is_connected(&self) -> bool { matches!(self.state, ConnectionState::Connected) } @@ -147,24 +154,24 @@ impl RemoteConnection { } pub fn can_send_message(&self, channel_id: u8) -> bool { - let channel = self.channels.get(&channel_id).expect("invalid channel id"); + let channel = self.send_channels.get(&channel_id).expect("invalid channel id"); channel.can_send_message() } pub fn send_message(&mut self, channel_id: u8, message: Bytes) { - let channel = self.channels.get_mut(&channel_id).expect("invalid channel id"); + let channel = self.send_channels.get_mut(&channel_id).expect("invalid channel id"); channel.send_message(message); } pub fn receive_message(&mut self, channel_id: u8) -> Option { - let channel = self.channels.get_mut(&channel_id).expect("invalid channel id"); + let channel = self.receive_channels.get_mut(&channel_id).expect("invalid channel id"); channel.receive_message() } pub fn advance_time(&mut self, duration: Duration) { self.current_time += duration; self.heartbeat_timer.advance(duration); - for channel in self.channels.values_mut() { + for channel in self.send_channels.values_mut() { channel.advance_time(duration); } } @@ -174,16 +181,24 @@ impl RemoteConnection { return Err(RechannelError::ClientDisconnected(reason)); } - for (&channel_id, channel) in self.channels.iter() { - if let Some(error) = channel.error() { - let reason = DisconnectionReason::ChannelError { channel_id, error }; + for (&channel_id, send_channel) in self.send_channels.iter() { + if let Some(error) = send_channel.error() { + let reason = DisconnectionReason::SendChannelError { channel_id, error }; + self.state = ConnectionState::Disconnected { reason }; + return Err(RechannelError::ClientDisconnected(reason)); + } + } + + for (&channel_id, receive_channel) in self.receive_channels.iter() { + if let Some(error) = receive_channel.error() { + let reason = DisconnectionReason::ReceiveChannelError { channel_id, error }; self.state = ConnectionState::Disconnected { reason }; return Err(RechannelError::ClientDisconnected(reason)); } } for ack in self.acks.drain(..) { - for channel in self.channels.values_mut() { + for channel in self.send_channels.values_mut() { channel.process_ack(ack); } } @@ -239,7 +254,7 @@ impl RemoteConnection { }; for channel_packet_data in channels_packet_data.into_iter() { - let channel = match self.channels.get_mut(&channel_packet_data.channel_id) { + let receive_channel = match self.receive_channels.get_mut(&channel_packet_data.channel_id) { Some(c) => c, None => { let reason = DisconnectionReason::InvalidChannelId(channel_packet_data.channel_id); @@ -248,7 +263,7 @@ impl RemoteConnection { } }; - channel.process_messages(channel_packet_data.messages); + receive_channel.process_messages(channel_packet_data.messages); } Ok(()) @@ -264,8 +279,8 @@ impl RemoteConnection { const HEADER_SIZE: u64 = 20; let mut available_bytes = self.config.max_packet_size - HEADER_SIZE; let mut channels_packet_data = vec![]; - for channel in self.channels.values_mut() { - if let Some(channel_packet_data) = channel.get_messages_to_send(available_bytes, sequence) { + for send_channel in self.send_channels.values_mut() { + if let Some(channel_packet_data) = send_channel.get_messages_to_send(available_bytes, sequence) { available_bytes -= bincode::options().serialized_size(&channel_packet_data)?; channels_packet_data.push(channel_packet_data) } diff --git a/rechannel/src/sequence_buffer.rs b/rechannel/src/sequence_buffer.rs index 36066d8a..252222b0 100644 --- a/rechannel/src/sequence_buffer.rs +++ b/rechannel/src/sequence_buffer.rs @@ -9,6 +9,8 @@ pub(crate) struct SequenceBuffer { impl SequenceBuffer { pub fn with_capacity(size: usize) -> Self { + assert!(size > 0, "tried to initialize SequenceBuffer with 0 size"); + Self { sequence: 0, entry_sequences: vec![None; size].into_boxed_slice(), @@ -16,6 +18,10 @@ impl SequenceBuffer { } } + pub fn size(&self) -> usize { + self.entries.len() + } + pub fn get_mut(&mut self, sequence: u16) -> Option<&mut T> { if self.exists(sequence) { let index = self.index(sequence); @@ -130,12 +136,12 @@ impl SequenceBuffer { // Since sequences can wrap we need to check when this when checking greater // Ocurring the cutover in the middle of u16 #[inline] -fn sequence_greater_than(s1: u16, s2: u16) -> bool { +pub fn sequence_greater_than(s1: u16, s2: u16) -> bool { ((s1 > s2) && (s1 - s2 <= 32768)) || ((s1 < s2) && (s2 - s1 > 32768)) } #[inline] -fn sequence_less_than(s1: u16, s2: u16) -> bool { +pub fn sequence_less_than(s1: u16, s2: u16) -> bool { sequence_greater_than(s2, s1) } diff --git a/rechannel/src/server.rs b/rechannel/src/server.rs index e2713e79..4f53dfa1 100644 --- a/rechannel/src/server.rs +++ b/rechannel/src/server.rs @@ -1,4 +1,3 @@ -use crate::channel::ChannelNetworkInfo; use crate::error::{DisconnectionReason, RechannelError}; use crate::packet::Payload; use crate::remote_connection::{ConnectionConfig, RemoteConnection}; @@ -60,13 +59,6 @@ impl RechannelServer { } } - pub fn channels_network_info(&self, connection_id: C) -> Vec<(u8, ChannelNetworkInfo)> { - match self.connections.get(&connection_id) { - Some(connection) => connection.channels_network_info(), - None => Vec::with_capacity(0), - } - } - /// Similar to disconnect but does not emit an event pub fn remove_connection(&mut self, connection_id: &C) { self.connections.remove(connection_id);