Skip to content

Commit

Permalink
feat: derive BorshSerialize for recursive structures
Browse files Browse the repository at this point in the history
  • Loading branch information
dj8yf0μl committed Jul 5, 2023
1 parent 387c665 commit 57d8f4f
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 18 deletions.
16 changes: 16 additions & 0 deletions borsh-derive-internal/src/generics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use quote::quote;
use syn::{Generics, Ident, WherePredicate};

pub fn compute_predicates(generics: &Generics, cratename: &Ident) -> Vec<WherePredicate> {
let mut where_predicates = vec![];
for type_param in generics.type_params() {
let type_param_name = &type_param.ident;
where_predicates.push(
syn::parse2(quote! {
#type_param_name: #cratename::ser::BorshSerialize
})
.unwrap(),
);
}
where_predicates
}
1 change: 1 addition & 0 deletions borsh-derive-internal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod attribute_helpers;
mod enum_de;
mod enum_discriminant_map;
mod enum_ser;
mod generics;
mod struct_de;
mod struct_ser;
mod union_de;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ expression: pretty_print_syn_str(&actual).unwrap()
impl<K: Key, V> borsh::ser::BorshSerialize for A<K, V>
where
V: Value,
HashMap<K, V>: borsh::ser::BorshSerialize,
String: borsh::ser::BorshSerialize,
K: borsh::ser::BorshSerialize,
V: borsh::ser::BorshSerialize,
{
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: borsh-derive-internal/src/struct_ser.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::ser::BorshSerialize for CRecC {
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
writer: &mut W,
) -> ::core::result::Result<(), borsh::__private::maybestd::io::Error> {
borsh::BorshSerialize::serialize(&self.a, writer)?;
borsh::BorshSerialize::serialize(&self.b, writer)?;
Ok(())
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
source: borsh-derive-internal/src/struct_ser.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl<T> borsh::ser::BorshSerialize for TupleA<T>
where
T: borsh::ser::BorshSerialize,
{
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
writer: &mut W,
) -> ::core::result::Result<(), borsh::__private::maybestd::io::Error> {
borsh::BorshSerialize::serialize(&self.0, writer)?;
borsh::BorshSerialize::serialize(&self.1, writer)?;
Ok(())
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ expression: pretty_print_syn_str(&actual).unwrap()
---
impl<K, V> borsh::ser::BorshSerialize for A<K, V>
where
HashMap<K, V>: borsh::ser::BorshSerialize,
String: borsh::ser::BorshSerialize,
K: borsh::ser::BorshSerialize,
V: borsh::ser::BorshSerialize,
{
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
source: borsh-derive-internal/src/struct_ser.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::ser::BorshSerialize for A
where
u64: borsh::ser::BorshSerialize,
String: borsh::ser::BorshSerialize,
{
impl borsh::ser::BorshSerialize for A {
fn serialize<W: borsh::__private::maybestd::io::Write>(
&self,
writer: &mut W,
Expand Down
38 changes: 29 additions & 9 deletions borsh-derive-internal/src/struct_ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{Fields, Ident, Index, ItemStruct, WhereClause};

use crate::attribute_helpers::contains_skip;
use crate::{attribute_helpers::contains_skip, generics::compute_predicates};

pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream2> {
let name = &input.ident;
Expand All @@ -16,6 +16,10 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStre
},
Clone::clone,
);
let predicates = compute_predicates(&input.generics, &cratename);
predicates
.into_iter()
.for_each(|predicate| where_clause.predicates.push(predicate));
let mut body = TokenStream2::new();
match &input.fields {
Fields::Named(fields) => {
Expand All @@ -28,14 +32,6 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStre
#cratename::BorshSerialize::serialize(&self.#field_name, writer)?;
};
body.extend(delta);

let field_type = &field.ty;
where_clause.predicates.push(
syn::parse2(quote! {
#field_type: #cratename::ser::BorshSerialize
})
.unwrap(),
);
}
}
Fields::Unnamed(fields) => {
Expand Down Expand Up @@ -97,6 +93,16 @@ mod tests {
insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn simple_generic_tuple_struct() {
let item_struct: ItemStruct = syn::parse2(quote!{
struct TupleA<T>(T, u32);
}).unwrap();

let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap();
insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn bound_generics() {
let item_struct: ItemStruct = syn::parse2(quote!{
Expand All @@ -109,4 +115,18 @@ mod tests {
let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap();
insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn recursive_struct() {
let item_struct: ItemStruct = syn::parse2(quote!{
struct CRecC {
a: String,
b: HashMap<String, CRecC>,
}
}).unwrap();

let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap();

insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
---
source: borsh/tests/test_generic_struct.rs
expression: data
---
[
5,
0,
0,
0,
102,
105,
101,
108,
100,
2,
0,
0,
0,
14,
0,
0,
0,
5,
0,
0,
0,
118,
97,
108,
117,
101,
34,
0,
0,
0,
7,
0,
0,
0,
97,
110,
111,
116,
104,
101,
114,
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
source: borsh/tests/test_recursive_structs.rs
expression: data
---
[
5,
0,
0,
0,
116,
104,
114,
101,
101,
2,
0,
0,
0,
3,
0,
0,
0,
111,
110,
101,
0,
0,
0,
0,
3,
0,
0,
0,
116,
119,
111,
0,
0,
0,
0,
]
41 changes: 41 additions & 0 deletions borsh/tests/test_generic_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#![cfg(feature = "derive")]
use core::marker::PhantomData;

#[cfg(feature = "hashbrown")]
use hashbrown::HashMap;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
Expand Down Expand Up @@ -31,6 +37,23 @@ enum B<F, G> {
Y(G),
}

#[derive(BorshSerialize, Debug)]
struct TupleA<T>(T, u32);

#[derive(BorshSerialize, Debug)]
struct NamedA<T> {
a: T,
b: u32,
}

/// `T: PartialOrd` bound is required for `BorshSerialize` derive to be successful
#[cfg(hash_collections)]
#[derive(BorshSerialize, BorshDeserialize)]
struct C<T: PartialOrd, U> {
a: String,
b: HashMap<T, U>,
}

#[test]
fn test_generic_struct() {
let a = A::<String, u64, String> {
Expand All @@ -47,3 +70,21 @@ fn test_generic_struct() {
let actual_a = from_slice::<A<String, u64, String>>(&data).unwrap();
assert_eq!(a, actual_a);
}

#[cfg(hash_collections)]
#[test]
fn test_generic_struct_hashmap() {
let mut hashmap = HashMap::new();
hashmap.insert(34, "another".to_string());
hashmap.insert(14, "value".to_string());
let a = C::<u32, String> {
a: "field".to_string(),
b: hashmap,
};
let data = a.try_to_vec().unwrap();
#[cfg(feature = "std")]
insta::assert_debug_snapshot!(data);
let actual_a = from_slice::<C<u32, String>>(&data).unwrap();
assert_eq!(actual_a.b.get(&14), Some("value".to_string()).as_ref());
assert_eq!(actual_a.b.get(&34), Some("another".to_string()).as_ref());
}
61 changes: 61 additions & 0 deletions borsh/tests/test_recursive_structs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#![cfg(feature = "derive")]
use borsh::BorshSerialize;

#[cfg(feature = "hashbrown")]
use hashbrown::HashMap;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};

/// strangely enough, this worked before current commit
#[cfg(hash_collections)]
#[derive(BorshSerialize)]
struct CRec<U: PartialOrd> {
a: String,
b: HashMap<U, CRec<U>>,
}

#[derive(BorshSerialize)]
struct CRecA {
a: String,
b: Box<CRecA>,
}

#[derive(BorshSerialize, PartialEq, Eq)]
struct CRecB {
a: String,
b: Vec<CRecB>,
}

#[cfg(hash_collections)]
#[derive(BorshSerialize)]
struct CRecC {
a: String,
b: HashMap<String, CRecC>,
}

#[test]
fn test_recursive_struct() {
let one = CRecB {
a: "one".to_string(),
b: vec![],
};
let two = CRecB {
a: "two".to_string(),
b: vec![],
};

let three = CRecB {
a: "three".to_string(),
b: vec![one, two],
};
let _data = three.try_to_vec().unwrap();
#[cfg(feature = "std")]
insta::assert_debug_snapshot!(_data);
// let actual_three = from_slice::<CRecB>(&data).unwrap();
}

0 comments on commit 57d8f4f

Please sign in to comment.