From 4a7bba0b241984fb8e6377600cd714a189e308c4 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Mon, 27 May 2024 10:19:08 +0300 Subject: [PATCH] Lightweight error value in TryFrom for enums (#1010) Add a new UnknownEnumValue error type to prost, change the TryFrom conversions generated for enums to use that as the associated Error type. The error value carries the original integer, making it also more informative than DecodeError. --- prost-derive/src/lib.rs | 6 +++--- prost/src/error.rs | 17 +++++++++++++++++ prost/src/lib.rs | 2 +- tests/src/lib.rs | 5 +---- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 42f0ccd1d..06bf465c8 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -349,12 +349,12 @@ fn try_enumeration(input: TokenStream) -> Result { } impl #impl_generics ::core::convert::TryFrom:: for #ident #ty_generics #where_clause { - type Error = ::prost::DecodeError; + type Error = ::prost::UnknownEnumValue; - fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> { + fn try_from(value: i32) -> ::core::result::Result<#ident, Self::Error> { match value { #(#try_from,)* - _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")), + _ => ::core::result::Result::Err(::prost::UnknownEnumValue(value)), } } } diff --git a/prost/src/error.rs b/prost/src/error.rs index 6572e502d..78874b038 100644 --- a/prost/src/error.rs +++ b/prost/src/error.rs @@ -131,3 +131,20 @@ impl From for std::io::Error { std::io::Error::new(std::io::ErrorKind::InvalidInput, error) } } + +/// An error indicating that an unknown enumeration value was encountered. +/// +/// The Protobuf spec mandates that enumeration value sets are ‘open’, so this +/// error's value represents an integer value unrecognized by the +/// presently used enum definition. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct UnknownEnumValue(pub i32); + +impl fmt::Display for UnknownEnumValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "unknown enumeration value {}", self.0) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for UnknownEnumValue {} diff --git a/prost/src/lib.rs b/prost/src/lib.rs index a898c8ed7..ab63edb9f 100644 --- a/prost/src/lib.rs +++ b/prost/src/lib.rs @@ -17,7 +17,7 @@ mod types; #[doc(hidden)] pub mod encoding; -pub use crate::error::{DecodeError, EncodeError}; +pub use crate::error::{DecodeError, EncodeError, UnknownEnumValue}; pub use crate::message::Message; pub use crate::name::Name; diff --git a/tests/src/lib.rs b/tests/src/lib.rs index e336a74d9..f0f69c84c 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -632,10 +632,7 @@ mod tests { Ok(PrivacyLevel::PrivacyLevelprivacyLevelFour), PrivacyLevel::try_from(4) ); - assert_eq!( - Err(prost::DecodeError::new("invalid enumeration value")), - PrivacyLevel::try_from(5) - ); + assert_eq!(Err(prost::UnknownEnumValue(5)), PrivacyLevel::try_from(5)); assert_eq!( Ok(ERemoteClientBroadcastMsg::KERemoteClientBroadcastMsgDiscovery),