diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 2518d98..08522eb 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -21,7 +21,7 @@ rust-version = "1.63.0" [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "1.0.56", features = ['derive'] } +syn = { version = "1.0.56", features = ['derive', 'parsing'] } [lib] proc_macro = true diff --git a/derive/src/container_attributes.rs b/derive/src/container_attributes.rs new file mode 100644 index 0000000..9a91ac8 --- /dev/null +++ b/derive/src/container_attributes.rs @@ -0,0 +1,72 @@ +use crate::ARBITRARY_ATTRIBUTE_NAME; +use syn::{ + parse::Error, punctuated::Punctuated, DeriveInput, Lit, Meta, MetaNameValue, NestedMeta, Token, + TypeParam, +}; + +pub struct ContainerAttributes { + /// Specify type bounds to be applied to the derived `Arbitrary` implementation instead of the + /// default inferred bounds. + /// + /// ```ignore + /// #[arbitrary(bound = "T: Default, U: Debug")] + /// ``` + /// + /// Multiple attributes will be combined as long as they don't conflict, e.g. + /// + /// ```ignore + /// #[arbitrary(bound = "T: Default")] + /// #[arbitrary(bound = "U: Default")] + /// ``` + pub bounds: Option>>, +} + +impl ContainerAttributes { + pub fn from_derive_input(derive_input: &DeriveInput) -> Result { + let mut bounds = None; + + for attr in &derive_input.attrs { + if !attr.path.is_ident(ARBITRARY_ATTRIBUTE_NAME) { + continue; + } + + let meta_list = match attr.parse_meta()? { + Meta::List(l) => l, + _ => { + return Err(Error::new_spanned( + attr, + format!( + "invalid `{}` attribute. expected list", + ARBITRARY_ATTRIBUTE_NAME + ), + )) + } + }; + + for nested_meta in meta_list.nested.iter() { + match nested_meta { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(bound_str_lit), + .. + })) if path.is_ident("bound") => { + bounds + .get_or_insert_with(Vec::new) + .push(bound_str_lit.parse_with(Punctuated::parse_terminated)?); + } + _ => { + return Err(Error::new_spanned( + attr, + format!( + "invalid `{}` attribute. expected `bound = \"..\"`", + ARBITRARY_ATTRIBUTE_NAME, + ), + )) + } + } + } + } + + Ok(Self { bounds }) + } +} diff --git a/derive/src/field_attributes.rs b/derive/src/field_attributes.rs index ccaba74..2ca0f1c 100644 --- a/derive/src/field_attributes.rs +++ b/derive/src/field_attributes.rs @@ -1,10 +1,8 @@ +use crate::ARBITRARY_ATTRIBUTE_NAME; use proc_macro2::{Group, Span, TokenStream, TokenTree}; use quote::quote; use syn::{spanned::Spanned, *}; -/// Used to filter out necessary field attribute and within error messages. -static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary"; - /// Determines how a value for a field should be constructed. #[cfg_attr(test, derive(Debug))] pub enum FieldConstructor { diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 4ed3817..5e05522 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -4,9 +4,12 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::*; +mod container_attributes; mod field_attributes; +use container_attributes::ContainerAttributes; use field_attributes::{determine_field_constructor, FieldConstructor}; +static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary"; static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; #[proc_macro_derive(Arbitrary, attributes(arbitrary))] @@ -18,6 +21,8 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr } fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result { + let container_attrs = ContainerAttributes::from_derive_input(&input)?; + let (lifetime_without_bounds, lifetime_with_bounds) = build_arbitrary_lifetime(input.generics.clone()); @@ -30,8 +35,13 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result { gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?; let size_hint_method = gen_size_hint_method(&input)?; let name = input.ident; - // Add a bound `T: Arbitrary` to every type parameter T. - let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone()); + + // Apply user-supplied bounds or automatic `T: ArbitraryBounds`. + let generics = apply_trait_bounds( + input.generics, + lifetime_without_bounds.clone(), + &container_attrs, + )?; // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90) let mut generics_with_lifetime = generics.clone(); @@ -77,6 +87,51 @@ fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) { (lifetime_without_bounds, lifetime_with_bounds) } +fn apply_trait_bounds( + mut generics: Generics, + lifetime: LifetimeDef, + container_attrs: &ContainerAttributes, +) -> Result { + // If user-supplied bounds exist, apply them to their matching type parameters. + if let Some(config_bounds) = &container_attrs.bounds { + let mut config_bounds_applied = 0; + for param in generics.params.iter_mut() { + if let GenericParam::Type(type_param) = param { + if let Some(replacement) = config_bounds + .iter() + .flatten() + .find(|p| p.ident == type_param.ident) + { + *type_param = replacement.clone(); + config_bounds_applied += 1; + } else { + // If no user-supplied bounds exist for this type, delete the original bounds. + // This mimics serde. + type_param.bounds = Default::default(); + type_param.default = None; + } + } + } + let config_bounds_supplied = config_bounds + .iter() + .map(|bounds| bounds.len()) + .sum::(); + if config_bounds_applied != config_bounds_supplied { + return Err(Error::new( + Span::call_site(), + format!( + "invalid `{}` attribute. too many bounds, only {} out of {} are applicable", + ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied, + ), + )); + } + Ok(generics) + } else { + // Otherwise, inject a `T: Arbitrary` bound for every parameter. + Ok(add_trait_bounds(generics, lifetime)) + } +} + // Add a bound `T: Arbitrary` to every type parameter T. fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics { for param in generics.params.iter_mut() { diff --git a/src/lib.rs b/src/lib.rs index a3fa48b..dfaebd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1338,5 +1338,56 @@ mod test { /// x: i32, /// } /// ``` +/// +/// Multiple conflicting bounds at the container-level: +/// ```compile_fail +/// #[derive(::arbitrary::Arbitrary)] +/// #[arbitrary(bound = "T: Default")] +/// #[arbitrary(bound = "T: Default")] +/// struct Point { +/// #[arbitrary(default)] +/// x: T, +/// } +/// ``` +/// +/// Multiple conflicting bounds in a single bound attribute: +/// ```compile_fail +/// #[derive(::arbitrary::Arbitrary)] +/// #[arbitrary(bound = "T: Default, T: Default")] +/// struct Point { +/// #[arbitrary(default)] +/// x: T, +/// } +/// ``` +/// +/// Multiple conflicting bounds in multiple bound attributes: +/// ```compile_fail +/// #[derive(::arbitrary::Arbitrary)] +/// #[arbitrary(bound = "T: Default", bound = "T: Default")] +/// struct Point { +/// #[arbitrary(default)] +/// x: T, +/// } +/// ``` +/// +/// Too many bounds supplied: +/// ```compile_fail +/// #[derive(::arbitrary::Arbitrary)] +/// #[arbitrary(bound = "T: Default")] +/// struct Point { +/// x: i32, +/// } +/// ``` +/// +/// Too many bounds supplied across multiple attributes: +/// ```compile_fail +/// #[derive(::arbitrary::Arbitrary)] +/// #[arbitrary(bound = "T: Default")] +/// #[arbitrary(bound = "U: Default")] +/// struct Point { +/// #[arbitrary(default)] +/// x: T, +/// } +/// ``` #[cfg(all(doctest, feature = "derive"))] pub struct CompileFailTests; diff --git a/tests/bound.rs b/tests/bound.rs new file mode 100644 index 0000000..7a772ac --- /dev/null +++ b/tests/bound.rs @@ -0,0 +1,142 @@ +#![cfg(feature = "derive")] + +use arbitrary::{Arbitrary, Unstructured}; + +fn arbitrary_from<'a, T: Arbitrary<'a>>(input: &'a [u8]) -> T { + let mut buf = Unstructured::new(input); + T::arbitrary(&mut buf).expect("can create arbitrary instance OK") +} + +/// This wrapper trait *implies* `Arbitrary`, but the compiler isn't smart enough to work that out +/// so when using this wrapper we *must* opt-out of the auto-generated `T: Arbitrary` bounds. +pub trait WrapperTrait: for<'a> Arbitrary<'a> {} + +impl WrapperTrait for u32 {} + +#[derive(Arbitrary)] +#[arbitrary(bound = "T: WrapperTrait")] +struct GenericSingleBound { + t: T, +} + +#[test] +fn single_bound() { + let v: GenericSingleBound = arbitrary_from(&[0, 0, 0, 0]); + assert_eq!(v.t, 0); +} + +#[derive(Arbitrary)] +#[arbitrary(bound = "T: WrapperTrait, U: WrapperTrait")] +struct GenericMultipleBoundsSingleAttribute { + t: T, + u: U, +} + +#[test] +fn multiple_bounds_single_attribute() { + let v: GenericMultipleBoundsSingleAttribute = + arbitrary_from(&[1, 0, 0, 0, 2, 0, 0, 0]); + assert_eq!(v.t, 1); + assert_eq!(v.u, 2); +} + +#[derive(Arbitrary)] +#[arbitrary(bound = "T: WrapperTrait")] +#[arbitrary(bound = "U: Default")] +struct GenericMultipleArbitraryAttributes { + t: T, + #[arbitrary(default)] + u: U, +} + +#[test] +fn multiple_arbitrary_attributes() { + let v: GenericMultipleArbitraryAttributes = arbitrary_from(&[1, 0, 0, 0]); + assert_eq!(v.t, 1); + assert_eq!(v.u, 0); +} + +#[derive(Arbitrary)] +#[arbitrary(bound = "T: WrapperTrait", bound = "U: Default")] +struct GenericMultipleBoundAttributes { + t: T, + #[arbitrary(default)] + u: U, +} + +#[test] +fn multiple_bound_attributes() { + let v: GenericMultipleBoundAttributes = arbitrary_from(&[1, 0, 0, 0]); + assert_eq!(v.t, 1); + assert_eq!(v.u, 0); +} + +#[derive(Arbitrary)] +#[arbitrary(bound = "T: WrapperTrait", bound = "U: Default")] +#[arbitrary(bound = "V: WrapperTrait, W: Default")] +struct GenericMultipleArbitraryAndBoundAttributes< + T: WrapperTrait, + U: Default, + V: WrapperTrait, + W: Default, +> { + t: T, + #[arbitrary(default)] + u: U, + v: V, + #[arbitrary(default)] + w: W, +} + +#[test] +fn multiple_arbitrary_and_bound_attributes() { + let v: GenericMultipleArbitraryAndBoundAttributes = + arbitrary_from(&[1, 0, 0, 0, 2, 0, 0, 0]); + assert_eq!(v.t, 1); + assert_eq!(v.u, 0); + assert_eq!(v.v, 2); + assert_eq!(v.w, 0); +} + +#[derive(Arbitrary)] +#[arbitrary(bound = "T: Default")] +struct GenericDefault { + #[arbitrary(default)] + x: T, +} + +#[test] +fn default_bound() { + // We can write a generic func without any `Arbitrary` bound. + fn generic_default() -> GenericDefault { + arbitrary_from(&[]) + } + + assert_eq!(generic_default::().x, 0); + assert_eq!(generic_default::().x, String::new()); + assert_eq!(generic_default::>().x, Vec::new()); +} + +#[derive(Arbitrary)] +#[arbitrary()] +struct EmptyArbitraryAttribute { + t: u32, +} + +#[test] +fn empty_arbitrary_attribute() { + let v: EmptyArbitraryAttribute = arbitrary_from(&[1, 0, 0, 0]); + assert_eq!(v.t, 1); +} + +#[derive(Arbitrary)] +#[arbitrary(bound = "")] +struct EmptyBoundAttribute { + t: u32, +} + +#[test] +fn empty_bound_attribute() { + let v: EmptyBoundAttribute = arbitrary_from(&[1, 0, 0, 0]); + assert_eq!(v.t, 1); +}