Skip to content

Commit

Permalink
feat: derive BorshSerialize for recursive structures
Browse files Browse the repository at this point in the history
  • Loading branch information
dj8yf0μl committed Jul 5, 2023
1 parent db334b9 commit 9107dbc
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 18 deletions.
22 changes: 22 additions & 0 deletions borsh-derive-internal/src/generics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use syn::{Generics, WherePredicate, Ident};
use quote::quote;


pub fn compute_predicates(
generics: &Generics,
cratename: &Ident,
) -> Vec<WherePredicate> {
// Generate function that returns the name of the type.
let mut where_predicates= vec![];
for type_param in generics.type_params() {

let type_param_name = &type_param.ident;
where_predicates.push(
syn::parse2(quote! {
#type_param_name: #cratename::ser::BorshSerialize
})
.unwrap(),
);
}
where_predicates
}
1 change: 1 addition & 0 deletions borsh-derive-internal/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![recursion_limit = "128"]

mod attribute_helpers;
mod generics;
mod enum_de;
mod enum_discriminant_map;
mod enum_ser;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ expression: pretty_print_syn_str(&actual).unwrap()
impl<K: Key, V> borsh::ser::BorshSerialize for A<K, V>
where
V: Value,
HashMap<K, V>: borsh::ser::BorshSerialize,
String: borsh::ser::BorshSerialize,
K: borsh::ser::BorshSerialize,
V: borsh::ser::BorshSerialize,
{
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: borsh-derive-internal/src/struct_ser.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::ser::BorshSerialize for CRecC {
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
writer: &mut W,
) -> ::core::result::Result<(), borsh::__private::maybestd::io::Error> {
borsh::BorshSerialize::serialize(&self.a, writer)?;
borsh::BorshSerialize::serialize(&self.b, writer)?;
Ok(())
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ expression: pretty_print_syn_str(&actual).unwrap()
---
impl<K, V> borsh::ser::BorshSerialize for A<K, V>
where
HashMap<K, V>: borsh::ser::BorshSerialize,
String: borsh::ser::BorshSerialize,
K: borsh::ser::BorshSerialize,
V: borsh::ser::BorshSerialize,
{
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
source: borsh-derive-internal/src/struct_ser.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::ser::BorshSerialize for A
where
u64: borsh::ser::BorshSerialize,
String: borsh::ser::BorshSerialize,
{
impl borsh::ser::BorshSerialize for A {
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
writer: &mut W,
Expand Down
28 changes: 19 additions & 9 deletions borsh-derive-internal/src/struct_ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{Fields, Ident, Index, ItemStruct, WhereClause};

use crate::attribute_helpers::contains_skip;
use crate::{attribute_helpers::contains_skip, generics::compute_predicates};

pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
Expand All @@ -16,6 +16,10 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStre
},
Clone::clone,
);
let predicates = compute_predicates(&input.generics, &cratename);
predicates
.into_iter()
.for_each(|predicate| where_clause.predicates.push(predicate));
let mut body = TokenStream2::new();
match &input.fields {
Fields::Named(fields) => {
Expand All @@ -28,14 +32,6 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStre
#cratename::BorshSerialize::serialize(&self.#field_name, writer)?;
};
body.extend(delta);

let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::ser::BorshSerialize
})
.unwrap(),
);
}
}
Fields::Unnamed(fields) => {
Expand Down Expand Up @@ -109,4 +105,18 @@ mod tests {
let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap();
insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn recursive_struct() {
let item_struct: ItemStruct = syn::parse2(quote!{
struct CRecC {
a: String,
b: HashMap<String, CRecC>,
}
}).unwrap();

let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap();

insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
---
source: borsh/tests/test_generic_struct.rs
expression: data
---
[
5,
0,
0,
0,
102,
105,
101,
108,
100,
2,
0,
0,
0,
14,
0,
0,
0,
5,
0,
0,
0,
118,
97,
108,
117,
101,
34,
0,
0,
0,
7,
0,
0,
0,
97,
110,
111,
116,
104,
101,
114,
]
33 changes: 33 additions & 0 deletions borsh/tests/test_generic_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#![cfg(feature = "derive")]
use core::marker::PhantomData;

#[cfg(feature = "hashbrown")]
use hashbrown::HashMap;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
Expand Down Expand Up @@ -31,6 +37,15 @@ enum B<F, G> {
Y(G),
}

/// `T: PartialOrd` bound is required for `BorshSerialize` derive to be successful
#[cfg(hash_collections)]
#[derive(BorshSerialize, BorshDeserialize)]
struct C<T: PartialOrd, U> {
a: String,
b: HashMap<T, U>,
}


#[test]
fn test_generic_struct() {
let a = A::<String, u64, String> {
Expand All @@ -47,3 +62,21 @@ fn test_generic_struct() {
let actual_a = from_slice::<A<String, u64, String>>(&data).unwrap();
assert_eq!(a, actual_a);
}

#[cfg(hash_collections)]
#[test]
fn test_generic_struct_hashmap() {
let mut hashmap = HashMap::new();
hashmap.insert(34, "another".to_string());
hashmap.insert(14, "value".to_string());
let a = C::<u32, String> {
a: "field".to_string(),
b: hashmap,
};
let data = a.try_to_vec().unwrap();
#[cfg(feature = "std")]
insta::assert_debug_snapshot!(data);
let actual_a = from_slice::<C::<u32, String>>(&data).unwrap();
assert_eq!(actual_a.b.get(&14), Some("value".to_string()).as_ref());
assert_eq!(actual_a.b.get(&34), Some("another".to_string()).as_ref());
}
45 changes: 45 additions & 0 deletions borsh/tests/test_recursive_structs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#![cfg(feature = "derive")]
use borsh::BorshSerialize;

#[cfg(feature = "hashbrown")]
use hashbrown::HashMap;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::{
string::{String, ToString},
boxed::Box,
vec::Vec,
};


/// strangely enough, this worked before current commit
#[cfg(hash_collections)]
#[derive(BorshSerialize)]
struct CRec<U: PartialOrd> {
a: String,
b: HashMap<U,CRec<U>>,
}

#[derive(BorshSerialize)]
struct CRecA {
a: String,
b: Box<CRecA>,
}

#[derive(BorshSerialize)]
struct CRecB {
a: String,
b: Vec<CRecB>,
}

#[cfg(hash_collections)]
#[derive(BorshSerialize)]
struct CRecC {
a: String,
b: HashMap<String, CRecC>,
}

0 comments on commit 9107dbc

Please sign in to comment.