Skip to content

Commit

Permalink
feat!: add DiscriminantValue to Definition::Enum::variants tuples (
Browse files Browse the repository at this point in the history
…#232)

* feat!: add `DiscriminantValue` to `Definition::Enum::variants` tuples

* test: add a couple of tests for enum with discriminants set schema

* chore: restrict valid discriminants to be of `u8` range

to be more congruent with `BorshSerialize`, `BorshDeserialize`

* doc: mention `use_discriminant = <bool>` in `BorshSchema` context
  • Loading branch information
dj8yfo authored Sep 26, 2023
1 parent 2e02f68 commit 2a13d3a
Show file tree
Hide file tree
Showing 26 changed files with 563 additions and 83 deletions.
89 changes: 86 additions & 3 deletions borsh-derive/src/internals/schema/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ use quote::{quote, ToTokens};
use std::collections::HashSet;
use syn::{Fields, Generics, Ident, ItemEnum, ItemStruct, Path, Variant, Visibility};

use crate::internals::{attributes::field, generics, schema};
use crate::internals::{
attributes::{field, item},
enum_discriminant::Discriminants,
generics, schema,
};

fn transform_variant_fields(mut input: Fields) -> Fields {
match input {
Expand Down Expand Up @@ -31,14 +35,23 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let mut where_clause = generics::default_where(where_clause);
let mut generics_output = schema::GenericsOutput::new(&generics);
let use_discriminant = item::contains_use_discriminant(input)?;
let discriminants = Discriminants::new(&input.variants);

// Generate functions that return the schema for variants.
let mut discriminant_variables = vec![];
let mut variants_defs = vec![];
let mut inner_defs = TokenStream2::new();
let mut add_recursive_defs = TokenStream2::new();
for variant in &input.variants {
for (variant_idx, variant) in input.variants.iter().enumerate() {
let discriminant_info = DiscriminantInfo {
variant_idx,
discriminants: &discriminants,
use_discriminant,
};
let variant_output = process_variant(
variant,
discriminant_info,
&cratename,
&enum_name,
&generics,
Expand All @@ -47,12 +60,14 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
inner_defs.extend(variant_output.inner_struct);
add_recursive_defs.extend(variant_output.add_definitions_recursively_call);
variants_defs.push(variant_output.variant_entry);
discriminant_variables.push(variant_output.discriminant_variable_assignment);
}

let type_definitions = quote! {
fn add_definitions_recursively(definitions: &mut #cratename::__private::maybestd::collections::BTreeMap<#cratename::schema::Declaration, #cratename::schema::Definition>) {
#inner_defs
#add_recursive_defs
#(#discriminant_variables)*
let definition = #cratename::schema::Definition::Enum {
tag_width: 1,
variants: #cratename::__private::maybestd::vec![#(#variants_defs),*],
Expand All @@ -78,12 +93,38 @@ struct VariantOutput {
inner_struct: TokenStream2,
/// call to `add_definitions_recursively`.
add_definitions_recursively_call: TokenStream2,
/// declaration of `u8` variable, holding the value for discriminant of a variant
discriminant_variable_assignment: TokenStream2,
/// entry with a variant's declaration, element in vector of whole enum's definition
variant_entry: TokenStream2,
}

struct DiscriminantInfo<'a> {
variant_idx: usize,
discriminants: &'a Discriminants,
use_discriminant: bool,
}

fn process_discriminant(
variant_ident: &Ident,
info: DiscriminantInfo<'_>,
) -> syn::Result<(Ident, TokenStream2)> {
let discriminant_value =
info.discriminants
.get(variant_ident, info.use_discriminant, info.variant_idx)?;

let discriminant_variable_name = format!("discriminant_{}", info.variant_idx);
let discriminant_variable = Ident::new(&discriminant_variable_name, Span::call_site());

let discriminant_variable_assignment = quote! {
let #discriminant_variable: u8 = #discriminant_value;
};
Ok((discriminant_variable, discriminant_variable_assignment))
}

fn process_variant(
variant: &Variant,
discriminant_info: DiscriminantInfo,
cratename: &Path,
enum_name: &str,
enum_generics: &Generics,
Expand All @@ -101,13 +142,18 @@ fn process_variant(
let add_definitions_recursively_call = quote! {
<#full_variant_ident #inner_struct_ty_generics as #cratename::BorshSchema>::add_definitions_recursively(definitions);
};

let (discriminant_variable, discriminant_variable_assignment) =
process_discriminant(&variant.ident, discriminant_info)?;

let variant_entry = quote! {
(#variant_name.to_string(), <#full_variant_ident #inner_struct_ty_generics>::declaration())
(#discriminant_variable as i64, #variant_name.to_string(), <#full_variant_ident #inner_struct_ty_generics>::declaration())
};
Ok(VariantOutput {
inner_struct,
add_definitions_recursively_call,
variant_entry,
discriminant_variable_assignment,
})
}

Expand Down Expand Up @@ -187,6 +233,43 @@ mod tests {
local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn borsh_discriminant_false() {
let item_enum: ItemEnum = syn::parse2(quote! {
#[borsh(use_discriminant = false)]
enum X {
A,
B = 20,
C,
D,
E = 10,
F,
}
})
.unwrap();
let actual = process(&item_enum, default_cratename()).unwrap();

local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}
#[test]
fn borsh_discriminant_true() {
let item_enum: ItemEnum = syn::parse2(quote! {
#[borsh(use_discriminant = true)]
enum X {
A,
B = 20,
C,
D,
E = 10,
F,
}
})
.unwrap();
let actual = process(&item_enum, default_cratename()).unwrap();

local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
}

#[test]
fn single_field_enum() {
let item_enum: ItemEnum = syn::parse2(quote! {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
---
source: borsh-derive/src/internals/schema/enums/mod.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::BorshSchema for X {
fn declaration() -> borsh::schema::Declaration {
"X".to_string()
}
fn add_definitions_recursively(
definitions: &mut borsh::__private::maybestd::collections::BTreeMap<
borsh::schema::Declaration,
borsh::schema::Definition,
>,
) {
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XA;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XB;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XC;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XD;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XE;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XF;
<XA as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XB as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XC as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XD as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XE as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XF as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0u8;
let discriminant_1: u8 = 1u8;
let discriminant_2: u8 = 2u8;
let discriminant_3: u8 = 3u8;
let discriminant_4: u8 = 4u8;
let discriminant_5: u8 = 5u8;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
(discriminant_0 as i64, "A".to_string(), < XA > ::declaration()),
(discriminant_1 as i64, "B".to_string(), < XB > ::declaration()),
(discriminant_2 as i64, "C".to_string(), < XC > ::declaration()),
(discriminant_3 as i64, "D".to_string(), < XD > ::declaration()),
(discriminant_4 as i64, "E".to_string(), < XE > ::declaration()),
(discriminant_5 as i64, "F".to_string(), < XF > ::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
---
source: borsh-derive/src/internals/schema/enums/mod.rs
expression: pretty_print_syn_str(&actual).unwrap()
---
impl borsh::BorshSchema for X {
fn declaration() -> borsh::schema::Declaration {
"X".to_string()
}
fn add_definitions_recursively(
definitions: &mut borsh::__private::maybestd::collections::BTreeMap<
borsh::schema::Declaration,
borsh::schema::Definition,
>,
) {
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XA;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XB;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XC;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XD;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XE;
#[allow(dead_code)]
#[derive(borsh::BorshSchema)]
#[borsh(crate = "borsh")]
struct XF;
<XA as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XB as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XC as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XD as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XE as borsh::BorshSchema>::add_definitions_recursively(definitions);
<XF as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0;
let discriminant_1: u8 = 20;
let discriminant_2: u8 = 20 + 1;
let discriminant_3: u8 = 20 + 1 + 1;
let discriminant_4: u8 = 10;
let discriminant_5: u8 = 10 + 1;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
(discriminant_0 as i64, "A".to_string(), < XA > ::declaration()),
(discriminant_1 as i64, "B".to_string(), < XB > ::declaration()),
(discriminant_2 as i64, "C".to_string(), < XC > ::declaration()),
(discriminant_3 as i64, "D".to_string(), < XD > ::declaration()),
(discriminant_4 as i64, "E".to_string(), < XE > ::declaration()),
(discriminant_5 as i64, "F".to_string(), < XF > ::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,18 @@ impl borsh::BorshSchema for A {
<AEggs as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASalad as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASausage as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0u8;
let discriminant_1: u8 = 1u8;
let discriminant_2: u8 = 2u8;
let discriminant_3: u8 = 3u8;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
("Bacon".to_string(), < ABacon > ::declaration()), ("Eggs".to_string(), <
AEggs > ::declaration()), ("Salad".to_string(), < ASalad >
::declaration()), ("Sausage".to_string(), < ASausage > ::declaration())
(discriminant_0 as i64, "Bacon".to_string(), < ABacon > ::declaration()),
(discriminant_1 as i64, "Eggs".to_string(), < AEggs > ::declaration()),
(discriminant_2 as i64, "Salad".to_string(), < ASalad > ::declaration()),
(discriminant_3 as i64, "Sausage".to_string(), < ASausage >
::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,18 @@ where
<AEggs as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASalad<C> as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASausage<W> as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0u8;
let discriminant_1: u8 = 1u8;
let discriminant_2: u8 = 2u8;
let discriminant_3: u8 = 3u8;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
("Bacon".to_string(), < ABacon > ::declaration()), ("Eggs".to_string(), <
AEggs > ::declaration()), ("Salad".to_string(), < ASalad < C > >
::declaration()), ("Sausage".to_string(), < ASausage < W > >
::declaration())
(discriminant_0 as i64, "Bacon".to_string(), < ABacon > ::declaration()),
(discriminant_1 as i64, "Eggs".to_string(), < AEggs > ::declaration()),
(discriminant_2 as i64, "Salad".to_string(), < ASalad < C > >
::declaration()), (discriminant_3 as i64, "Sausage".to_string(), <
ASausage < W > > ::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,18 @@ where
<AEggs as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASalad<C> as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASausage<W, U> as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0u8;
let discriminant_1: u8 = 1u8;
let discriminant_2: u8 = 2u8;
let discriminant_3: u8 = 3u8;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
("Bacon".to_string(), < ABacon > ::declaration()), ("Eggs".to_string(), <
AEggs > ::declaration()), ("Salad".to_string(), < ASalad < C > >
::declaration()), ("Sausage".to_string(), < ASausage < W, U > >
::declaration())
(discriminant_0 as i64, "Bacon".to_string(), < ABacon > ::declaration()),
(discriminant_1 as i64, "Eggs".to_string(), < AEggs > ::declaration()),
(discriminant_2 as i64, "Salad".to_string(), < ASalad < C > >
::declaration()), (discriminant_3 as i64, "Sausage".to_string(), <
ASausage < W, U > > ::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,18 @@ where
<AEggs as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASalad<C> as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ASausage<W> as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0u8;
let discriminant_1: u8 = 1u8;
let discriminant_2: u8 = 2u8;
let discriminant_3: u8 = 3u8;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
("Bacon".to_string(), < ABacon > ::declaration()), ("Eggs".to_string(), <
AEggs > ::declaration()), ("Salad".to_string(), < ASalad < C > >
::declaration()), ("Sausage".to_string(), < ASausage < W > >
::declaration())
(discriminant_0 as i64, "Bacon".to_string(), < ABacon > ::declaration()),
(discriminant_1 as i64, "Eggs".to_string(), < AEggs > ::declaration()),
(discriminant_2 as i64, "Salad".to_string(), < ASalad < C > >
::declaration()), (discriminant_3 as i64, "Sausage".to_string(), <
ASausage < W > > ::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ impl borsh::BorshSchema for A {
}
<AB as borsh::BorshSchema>::add_definitions_recursively(definitions);
<ANegative as borsh::BorshSchema>::add_definitions_recursively(definitions);
let discriminant_0: u8 = 0u8;
let discriminant_1: u8 = 1u8;
let definition = borsh::schema::Definition::Enum {
tag_width: 1,
variants: borsh::__private::maybestd::vec![
("B".to_string(), < AB > ::declaration()), ("Negative".to_string(), <
ANegative > ::declaration())
(discriminant_0 as i64, "B".to_string(), < AB > ::declaration()),
(discriminant_1 as i64, "Negative".to_string(), < ANegative >
::declaration())
],
};
borsh::schema::add_definition(Self::declaration(), definition, definitions);
Expand Down
Loading

0 comments on commit 2a13d3a

Please sign in to comment.