diff --git a/borsh-derive-internal/src/enum_de.rs b/borsh-derive-internal/src/enum_de.rs index 37f2d34cb..1ecb351cc 100644 --- a/borsh-derive-internal/src/enum_de.rs +++ b/borsh-derive-internal/src/enum_de.rs @@ -72,49 +72,38 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result #variant_idx => #name::#variant_ident #variant_header , }); } - let variant_idx = quote! { - let variant_idx: u8 = #cratename::BorshDeserialize::deserialize_reader(reader)?; + + let init = if let Some(method_ident) = init_method { + quote! { + return_value.#method_ident(); + } + } else { + quote! {} }; - if let Some(method_ident) = init_method { - Ok(quote! { - impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { - fn deserialize_reader(reader: &mut R) -> ::core::result::Result { - #variant_idx - let mut return_value = match variant_idx { - #variant_arms - _ => { - let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx); - return Err(#cratename::maybestd::io::Error::new( - #cratename::maybestd::io::ErrorKind::InvalidInput, - msg, - )); - } - }; - return_value.#method_ident(); - Ok(return_value) - } + Ok(quote! { + impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { + fn deserialize_reader(reader: &mut R) -> ::core::result::Result { + let tag = ::deserialize_reader(reader)?; + ::deserialize_variant(reader, tag) } - }) - } else { - Ok(quote! { - impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { - fn deserialize_reader(reader: &mut R) -> ::core::result::Result { - #variant_idx - let return_value = match variant_idx { - #variant_arms - _ => { - let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx); + } - return Err(#cratename::maybestd::io::Error::new( - #cratename::maybestd::io::ErrorKind::InvalidInput, - msg, - )); - } - }; - Ok(return_value) - } + impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause { + fn deserialize_variant( + reader: &mut R, + variant_idx: u8, + ) -> ::core::result::Result { + let mut return_value = match variant_idx { + #variant_arms + _ => return Err(#cratename::maybestd::io::Error::new( + #cratename::maybestd::io::ErrorKind::InvalidInput, + #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx), + )) + }; + #init + Ok(return_value) } - }) - } + } + }) } diff --git a/borsh/src/de/mod.rs b/borsh/src/de/mod.rs index 0c6299944..973af3f79 100644 --- a/borsh/src/de/mod.rs +++ b/borsh/src/de/mod.rs @@ -73,6 +73,55 @@ pub trait BorshDeserialize: Sized { } } +/// Additional methods offered on enums which uses `[derive(BorshDeserialize)]`. +pub trait EnumExt: BorshDeserialize { + /// Deserialises given variant of an enum from the reader. + /// + /// This may be used to perform validation or filtering based on what + /// variant is being deserialised. + /// + /// ``` + /// use borsh::BorshDeserialize; + /// use borsh::de::EnumExt as _; + /// + /// #[derive(Debug, PartialEq, Eq, BorshDeserialize)] + /// enum MyEnum { + /// Zero, + /// One(u8), + /// Many(Vec) + /// } + /// + /// #[derive(Debug, PartialEq, Eq)] + /// struct OneOrZero(MyEnum); + /// + /// impl borsh::de::BorshDeserialize for OneOrZero { + /// fn deserialize_reader( + /// reader: &mut R, + /// ) -> borsh::maybestd::io::Result { + /// use borsh::de::EnumExt; + /// let tag = u8::deserialize_reader(reader)?; + /// if tag == 2 { + /// Err(borsh::maybestd::io::Error::new( + /// borsh::maybestd::io::ErrorKind::InvalidInput, + /// "MyEnum::Many not allowed here", + /// )) + /// } else { + /// MyEnum::deserialize_variant(reader, tag).map(Self) + /// } + /// } + /// } + /// + /// let data = b"\0"; + /// assert_eq!(MyEnum::Zero, MyEnum::try_from_slice(&data[..]).unwrap()); + /// assert_eq!(MyEnum::Zero, OneOrZero::try_from_slice(&data[..]).unwrap().0); + /// + /// let data = b"\x02\0\0\0\0"; + /// assert_eq!(MyEnum::Many(Vec::new()), MyEnum::try_from_slice(&data[..]).unwrap()); + /// assert!(OneOrZero::try_from_slice(&data[..]).is_err()); + /// ``` + fn deserialize_variant(reader: &mut R, tag: u8) -> Result; +} + fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error { if e.kind() == ErrorKind::UnexpectedEof { Error::new(ErrorKind::InvalidInput, ERROR_UNEXPECTED_LENGTH_OF_INPUT)