Use MAX_PREALLOCATION consistently #605

Merged · 3 commits · Jul 19, 2024
210 changes: 90 additions & 120 deletions src/codec.rs
@@ -47,7 +47,7 @@ use crate::{
DecodeFinished, Error,
};

pub(crate) const MAX_PREALLOCATION: usize = 4 * 1024;
pub(crate) const MAX_PREALLOCATION: usize = 16 * 1024;
const A_BILLION: u32 = 1_000_000_000;
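
For context on this constant: the length prefix of a SCALE-encoded vector is attacker-controlled, so a few bytes of input can claim billions of elements, and the decoder must never preallocate based on the claimed length alone. A minimal sketch of the failure mode the cap guards against, using the crate's public Compact/Encode/Decode API (the exact error message is not asserted):

use parity_scale_codec::{Compact, Decode, Encode};

// A four-byte compact length claiming one billion u64 elements, with no
// element bytes following. Capping preallocation means the decoder reserves
// at most MAX_PREALLOCATION bytes up front and fails cleanly instead of
// attempting a multi-gigabyte allocation.
let malicious = Compact(1_000_000_000u32).encode();
let result = <Vec<u64>>::decode(&mut &malicious[..]);
assert!(result.is_err());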

/// Trait that allows reading of data into a slice.
@@ -834,52 +834,6 @@ pub(crate) fn encode_slice_no_len<T: Encode, W: Output + ?Sized>(slice: &[T], de
}
}

/// Decode the vec (without a prepended len).
///
/// This is equivalent to decode all elements one by one, but it is optimized in some
/// situation.
pub fn decode_vec_with_len<T: Decode, I: Input>(
input: &mut I,
len: usize,
) -> Result<Vec<T>, Error> {
fn decode_unoptimized<I: Input, T: Decode>(
input: &mut I,
items_len: usize,
) -> Result<Vec<T>, Error> {
let input_capacity = input
.remaining_len()?
.unwrap_or(MAX_PREALLOCATION)
.checked_div(mem::size_of::<T>())
.unwrap_or(0);
let mut r = Vec::with_capacity(input_capacity.min(items_len));
input.descend_ref()?;
for _ in 0..items_len {
r.push(T::decode(input)?);
}
input.ascend_ref();
Ok(r)
}

macro_rules! decode {
( $ty:ty, $input:ident, $len:ident ) => {{
if cfg!(target_endian = "little") || mem::size_of::<T>() == 1 {
let vec = read_vec_from_u8s::<_, $ty>($input, $len)?;
Ok(unsafe { mem::transmute::<Vec<$ty>, Vec<T>>(vec) })
} else {
decode_unoptimized($input, $len)
}
}};
}

with_type_info! {
<T as Decode>::TYPE_INFO,
decode(input, len),
{
decode_unoptimized(input, len)
},
}
}

impl_for_non_zero! {
NonZeroI8,
NonZeroI16,
@@ -1113,71 +1067,110 @@ impl<T: Encode> Encode for [T] {
}
}

/// Create a `Vec<T>` by casting directly from a buffer of read `u8`s
///
/// The encoding of `T` must be equal to its binary representation, and size of `T` must be less or
/// equal to [`MAX_PREALLOCATION`].
pub(crate) fn read_vec_from_u8s<I, T>(input: &mut I, items_len: usize) -> Result<Vec<T>, Error>
fn decode_vec_chunked<T, F>(len: usize, mut decode_chunk: F) -> Result<Vec<T>, Error>
where
I: Input,
T: ToMutByteSlice + Default + Clone,
F: FnMut(&mut Vec<T>, usize) -> Result<(), Error>,
{
debug_assert!(MAX_PREALLOCATION >= mem::size_of::<T>(), "Invalid precondition");
Contributor: We should make this into a static assert and check it at compile time.

Contributor Author: I couldn't manage to do this so far. I tried something like

const _: () = {
    assert!(MAX_PREALLOCATION >= mem::size_of::<T>())
};

inside decode_vec_chunked(), but I'm getting an error: "can't use generic parameters from outer item". Any suggestion would be helpful.

Contributor: You don't need to define a constant; since Rust 1.79 you can use a const {} block to force const evaluation of an expression.

Contributor Author (@serban300, Jul 22, 2024): Yes, this works, thanks! PTAL on #615. But the CI fails because the CI image uses Rust 1.73.0. We can try to release a paritytech/ci-unified:bullseye-1.79.0 image. Will check how this can be done.
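
A minimal sketch of the suggested check (assuming Rust 1.79+: an inline const block, unlike a standalone const item, may reference the enclosing function's generic parameters and is evaluated once per monomorphization):

fn decode_vec_chunked<T, F>(len: usize, mut decode_chunk: F) -> Result<Vec<T>, Error>
where
    F: FnMut(&mut Vec<T>, usize) -> Result<(), Error>,
{
    // Fails the build for any `T` larger than MAX_PREALLOCATION,
    // instead of tripping a debug_assert at runtime.
    const { assert!(mem::size_of::<T>() <= MAX_PREALLOCATION) };
    // ... rest of the body unchanged ...
}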

let chunk_len = MAX_PREALLOCATION / mem::size_of::<T>();

let byte_len = items_len
.checked_mul(mem::size_of::<T>())
.ok_or("Item is too big and cannot be allocated")?;
let mut decoded_vec = vec![];
let mut num_undecoded_items = len;
while num_undecoded_items > 0 {
let chunk_len = chunk_len.min(num_undecoded_items);
decoded_vec.reserve_exact(chunk_len);

let input_len = input.remaining_len()?;
decode_chunk(&mut decoded_vec, chunk_len)?;

// If there is input len and it cannot be pre-allocated then return directly.
if input_len.map(|l| l < byte_len).unwrap_or(false) {
return Err("Not enough data to decode vector".into());
num_undecoded_items = num_undecoded_items.saturating_sub(chunk_len);
Contributor: This saturating_sub is completely unnecessary here, since chunk_len > num_undecoded_items is impossible due to the min.

Contributor Author: Addressed in #615.
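
In other words, the min just above guarantees chunk_len <= num_undecoded_items, so a plain subtraction suffices (a sketch of the simplification; whether #615 does exactly this is not shown here):

let chunk_len = chunk_len.min(num_undecoded_items);
decoded_vec.reserve_exact(chunk_len);
decode_chunk(&mut decoded_vec, chunk_len)?;
// Cannot underflow: chunk_len <= num_undecoded_items by construction.
num_undecoded_items -= chunk_len;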

}

// In both these branches we're going to be creating and resizing a Vec<T>,
// but casting it to a &mut [u8] for reading.

// Note: we checked that if input_len is some then it can preallocated.
let r = if input_len.is_some() || byte_len < MAX_PREALLOCATION {
// Here we pre-allocate the whole buffer.
let mut items: Vec<T> = vec![Default::default(); items_len];
let bytes_slice = items.as_mut_byte_slice();
input.read(bytes_slice)?;
Ok(decoded_vec)
}
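
decode_vec_chunked factors out the loop shared by both decode paths below: the caller supplies a closure that appends exactly chunk_len elements per call, and the helper guarantees that no more than MAX_PREALLOCATION bytes are reserved ahead of the data actually read. An illustrative in-crate usage, inside a function returning Result (the constant-filling closure is made up for the example; real callers decode from an Input):

let v: Vec<u8> = decode_vec_chunked(5000, |vec, chunk_len| {
    // Each call must append exactly `chunk_len` elements.
    vec.extend(core::iter::repeat(0u8).take(chunk_len));
    Ok(())
})?;
assert_eq!(v.len(), 5000);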

items
} else {
// An allowed number of preallocated item.
// Note: `MAX_PREALLOCATION` is expected to be more or equal to size of `T`, precondition.
let max_preallocated_items = MAX_PREALLOCATION / mem::size_of::<T>();
/// Create a `Vec<T>` by casting directly from a buffer of read `u8`s
///
/// The encoding of `T` must be equal to its binary representation, and the size of `T` must be
/// less than or equal to [`MAX_PREALLOCATION`].
fn read_vec_from_u8s<T, I>(input: &mut I, len: usize) -> Result<Vec<T>, Error>
where
T: ToMutByteSlice + Default + Clone,
I: Input,
{
let byte_len = len
.checked_mul(mem::size_of::<T>())
.ok_or("Item is too big and cannot be allocated")?;

// Here we pre-allocate only the maximum pre-allocation
let mut items: Vec<T> = vec![];
// Check if there is enough data in the input buffer.
if let Some(input_len) = input.remaining_len()? {
if input_len < byte_len {
return Err("Not enough data to decode vector".into());
}
}

let mut items_remains = items_len;
decode_vec_chunked(len, |decoded_vec, chunk_len| {
let decoded_vec_len = decoded_vec.len();
let decoded_vec_size = decoded_vec_len * mem::size_of::<T>();
unsafe {
decoded_vec.set_len(decoded_vec_len + chunk_len);
}

while items_remains > 0 {
let items_len_read = max_preallocated_items.min(items_remains);
let bytes_slice = decoded_vec.as_mut_byte_slice();
input.read(&mut bytes_slice[decoded_vec_size..])
})
}

let items_len_filled = items.len();
let items_new_size = items_len_filled + items_len_read;
fn decode_vec_from_items<T, I>(input: &mut I, len: usize) -> Result<Vec<T>, Error>
where
T: Decode,
I: Input,
{
// Check if there is enough data in the input buffer.
if let Some(input_len) = input.remaining_len()? {
if input_len < len {
return Err("Not enough data to decode vector".into());
}
}
Comment on lines +1129 to +1133

Contributor: This isn't correct, as deserializing T might take any number of bytes (including even zero bytes, e.g. ()).

What we should do here is have a serialized_size_hint() method on T (or, more specifically, probably an associated const, so that it can be checked statically to fit within MAX_PREALLOCATION) which would return a value that could allow this check. (We already have encoded_fixed_size, but that returns an exact number of bytes; it could be used here, but technically that's too strict, and we can do better by using the minimum.)

Member: We should just drop this check.

Contributor: Yeah, alternatively we can drop it. Although having it here has one benefit: if we end up not having enough data, this will return an early error instead of wasting time trying to deserialize. Nice to have, but not strictly necessary.

Member: To implement this, we would need to write quite a lot of code. For example, for an enum we would need to know the variant that requires the least amount of bytes. However, decoding could then still fail anyway, e.g. when the input actually holds an enum variant that uses many more bytes.

Contributor: Hm, well, would it be that much code? I implemented this in my serialization crate and it's mostly fine; for enums you essentially just autogenerate a (min(variant1, variant2, ..), max(variant1, variant2, ..)) pair in your impl. Of course this is just an optimization (in some cases it would make incomplete deserializations fail early, and in some cases it would allow the compiler to remove per-element size checks), and as you've said it can still fail at decoding depending on what you're decoding.

Anyway, I'm fine with going with your suggestion to just delete the check.

Contributor Author (@serban300, Jul 22, 2024): Removed it for the moment: #615
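
For the record, a hypothetical sketch of the size-hint idea floated above. None of these names exist in this crate; MinEncodedLen, the impls, and the enum handling are illustrative only:

/// Lower bound on the encoded size of any value of a type.
trait MinEncodedLen {
    const MIN_ENCODED_LEN: usize;
}

impl MinEncodedLen for u32 {
    const MIN_ENCODED_LEN: usize = 4; // fixed-width encoding
}

impl MinEncodedLen for () {
    const MIN_ENCODED_LEN: usize = 0; // zero bytes, per the objection above
}

// A derive for enums would take the minimum over all variants:
// one discriminant byte plus the smallest variant payload.
enum Example {
    A(u32),
    B([u8; 16]),
}

impl MinEncodedLen for Example {
    const MIN_ENCODED_LEN: usize = 1 + 4; // variant `A` is the smallest
}

The early-exit check would then be sound for any T: reject the input whenever input.remaining_len()? reports fewer than len.saturating_mul(T::MIN_ENCODED_LEN) bytes.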


items.reserve_exact(items_len_read);
unsafe {
items.set_len(items_new_size);
}
input.descend_ref()?;
let vec = decode_vec_chunked(len, |decoded_vec, chunk_len| {
for _ in 0..chunk_len {
decoded_vec.push(T::decode(input)?);
}

let bytes_slice = items.as_mut_byte_slice();
let bytes_len_filled = items_len_filled * mem::size_of::<T>();
input.read(&mut bytes_slice[bytes_len_filled..])?;
Ok(())
})?;
input.ascend_ref();

items_remains = items_remains.saturating_sub(items_len_read);
}
Ok(vec)
}

items
};
/// Decode the vec (without a prepended len).
///
/// This is equivalent to decoding all elements one by one, but it is optimized in some
/// situations.
pub fn decode_vec_with_len<T: Decode, I: Input>(
input: &mut I,
len: usize,
) -> Result<Vec<T>, Error> {
macro_rules! decode {
( $ty:ty, $input:ident, $len:ident ) => {{
if cfg!(target_endian = "little") || mem::size_of::<T>() == 1 {
let vec = read_vec_from_u8s::<$ty, _>($input, $len)?;
Ok(unsafe { mem::transmute::<Vec<$ty>, Vec<T>>(vec) })
} else {
decode_vec_from_items::<T, _>($input, $len)
}
}};
}

Ok(r)
with_type_info! {
<T as Decode>::TYPE_INFO,
decode(input, len),
{
decode_vec_from_items::<T, _>(input, len)
},
}
}
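
For reference, a usage sketch of the public entry point above, which decodes the elements once the length has been obtained elsewhere (this assumes the function is re-exported at the crate root, as its pub visibility suggests):

use parity_scale_codec::decode_vec_with_len;

// Three little-endian u32s with no length prefix; the caller supplies len = 3.
let bytes = [1u8, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
let v: Vec<u32> = decode_vec_with_len(&mut &bytes[..], 3).unwrap();
assert_eq!(v, vec![1, 2, 3]);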

impl<T> WrapperTypeEncode for Vec<T> {}
@@ -1260,32 +1253,9 @@ impl<T: Encode> Encode for VecDeque<T> {
fn encode_to<W: Output + ?Sized>(&self, dest: &mut W) {
compact_encode_len_to(dest, self.len()).expect("Compact encodes length");

macro_rules! encode_to {
( $ty:ty, $self:ident, $dest:ident ) => {{
if cfg!(target_endian = "little") || mem::size_of::<T>() == 1 {
let slices = $self.as_slices();
let typed =
unsafe { core::mem::transmute::<(&[T], &[T]), (&[$ty], &[$ty])>(slices) };

$dest.write(<[$ty] as AsByteSlice<$ty>>::as_byte_slice(typed.0));
$dest.write(<[$ty] as AsByteSlice<$ty>>::as_byte_slice(typed.1));
} else {
for item in $self {
item.encode_to($dest);
}
}
}};
}

with_type_info! {
<T as Encode>::TYPE_INFO,
encode_to(self, dest),
{
for item in self {
item.encode_to(dest);
}
},
}
let slices = self.as_slices();
encode_slice_no_len(slices.0, dest);
encode_slice_no_len(slices.1, dest);
}
}
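
This simplification works because encode_slice_no_len (the context of the first hunk above) already performs the same little-endian byte-cast fast path internally, so duplicating the macro here bought nothing. It also preserves the invariant that a VecDeque encodes exactly like a Vec holding the same elements, however the ring buffer happens to be split; a quick property check (assuming the crate's public Encode API):

use std::collections::VecDeque;
use parity_scale_codec::Encode;

let mut deque: VecDeque<u32> = (0..100).collect();
deque.rotate_left(37); // encourage a non-contiguous internal layout
let flat: Vec<u32> = deque.iter().copied().collect();
// Same compact length prefix, same element bytes in logical order.
assert_eq!(deque.encode(), flat.encode());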
