From 555a893026ead0344fa4314965c435802b69d4bb Mon Sep 17 00:00:00 2001 From: frozenlib Date: Mon, 9 Jan 2023 10:54:16 +0900 Subject: [PATCH] Fixed a problem with function filters that prevented the use of `args`. --- src/arbitrary.rs | 20 +++++-------- tests/arbitrary.rs | 75 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/src/arbitrary.rs b/src/arbitrary.rs index 0c8a132..865b414 100644 --- a/src/arbitrary.rs +++ b/src/arbitrary.rs @@ -8,11 +8,11 @@ use quote::{quote, quote_spanned, ToTokens}; use std::collections::BTreeMap; use std::{collections::HashMap, fmt::Write, mem::take}; use structmeta::*; -use syn::Pat; use syn::{ parse2, parse_quote, parse_str, spanned::Spanned, Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, Ident, Index, Lit, Member, Path, Result, Type, }; +use syn::{parse_quote_spanned, Pat}; pub fn derive_arbitrary(input: DeriveInput) -> Result { let args: ArbitraryArgsForType = parse_from_attrs(&input.attrs, "arbitrary")?; @@ -42,7 +42,7 @@ pub fn derive_arbitrary(input: DeriveInput) -> Result { #[allow(clippy::redundant_closure_call)] fn arbitrary_with(args: ::Parameters) -> Self::Strategy { #[allow(dead_code)] - fn _to_fn_ptr(f: fn(&T) -> bool) -> fn(&T) -> bool { + fn _to_fn(f: impl Fn(&T) -> bool) -> impl Fn(&T) -> bool { f } #[allow(dead_code)] @@ -279,11 +279,10 @@ impl Filter { } fn make_let_func(&self, var: &Ident, target: Expr, arg_ty: &Type) -> TokenStream { - let whence = &self.whence; + let span = self.fun.span(); let fun = &self.fun; - quote_spanned! {fun.span()=> - let #var = proptest::strategy::Strategy::prop_filter(#var, #whence, |_this| (_to_fn_ptr::<#arg_ty>(#fun))(#target)); - } + let fun = parse_quote_spanned!(span=> (_to_fn::<#arg_ty>(#fun))(#target)); + Self::make_let_as(&self.whence, &fun, var, quote!()) } fn make_let_expr(&self, var: &Ident, target: Expr, ident: &Ident, by_ref: bool) -> TokenStream { let lets = if by_ref { @@ -291,11 +290,9 @@ impl Filter { } else { quote! { let #ident = std::clone::Clone::clone(#target); } }; - self.make_let_as(var, lets) + Self::make_let_as(&self.whence, &self.fun, var, lets) } - fn make_let_as(&self, var: &Ident, lets: TokenStream) -> TokenStream { - let whence = &self.whence; - let fun = &self.fun; + fn make_let_as(whence: &Expr, fun: &Expr, var: &Ident, lets: TokenStream) -> TokenStream { quote_spanned! {fun.span()=> let #var = { #[allow(unused_variables)] @@ -359,8 +356,7 @@ impl FieldsFilter { } else { parse_quote!(#fun) }; - let whence = self.filter.whence.clone(); - Filter { fun, whence }.make_let_as(var, quote!(#(#lets)*)) + Filter::make_let_as(&self.filter.whence, &fun, var, quote!(#(#lets)*)) } } diff --git a/tests/arbitrary.rs b/tests/arbitrary.rs index 2da57f3..c8de7d9 100644 --- a/tests/arbitrary.rs +++ b/tests/arbitrary.rs @@ -809,6 +809,7 @@ fn args_with_struct_filter_sharp_val_x2() { x: i32, } } + #[test] fn args_with_struct_filter_sharp_self() { #[derive(Default)] @@ -838,6 +839,43 @@ fn args_with_struct_filter_sharp_self_x2() { } } +#[test] +fn args_with_struct_filter_fn() { + #[derive(Default)] + struct TestArgs { + m: i32, + } + #[derive(Arbitrary, Debug, PartialEq)] + #[arbitrary(args = TestArgs)] + #[filter(is_valid_fn(args.m))] + struct TestStruct { + x: i32, + } + + fn is_valid_fn(_: i32) -> impl Fn(&TestStruct) -> bool { + |_| true + } +} + +#[test] +fn args_with_struct_filter_fn_x2() { + #[derive(Default)] + struct TestArgs { + m: i32, + } + #[derive(Arbitrary, Debug, PartialEq)] + #[arbitrary(args = TestArgs)] + #[filter(is_valid_fn(args.m))] + #[filter(is_valid_fn(args.m + 1))] + struct TestStruct { + x: i32, + } + + fn is_valid_fn(_: i32) -> impl Fn(&TestStruct) -> bool { + |_| true + } +} + #[test] fn args_with_enum_filter_sharp_val() { #[derive(Default)] @@ -951,6 +989,43 @@ fn args_with_field_filter_sharp_val_x2() { } } +#[test] +fn args_with_field_filter_fn() { + #[derive(Default)] + struct TestArgs { + m: i32, + } + #[derive(Arbitrary, Debug, PartialEq)] + #[arbitrary(args = TestArgs)] + struct TestStruct { + #[filter(is_larger_than(args.m))] + x: i32, + } + + fn is_larger_than(t: i32) -> impl Fn(&i32) -> bool { + move |x: &i32| *x > t + } +} + +#[test] +fn args_with_field_filter_fn_x2() { + #[derive(Default)] + struct TestArgs { + m: i32, + } + #[derive(Arbitrary, Debug, PartialEq)] + #[arbitrary(args = TestArgs)] + struct TestStruct { + #[filter(is_larger_than(args.m))] + #[filter(is_larger_than(args.m + 1))] + x: i32, + } + + fn is_larger_than(t: i32) -> impl Fn(&i32) -> bool { + move |x: &i32| *x > t + } +} + #[test] fn auto_bound_tuple_struct() { #[derive(Arbitrary, Debug, PartialEq)]