Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added support for BTreeMap, BTreeSet, BinaryHeap, LinkedList types #6

Merged
merged 3 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 55 additions & 34 deletions borsh/src/de/mod.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
use core::{
convert::TryInto,
hash::Hash,
mem::{forget, size_of}
mem::{forget, size_of},
};

use crate::maybestd::{
io::{Error, ErrorKind, Result},
borrow::{
Cow,
ToOwned,
Borrow
},
collections::{BTreeMap, HashMap, HashSet},
borrow::{Borrow, Cow, ToOwned},
boxed::Box,
collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque},
format,
io::{Error, ErrorKind, Result},
string::{String, ToString},
vec::Vec,
boxed::Box
};


mod hint;

const ERROR_NOT_ALL_BYTES_READ: &str = "Not all bytes read";
Expand All @@ -35,10 +30,7 @@ pub trait BorshDeserialize: Sized {
let mut v_mut = v;
let result = Self::deserialize(&mut v_mut)?;
if !v_mut.is_empty() {
return Err(Error::new(
ErrorKind::InvalidData,
ERROR_NOT_ALL_BYTES_READ,
));
return Err(Error::new(ErrorKind::InvalidData, ERROR_NOT_ALL_BYTES_READ));
}
Ok(result)
}
Expand Down Expand Up @@ -153,16 +145,11 @@ impl BorshDeserialize for bool {
} else {
let msg = format!("Invalid bool representation: {}", b);

Err(Error::new(
ErrorKind::InvalidInput,
msg,
))
Err(Error::new(ErrorKind::InvalidInput, msg))
}
}
}



impl<T> BorshDeserialize for Option<T>
where
T: BorshDeserialize,
Expand All @@ -187,10 +174,7 @@ where
flag
);

Err(Error::new(
ErrorKind::InvalidInput,
msg,
))
Err(Error::new(ErrorKind::InvalidInput, msg))
}
}
}
Expand Down Expand Up @@ -220,22 +204,18 @@ where
flag
);

Err(Error::new(
ErrorKind::InvalidInput,
msg,
))
Err(Error::new(ErrorKind::InvalidInput, msg))
}
}
}

impl BorshDeserialize for String {
#[inline]
fn deserialize(buf: &mut &[u8]) -> Result<Self> {
String::from_utf8(Vec::<u8>::deserialize(buf)?)
.map_err(|err| {
let msg = err.to_string();
Error::new(ErrorKind::InvalidData, msg)
})
String::from_utf8(Vec::<u8>::deserialize(buf)?).map_err(|err| {
let msg = err.to_string();
Error::new(ErrorKind::InvalidData, msg)
})
}
}

Expand Down Expand Up @@ -310,6 +290,38 @@ where
}
}

/// Decodes a `VecDeque<T>` by decoding a `Vec<T>` from `buf` and converting
/// the resulting vector into a deque.
impl<T> BorshDeserialize for VecDeque<T>
where
    T: BorshDeserialize,
{
    #[inline]
    fn deserialize(buf: &mut &[u8]) -> Result<Self> {
        // Delegate the actual decoding to `Vec<T>`; any error it reports is
        // passed through unchanged.
        Vec::<T>::deserialize(buf).map(VecDeque::from)
    }
}

impl<T> BorshDeserialize for LinkedList<T>
where
T: BorshDeserialize,
{
#[inline]
fn deserialize(buf: &mut &[u8]) -> Result<Self> {
let vec = <Vec<T>>::deserialize(buf)?;
Ok(vec.into_iter().collect::<LinkedList<T>>())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code looks beautiful; however, should we be concerned about avoiding the extra copy and conversion cost if we really want the highest performance?

Copy link
Collaborator Author

@frol frol Jan 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is how HashSet was already implemented. I think we can do another round where we write a helper that deserializes into an iterator, which would both improve performance and avoid duplicated code.

}
}

/// Decodes a `BinaryHeap<T>` by decoding a `Vec<T>` from `buf` and building
/// the heap from those elements.
impl<T> BorshDeserialize for BinaryHeap<T>
where
    T: BorshDeserialize + Ord,
{
    #[inline]
    fn deserialize(buf: &mut &[u8]) -> Result<Self> {
        let elements = Vec::<T>::deserialize(buf)?;
        // Build the heap straight from the decoded vector.
        Ok(BinaryHeap::from(elements))
    }
}

impl<T> BorshDeserialize for HashSet<T>
where
Expand All @@ -322,7 +334,6 @@ where
}
}


impl<K, V> BorshDeserialize for HashMap<K, V>
where
K: BorshDeserialize + Eq + Hash,
Expand All @@ -342,6 +353,16 @@ where
}
}

/// Decodes a `BTreeSet<T>` by decoding a `Vec<T>` from `buf` and inserting
/// every element into a set.
impl<T> BorshDeserialize for BTreeSet<T>
where
    T: BorshDeserialize + Ord,
{
    #[inline]
    fn deserialize(buf: &mut &[u8]) -> Result<Self> {
        // The input order is not validated here; the set orders the elements
        // itself and keeps a single entry for elements that compare equal.
        Vec::<T>::deserialize(buf).map(|items| items.into_iter().collect::<Self>())
    }
}

impl<K, V> BorshDeserialize for BTreeMap<K, V>
where
Expand Down
132 changes: 112 additions & 20 deletions borsh/src/ser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use core::convert::TryFrom;
use core::mem::size_of;

use crate::maybestd::{
borrow::{Cow, ToOwned},
boxed::Box,
collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque},
io::{ErrorKind, Result, Write},
collections::{HashMap, HashSet},
borrow::{ToOwned, Cow},
string::String,
boxed::Box,
vec::Vec
vec::Vec,
};

const DEFAULT_SERIALIZER_CAPACITY: usize = 1024;
Expand Down Expand Up @@ -145,6 +145,24 @@ impl BorshSerialize for String {
}
}

/// Writes the elements of `data` to `writer` one after another, without any
/// length prefix.
///
/// When `T::is_u8()` holds and `T` is one byte wide, the slice is written in
/// a single `write_all` call; otherwise every element is serialized
/// individually through its own `BorshSerialize` implementation.
///
/// # Errors
///
/// Returns the first error reported by the writer or by an element's
/// `serialize`.
#[inline]
fn serialize_slice<T: BorshSerialize, W: Write>(data: &[T], writer: &mut W) -> Result<()> {
    let byte_sized = T::is_u8() && size_of::<T>() == size_of::<u8>();
    if !byte_sized {
        // General path: element-by-element, stopping at the first failure.
        return data.iter().try_for_each(|item| item.serialize(writer));
    }

    // `T::is_u8()` is a workaround for not being able to implement `Vec<u8>`
    // separately.
    //
    // SAFETY: the size check above guarantees `size_of::<T>() == size_of::<u8>()`,
    // so `data.len()` elements occupy exactly `data.len()` bytes starting at
    // `data.as_ptr()`. This additionally relies on `T::is_u8()` (defined outside
    // this block) being true only for `u8`.
    let bytes = unsafe { core::slice::from_raw_parts(data.as_ptr() as *const u8, data.len()) };
    writer.write_all(bytes)
}

impl<T> BorshSerialize for [T]
where
T: BorshSerialize,
Expand All @@ -154,19 +172,7 @@ where
writer.write_all(
&(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?).to_le_bytes(),
)?;
if T::is_u8() && size_of::<T>() == size_of::<u8>() {
// The code below uses unsafe memory representation from `&[T]` to `&[u8]`.
// The size of the memory should match because `size_of::<T>() == size_of::<u8>()`.
//
// `T::is_u8()` is a workaround for not being able to implement `Vec<u8>` separately.
let buf = unsafe { core::slice::from_raw_parts(self.as_ptr() as *const u8, self.len()) };
writer.write_all(buf)?;
} else {
for item in self {
item.serialize(writer)?;
}
}
Ok(())
serialize_slice(self, writer)
}
}

Expand Down Expand Up @@ -197,6 +203,56 @@ where
}
}

/// Serializes a `VecDeque<T>` as a little-endian `u32` element count followed
/// by the elements in front-to-back order.
impl<T> BorshSerialize for VecDeque<T>
where
    T: BorshSerialize,
{
    #[inline]
    fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
        // The length prefix is a `u32`; a longer deque is rejected.
        let len = u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?;
        writer.write_all(&len.to_le_bytes())?;

        // A deque's contents are exposed as two slices (front part, then back
        // part); writing them in that order yields the logical element order.
        let (front, back) = self.as_slices();
        serialize_slice(front, writer)?;
        serialize_slice(back, writer)
    }
}

/// Serializes a `LinkedList<T>` as a little-endian `u32` element count
/// followed by each element in list order.
impl<T> BorshSerialize for LinkedList<T>
where
    T: BorshSerialize,
{
    #[inline]
    fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
        // The length prefix is a `u32`; a longer list is rejected.
        let count = u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?;
        writer.write_all(&count.to_le_bytes())?;
        // Stop at the first element that fails to serialize.
        self.iter()
            .try_for_each(|element| element.serialize(writer))
    }
}

/// Serializes a `BinaryHeap<T>` as a little-endian `u32` element count
/// followed by each element in the order the heap's iterator yields them.
impl<T> BorshSerialize for BinaryHeap<T>
where
    T: BorshSerialize,
{
    #[inline]
    fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
        // This would simply be `self.as_slice().serialize(writer)` if
        // `BinaryHeap` offered an `as_slice()` method; see
        // https://internals.rust-lang.org/t/should-i-add-as-slice-method-to-binaryheap/13816
        let count = u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidInput)?;
        writer.write_all(&count.to_le_bytes())?;
        for element in self.iter() {
            element.serialize(writer)?;
        }
        Ok(())
    }
}

impl<K, V> BorshSerialize for HashMap<K, V>
where
K: BorshSerialize + PartialOrd,
Expand All @@ -217,10 +273,9 @@ where
}
}


impl<T> BorshSerialize for HashSet<T>
where
T: BorshSerialize + PartialOrd,
where
T: BorshSerialize + PartialOrd,
{
#[inline]
fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
Expand All @@ -236,6 +291,43 @@ impl<T> BorshSerialize for HashSet<T>
}
}

/// Serializes a `BTreeMap<K, V>` as an entry count followed by the entries,
/// each written as its key and then its value, in ascending key order.
///
/// The count is converted to `u32` and written through `u32`'s own
/// `BorshSerialize` implementation; a map with more than `u32::MAX` entries
/// yields `ErrorKind::InvalidInput`.
impl<K, V> BorshSerialize for BTreeMap<K, V>
where
K: BorshSerialize + PartialOrd,
V: BorshSerialize,
{
#[inline]
fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
// Collect references to all entries and order them by key.
// NOTE(review): `BTreeMap::iter` already yields entries in ascending key
// order, so this collect + sort looks redundant, and the `unwrap()` would
// panic if a key's `partial_cmp` ever returned `None` — consider iterating
// `self` directly.
let mut vec = self.iter().collect::<Vec<_>>();
vec.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap());
frol marked this conversation as resolved.
Show resolved Hide resolved
// Entry count first, then every key/value pair.
u32::try_from(vec.len())
.map_err(|_| ErrorKind::InvalidInput)?
.serialize(writer)?;
for (key, value) in vec {
key.serialize(writer)?;
value.serialize(writer)?;
}
Ok(())
}
}

/// Serializes a `BTreeSet<T>` as an element count followed by the elements in
/// ascending order.
///
/// The count is converted to `u32` and written through `u32`'s own
/// `BorshSerialize` implementation.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when the set holds more than `u32::MAX`
/// elements, or the first error reported while writing an element.
impl<T> BorshSerialize for BTreeSet<T>
where
    T: BorshSerialize + PartialOrd,
{
    #[inline]
    fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
        u32::try_from(self.len())
            .map_err(|_| ErrorKind::InvalidInput)?
            .serialize(writer)?;
        // A `BTreeSet` already iterates in ascending order, so the elements
        // can be streamed directly: no temporary `Vec`, no re-sort, and no
        // `partial_cmp(..).unwrap()` that could panic.
        for item in self {
            item.serialize(writer)?;
        }
        Ok(())
    }
}

#[cfg(feature = "std")]
impl BorshSerialize for std::net::SocketAddr {
Expand Down
30 changes: 30 additions & 0 deletions borsh/tests/test_binary_heaps.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use borsh::maybestd::collections::BinaryHeap;
use borsh::{BorshDeserialize, BorshSerialize};

/// Round-trips one `BinaryHeap<$t>` value through Borsh and checks that the
/// decoded heap holds the same underlying vector as the original.
///
/// `$v` is the heap expression under test; `$t` is its element type.
macro_rules! test_binary_heap {
    ($v: expr, $t: ty) => {
        // Evaluate the (possibly expensive) heap expression exactly once
        // instead of once for encoding and again for the comparison.
        let expected = $v;
        let buf = expected.try_to_vec().unwrap();
        let actual_v: BinaryHeap<$t> =
            BorshDeserialize::try_from_slice(&buf).expect("failed to deserialize");
        assert_eq!(actual_v.into_vec(), expected.into_vec());
    };
}

/// Generates a `#[test]` function named `$test_name` that round-trips heaps of
/// 0, 1, 10, 100, 1000 and 10000 copies of `$el` (element type `$t`).
macro_rules! test_binary_heaps {
    ($test_name: ident, $el: expr, $t: ty) => {
        #[test]
        fn $test_name() {
            // The empty heap is built explicitly; the rest come from a size table.
            test_binary_heap!(BinaryHeap::<$t>::new(), $t);
            for &len in &[1usize, 10, 100, 1000, 10000] {
                test_binary_heap!(vec![$el; len].into_iter().collect::<BinaryHeap<_>>(), $t);
            }
        }
    };
}

// Instantiate the round-trip test for several element types: unsigned and
// signed single-byte integers, a 32-bit integer, and an owned string.
test_binary_heaps!(test_binary_heap_u8, 100u8, u8);
test_binary_heaps!(test_binary_heap_i8, 100i8, i8);
test_binary_heaps!(test_binary_heap_u32, 1000000000u32, u32);
test_binary_heaps!(test_binary_heap_string, "a".to_string(), String);
Loading