Skip to content

Commit

Permalink
Merge pull request #179 from Carter12s/fix-178-subscriber-buffer-size
Browse files Browse the repository at this point in the history
Fix 178 subscriber buffer size
  • Loading branch information
Carter12s authored Jul 30, 2024
2 parents 8e0066c + a0a692c commit 952cba7
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 65 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Bug with message_definitions provided by Publisher in the connection header not being the fully expanded definition.
- Bug with ROS1 native subscribers not being able to receive messages larger than 4096 bytes.

### Changed

Expand Down
2 changes: 1 addition & 1 deletion roslibrust/src/ros1/publisher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl Publication {
);

// Read the connection header:
let connection_header = match tcpros::recieve_header(&mut stream).await {
let connection_header = match tcpros::receive_header(&mut stream).await {
Ok(header) => header,
Err(e) => {
log::error!("Failed to read connection header: {e:?}");
Expand Down
34 changes: 8 additions & 26 deletions roslibrust/src/ros1/service_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use tokio::{
},
};

use super::tcpros;

pub type CallServiceRequest = (Vec<u8>, oneshot::Sender<CallServiceResponse>);
pub type CallServiceResponse = RosLibRustResult<Vec<u8>>;

Expand Down Expand Up @@ -67,13 +69,7 @@ impl<T: RosServiceType> ServiceClient<T> {
self.service_name,
result_payload
);

// Okay the 1.. is funky and needs to be addressed
// This is a little buried in the ROS documentation by the first byte is the "success" byte
// if it is 1 then the rest of the payload is the response
// Otherwise ros silently swaps the payload out for an error message
// We need to parse that error message and display somewhere
let response: T::Response = serde_rosmsg::from_slice(&result_payload[1..])
let response: T::Response = serde_rosmsg::from_slice(&result_payload)
.map_err(|err| RosLibRustError::SerializationError(err.to_string()))?;
return Ok(response);
}
Expand Down Expand Up @@ -201,26 +197,12 @@ impl ServiceClientLink {

if success {
// Parse length of the payload body
let mut body_len_bytes = [0u8; 4];
let _body_bytes_read = stream.read_exact(&mut body_len_bytes).await?;
let body_len = u32::from_le_bytes(body_len_bytes) as usize;

let mut body = vec![0u8; body_len];
stream.read_exact(&mut body).await?;

// Dumb mangling here, our implementation expects the length and success at the front
// got to be a better way than this
let full_body = [success_byte.to_vec(), body_len_bytes.to_vec(), body].concat();

Ok(full_body)
let body = tcpros::receive_body(stream).await?;
Ok(body)
} else {
let mut body_len_bytes = [0u8; 4];
let _body_bytes_read = stream.read_exact(&mut body_len_bytes).await?;
let body_len = u32::from_le_bytes(body_len_bytes) as usize;
let mut body = vec![0u8; body_len];
stream.read_exact(&mut body).await?;
let full_body = [body_len_bytes.to_vec(), body].concat();
let err_msg: String = serde_rosmsg::from_slice(&full_body).map_err(|err| {
// Parse an error message as the body
let error_body = tcpros::receive_body(stream).await?;
let err_msg: String = serde_rosmsg::from_slice(&error_body).map_err(|err| {
log::error!("Failed to parse service call error message: {err}");
std::io::Error::new(
std::io::ErrorKind::InvalidData,
Expand Down
34 changes: 11 additions & 23 deletions roslibrust/src/ros1/service_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use abort_on_drop::ChildTask;
use log::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::AsyncWriteExt;

use crate::ros1::tcpros::{self, ConnectionHeader};

Expand Down Expand Up @@ -170,7 +170,7 @@ impl ServiceServerLink {
// Probably it is better to try to send an error back?
debug!("Received service_request connection from {peer_addr} for {service_name}");

let connection_header = match tcpros::recieve_header(&mut stream).await {
let connection_header = match tcpros::receive_header(&mut stream).await {
Ok(header) => header,
Err(e) => {
warn!("Communication error while handling service request connection for {service_name}, could not parse header: {e:?}");
Expand Down Expand Up @@ -203,27 +203,15 @@ impl ServiceServerLink {
// That means we expect one header exchange, and then multiple body exchanges
// Each loop is one body:
loop {
let mut body_len_bytes = [0u8; 4];
if let Err(e) = stream.read_exact(&mut body_len_bytes).await {
// Note: this was lowered to debug! from warn! because this is intentionally done by tools like `rosservice` to discover service type
debug!("Communication error while handling service request connection for {service_name}, could not get body length: {e:?}");
// TODO returning here simply closes the socket? Should we respond with an error instead?
return;
}
let body_len = u32::from_le_bytes(body_len_bytes) as usize;
trace!("Got body length {body_len} for service {service_name}");

let mut body = vec![0u8; body_len];
if let Err(e) = stream.read_exact(&mut body).await {
warn!("Communication error while handling service request connection for {service_name}, could not get body: {e:?}");
// TODO returning here simply closes the socket? Should we respond with an error instead?
return;
}
trace!("Got body for service {service_name}: {body:#?}");

// Okay this is funky and I should be able to do better here
// serde_rosmsg expects the length at the front
let full_body = [body_len_bytes.to_vec(), body].concat();
let full_body = match tcpros::receive_body(&mut stream).await {
Ok(body) => body,
Err(e) => {
// Note this was degraded to debug! from warn! as every single use client produces this message
debug!("Communication error while handling service request connection for {service_name}, could not read body: {e:?}");
// Returning here closes the socket
return;
}
};

let response = (method)(full_body);

Expand Down
78 changes: 65 additions & 13 deletions roslibrust/src/ros1/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use abort_on_drop::ChildTask;
use roslibrust_codegen::RosMessageType;
use std::{marker::PhantomData, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
io::AsyncWriteExt,
net::TcpStream,
sync::{
broadcast::{self, error::RecvError},
Expand Down Expand Up @@ -117,21 +117,19 @@ impl Subscription {
{
publisher_list.write().await.push(publisher_uri.to_owned());
// Repeatedly read from the stream until its dry
let mut read_buffer = Vec::with_capacity(4 * 1024);
loop {
if let Ok(bytes_read) = stream.read_buf(&mut read_buffer).await {
if bytes_read == 0 {
log::debug!("Got a message with 0 bytes, probably an EOF, closing connection");
break;
match tcpros::receive_body(&mut stream).await {
Ok(body) => {
let send_result = sender.send(body);
if let Err(err) = send_result {
log::error!("Unable to send message data due to dropped channel, closing connection: {err}");
break;
}
}
log::debug!("Read {bytes_read} bytes from the publisher connection");
if let Err(err) = sender.send(Vec::from(&read_buffer[..bytes_read])) {
log::error!("Unable to send message data due to dropped channel, closing connection: {err}");
Err(e) => {
log::debug!("Failed to read body from publisher connection: {e}, closing connection");
break;
}
read_buffer.clear();
} else {
log::warn!("Got an error reading from the publisher connection on topic {topic_name:?}, closing");
}
}
}
Expand All @@ -155,7 +153,7 @@ async fn establish_publisher_connection(
let conn_header_bytes = conn_header.to_bytes(true)?;
stream.write_all(&conn_header_bytes[..]).await?;

if let Ok(responded_header) = tcpros::recieve_header(&mut stream).await {
if let Ok(responded_header) = tcpros::receive_header(&mut stream).await {
if conn_header.md5sum == responded_header.md5sum {
log::debug!(
"Established connection with publisher for {:?}",
Expand Down Expand Up @@ -248,3 +246,57 @@ impl From<serde_rosmsg::Error> for SubscriberError {
Self::DeserializeError(value.to_string())
}
}

#[cfg(test)]
mod test {

use crate::ros1::NodeHandle;

// TODO stop redundantly doing codegen so many times in tests
roslibrust_codegen_macro::find_and_generate_ros_messages!(
"assets/ros1_test_msgs",
"assets/ros1_common_interfaces"
);

#[test_log::test(tokio::test)]
async fn test_large_payload_subscriber() {
let nh = NodeHandle::new("http://localhost:11311", "/test_large_payload_subscriber")
.await
.unwrap();

let publisher = nh
.advertise::<test_msgs::RoundTripArrayRequest>("/large_payload_topic", 1)
.await
.unwrap();

let mut subscriber = nh
.subscribe::<test_msgs::RoundTripArrayRequest>("/large_payload_topic", 1)
.await
.unwrap();

// Give some time for subscriber to connect to publisher
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

for _i in 0..10 {
let bytes = vec![0; 10_000];
publisher
.publish(&test_msgs::RoundTripArrayRequest {
bytes: bytes.clone(),
})
.await
.unwrap();

match subscriber.next().await {
Some(Ok(msg)) => {
assert_eq!(msg.bytes, bytes);
}
Some(Err(e)) => {
panic!("Got error: {e:?}");
}
None => {
panic!("Got None");
}
}
}
}
}
27 changes: 25 additions & 2 deletions roslibrust/src/ros1/tcpros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ pub async fn establish_connection(
stream.write_all(&conn_header_bytes[..]).await?;

// Recieve the header from the server
let responded_header = recieve_header(&mut stream).await;
let responded_header = receive_header(&mut stream).await;
if let Ok(_responded_header) = responded_header {
// TODO we should really examine this md5sum logic...
// according to the ROS documentation, the service isn't required to respond
Expand All @@ -226,7 +226,7 @@ pub async fn establish_connection(
}

// Reads a complete ROS connection header from the given stream
pub async fn recieve_header(stream: &mut TcpStream) -> Result<ConnectionHeader, std::io::Error> {
pub async fn receive_header(stream: &mut TcpStream) -> Result<ConnectionHeader, std::io::Error> {
// Bring trait def into scope
use tokio::io::AsyncReadExt;
// Recieve the header length
Expand All @@ -241,6 +241,29 @@ pub async fn recieve_header(stream: &mut TcpStream) -> Result<ConnectionHeader,
ConnectionHeader::from_bytes(&header_bytes)
}

/// Reads the body of a message from the given stream
/// It first reads the length of the body, then reads the body itself
/// The returned Vec<> includes the length of the body at the front as serde_rosmsg expects
pub async fn receive_body(stream: &mut TcpStream) -> Result<Vec<u8>, std::io::Error> {
// Bring trait def into scope
use tokio::io::AsyncReadExt;

// Read the four bytes of size directly
let mut body_len_bytes = [0u8; 4];
stream.read_exact(&mut body_len_bytes).await?;
let body_len = u32::from_le_bytes(body_len_bytes);

// Allocate buffer space for length and body
let mut body = vec![0u8; body_len as usize + 4];
// Copy the length into the first four bytes
body[..4].copy_from_slice(&body_len.to_le_bytes());
// Read the body into the buffer
stream.read_exact(&mut body[4..]).await?;

// Return body
Ok(body)
}

#[cfg(test)]
mod test {
use super::ConnectionHeader;
Expand Down

0 comments on commit 952cba7

Please sign in to comment.