Skip to content

Commit

Permalink
Avoid calling AsyncRead::read with empty target
Browse files Browse the repository at this point in the history
  • Loading branch information
akonradi-signal committed Jun 28, 2024
1 parent f26fd13 commit 4cce563
Showing 1 changed file with 80 additions and 22 deletions.
102 changes: 80 additions & 22 deletions rust/message-backup/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,9 @@ impl<R: AsyncRead + Unpin> VarintDelimitedReader<R> {
async fn read_next_varint(&mut self) -> Result<Option<usize>, ParseError> {
let Self { buffer, reader } = self;

// First fill up the buffer with zeros so it can be treated as a slice.
// Keep track of how many bytes in the buffer have actually been read
// from the reader.
let mut read_bytes = buffer.len();
buffer.extend(
[0; VARINT_MAX_LENGTH][..buffer.remaining_capacity()]
.iter()
.cloned(),
);

// Read into the invalid portion until it's full or the reader is empty.
loop {
let n = reader.read(&mut buffer[read_bytes..]).await?;
if n == 0 {
break;
}
read_bytes += n;
}

// Chop off any zeroed-but-not-read bytes.
buffer.truncate(read_bytes);
fill_buffer_from_reader(reader, buffer).await?;

if read_bytes == 0 {
if buffer.is_empty() {
return Ok(None);
}

Expand All @@ -107,6 +87,33 @@ impl<R: AsyncRead + Unpin> VarintDelimitedReader<R> {
}
}

/// Read bytes from `reader` into `buffer`.
///
/// Returns when the latter is full or the former is exhausted.
async fn fill_buffer_from_reader<R: AsyncRead + Unpin, const N: usize>(
reader: &mut R,
buffer: &mut ArrayVec<u8, N>,
) -> Result<(), ParseError> {
// First fill up the buffer with zeros so it can be treated as a slice.
// Keep track of how many bytes in the buffer have actually been read
// from the reader.
let mut valid_bytes = buffer.len();
buffer.extend(
[0; VARINT_MAX_LENGTH][..buffer.remaining_capacity()]
.iter()
.cloned(),
);
while valid_bytes < buffer.len() {
let n = reader.read(&mut buffer[valid_bytes..]).await?;
if n == 0 {
break;
}
valid_bytes += n;
}
buffer.truncate(valid_bytes);
Ok(())
}

#[cfg(test)]
mod test {
use assert_matches::assert_matches;
Expand Down Expand Up @@ -222,4 +229,55 @@ mod test {
);
assert_matches!(block_on(reader.read_next()), Ok(None));
}

#[derive(Debug)]
struct ForbidZeroLengthTargetReader<R>(R);

impl<R: AsyncRead + Unpin> AsyncRead for ForbidZeroLengthTargetReader<R> {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> {
if buf.is_empty() {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"zero-length read is forbidden",
)));
}
std::pin::Pin::new(&mut self.as_mut().0).poll_read(cx, buf)
}
}

/// Regression test for a behavior where VarintDelimitedReader was calling
/// `AsyncRead::read` with an empty slice.
#[test]
fn varint_read_with_empty_target_slice() {
const FIRST: MessageAndLen<1, 9> = MessageAndLen::new([9], *b"123456789");
const SECOND: MessageAndLen<1, 7> = MessageAndLen::new([7], *b"abcdefg");

// Assert that our constants are correct before using them as input to
// the actual test.
assert_valid(&FIRST);
assert_valid(&SECOND);

let concatenated_reader = FIRST.into_reader().chain(SECOND.into_reader());
let reader = VarintDelimitedReader::new(ForbidZeroLengthTargetReader(concatenated_reader));
pin_mut!(reader);

// Read two messages.
assert_eq!(
*block_on(reader.read_next())
.expect("can read")
.expect("has frame"),
FIRST.message
);
assert_eq!(
*block_on(reader.read_next())
.expect("can read")
.expect("has frame"),
SECOND.message
);
assert_matches!(block_on(reader.read_next()), Ok(None));
}
}

0 comments on commit 4cce563

Please sign in to comment.