Skip to content

Commit

Permalink
fix: remove unnecessary trait bounds requirements for array (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
austinabell authored Oct 10, 2022
1 parent 4f436f9 commit aec5a4e
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 14 deletions.
114 changes: 100 additions & 14 deletions borsh/src/de/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::{
convert::{TryFrom, TryInto},
hash::{BuildHasher, Hash},
Expand Down Expand Up @@ -53,10 +54,9 @@ pub trait BorshDeserialize: Sized {

#[inline]
#[doc(hidden)]
fn copy_from_bytes(buf: &mut &[u8], out: &mut [Self]) -> Result<bool> {
fn array_from_bytes<const N: usize>(buf: &mut &[u8]) -> Result<Option<[Self; N]>> {
let _ = buf;
let _ = out;
Ok(false)
Ok(None)
}
}

Expand Down Expand Up @@ -91,17 +91,17 @@ impl BorshDeserialize for u8 {

#[inline]
#[doc(hidden)]
fn copy_from_bytes(buf: &mut &[u8], out: &mut [Self]) -> Result<bool> {
if buf.len() < out.len() {
fn array_from_bytes<const N: usize>(buf: &mut &[u8]) -> Result<Option<[Self; N]>> {
if buf.len() < N {
return Err(Error::new(
ErrorKind::InvalidInput,
ERROR_UNEXPECTED_LENGTH_OF_INPUT,
));
}
let (front, rest) = buf.split_at(out.len());
out.copy_from_slice(front);
let (front, rest) = buf.split_at(N);
*buf = rest;
Ok(true)
let front: [u8; N] = front.try_into().unwrap();
Ok(Some(front))
}
}

Expand Down Expand Up @@ -528,17 +528,103 @@ where

impl<T, const N: usize> BorshDeserialize for [T; N]
where
T: BorshDeserialize + Default + Copy,
T: BorshDeserialize,
{
#[inline]
fn deserialize(buf: &mut &[u8]) -> Result<Self> {
let mut result = [T::default(); N];
if N > 0 && !T::copy_from_bytes(buf, &mut result)? {
for i in result.iter_mut() {
*i = T::deserialize(buf)?;
struct ArrayDropGuard<T, const N: usize> {
buffer: [MaybeUninit<T>; N],
init_count: usize,
}
impl<T, const N: usize> Drop for ArrayDropGuard<T, N> {
fn drop(&mut self) {
let init_range = &mut self.buffer[..self.init_count];
// SAFETY: Elements up to self.init_count have been initialized. Assumes this value
// is only incremented in `fill_buffer`, which writes the element before
// increasing the init_count.
unsafe {
core::ptr::drop_in_place(init_range as *mut _ as *mut [T]);
};
}
}
Ok(result)
impl<T, const N: usize> ArrayDropGuard<T, N> {
unsafe fn transmute_to_array(mut self) -> [T; N] {
debug_assert_eq!(self.init_count, N);
// Set init_count to 0 so that the values do not get dropped twice.
self.init_count = 0;
// SAFETY: This cast is required because `mem::transmute` does not work with
// const generics https://github.com/rust-lang/rust/issues/61956. This
// array is guaranteed to be initialized by this point.
core::ptr::read(&self.buffer as *const _ as *const [T; N])
}
fn fill_buffer(&mut self, mut f: impl FnMut() -> Result<T>) -> Result<()> {
// TODO: replace with `core::array::try_from_fn` when stabilized to avoid manually
// dropping uninitialized values through the guard drop.
for elem in self.buffer.iter_mut() {
elem.write(f()?);
self.init_count += 1;
}
Ok(())
}
}

if let Some(arr) = T::array_from_bytes(buf)? {
Ok(arr)
} else {
let mut result = ArrayDropGuard {
buffer: unsafe { MaybeUninit::uninit().assume_init() },
init_count: 0,
};

result.fill_buffer(|| T::deserialize(buf))?;

// SAFETY: The elements up to `i` have been initialized in `fill_buffer`.
Ok(unsafe { result.transmute_to_array() })
}
}
}

#[test]
fn array_deserialization_doesnt_leak() {
use core::sync::atomic::{AtomicUsize, Ordering};

static DESERIALIZE_COUNT: AtomicUsize = AtomicUsize::new(0);
static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

struct MyType(u8);
impl BorshDeserialize for MyType {
fn deserialize(buf: &mut &[u8]) -> Result<Self> {
let val = u8::deserialize(buf)?;
let v = DESERIALIZE_COUNT.fetch_add(1, Ordering::SeqCst);
if v >= 7 {
panic!("panic in deserialize");
}
Ok(MyType(val))
}
}
impl Drop for MyType {
fn drop(&mut self) {
DROP_COUNT.fetch_add(1, Ordering::SeqCst);
}
}

assert!(<[MyType; 5] as BorshDeserialize>::deserialize(&mut &[0u8; 3][..]).is_err());
assert_eq!(DESERIALIZE_COUNT.load(Ordering::SeqCst), 3);
assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);

assert!(<[MyType; 2] as BorshDeserialize>::deserialize(&mut &[0u8; 2][..]).is_ok());
assert_eq!(DESERIALIZE_COUNT.load(Ordering::SeqCst), 5);
assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 5);

#[cfg(feature = "std")]
{
// Test that during a panic in deserialize, the values are still dropped.
let result = std::panic::catch_unwind(|| {
<[MyType; 3] as BorshDeserialize>::deserialize(&mut &[0u8; 3][..]).unwrap();
});
assert!(result.is_err());
assert_eq!(DESERIALIZE_COUNT.load(Ordering::SeqCst), 8);
assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 7); // 5 because 6 panicked and was not init
}
}

Expand Down
20 changes: 20 additions & 0 deletions borsh/tests/test_arrays.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,23 @@ test_arrays!(
);
test_arrays!(test_array_f32, 1000000000.0f32, f32);
test_arrays!(test_array_array_u8, [100u8; 32], [u8; 32]);
test_arrays!(test_array_zst, (), ());

#[derive(BorshDeserialize, BorshSerialize, PartialEq, Debug)]
struct CustomStruct(u8);

#[test]
fn test_custom_struct_array() {
let arr = [CustomStruct(0), CustomStruct(1), CustomStruct(2)];
let serialized = arr.try_to_vec().unwrap();
let deserialized: [CustomStruct; 3] = BorshDeserialize::try_from_slice(&serialized).unwrap();
assert_eq!(arr, deserialized);
}

#[test]
fn test_string_array() {
let arr = ["0".to_string(), "1".to_string(), "2".to_string()];
let serialized = arr.try_to_vec().unwrap();
let deserialized: [String; 3] = BorshDeserialize::try_from_slice(&serialized).unwrap();
assert_eq!(arr, deserialized);
}

0 comments on commit aec5a4e

Please sign in to comment.