diff --git a/src/codec.rs b/src/codec.rs index c8d9e40d..757e9f0f 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -85,6 +85,11 @@ pub trait Input { /// This is called when decoding reference-based type is finished. fn ascend_ref(&mut self) {} + /// Try to allocate a contiguous chunk of memory of `size` bytes. + fn try_alloc(&mut self, size: usize) -> Result<(), Error> { + Ok(()) + } + /// !INTERNAL USE ONLY! /// /// Decodes a `bytes::Bytes`. @@ -1135,6 +1140,9 @@ where return Err("Not enough data to decode vector".into()); } + // Check that we have enough memory left to do this. + input.try_alloc(byte_len)?; + // In both these branches we're going to be creating and resizing a Vec, // but casting it to a &mut [u8] for reading. @@ -1187,11 +1195,16 @@ impl, U: Encode> EncodeLike> for &[T] {} impl Decode for Vec { fn decode(input: &mut I) -> Result { + input.try_alloc(mem::size_of::())?; + >::decode(input) .and_then(move |Compact(len)| decode_vec_with_len(input, len as usize)) } } +// Mark vec as MemLimited since we track the allocated memory: +impl crate::memory_limit::DecodeMemLimit for Vec { } + macro_rules! impl_codec_through_iterator { ($( $type:ident @@ -1466,6 +1479,8 @@ macro_rules! impl_endians { const TYPE_INFO: TypeInfo = TypeInfo::$ty_info; fn decode(input: &mut I) -> Result { + input.try_alloc(mem::size_of::<$t>())?; + let mut buf = [0u8; mem::size_of::<$t>()]; input.read(&mut buf)?; Ok(<$t>::from_le_bytes(buf)) @@ -1497,6 +1512,8 @@ macro_rules! impl_one_byte { const TYPE_INFO: TypeInfo = TypeInfo::$ty_info; fn decode(input: &mut I) -> Result { + input.try_alloc(1)?; + Ok(input.read_byte()? as $t) } } diff --git a/src/depth_limit.rs b/src/depth_limit.rs index 2af17843..dbbb059d 100644 --- a/src/depth_limit.rs +++ b/src/depth_limit.rs @@ -32,8 +32,13 @@ pub trait DecodeLimit: Sized { } struct DepthTrackingInput<'a, I> { + /// The actual input. input: &'a mut I, + + /// Current recursive depth. depth: u32, + + /// Maximum allowed recursive depth. max_depth: u32, } @@ -64,6 +69,10 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> { self.input.ascend_ref(); self.depth -= 1; } + + fn try_alloc(&mut self, size: usize) -> Result<(), Error> { + self.input.try_alloc(size) + } } impl DecodeLimit for T { diff --git a/src/lib.rs b/src/lib.rs index 9f95cf69..0c75cb03 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,7 @@ mod const_encoded_len; mod decode_all; mod decode_finished; mod depth_limit; +mod memory_limit; mod encode_append; mod encode_like; mod error; diff --git a/src/memory_limit.rs b/src/memory_limit.rs new file mode 100644 index 00000000..99cb38c5 --- /dev/null +++ b/src/memory_limit.rs @@ -0,0 +1,182 @@ +// Copyright 2017, 2018 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::{Decode, Error, Input}; + +/// The error message returned when the memory limit is reached. +const DECODE_OOM_ERROR: &str = "Out of memory when decoding"; + +/// Extension trait to [`Decode`] for decoding with a maximum memory consumption. +/// +/// Should be used as a marker for types that do the memory tracking in their `decode` implementation with [`Input::try_alloc`]. +pub trait DecodeMemLimit: Decode + Sized { + fn decode_with_mem_limit(limit: MemLimit, input: &mut I) -> Result { + let mut input = MemTrackingInput { inner: input, limit }; + Self::decode(&mut input) + } +} + +// Mark tuples as memory-tracking compatible if all of their elements are, since they dont consume any memory themselves. +#[impl_trait_for_tuples::impl_for_tuples(18)] +impl DecodeMemLimit for Tuple { + for_tuples!( where #( Tuple: DecodeMemLimit )* ); + + fn decode_with_mem_limit(limit: MemLimit, input: &mut I) -> Result { + let mut input = MemTrackingInput { inner: input, limit }; + let r = for_tuples!( ( #( Tuple::decode(&mut input)? ),* ) ); + Ok(r) + } +} + +/// An input that additionally tracks memory usage. +pub struct MemTrackingInput<'a, I> { + /// The actual input. + pub inner: &'a mut I, + + /// The remaining memory limit. + pub limit: MemLimit, +} + +/// A limit on allocated memory. +pub struct MemLimit { + /// The remaining memory limit. + limit: usize, + /// Memory alignment to be applied before allocating memory. + align: Option, +} + +impl MemLimit { + /// Try to allocate a contiguous chunk of memory. + pub fn try_alloc(&mut self, size: usize) -> Result<(), Error> { + let size = self.align.as_ref().map_or(size, |a| a.align(size)); + + if let Some(remaining) = self.limit.checked_sub(size) { + self.limit = remaining; + Ok(()) + } else { + Err(DECODE_OOM_ERROR.into()) + } + } + + /// Maximal possible limit. + pub fn max() -> Self { + Self { limit: usize::MAX, align: None } + } +} + +/// Alignment of some amount of memory. +/// +/// Normally the word `alignment` is used in the context of a pointer - not an amount of memory, but this is still +/// the most fitting name. +pub enum MemAlignment { + /// Round up to the next power of two. + NextPowerOfTwo, +} + +impl MemAlignment { + fn align(&self, size: usize) -> usize { + match self { + MemAlignment::NextPowerOfTwo => { + size.next_power_of_two().max(size) + }, + } + } +} + +impl > From for MemLimit { + fn from(limit: T) -> Self { + Self { limit: limit.into(), align: None } + } +} + +impl<'a, I: Input> Input for MemTrackingInput<'a, I> { + fn remaining_len(&mut self) -> Result, Error> { + self.inner.remaining_len() + } + + fn read(&mut self, into: &mut [u8]) -> Result<(), Error> { + self.inner.read(into) + } + + fn read_byte(&mut self) -> Result { + self.inner.read_byte() + } + + fn descend_ref(&mut self) -> Result<(), Error> { + self.inner.descend_ref() + } + + fn ascend_ref(&mut self) { + self.inner.ascend_ref() + } + + fn try_alloc(&mut self, size: usize) -> Result<(), Error> { + self.limit.try_alloc(size) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::*; + use core::mem; + + #[test] + fn decode_with_mem_limit_oom_detected() { + let bytes = (Compact(1024 as u32), vec![0u32; 1024]).encode(); + let mut input = &bytes[..]; + + // Limit is one too small. + let limit = 4096 + mem::size_of::>() - 1; + let result = >::decode_with_mem_limit((limit as usize).into(), &mut input); + assert_eq!(result, Err("Out of memory when decoding".into())); + + // Now it works: + let limit = limit + 1; + let result = >::decode_with_mem_limit((limit as usize).into(), &mut input); + assert_eq!(result, Ok(vec![0u32; 1024])); + } + + #[test] + fn decode_with_mem_limit_tuple_oom_detected() { + // First entry is 1 KiB, second is 4 KiB. + let data = (vec![0u8; 1024], vec![0u32; 1024]); + let bytes = data.encode(); + + // Limit is one too small. + let limit = 1024 + 4096 + 2 * mem::size_of::>() - 1; + let result = <(Vec, Vec)>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]); + assert_eq!(result, Err("Out of memory when decoding".into())); + + // Now it works: + let limit = limit + 1; + let result = <(Vec, Vec)>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]); + assert_eq!(result, Ok(data)); + } + + #[test] + fn decode_with_mem_limit_nested_oom_detected() { + // Total size is 4 KiB + 3 * vector_size. + let data = vec![vec![vec![0u32; 1024]]]; + let bytes = data.encode(); + + let limit = 4096 + 3 * mem::size_of::>() - 1; + let result = >>>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]); + assert_eq!(result, Err("Out of memory when decoding".into())); + + let limit = limit + 1; + let result = >>>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]); + assert_eq!(result, Ok(data)); + } +}