Decode with mem limit #602

Closed
wants to merge 6 commits into from
17 changes: 17 additions & 0 deletions src/codec.rs
@@ -85,6 +85,11 @@ pub trait Input {
/// This is called when decoding reference-based type is finished.
fn ascend_ref(&mut self) {}

/// Try to allocate a contiguous chunk of memory of `size` bytes.
fn try_alloc(&mut self, size: usize) -> Result<(), Error> {
Contributor

Nit: It looks like we don't really allocate memory here, but only keep track of it. So I would rename this to something like note_alloc(), or even something more generic like on_before_decode().
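
To illustrate the point about tracking rather than allocating, here is a minimal sketch of such an accounting hook. It assumes a trimmed-down `Input` trait; `note_alloc` is the name suggested above, while `Untracked`, `Budgeted`, and the local `Error` type are hypothetical stand-ins, not the crate's types.

```rust
/// Hypothetical stand-in for the crate's `Error` type.
#[derive(Debug, PartialEq)]
pub struct Error(&'static str);

/// Hypothetical, trimmed-down `Input` trait with only the accounting hook.
pub trait Input {
    /// Record that the decoder is about to use `size` bytes.
    /// Nothing is allocated here; the default tracks nothing and never fails.
    fn note_alloc(&mut self, _size: usize) -> Result<(), Error> {
        Ok(())
    }
}

/// An input that keeps the default behaviour (no tracking).
pub struct Untracked;
impl Input for Untracked {}

/// An input that subtracts every noted size from a fixed budget.
pub struct Budgeted {
    pub remaining: usize,
}

impl Input for Budgeted {
    fn note_alloc(&mut self, size: usize) -> Result<(), Error> {
        match self.remaining.checked_sub(size) {
            Some(rest) => {
                self.remaining = rest;
                Ok(())
            }
            None => Err(Error("Out of memory when decoding")),
        }
    }
}

fn main() {
    let mut input = Budgeted { remaining: 8 };
    assert_eq!(input.note_alloc(6), Ok(()));
    // The second request exceeds the budget and leaves it untouched.
    assert!(input.note_alloc(6).is_err());
    assert_eq!(input.remaining, 2);
}
```

Keeping the default a no-op means existing `Input` implementations stay source-compatible, which matches the default body in the diff.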

Ok(())
}

/// !INTERNAL USE ONLY!
///
/// Decodes a `bytes::Bytes`.
@@ -1135,6 +1140,9 @@ where
return Err("Not enough data to decode vector".into());
}

// Check that we have enough memory left to do this.
input.try_alloc(byte_len)?;

// In both these branches we're going to be creating and resizing a Vec<T>,
// but casting it to a &mut [u8] for reading.

@@ -1187,11 +1195,16 @@ impl<T: EncodeLike<U>, U: Encode> EncodeLike<Vec<U>> for &[T] {}

impl<T: Decode> Decode for Vec<T> {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
input.try_alloc(mem::size_of::<Self>())?;

<Compact<u32>>::decode(input)
.and_then(move |Compact(len)| decode_vec_with_len(input, len as usize))
}
}

// Mark vec as MemLimited since we track the allocated memory:
impl<T: Decode> crate::memory_limit::DecodeMemLimit for Vec<T> { }
Contributor

I think we need to do impl<T: crate::memory_limit::DecodeMemLimit> crate::memory_limit::DecodeMemLimit for Vec<T> { }. Otherwise I can call decode_with_mem_limit on, for example, a Vec<Box<_>>, which shouldn't support this yet.
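
A small sketch of why the bound matters, assuming a bare marker trait with no methods. The trait here is a local stand-in with the same name, and `assert_mem_limited` is a hypothetical helper used only to probe which types carry the marker.

```rust
/// Local stand-in for the marker trait discussed above.
pub trait DecodeMemLimit {}

// A leaf type that is assumed to track its own memory use.
impl DecodeMemLimit for u32 {}

// Bounding `T` by the marker itself (not merely by `Decode`) means a `Vec`
// is only marked when its element type is marked too.
impl<T: DecodeMemLimit> DecodeMemLimit for Vec<T> {}

/// Hypothetical helper: compiles only for types that carry the marker.
fn assert_mem_limited<T: DecodeMemLimit>() -> bool {
    true
}

fn main() {
    assert!(assert_mem_limited::<Vec<u32>>());
    assert!(assert_mem_limited::<Vec<Vec<u32>>>());
    // The next line would not compile, because `Box<u32>` lacks the marker:
    // assert_mem_limited::<Vec<Box<u32>>>();
}
```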


macro_rules! impl_codec_through_iterator {
($(
$type:ident
@@ -1466,6 +1479,8 @@ macro_rules! impl_endians {
const TYPE_INFO: TypeInfo = TypeInfo::$ty_info;

fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
input.try_alloc(mem::size_of::<$t>())?;

let mut buf = [0u8; mem::size_of::<$t>()];
input.read(&mut buf)?;
Ok(<$t>::from_le_bytes(buf))
@@ -1497,6 +1512,8 @@ macro_rules! impl_one_byte {
const TYPE_INFO: TypeInfo = TypeInfo::$ty_info;

fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
input.try_alloc(1)?;

Ok(input.read_byte()? as $t)
}
}
9 changes: 9 additions & 0 deletions src/depth_limit.rs
@@ -32,8 +32,13 @@ pub trait DecodeLimit: Sized {
}

struct DepthTrackingInput<'a, I> {
/// The actual input.
input: &'a mut I,

/// Current recursive depth.
depth: u32,

/// Maximum allowed recursive depth.
max_depth: u32,
}

@@ -64,6 +69,10 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> {
self.input.ascend_ref();
self.depth -= 1;
}

fn try_alloc(&mut self, size: usize) -> Result<(), Error> {
self.input.try_alloc(size)
}
}

impl<T: Decode> DecodeLimit for T {
1 change: 1 addition & 0 deletions src/lib.rs
@@ -48,6 +48,7 @@ mod const_encoded_len;
mod decode_all;
mod decode_finished;
mod depth_limit;
mod memory_limit;
mod encode_append;
mod encode_like;
mod error;
182 changes: 182 additions & 0 deletions src/memory_limit.rs
@@ -0,0 +1,182 @@
// Copyright 2017, 2018 Parity Technologies
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::{Decode, Error, Input};

/// The error message returned when the memory limit is reached.
const DECODE_OOM_ERROR: &str = "Out of memory when decoding";

/// Extension trait to [`Decode`] for decoding with a maximum memory consumption.
///
/// Should be used as a marker for types that do the memory tracking in their `decode` implementation with [`Input::try_alloc`].
pub trait DecodeMemLimit: Decode + Sized {
fn decode_with_mem_limit<I: Input>(limit: MemLimit, input: &mut I) -> Result<Self, Error> {
Contributor

This approach seems a bit limiting since we can only do decode_with_mem_limit or decode_with_depth_limit. Looks like we can't combine them. It would be nice to be able to do something more generic. For example something like:

decode_with_check(input: &mut I, check: FnMut(DecodeContext) -> Result<(), Error>)

where DecodeContext would be something like

DecodeContext {
    depth: u32,
    used_memory: usize,
    ...
}

This way we could check both for depth and memory limit at the same time. Maybe for other parameters as well. Also would be nice to be able to keep track of the number of objects of a certain type that we encounter recursively while decoding since this is also something that we needed in Polkadot.
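
A rough sketch of the callback idea, under the assumption that one context struct is updated by the decoder and handed to a user-supplied check after each update. `CheckedInput`, `limits`, and the local `Error` type are hypothetical names for illustration, and the depth error string is made up for this sketch.

```rust
/// Hypothetical stand-in for the crate's `Error` type.
#[derive(Debug, PartialEq)]
pub struct Error(&'static str);

/// Counters the decoder would maintain while decoding.
#[derive(Debug, Default, Clone, Copy)]
pub struct DecodeContext {
    pub depth: u32,
    pub used_memory: usize,
}

/// Hypothetical wrapper that runs `check` after every context update.
pub struct CheckedInput<F> {
    ctx: DecodeContext,
    check: F,
}

impl<F: FnMut(&DecodeContext) -> Result<(), Error>> CheckedInput<F> {
    pub fn new(check: F) -> Self {
        Self { ctx: DecodeContext::default(), check }
    }

    /// Account for `size` bytes, then ask the check whether to continue.
    pub fn note_alloc(&mut self, size: usize) -> Result<(), Error> {
        self.ctx.used_memory = self.ctx.used_memory.saturating_add(size);
        (self.check)(&self.ctx)
    }

    /// Enter one level of nesting, then ask the check whether to continue.
    pub fn descend(&mut self) -> Result<(), Error> {
        self.ctx.depth = self.ctx.depth.saturating_add(1);
        (self.check)(&self.ctx)
    }

    /// Leave one level of nesting.
    pub fn ascend(&mut self) {
        self.ctx.depth = self.ctx.depth.saturating_sub(1);
    }
}

/// Example policy that enforces a depth limit and a memory limit together.
pub fn limits(
    max_depth: u32,
    max_memory: usize,
) -> impl FnMut(&DecodeContext) -> Result<(), Error> {
    move |ctx: &DecodeContext| {
        if ctx.depth > max_depth {
            Err(Error("Maximum recursion depth reached"))
        } else if ctx.used_memory > max_memory {
            Err(Error("Out of memory when decoding"))
        } else {
            Ok(())
        }
    }
}

fn main() {
    let mut input = CheckedInput::new(limits(2, 16));
    assert_eq!(input.descend(), Ok(()));
    assert_eq!(input.note_alloc(16), Ok(()));
    assert!(input.note_alloc(1).is_err());
    input.ascend();
}
```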

Member Author

@ggwpez ggwpez May 31, 2024

> Looks like we can't combine them.

We can, which is why I chose this approach. These functions are just wrappers, and nothing stops you from doing:

let mut input = MemTrackingInput { inner: DepthTrackingInput { inner: .. }};
T::decode(&mut input);

But it won't have any meaning without requiring T: DecodeMemLimit.

> This way we could check both for depth and memory limit at the same time. Maybe for other parameters as well. Also would be nice to be able to keep track of the number of objects of a certain type that we encounter recursively while decoding since this is also something that we needed in Polkadot.

Yeah, I had a decode_with_context<C>(..) implemented before, but it makes this less of an opt-in, for example when one type only wants to track recursion and another only size.
If we enforced this instead, it would be a big breaking change.

The breaking change would occur in the case where we make the Parameter type of a FRAME extrinsic require this.
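
For reference, a self-contained sketch of stacking the two wrappers. It assumes a simplified `Input` trait with only the three hooks involved; `Raw`, `DepthTracking`, `MemTracking`, and `stacked` are hypothetical simplifications (owning their inner input rather than borrowing it), not the types from this PR.

```rust
/// Hypothetical stand-in for the crate's `Error` type.
#[derive(Debug, PartialEq)]
pub struct Error(&'static str);

/// Simplified `Input` with only the hooks relevant to limits.
pub trait Input {
    fn descend_ref(&mut self) -> Result<(), Error> {
        Ok(())
    }
    fn ascend_ref(&mut self) {}
    fn try_alloc(&mut self, _size: usize) -> Result<(), Error> {
        Ok(())
    }
}

/// Innermost input without any limits.
pub struct Raw;
impl Input for Raw {}

/// Enforces a depth limit and forwards everything else.
pub struct DepthTracking<I> {
    pub inner: I,
    pub depth: u32,
    pub max_depth: u32,
}

impl<I: Input> Input for DepthTracking<I> {
    fn descend_ref(&mut self) -> Result<(), Error> {
        if self.depth >= self.max_depth {
            return Err(Error("Maximum recursion depth reached"));
        }
        self.depth += 1;
        self.inner.descend_ref()
    }
    fn ascend_ref(&mut self) {
        self.inner.ascend_ref();
        self.depth = self.depth.saturating_sub(1);
    }
    fn try_alloc(&mut self, size: usize) -> Result<(), Error> {
        self.inner.try_alloc(size)
    }
}

/// Enforces a memory budget and forwards everything else.
pub struct MemTracking<I> {
    pub inner: I,
    pub remaining: usize,
}

impl<I: Input> Input for MemTracking<I> {
    fn descend_ref(&mut self) -> Result<(), Error> {
        self.inner.descend_ref()
    }
    fn ascend_ref(&mut self) {
        self.inner.ascend_ref()
    }
    fn try_alloc(&mut self, size: usize) -> Result<(), Error> {
        self.remaining = self
            .remaining
            .checked_sub(size)
            .ok_or(Error("Out of memory when decoding"))?;
        self.inner.try_alloc(size)
    }
}

/// Builds the stacked input: memory tracking on top of depth tracking.
pub fn stacked(max_depth: u32, limit: usize) -> MemTracking<DepthTracking<Raw>> {
    MemTracking {
        inner: DepthTracking { inner: Raw, depth: 0, max_depth },
        remaining: limit,
    }
}

fn main() {
    let mut input = stacked(1, 10);
    assert_eq!(input.descend_ref(), Ok(()));
    assert!(input.descend_ref().is_err()); // depth limit hit
    assert_eq!(input.try_alloc(10), Ok(()));
    assert!(input.try_alloc(1).is_err()); // memory limit hit
}
```

Because each wrapper forwards the hooks it does not care about, the order of stacking does not change which limits apply.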

let mut input = MemTrackingInput { inner: input, limit };
Self::decode(&mut input)
}
}

// Mark tuples as memory-tracking compatible if all of their elements are, since they don't consume any memory themselves.
#[impl_trait_for_tuples::impl_for_tuples(18)]
impl DecodeMemLimit for Tuple {
for_tuples!( where #( Tuple: DecodeMemLimit )* );

fn decode_with_mem_limit<I: Input>(limit: MemLimit, input: &mut I) -> Result<Self, Error> {
let mut input = MemTrackingInput { inner: input, limit };
let r = for_tuples!( ( #( Tuple::decode(&mut input)? ),* ) );
Ok(r)
}
}

/// An input that additionally tracks memory usage.
pub struct MemTrackingInput<'a, I> {
/// The actual input.
pub inner: &'a mut I,

/// The remaining memory limit.
pub limit: MemLimit,
}

/// A limit on allocated memory.
pub struct MemLimit {
/// The remaining memory limit.
limit: usize,
/// Memory alignment to be applied before allocating memory.
align: Option<MemAlignment>,
}

impl MemLimit {
/// Try to allocate a contiguous chunk of memory.
pub fn try_alloc(&mut self, size: usize) -> Result<(), Error> {
Contributor

I would rather just provide the type and call std::mem::size_of() here, which should account for the alignment automatically.
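
A minimal sketch of a type-driven variant, assuming the budget is a plain `usize` counter; `charge_for` and `Padded` are hypothetical names. Note that `size_of` covers the padding inside a type's layout, while rounding done by an allocator is a separate matter.

```rust
use core::mem;

/// Example type whose layout contains padding between `a` and `b`.
#[allow(dead_code)]
#[repr(C)]
struct Padded {
    a: u8,
    b: u32,
}

/// Hypothetical helper: charge the in-memory size of `T` against a budget.
pub fn charge_for<T>(remaining: &mut usize) -> Result<(), &'static str> {
    // `size_of` already includes any padding required by the layout of `T`.
    let size = mem::size_of::<T>();
    *remaining = remaining
        .checked_sub(size)
        .ok_or("Out of memory when decoding")?;
    Ok(())
}

fn main() {
    // The fields alone need 5 bytes; with padding the type is larger
    // wherever `u32` has an alignment greater than 1.
    assert!(mem::size_of::<Padded>() >= 5);

    let mut remaining = mem::size_of::<Padded>();
    assert_eq!(charge_for::<Padded>(&mut remaining), Ok(()));
    assert_eq!(remaining, 0);
    assert!(charge_for::<u8>(&mut remaining).is_err());
}
```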

let size = self.align.as_ref().map_or(size, |a| a.align(size));

if let Some(remaining) = self.limit.checked_sub(size) {
self.limit = remaining;
Ok(())
} else {
Err(DECODE_OOM_ERROR.into())
}
}

/// Maximal possible limit.
pub fn max() -> Self {
Self { limit: usize::MAX, align: None }
}
}

/// Alignment of some amount of memory.
///
/// Normally the word `alignment` is used in the context of a pointer - not an amount of memory, but this is still
/// the most fitting name.
pub enum MemAlignment {
/// Round up to the next power of two.
NextPowerOfTwo,
Member Author

Note: Currently, when this is used with a Vec, we align each element individually.
But a Vec allocates all of its elements in one chunk through the WASM bump allocator, so what we have to do is compute round(sum(e...)) over all elements instead of sum(round(e)...).

Rounding each element's size up to the next power of two is not enough: the vector may still allocate more, since the rounding applies to the sum of the sizes of all elements.
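
A worked example of the difference, assuming power-of-two rounding as in `NextPowerOfTwo`; `round_up`, `sum_of_rounded`, and `rounded_sum` are hypothetical helper names.

```rust
/// Round up to the next power of two, keeping `size` if that would overflow.
fn round_up(size: usize) -> usize {
    size.checked_next_power_of_two().unwrap_or(size)
}

/// sum(round(e)...): what per-element alignment computes.
pub fn sum_of_rounded(sizes: &[usize]) -> usize {
    sizes.iter().map(|s| round_up(*s)).sum()
}

/// round(sum(e...)): what a single chunk allocation would need.
pub fn rounded_sum(sizes: &[usize]) -> usize {
    round_up(sizes.iter().sum())
}

fn main() {
    // Five 3-byte elements: per-element gives 5 * 4 = 20, the chunk needs 16.
    let small = [3usize; 5];
    assert_eq!(sum_of_rounded(&small), 20);
    assert_eq!(rounded_sum(&small), 16);

    // Three 12-byte elements: per-element gives 3 * 16 = 48,
    // but the chunk of 36 bytes rounds up to 64.
    let larger = [12usize; 3];
    assert_eq!(sum_of_rounded(&larger), 48);
    assert_eq!(rounded_sum(&larger), 64);
}
```

The second case is the problematic one: per-element rounding under-counts what the chunk allocation would take.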

}

impl MemAlignment {
fn align(&self, size: usize) -> usize {
match self {
MemAlignment::NextPowerOfTwo => {
size.next_power_of_two().max(size)
},
}
}
}

impl <T: Into<usize>> From<T> for MemLimit {
fn from(limit: T) -> Self {
Self { limit: limit.into(), align: None }
}
}

impl<'a, I: Input> Input for MemTrackingInput<'a, I> {
fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
self.inner.remaining_len()
}

fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
self.inner.read(into)
}

fn read_byte(&mut self) -> Result<u8, Error> {
self.inner.read_byte()
}

fn descend_ref(&mut self) -> Result<(), Error> {
self.inner.descend_ref()
}

fn ascend_ref(&mut self) {
self.inner.ascend_ref()
}

fn try_alloc(&mut self, size: usize) -> Result<(), Error> {
self.limit.try_alloc(size)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::*;
use core::mem;

#[test]
fn decode_with_mem_limit_oom_detected() {
let bytes = (Compact(1024 as u32), vec![0u32; 1024]).encode();
let mut input = &bytes[..];

// Limit is one too small.
let limit = 4096 + mem::size_of::<Vec<u32>>() - 1;
let result = <Vec<u32>>::decode_with_mem_limit((limit as usize).into(), &mut input);
assert_eq!(result, Err("Out of memory when decoding".into()));

// Now it works:
let limit = limit + 1;
let result = <Vec<u32>>::decode_with_mem_limit((limit as usize).into(), &mut input);
assert_eq!(result, Ok(vec![0u32; 1024]));
}

#[test]
fn decode_with_mem_limit_tuple_oom_detected() {
// First entry is 1 KiB, second is 4 KiB.
let data = (vec![0u8; 1024], vec![0u32; 1024]);
let bytes = data.encode();

// Limit is one too small.
let limit = 1024 + 4096 + 2 * mem::size_of::<Vec<u32>>() - 1;
let result = <(Vec<u8>, Vec<u32>)>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]);
assert_eq!(result, Err("Out of memory when decoding".into()));

// Now it works:
let limit = limit + 1;
let result = <(Vec<u8>, Vec<u32>)>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]);
assert_eq!(result, Ok(data));
}

#[test]
fn decode_with_mem_limit_nested_oom_detected() {
// Total size is 4 KiB + 3 * vector_size.
let data = vec![vec![vec![0u32; 1024]]];
let bytes = data.encode();

let limit = 4096 + 3 * mem::size_of::<Vec<u32>>() - 1;
let result = <Vec<Vec<Vec<u32>>>>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]);
assert_eq!(result, Err("Out of memory when decoding".into()));

let limit = limit + 1;
let result = <Vec<Vec<Vec<u32>>>>::decode_with_mem_limit((limit as usize).into(), &mut &bytes[..]);
assert_eq!(result, Ok(data));
}
}