Skip to content

Commit

Permalink
Support async fn test.
Browse files Browse the repository at this point in the history
  • Loading branch information
frozenlib committed May 27, 2023
1 parent aa0a62a commit fd49486
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 26 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ structmeta = "0.2.0"
[dev-dependencies]
proptest = "1.1.0"
trybuild = "1.0.80"
tokio = { version = "1.28.1", features = ["rt-multi-thread"] }
108 changes: 85 additions & 23 deletions src/proptest_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::syn_utils::{Arg, Args};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse2, parse_quote, parse_str, spanned::Spanned, token, Field, FieldMutability, FnArg, Ident,
ItemFn, Pat, Result, Visibility,
parse2, parse_quote, parse_str, spanned::Spanned, token, Block, Field, FieldMutability, FnArg,
Ident, ItemFn, LitStr, Pat, Result, Visibility,
};

pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStream> {
Expand All @@ -20,6 +20,8 @@ pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStr
true
}
});
let (mut attr_args, config_args) = TestFnAttrArgs::from(attr_args.unwrap_or_default())?;

let args_type_str = format!("_{}Args", to_camel_case(&item_fn.sig.ident.to_string()));
let args_type_ident: Ident = parse_str(&args_type_str).unwrap();
let args = item_fn
Expand All @@ -30,6 +32,15 @@ pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStr
.collect::<Result<Vec<_>>>()?;
let args_pats = args.iter().map(|arg| arg.pat());
let block = &item_fn.block;
if item_fn.sig.asyncness.is_none() {
attr_args.r#async = None;
}
let block = if let Some(a) = attr_args.r#async {
item_fn.sig.asyncness = None;
a.apply(block)
} else {
quote!(#block)
};
let block = quote! {
{
let #args_type_ident { #(#args_pats,)* } = input;
Expand All @@ -39,7 +50,7 @@ pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStr
item_fn.sig.inputs = parse_quote! { input: #args_type_ident };
item_fn.block = Box::new(parse2(block)?);
let args_fields = args.iter().map(|arg| &arg.field);
let config = to_proptest_config(attr_args);
let config = to_proptest_config(config_args);
let ts = quote! {
#[derive(test_strategy::Arbitrary, Debug)]
struct #args_type_ident {
Expand All @@ -57,27 +68,26 @@ pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStr
Ok(ts)
}

fn to_proptest_config(args: Option<Args>) -> TokenStream {
if let Some(args) = args {
let mut base_expr = None;
let mut inits = Vec::new();
for arg in args {
match arg {
Arg::Value(value) => base_expr = Some(value),
Arg::NameValue { name, value, .. } => inits.push(quote!(#name : #value)),
}
}
let base_expr = base_expr.unwrap_or_else(|| {
parse_quote!(<proptest::test_runner::Config as std::default::Default>::default())
});
quote! {
#![proptest_config(proptest::test_runner::Config {
#(#inits,)*
.. #base_expr
})]
fn to_proptest_config(args: Args) -> TokenStream {
if args.is_empty() {
return quote!();
}
let mut base_expr = None;
let mut inits = Vec::new();
for arg in args {
match arg {
Arg::Value(value) => base_expr = Some(value),
Arg::NameValue { name, value, .. } => inits.push(quote!(#name : #value)),
}
} else {
quote! {}
}
let base_expr = base_expr.unwrap_or_else(|| {
parse_quote!(<proptest::test_runner::Config as std::default::Default>::default())
});
quote! {
#![proptest_config(proptest::test_runner::Config {
#(#inits,)*
.. #base_expr
})]
}
}
struct TestFnArg {
Expand Down Expand Up @@ -118,6 +128,58 @@ impl TestFnArg {
}
}

#[derive(Debug, Clone, Copy)]
enum Async {
Tokio,
}
impl Async {
fn apply(&self, block: &Block) -> TokenStream {
match self {
Async::Tokio => {
quote! {
let ret: ::core::result::Result<_, proptest::test_runner::TestCaseError> =
tokio::runtime::Runtime::new()
.unwrap()
.block_on(async move {
#block
Ok(())
});
ret?;
}
}
}
}
}
impl syn::parse::Parse for Async {
fn parse(input: syn::parse::ParseStream) -> Result<Self> {
let s: LitStr = input.parse()?;
match s.value().as_str() {
"tokio" => Ok(Async::Tokio),
_ => bail!(s.span(), "expected `tokio`."),
}
}
}

struct TestFnAttrArgs {
r#async: Option<Async>,
}
impl TestFnAttrArgs {
fn from(args: Args) -> Result<(Self, Args)> {
let mut config_args = Args::new();
let mut this = TestFnAttrArgs { r#async: None };
for arg in args {
if let Arg::NameValue { name, value, .. } = &arg {
if name == "async" {
this.r#async = Some(parse2(value.to_token_stream())?);
continue;
}
}
config_args.0.push(arg);
}
Ok((this, config_args))
}
}

fn to_camel_case(s: &str) -> String {
let mut upper = true;
let mut r = String::new();
Expand Down
4 changes: 2 additions & 2 deletions src/syn_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ impl<T> Deref for Parenthesized<T> {
}

#[derive(Parse)]
pub struct Args(#[parse(terminated)] Punctuated<Arg, Comma>);
pub struct Args(#[parse(terminated)] pub Punctuated<Arg, Comma>);

impl Args {
fn new() -> Self {
pub fn new() -> Self {
Self(Punctuated::new())
}
pub fn expect_single_value(&self, span: Span) -> Result<&Expr> {
Expand Down
17 changes: 16 additions & 1 deletion tests/proptest_fn.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use proptest::prelude::ProptestConfig;
use proptest::{prelude::ProptestConfig, prop_assert};
use test_strategy::proptest;

#[proptest]
Expand Down Expand Up @@ -57,3 +57,18 @@ fn config_field() {
fn config_expr_and_field() {
std::thread::sleep(std::time::Duration::from_millis(30));
}

#[proptest(async = "tokio")]
async fn tokio_test() {
tokio::task::spawn(async {}).await.unwrap()
}

#[proptest(async = "tokio")]
async fn tokio_test_no_copy_arg(#[strategy("a+")] s: String) {
prop_assert!(s.contains("a"));
}

#[proptest(async = "tokio")]
async fn tokio_test_prop_assert() {
prop_assert!(true);
}

0 comments on commit fd49486

Please sign in to comment.