From 9107dbc43fb586e02e403349ac97ee832579272d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dj8yf0=CE=BCl?= Date: Wed, 5 Jul 2023 21:45:44 +0300 Subject: [PATCH] feat: derive `BorshSerialize` for recursive structures --- borsh-derive-internal/src/generics.rs | 22 +++++++++ borsh-derive-internal/src/lib.rs | 1 + ...al__struct_ser__tests__bound_generics.snap | 4 +- ...__struct_ser__tests__recursive_struct.snap | 15 ++++++ ...l__struct_ser__tests__simple_generics.snap | 4 +- ...nal__struct_ser__tests__simple_struct.snap | 6 +-- borsh-derive-internal/src/struct_ser.rs | 28 +++++++---- ...eneric_struct__generic_struct_hashmap.snap | 47 +++++++++++++++++++ borsh/tests/test_generic_struct.rs | 33 +++++++++++++ borsh/tests/test_recursive_structs.rs | 45 ++++++++++++++++++ 10 files changed, 187 insertions(+), 18 deletions(-) create mode 100644 borsh-derive-internal/src/generics.rs create mode 100644 borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap create mode 100644 borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap create mode 100644 borsh/tests/test_recursive_structs.rs diff --git a/borsh-derive-internal/src/generics.rs b/borsh-derive-internal/src/generics.rs new file mode 100644 index 000000000..f32fbd555 --- /dev/null +++ b/borsh-derive-internal/src/generics.rs @@ -0,0 +1,22 @@ +use syn::{Generics, WherePredicate, Ident}; +use quote::quote; + + +pub fn compute_predicates( + generics: &Generics, + cratename: &Ident, +) -> Vec { + // Generate function that returns the name of the type. + let mut where_predicates= vec![]; + for type_param in generics.type_params() { + + let type_param_name = &type_param.ident; + where_predicates.push( + syn::parse2(quote! { + #type_param_name: #cratename::ser::BorshSerialize + }) + .unwrap(), + ); + } + where_predicates +} diff --git a/borsh-derive-internal/src/lib.rs b/borsh-derive-internal/src/lib.rs index 469d73b2f..71b9dd043 100644 --- a/borsh-derive-internal/src/lib.rs +++ b/borsh-derive-internal/src/lib.rs @@ -1,6 +1,7 @@ #![recursion_limit = "128"] mod attribute_helpers; +mod generics; mod enum_de; mod enum_discriminant_map; mod enum_ser; diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap index e3f607227..907044899 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__bound_generics.snap @@ -5,8 +5,8 @@ expression: pretty_print_syn_str(&actual).unwrap() impl borsh::ser::BorshSerialize for A where V: Value, - HashMap: borsh::ser::BorshSerialize, - String: borsh::ser::BorshSerialize, + K: borsh::ser::BorshSerialize, + V: borsh::ser::BorshSerialize, { fn serialize( &self, diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap new file mode 100644 index 000000000..1a334016c --- /dev/null +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__recursive_struct.snap @@ -0,0 +1,15 @@ +--- +source: borsh-derive-internal/src/struct_ser.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::ser::BorshSerialize for CRecC { + fn serialize( + &self, + writer: &mut W, + ) -> ::core::result::Result<(), borsh::__private::maybestd::io::Error> { + borsh::BorshSerialize::serialize(&self.a, writer)?; + borsh::BorshSerialize::serialize(&self.b, writer)?; + Ok(()) + } +} + diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap index db820e829..4b2c16e4d 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_generics.snap @@ -4,8 +4,8 @@ expression: pretty_print_syn_str(&actual).unwrap() --- impl borsh::ser::BorshSerialize for A where - HashMap: borsh::ser::BorshSerialize, - String: borsh::ser::BorshSerialize, + K: borsh::ser::BorshSerialize, + V: borsh::ser::BorshSerialize, { fn serialize( &self, diff --git a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap index 492711a20..eb7b34f9c 100644 --- a/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap +++ b/borsh-derive-internal/src/snapshots/borsh_derive_internal__struct_ser__tests__simple_struct.snap @@ -2,11 +2,7 @@ source: borsh-derive-internal/src/struct_ser.rs expression: pretty_print_syn_str(&actual).unwrap() --- -impl borsh::ser::BorshSerialize for A -where - u64: borsh::ser::BorshSerialize, - String: borsh::ser::BorshSerialize, -{ +impl borsh::ser::BorshSerialize for A { fn serialize( &self, writer: &mut W, diff --git a/borsh-derive-internal/src/struct_ser.rs b/borsh-derive-internal/src/struct_ser.rs index 21ca5752e..889a65f95 100644 --- a/borsh-derive-internal/src/struct_ser.rs +++ b/borsh-derive-internal/src/struct_ser.rs @@ -4,7 +4,7 @@ use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::quote; use syn::{Fields, Ident, Index, ItemStruct, WhereClause}; -use crate::attribute_helpers::contains_skip; +use crate::{attribute_helpers::contains_skip, generics::compute_predicates}; pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { let name = &input.ident; @@ -16,6 +16,10 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { @@ -28,14 +32,6 @@ pub fn struct_ser(input: &ItemStruct, cratename: Ident) -> syn::Result { @@ -109,4 +105,18 @@ mod tests { let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); } + + #[test] + fn recursive_struct() { + let item_struct: ItemStruct = syn::parse2(quote!{ + struct CRecC { + a: String, + b: HashMap, + } + }).unwrap(); + + let actual = struct_ser(&item_struct, Ident::new("borsh", Span::call_site())).unwrap(); + + insta::assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } } diff --git a/borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap b/borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap new file mode 100644 index 000000000..277db2640 --- /dev/null +++ b/borsh/tests/snapshots/test_generic_struct__generic_struct_hashmap.snap @@ -0,0 +1,47 @@ +--- +source: borsh/tests/test_generic_struct.rs +expression: data +--- +[ + 5, + 0, + 0, + 0, + 102, + 105, + 101, + 108, + 100, + 2, + 0, + 0, + 0, + 14, + 0, + 0, + 0, + 5, + 0, + 0, + 0, + 118, + 97, + 108, + 117, + 101, + 34, + 0, + 0, + 0, + 7, + 0, + 0, + 0, + 97, + 110, + 111, + 116, + 104, + 101, + 114, +] diff --git a/borsh/tests/test_generic_struct.rs b/borsh/tests/test_generic_struct.rs index b509972f2..87264a434 100644 --- a/borsh/tests/test_generic_struct.rs +++ b/borsh/tests/test_generic_struct.rs @@ -2,6 +2,12 @@ #![cfg(feature = "derive")] use core::marker::PhantomData; +#[cfg(feature = "hashbrown")] +use hashbrown::HashMap; + +#[cfg(feature = "std")] +use std::collections::HashMap; + #[cfg(not(feature = "std"))] extern crate alloc; #[cfg(not(feature = "std"))] @@ -31,6 +37,15 @@ enum B { Y(G), } +/// `T: PartialOrd` bound is required for `BorshSerialize` derive to be successful +#[cfg(hash_collections)] +#[derive(BorshSerialize, BorshDeserialize)] +struct C { + a: String, + b: HashMap, +} + + #[test] fn test_generic_struct() { let a = A:: { @@ -47,3 +62,21 @@ fn test_generic_struct() { let actual_a = from_slice::>(&data).unwrap(); assert_eq!(a, actual_a); } + +#[cfg(hash_collections)] +#[test] +fn test_generic_struct_hashmap() { + let mut hashmap = HashMap::new(); + hashmap.insert(34, "another".to_string()); + hashmap.insert(14, "value".to_string()); + let a = C:: { + a: "field".to_string(), + b: hashmap, + }; + let data = a.try_to_vec().unwrap(); + #[cfg(feature = "std")] + insta::assert_debug_snapshot!(data); + let actual_a = from_slice::>(&data).unwrap(); + assert_eq!(actual_a.b.get(&14), Some("value".to_string()).as_ref()); + assert_eq!(actual_a.b.get(&34), Some("another".to_string()).as_ref()); +} diff --git a/borsh/tests/test_recursive_structs.rs b/borsh/tests/test_recursive_structs.rs new file mode 100644 index 000000000..0a19a1cee --- /dev/null +++ b/borsh/tests/test_recursive_structs.rs @@ -0,0 +1,45 @@ +#![cfg(feature = "derive")] +use borsh::BorshSerialize; + +#[cfg(feature = "hashbrown")] +use hashbrown::HashMap; + +#[cfg(feature = "std")] +use std::collections::HashMap; + +#[cfg(not(feature = "std"))] +extern crate alloc; +#[cfg(not(feature = "std"))] +use alloc::{ + string::{String, ToString}, + boxed::Box, + vec::Vec, +}; + + +/// strangely enough, this worked before current commit +#[cfg(hash_collections)] +#[derive(BorshSerialize)] +struct CRec { + a: String, + b: HashMap>, +} + +#[derive(BorshSerialize)] +struct CRecA { + a: String, + b: Box, +} + +#[derive(BorshSerialize)] +struct CRecB { + a: String, + b: Vec, +} + +#[cfg(hash_collections)] +#[derive(BorshSerialize)] +struct CRecC { + a: String, + b: HashMap, +}