Skip to content

Commit

Permalink
Fix Bytes decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
serban300 committed Jul 24, 2024
1 parent 0313d3b commit 5b1b387
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 19 deletions.
58 changes: 40 additions & 18 deletions src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ pub trait Input {

/// !INTERNAL USE ONLY!
///
/// Decodes a `bytes::Bytes`.
/// Used when decoding a `bytes::Bytes` from a `BytesCursor` input.
#[cfg(feature = "bytes")]
#[doc(hidden)]
fn scale_internal_decode_bytes(&mut self) -> Result<bytes::Bytes, Error>
fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor>
where
Self: Sized,
{
Vec::<u8>::decode(self).map(bytes::Bytes::from)
None
}
}

Expand Down Expand Up @@ -414,12 +414,32 @@ mod feature_wrapper_bytes {
impl EncodeLike<Bytes> for Vec<u8> {}
}

/// `Input` implementation optimized for decoding `bytes::Bytes`.
#[cfg(feature = "bytes")]
struct BytesCursor {
pub struct BytesCursor {
bytes: bytes::Bytes,
position: usize,
}

#[cfg(feature = "bytes")]
impl BytesCursor {
/// Create a new instance of `BytesCursor`.
pub fn new(bytes: bytes::Bytes) -> Self {
Self { bytes, position: 0 }
}

fn decode_bytes_with_len(&mut self, length: usize) -> Result<bytes::Bytes, Error> {
bytes::Buf::advance(&mut self.bytes, self.position);
self.position = 0;

if length > self.bytes.len() {
return Err("Not enough data to fill buffer".into());
}

Ok(self.bytes.split_to(length))
}
}

#[cfg(feature = "bytes")]
impl Input for BytesCursor {
fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
Expand All @@ -436,18 +456,11 @@ impl Input for BytesCursor {
Ok(())
}

fn scale_internal_decode_bytes(&mut self) -> Result<bytes::Bytes, Error> {
let length = <Compact<u32>>::decode(self)?.0 as usize;

bytes::Buf::advance(&mut self.bytes, self.position);
self.position = 0;

if length > self.bytes.len() {
return Err("Not enough data to fill buffer".into());
}

self.on_before_alloc_mem(length)?;
Ok(self.bytes.split_to(length))
fn __private_bytes_cursor(&mut self) -> Option<&mut BytesCursor>
where
Self: Sized,
{
Some(self)
}
}

Expand All @@ -473,14 +486,23 @@ where
// However, if `T` doesn't contain any `Bytes` then this extra allocation is
// technically unnecessary, and we can avoid it by tracking the position ourselves
// and treating the underlying `Bytes` as a fancy `&[u8]`.
let mut input = BytesCursor { bytes, position: 0 };
let mut input = BytesCursor::new(bytes);
T::decode(&mut input)
}

#[cfg(feature = "bytes")]
impl Decode for bytes::Bytes {
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
input.scale_internal_decode_bytes()
let len = <Compact<u32>>::decode(input)?.0 as usize;
if input.__private_bytes_cursor().is_some() {
input.on_before_alloc_mem(len)?;
}

if let Some(bytes_cursor) = input.__private_bytes_cursor() {
bytes_cursor.decode_bytes_with_len(len)
} else {
decode_vec_with_len::<u8, _>(input, len).map(bytes::Bytes::from)
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/depth_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> {
self.input.ascend_ref();
self.depth -= 1;
}

fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> {
self.input.on_before_alloc_mem(size)
}

#[cfg(feature = "bytes")]
fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor>
where
Self: Sized,
{
self.input.__private_bytes_cursor()
}
}

impl<T: Decode> DecodeLimit for T {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ pub use max_encoded_len::MaxEncodedLen;
pub use parity_scale_codec_derive::MaxEncodedLen;

#[cfg(feature = "bytes")]
pub use self::codec::decode_from_bytes;
pub use self::codec::{decode_from_bytes, BytesCursor};
8 changes: 8 additions & 0 deletions src/mem_tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,12 @@ impl<'a, I: Input> Input for MemTrackingInput<'a, I> {

Ok(())
}

#[cfg(feature = "bytes")]
fn __private_bytes_cursor(&mut self) -> Option<&mut crate::BytesCursor>
where
Self: Sized,
{
self.input.__private_bytes_cursor()
}
}
14 changes: 14 additions & 0 deletions tests/mem_tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ fn decode_complex_objects_works() {
assert!(decode_object(Box::new(Rc::new(vec![String::from("test")])), usize::MAX, 60).is_ok());
}

#[cfg(feature = "bytes")]
#[test]
fn decode_bytes_from_bytes_works() {
use parity_scale_codec::Decode;

let obj = ([0u8; 100], Box::new(0u8), bytes::Bytes::from(vec![0u8; 50]));
let encoded_bytes = obj.encode();
let mut bytes_cursor = parity_scale_codec::BytesCursor::new(bytes::Bytes::from(encoded_bytes));
let mut input = MemTrackingInput::new(&mut bytes_cursor, usize::MAX);
let decoded_obj = <([u8; 100], Box<u8>, bytes::Bytes)>::decode(&mut input).unwrap();
assert_eq!(&decoded_obj, &obj);
assert_eq!(input.used_mem(), 51);
}

#[test]
fn decode_complex_derived_struct_works() {
#[derive(DeriveEncode, DeriveDecode, PartialEq, Debug)]
Expand Down

0 comments on commit 5b1b387

Please sign in to comment.