diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index 4d7114407..923f1ee62 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -18,6 +18,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut fields_body = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; let discriminants = Discriminants::new(&input.variants); + let mut has_unit_variant = false; for (variant_idx, variant) in input.variants.iter().enumerate() { let variant_ident = &variant.ident; @@ -30,13 +31,16 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { &mut generics_output, )?; all_variants_idx_body.extend(variant_output.variant_idx_body); - let (variant_header, variant_body) = (variant_output.header, variant_output.body); - fields_body.extend(quote!( - #enum_ident::#variant_ident #variant_header => { - #variant_body - } - )) + match variant_output.body { + VariantBody::Unit => has_unit_variant = true, + VariantBody::Fields(VariantFields { header, body }) => fields_body.extend(quote!( + #enum_ident::#variant_ident #header => { + #body + } + )), + } } + let fields_body = optimize_fields_body(fields_body, has_unit_variant); generics_output.extend(&mut where_clause, &cratename); Ok(quote! { @@ -47,31 +51,78 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { }; writer.write_all(&variant_idx.to_le_bytes())?; - match self { - #fields_body - } + #fields_body Ok(()) } } }) } -struct VariantOutput { +fn optimize_fields_body(fields_body: TokenStream2, has_unit_variant: bool) -> TokenStream2 { + if fields_body.is_empty() { + // If we no variants with fields, there's nothing to match against. Just + // re-use the empty token stream. + fields_body + } else { + let unit_fields_catchall = if has_unit_variant { + // We had some variants with unit fields, create a catch-all for + // these to be used at the bottom. + quote!( + _ => {} + ) + } else { + TokenStream2::new() + }; + // Create a match that serialises all the fields for each non-unit + // variant and add a catch-all at the bottom if we do have unit + // variants. + quote!( + match self { + #fields_body + #unit_fields_catchall + } + ) + } +} + +#[derive(Default)] +struct VariantFields { header: TokenStream2, body: TokenStream2, - variant_idx_body: TokenStream2, } -impl VariantOutput { - fn new() -> Self { - Self { - body: TokenStream2::new(), - header: TokenStream2::new(), - variant_idx_body: TokenStream2::new(), +impl VariantFields { + fn named_header(self) -> Self { + let header = self.header; + + VariantFields { + // `..` pattern matching works even if all fields were specified + header: quote! { { #header.. }}, + body: self.body, + } + } + fn unnamed_header(self) -> Self { + let header = self.header; + + VariantFields { + header: quote! { ( #header )}, + body: self.body, } } } +enum VariantBody { + // No body variant, unit enum variant. + Unit, + // Variant with body (fields) + Fields(VariantFields), +} + +struct VariantOutput { + body: VariantBody, + variant_idx_body: TokenStream2, +} + fn process_variant( variant: &Variant, enum_ident: &Ident, @@ -80,36 +131,39 @@ fn process_variant( generics: &mut serialize::GenericsOutput, ) -> syn::Result { let variant_ident = &variant.ident; - let mut variant_output = VariantOutput::new(); - match &variant.fields { + let variant_output = match &variant.fields { Fields::Named(fields) => { + let mut variant_fields = VariantFields::default(); for field in &fields.named { let field_id = serialize::FieldId::Enum(field.ident.clone().unwrap()); - process_field(field, field_id, cratename, generics, &mut variant_output)?; + process_field(field, field_id, cratename, generics, &mut variant_fields)?; + } + VariantOutput { + body: VariantBody::Fields(variant_fields.named_header()), + variant_idx_body: quote!( + #enum_ident::#variant_ident {..} => #discriminant_value, + ), } - let header = variant_output.header; - // `..` pattern matching works even if all fields were specified - variant_output.header = quote! { { #header.. }}; - variant_output.variant_idx_body = quote!( - #enum_ident::#variant_ident {..} => #discriminant_value, - ); } Fields::Unnamed(fields) => { + let mut variant_fields = VariantFields::default(); for (field_idx, field) in fields.unnamed.iter().enumerate() { let field_id = serialize::FieldId::new_enum_unnamed(field_idx)?; - process_field(field, field_id, cratename, generics, &mut variant_output)?; + process_field(field, field_id, cratename, generics, &mut variant_fields)?; + } + VariantOutput { + body: VariantBody::Fields(variant_fields.unnamed_header()), + variant_idx_body: quote!( + #enum_ident::#variant_ident(..) => #discriminant_value, + ), } - let header = variant_output.header; - variant_output.header = quote! { ( #header )}; - variant_output.variant_idx_body = quote!( - #enum_ident::#variant_ident(..) => #discriminant_value, - ); } - Fields::Unit => { - variant_output.variant_idx_body = quote!( + Fields::Unit => VariantOutput { + body: VariantBody::Unit, + variant_idx_body: quote!( #enum_ident::#variant_ident => #discriminant_value, - ); - } + ), + }, }; Ok(variant_output) } @@ -119,7 +173,7 @@ fn process_field( field_id: serialize::FieldId, cratename: &Path, generics: &mut serialize::GenericsOutput, - output: &mut VariantOutput, + output: &mut VariantFields, ) -> syn::Result<()> { let parsed = field::Attributes::parse(&field.attrs)?; @@ -425,4 +479,20 @@ mod tests { local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); } + + #[test] + fn mixed_with_unit_variants() { + let item_enum: ItemEnum = syn::parse2(quote! { + enum X { + A(u16), + B, + C {x: i32, y: i32}, + D, + } + }) + .unwrap(); + let actual = process(&item_enum, default_cratename()).unwrap(); + + local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } } diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap index f2d9971f7..9fc0d1487 100644 --- a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap +++ b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap @@ -16,14 +16,6 @@ impl borsh::ser::BorshSerialize for X { X::F => 5u8, }; writer.write_all(&variant_idx.to_le_bytes())?; - match self { - X::A => {} - X::B => {} - X::C => {} - X::D => {} - X::E => {} - X::F => {} - } Ok(()) } } diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap index 7191ea5bb..75ae04424 100644 --- a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap +++ b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap @@ -16,14 +16,6 @@ impl borsh::ser::BorshSerialize for X { X::F => 10 + 1, }; writer.write_all(&variant_idx.to_le_bytes())?; - match self { - X::A => {} - X::B => {} - X::C => {} - X::D => {} - X::E => {} - X::F => {} - } Ok(()) } } diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap b/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap new file mode 100644 index 000000000..6de338b4b --- /dev/null +++ b/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap @@ -0,0 +1,30 @@ +--- +source: borsh-derive/src/internals/serialize/enums/mod.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::ser::BorshSerialize for X { + fn serialize( + &self, + writer: &mut W, + ) -> ::core::result::Result<(), borsh::io::Error> { + let variant_idx: u8 = match self { + X::A(..) => 0u8, + X::B => 1u8, + X::C { .. } => 2u8, + X::D => 3u8, + }; + writer.write_all(&variant_idx.to_le_bytes())?; + match self { + X::A(id0) => { + borsh::BorshSerialize::serialize(id0, writer)?; + } + X::C { x, y, .. } => { + borsh::BorshSerialize::serialize(x, writer)?; + borsh::BorshSerialize::serialize(y, writer)?; + } + _ => {} + } + Ok(()) + } +} + diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap new file mode 100644 index 000000000..e70b0c847 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap @@ -0,0 +1,7 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 1, +] diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap new file mode 100644 index 000000000..121d21f21 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap @@ -0,0 +1,15 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 2, + 132, + 0, + 0, + 0, + 239, + 255, + 255, + 255, +] diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap new file mode 100644 index 000000000..130ef3f52 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap @@ -0,0 +1,7 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 3, +] diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum.snap new file mode 100644 index 000000000..ac859bf65 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum.snap @@ -0,0 +1,9 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 0, + 13, + 0, +] diff --git a/borsh/tests/test_simple_structs.rs b/borsh/tests/test_simple_structs.rs index 647bf3fe3..5385e2a8c 100644 --- a/borsh/tests/test_simple_structs.rs +++ b/borsh/tests/test_simple_structs.rs @@ -223,3 +223,30 @@ fn test_object_length() { assert_eq!(encoded_a_len, len_helper_result); } + +#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)] +enum MixedWithUnitVariants { + A(u16), + B, + C { x: i32, y: i32 }, + D, +} + +#[test] +fn test_mixed_enum() { + let vars = vec![ + MixedWithUnitVariants::A(13), + MixedWithUnitVariants::B, + MixedWithUnitVariants::C { x: 132, y: -17 }, + MixedWithUnitVariants::D, + ]; + for variant in vars { + let encoded = to_vec(&variant).unwrap(); + #[cfg(feature = "std")] + insta::assert_debug_snapshot!(encoded); + + let decoded = from_slice::(&encoded).unwrap(); + + assert_eq!(variant, decoded); + } +}