From 57d8f4f4f28566c2ca2b94224e3d185c8dbfacc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dj8yf0=CE=BCl?= Date: Wed, 5 Jul 2023 21:45:44 +0300 Subject: [PATCH] feat: derive `BorshSerialize` for recursive structures --- borsh-derive-internal/src/generics.rs | 16 +++++ borsh-derive-internal/src/lib.rs | 1 + ...al__struct_ser__tests__bound_generics.snap | 4 +- ...__struct_ser__tests__recursive_struct.snap | 15 +++++ ...r__tests__simple_generic_tuple_struct.snap | 18 ++++++ ...l__struct_ser__tests__simple_generics.snap | 4 +- ...nal__struct_ser__tests__simple_struct.snap | 6 +- borsh-derive-internal/src/struct_ser.rs | 38 +++++++++--- ...eneric_struct__generic_struct_hashmap.snap | 47 ++++++++++++++ ...t_recursive_structs__recursive_struct.snap | 41 +++++++++++++ borsh/tests/test_generic_struct.rs | 41 +++++++++++++ borsh/tests/test_recursive_structs.rs | 61 +++++++++++++++++++ 12 files changed, 274 insertions(+), 18 deletions(-) create mode 100644 borsh-derive-internal/src/generics.rs create mode 100644 borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap create mode 100644 borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generic_tuple_struct.snap create mode 100644 borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap create mode 100644 borsh/tests/snapshots/test_recursive_structs__recursive_struct.snap create mode 100644 borsh/tests/test_recursive_structs.rs diff --git a/borsh-derive-internal/src/generics.rs b/borsh-derive-internal/src/generics.rs new file mode 100644 index 000000000..f181b391e --- /dev/null +++ b/borsh-derive-internal/src/generics.rs @@ -0,0 +1,16 @@ +use quote::quote; +use syn::{Generics, Ident, WherePredicate}; + +pub fn compute_predicates(generics: &Generics, cratename: &Ident) -> Vec { + let mut where_predicates = vec![]; + for type_param in generics.type_params() { + let type_param_name = &type_param.ident; + where_predicates.push( + syn::parse2(quote! { + #type_param_name: #cratename::ser::BorshSerialize + }) + .unwrap(), + ); + } + where_predicates +} diff --git a/borsh-derive-internal/src/lib.rs b/borsh-derive-internal/src/lib.rs index 469d73b2f..9d4a1328a 100644 --- a/borsh-derive-internal/src/lib.rs +++ b/borsh-derive-internal/src/lib.rs @@ -4,6 +4,7 @@ mod attribute_helpers; mod enum_de; mod enum_discriminant_map; mod enum_ser; +mod generics; mod struct_de; mod struct_ser; mod union_de; diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap index e3f607227..907044899 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap @@ -5,8 +5,8 @@ expression: pretty_print_syn_str(&actual).unwrap() impl borsh::ser::BorshSerialize for A where V: Value, - HashMap: borsh::ser::BorshSerialize, - String: borsh::ser::BorshSerialize, + K: borsh::ser::BorshSerialize, + V: borsh::ser::BorshSerialize, { fn serialize( &self, diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap new file mode 100644 index 000000000..1a334016c --- /dev/null +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap @@ -0,0 +1,15 @@ +--- +source: borsh-derive-internal/src/struct_ser.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::ser::BorshSerialize for CRecC { + fn serialize( + &self, + writer: &mut W, + ) -> ::core::result::Result<(), borsh::__private::maybestd::io::Error> { + borsh::BorshSerialize::serialize(&self.a, writer)?; + borsh::BorshSerialize::serialize(&self.b, writer)?; + Ok(()) + } +} + diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generic_tuple_struct.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generic_tuple_struct.snap new file mode 100644 index 000000000..fe538045d --- /dev/null +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generic_tuple_struct.snap @@ -0,0 +1,18 @@ +--- +source: borsh-derive-internal/src/struct_ser.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::ser::BorshSerialize for TupleA +where + T: borsh::ser::BorshSerialize, +{ + fn serialize( + &self, + writer: &mut W, + ) -> ::core::result::Result<(), borsh::__private::maybestd::io::Error> { + borsh::BorshSerialize::serialize(&self.0, writer)?; + borsh::BorshSerialize::serialize(&self.1, writer)?; + Ok(()) + } +} + diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap index db820e829..4b2c16e4d 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap @@ -4,8 +4,8 @@ expression: pretty_print_syn_str(&actual).unwrap() --- impl borsh::ser::BorshSerialize for A where - HashMap: borsh::ser::BorshSerialize, - String: borsh::ser::BorshSerialize, + K: borsh::ser::BorshSerialize, + V: borsh::ser::BorshSerialize, { fn serialize( &self, diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap index 492711a20..eb7b34f9c 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap @@ -2,11 +2,7 @@ source: borsh-derive-internal/src/struct_ser.rs expression: pretty_print_syn_str(&actual).unwrap() --- -impl borsh::ser::BorshSerialize for A -where - u64: borsh::ser::BorshSerialize, - String: borsh::ser::BorshSerialize, -{ +impl borsh::ser::BorshSerialize for A { fn serialize( &self, writer: &mut W, diff --git a/borsh-derive-internal/src/struct_ser.rs b/borsh-derive-internal/src/struct_ser.rs index 21ca5752e..3d69859ba 100644 --- a/borsh-derive-internal/src/struct_ser.rs +++ b/borsh-derive-internal/src/struct_ser.rs @@ -4,7 +4,7 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::quote; use syn::{Fields, Ident, Index, ItemStruct, WhereClause}; -use crate::attribute_helpers::contains_skip; +use crate::{attribute_helpers::contains_skip, generics::compute_predicates}; pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { let name = &input.ident; @@ -16,6 +16,10 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { @@ -28,14 +32,6 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { @@ -97,6 +93,16 @@ mod tests { insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); } + #[test] + fn simple_generic_tuple_struct() { + let item_struct: ItemStruct = syn::parse2(quote!{ + struct TupleA(T, u32); + }).unwrap(); + + let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } + #[test] fn bound_generics() { let item_struct: ItemStruct = syn::parse2(quote!{ @@ -109,4 +115,18 @@ mod tests { let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); } + + #[test] + fn recursive_struct() { + let item_struct: ItemStruct = syn::parse2(quote!{ + struct CRecC { + a: String, + b: HashMap, + } + }).unwrap(); + + let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + + insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } } diff --git a/borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap b/borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap new file mode 100644 index 000000000..277db2640 --- /dev/null +++ b/borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap @@ -0,0 +1,47 @@ +--- +source: borsh/tests/test_generic_struct.rs +expression: data +--- +[ + 5, + 0, + 0, + 0, + 102, + 105, + 101, + 108, + 100, + 2, + 0, + 0, + 0, + 14, + 0, + 0, + 0, + 5, + 0, + 0, + 0, + 118, + 97, + 108, + 117, + 101, + 34, + 0, + 0, + 0, + 7, + 0, + 0, + 0, + 97, + 110, + 111, + 116, + 104, + 101, + 114, +] diff --git a/borsh/tests/snapshots/test_recursive_structs__recursive_struct.snap b/borsh/tests/snapshots/test_recursive_structs__recursive_struct.snap new file mode 100644 index 000000000..030c7ce64 --- /dev/null +++ b/borsh/tests/snapshots/test_recursive_structs__recursive_struct.snap @@ -0,0 +1,41 @@ +--- +source: borsh/tests/test_recursive_structs.rs +expression: data +--- +[ + 5, + 0, + 0, + 0, + 116, + 104, + 114, + 101, + 101, + 2, + 0, + 0, + 0, + 3, + 0, + 0, + 0, + 111, + 110, + 101, + 0, + 0, + 0, + 0, + 3, + 0, + 0, + 0, + 116, + 119, + 111, + 0, + 0, + 0, + 0, +] diff --git a/borsh/tests/test_generic_struct.rs b/borsh/tests/test_generic_struct.rs index b509972f2..f9133c530 100644 --- a/borsh/tests/test_generic_struct.rs +++ b/borsh/tests/test_generic_struct.rs @@ -2,6 +2,12 @@ #![cfg(feature = "derive")] use core::marker::PhantomData; +#[cfg(feature = "hashbrown")] +use hashbrown::HashMap; + +#[cfg(feature = "std")] +use std::collections::HashMap; + #[cfg(not(feature = "std"))] extern crate alloc; #[cfg(not(feature = "std"))] @@ -31,6 +37,23 @@ enum B { Y(G), } +#[derive(BorshSerialize, Debug)] +struct TupleA(T, u32); + +#[derive(BorshSerialize, Debug)] +struct NamedA { + a: T, + b: u32, +} + +/// `T: PartialOrd` bound is required for `BorshSerialize` derive to be successful +#[cfg(hash_collections)] +#[derive(BorshSerialize, BorshDeserialize)] +struct C { + a: String, + b: HashMap, +} + #[test] fn test_generic_struct() { let a = A:: { @@ -47,3 +70,21 @@ fn test_generic_struct() { let actual_a = from_slice::>(&data).unwrap(); assert_eq!(a, actual_a); } + +#[cfg(hash_collections)] +#[test] +fn test_generic_struct_hashmap() { + let mut hashmap = HashMap::new(); + hashmap.insert(34, "another".to_string()); + hashmap.insert(14, "value".to_string()); + let a = C:: { + a: "field".to_string(), + b: hashmap, + }; + let data = a.try_to_vec().unwrap(); + #[cfg(feature = "std")] + insta::assert_debug_snapshot!(data); + let actual_a = from_slice::>(&data).unwrap(); + assert_eq!(actual_a.b.get(&14), Some("value".to_string()).as_ref()); + assert_eq!(actual_a.b.get(&34), Some("another".to_string()).as_ref()); +} diff --git a/borsh/tests/test_recursive_structs.rs b/borsh/tests/test_recursive_structs.rs new file mode 100644 index 000000000..ae0511ee7 --- /dev/null +++ b/borsh/tests/test_recursive_structs.rs @@ -0,0 +1,61 @@ +#![cfg(feature = "derive")] +use borsh::BorshSerialize; + +#[cfg(feature = "hashbrown")] +use hashbrown::HashMap; + +#[cfg(feature = "std")] +use std::collections::HashMap; + +#[cfg(not(feature = "std"))] +extern crate alloc; +#[cfg(not(feature = "std"))] +use alloc::{boxed::Box, string::String, vec::Vec}; + +/// strangely enough, this worked before current commit +#[cfg(hash_collections)] +#[derive(BorshSerialize)] +struct CRec { + a: String, + b: HashMap>, +} + +#[derive(BorshSerialize)] +struct CRecA { + a: String, + b: Box, +} + +#[derive(BorshSerialize, PartialEq, Eq)] +struct CRecB { + a: String, + b: Vec, +} + +#[cfg(hash_collections)] +#[derive(BorshSerialize)] +struct CRecC { + a: String, + b: HashMap, +} + +#[test] +fn test_recursive_struct() { + let one = CRecB { + a: "one".to_string(), + b: vec![], + }; + let two = CRecB { + a: "two".to_string(), + b: vec![], + }; + + let three = CRecB { + a: "three".to_string(), + b: vec![one, two], + }; + let _data = three.try_to_vec().unwrap(); + #[cfg(feature = "std")] + insta::assert_debug_snapshot!(_data); + // let actual_three = from_slice::(&data).unwrap(); +}