Skip to content

Commit

Permalink
de: add EnumExt trait with deserialize_variant method
Browse files Browse the repository at this point in the history
Introduce de::EnumExt trait with deserialize_variant method which
takes variant number as an argument rather than reading it from the
input.

This trait is derived for all enums which use BorshDeserialize derive.

It is useful when customer wants to perform some validation after the
variant is known but before the variant is fully deserialised.
  • Loading branch information
mina86 committed Jan 25, 2023
1 parent 28bb5fd commit 562af78
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 40 deletions.
69 changes: 29 additions & 40 deletions borsh-derive-internal/src/enum_de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,49 +72,38 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2>
#variant_idx => #name::#variant_ident #variant_header ,
});
}
let variant_idx = quote! {
let variant_idx: u8 = #cratename::BorshDeserialize::deserialize_reader(reader)?;

let init = if let Some(method_ident) = init_method {
quote! {
return_value.#method_ident();
}
} else {
quote! {}
};
if let Some(method_ident) = init_method {
Ok(quote! {
impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
fn deserialize_reader<R: borsh::maybestd::io::Read>(reader: &mut R) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
#variant_idx
let mut return_value = match variant_idx {
#variant_arms
_ => {
let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx);

return Err(#cratename::maybestd::io::Error::new(
#cratename::maybestd::io::ErrorKind::InvalidInput,
msg,
));
}
};
return_value.#method_ident();
Ok(return_value)
}
Ok(quote! {
impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
fn deserialize_reader<R: borsh::maybestd::io::Read>(reader: &mut R) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
let tag = <u8 as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?;
<Self as #cratename::de::EnumExt>::deserialize_variant(reader, tag)
}
})
} else {
Ok(quote! {
impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
fn deserialize_reader<R: borsh::maybestd::io::Read>(reader: &mut R) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
#variant_idx
let return_value = match variant_idx {
#variant_arms
_ => {
let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx);
}

return Err(#cratename::maybestd::io::Error::new(
#cratename::maybestd::io::ErrorKind::InvalidInput,
msg,
));
}
};
Ok(return_value)
}
impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
fn deserialize_variant<R: borsh::maybestd::io::Read>(
reader: &mut R,
variant_idx: u8,
) -> ::core::result::Result<Self, #cratename::maybestd::io::Error> {
let mut return_value = match variant_idx {
#variant_arms
_ => return Err(#cratename::maybestd::io::Error::new(
#cratename::maybestd::io::ErrorKind::InvalidInput,
#cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx),
))
};
#init
Ok(return_value)
}
})
}
}
})
}
49 changes: 49 additions & 0 deletions borsh/src/de/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,55 @@ pub trait BorshDeserialize: Sized {
}
}

/// Additional methods offered on enums which uses `[derive(BorshDeserialize)]`.
pub trait EnumExt: BorshDeserialize {
/// Deserialises given variant of an enum from the reader.
///
/// This may be used to perform validation or filtering based on what
/// variant is being deserialised.
///
/// ```
/// use borsh::BorshDeserialize;
/// use borsh::de::EnumExt as _;
///
/// #[derive(Debug, PartialEq, Eq, BorshDeserialize)]
/// enum MyEnum {
/// Zero,
/// One(u8),
/// Many(Vec<u8>)
/// }
///
/// #[derive(Debug, PartialEq, Eq)]
/// struct OneOrZero(MyEnum);
///
/// impl borsh::de::BorshDeserialize for OneOrZero {
/// fn deserialize_reader<R: std::io::Read>(
/// reader: &mut R,
/// ) -> std::io::Result<Self> {
/// use borsh::de::EnumExt;
/// let tag = u8::deserialize_reader(reader)?;
/// if tag == 2 {
/// Err(std::io::Error::new(
/// std::io::ErrorKind::InvalidInput,
/// "MyEnum::Many not allowed here",
/// ))
/// } else {
/// MyEnum::deserialize_variant(reader, tag).map(Self)
/// }
/// }
/// }
///
/// let data = b"\0";
/// assert_eq!(MyEnum::Zero, MyEnum::try_from_slice(&data[..]).unwrap());
/// assert_eq!(MyEnum::Zero, OneOrZero::try_from_slice(&data[..]).unwrap().0);
///
/// let data = b"\x02\0\0\0\0";
/// assert_eq!(MyEnum::Many(Vec::new()), MyEnum::try_from_slice(&data[..]).unwrap());
/// assert!(OneOrZero::try_from_slice(&data[..]).is_err());
/// ```
fn deserialize_variant<R: Read>(reader: &mut R, tag: u8) -> Result<Self>;
}

fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error {
if e.kind() == ErrorKind::UnexpectedEof {
Error::new(ErrorKind::InvalidInput, ERROR_UNEXPECTED_LENGTH_OF_INPUT)
Expand Down

0 comments on commit 562af78

Please sign in to comment.