From 562af7853d42035edd5486180fee565827cce09f Mon Sep 17 00:00:00 2001 From: Michal Nazarewicz Date: Wed, 25 Jan 2023 23:56:46 +0100 Subject: [PATCH] de: add EnumExt trait with deserialize_variant method Introduce de::EnumExt trait with deserialize_variant method which takes variant number as an argument rather than reading it from the input. This trait is derived for all enums which use BorshDeserialize derive. It is useful when customer wants to perform some validation after the variant is known but before the variant is fully deserialised. --- borsh-derive-internal/src/enum_de.rs | 69 ++++++++++++---------------- borsh/src/de/mod.rs | 49 ++++++++++++++++++++ 2 files changed, 78 insertions(+), 40 deletions(-) diff --git a/borsh-derive-internal/src/enum_de.rs b/borsh-derive-internal/src/enum_de.rs index 37f2d34cb..1ecb351cc 100644 --- a/borsh-derive-internal/src/enum_de.rs +++ b/borsh-derive-internal/src/enum_de.rs @@ -72,49 +72,38 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result #variant_idx => #name::#variant_ident #variant_header , }); } - let variant_idx = quote! { - let variant_idx: u8 = #cratename::BorshDeserialize::deserialize_reader(reader)?; + + let init = if let Some(method_ident) = init_method { + quote! { + return_value.#method_ident(); + } + } else { + quote! {} }; - if let Some(method_ident) = init_method { - Ok(quote! { - impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { - fn deserialize_reader(reader: &mut R) -> ::core::result::Result { - #variant_idx - let mut return_value = match variant_idx { - #variant_arms - _ => { - let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx); - return Err(#cratename::maybestd::io::Error::new( - #cratename::maybestd::io::ErrorKind::InvalidInput, - msg, - )); - } - }; - return_value.#method_ident(); - Ok(return_value) - } + Ok(quote! { + impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { + fn deserialize_reader(reader: &mut R) -> ::core::result::Result { + let tag = ::deserialize_reader(reader)?; + ::deserialize_variant(reader, tag) } - }) - } else { - Ok(quote! { - impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { - fn deserialize_reader(reader: &mut R) -> ::core::result::Result { - #variant_idx - let return_value = match variant_idx { - #variant_arms - _ => { - let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx); + } - return Err(#cratename::maybestd::io::Error::new( - #cratename::maybestd::io::ErrorKind::InvalidInput, - msg, - )); - } - }; - Ok(return_value) - } + impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause { + fn deserialize_variant( + reader: &mut R, + variant_idx: u8, + ) -> ::core::result::Result { + let mut return_value = match variant_idx { + #variant_arms + _ => return Err(#cratename::maybestd::io::Error::new( + #cratename::maybestd::io::ErrorKind::InvalidInput, + #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx), + )) + }; + #init + Ok(return_value) } - }) - } + } + }) } diff --git a/borsh/src/de/mod.rs b/borsh/src/de/mod.rs index 0c6299944..c968a83af 100644 --- a/borsh/src/de/mod.rs +++ b/borsh/src/de/mod.rs @@ -73,6 +73,55 @@ pub trait BorshDeserialize: Sized { } } +/// Additional methods offered on enums which uses `[derive(BorshDeserialize)]`. +pub trait EnumExt: BorshDeserialize { + /// Deserialises given variant of an enum from the reader. + /// + /// This may be used to perform validation or filtering based on what + /// variant is being deserialised. + /// + /// ``` + /// use borsh::BorshDeserialize; + /// use borsh::de::EnumExt as _; + /// + /// #[derive(Debug, PartialEq, Eq, BorshDeserialize)] + /// enum MyEnum { + /// Zero, + /// One(u8), + /// Many(Vec) + /// } + /// + /// #[derive(Debug, PartialEq, Eq)] + /// struct OneOrZero(MyEnum); + /// + /// impl borsh::de::BorshDeserialize for OneOrZero { + /// fn deserialize_reader( + /// reader: &mut R, + /// ) -> std::io::Result { + /// use borsh::de::EnumExt; + /// let tag = u8::deserialize_reader(reader)?; + /// if tag == 2 { + /// Err(std::io::Error::new( + /// std::io::ErrorKind::InvalidInput, + /// "MyEnum::Many not allowed here", + /// )) + /// } else { + /// MyEnum::deserialize_variant(reader, tag).map(Self) + /// } + /// } + /// } + /// + /// let data = b"\0"; + /// assert_eq!(MyEnum::Zero, MyEnum::try_from_slice(&data[..]).unwrap()); + /// assert_eq!(MyEnum::Zero, OneOrZero::try_from_slice(&data[..]).unwrap().0); + /// + /// let data = b"\x02\0\0\0\0"; + /// assert_eq!(MyEnum::Many(Vec::new()), MyEnum::try_from_slice(&data[..]).unwrap()); + /// assert!(OneOrZero::try_from_slice(&data[..]).is_err()); + /// ``` + fn deserialize_variant(reader: &mut R, tag: u8) -> Result; +} + fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error { if e.kind() == ErrorKind::UnexpectedEof { Error::new(ErrorKind::InvalidInput, ERROR_UNEXPECTED_LENGTH_OF_INPUT)