Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Fix OOM on incorrect input. #2

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions protobuf/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub enum WireError {
Utf8Error,
InvalidEnumValue(i32),
OverRecursionLimit,
TruncatedMessage,
Other,
}

Expand Down Expand Up @@ -57,6 +58,7 @@ impl Error for ProtobufError {
WireError::IncompleteMap => "incomplete map",
WireError::UnexpectedEof => "unexpected EOF",
WireError::OverRecursionLimit => "over recursion limit",
WireError::TruncatedMessage => "truncated message",
WireError::Other => "other error",
}
}
Expand Down
58 changes: 54 additions & 4 deletions protobuf/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ const OUTPUT_STREAM_BUFFER_SIZE: usize = 8 * 1024;
// Default recursion level limit. 100 is the default value of C++'s implementation.
const DEFAULT_RECURSION_LIMIT: u32 = 100;

// Max allocated vec when reading length-delimited from unknown input stream
const READ_RAW_BYTES_MAX_ALLOC: usize = 10_000_000;


pub mod wire_format {
// TODO: temporary
Expand Down Expand Up @@ -623,14 +626,34 @@ impl<'a> CodedInputStream<'a> {
/// Read raw bytes into the supplied vector. The vector will be resized as needed and
/// overwritten.
pub fn read_raw_bytes_into(&mut self, count: u32, target: &mut Vec<u8>) -> ProtobufResult<()> {
let count = count as usize;

// TODO: also do some limits when reading from unlimited source
if count as u64 > self.source.bytes_until_limit() {
return Err(ProtobufError::WireError(WireError::TruncatedMessage));
}

unsafe {
target.set_len(0);
}
target.reserve(count as usize);
unsafe {
target.set_len(count as usize);

if count >= READ_RAW_BYTES_MAX_ALLOC {
// avoid calling `reserve` on buf with very large buffer: could be a malformed message

let mut take = self.by_ref().take(count as u64);
take.read_to_end(target)?;

if target.len() != count {
return Err(ProtobufError::WireError(WireError::TruncatedMessage));
}
} else {
target.reserve(count);
unsafe {
target.set_len(count);
}

self.source.read_exact(target)?;
}
self.read(target)?;
Ok(())
}

Expand Down Expand Up @@ -1255,6 +1278,7 @@ mod test {
use super::wire_format;
use super::CodedInputStream;
use super::CodedOutputStream;
use super::READ_RAW_BYTES_MAX_ALLOC;

fn test_read_partial<F>(hex: &str, mut callback: F)
where
Expand Down Expand Up @@ -1425,6 +1449,32 @@ mod test {
});
}

#[test]
fn test_input_stream_read_raw_bytes_into_huge() {
let mut v = Vec::new();
for i in 0..READ_RAW_BYTES_MAX_ALLOC + 1000 {
v.push((i % 10) as u8);
}

let mut slice: &[u8] = v.as_slice();

let mut is = CodedInputStream::new(&mut slice);

let mut buf = Vec::new();

is.read_raw_bytes_into(READ_RAW_BYTES_MAX_ALLOC as u32 + 10, &mut buf).expect("read");

assert_eq!(READ_RAW_BYTES_MAX_ALLOC + 10, buf.len());

buf.clear();

is.read_raw_bytes_into(1000 - 10, &mut buf).expect("read");

assert_eq!(1000 - 10, buf.len());

assert!(is.eof().expect("eof"));
}

fn test_write<F>(expected: &str, mut gen: F)
where
F : FnMut(&mut CodedOutputStream) -> ProtobufResult<()>,
Expand Down