Skip to content
This repository has been archived by the owner on Nov 15, 2023. It is now read-only.

Updates Benchmark macro parsing to use Generic Argument #13919

Merged
merged 6 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frame/benchmarking/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ pub mod v2 {
// Used in #[benchmark] implementation to ensure that benchmark function arguments
// implement [`ParamRange`].
#[doc(hidden)]
pub use static_assertions::{assert_impl_all, assert_type_eq_all};
pub use static_assertions::{assert_impl_all, assert_type_eq_all, const_assert};

/// Used by the new benchmarking code to specify that a benchmarking variable is linear
/// over some specified range, i.e. `Linear<0, 1_000>` means that the corresponding variable
Expand Down
20 changes: 17 additions & 3 deletions frame/support/procedural/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,12 @@ struct UnrolledParams {
param_ranges: Vec<TokenStream2>,
param_names: Vec<TokenStream2>,
param_types: Vec<TokenStream2>,
param_conditions: Vec<TokenStream2>,
}

impl UnrolledParams {
/// Constructs an [`UnrolledParams`] from a [`Vec<ParamDef>`]
fn from(params: &Vec<ParamDef>) -> UnrolledParams {
fn from(params: &Vec<ParamDef>, home: TokenStream2) -> UnrolledParams {
let param_ranges: Vec<TokenStream2> = params
.iter()
.map(|p| {
Expand All @@ -703,14 +704,22 @@ impl UnrolledParams {
quote!(#name)
})
.collect();
let param_conditions: Vec<TokenStream2> = params
.iter()
.map(|p| {
let start = &p.start;
let end = &p.end;
quote!(#home::const_assert!(#start <= #end);)
})
.collect();
let param_types: Vec<TokenStream2> = params
.iter()
.map(|p| {
let typ = &p.typ;
quote!(#typ)
})
.collect();
UnrolledParams { param_ranges, param_names, param_types }
UnrolledParams { param_ranges, param_names, param_types, param_conditions }
}
}

Expand All @@ -735,10 +744,11 @@ fn expand_benchmark(
let test_ident = Ident::new(format!("test_{}", name.to_string()).as_str(), Span::call_site());

// unroll params (prepare for quoting)
let unrolled = UnrolledParams::from(&benchmark_def.params);
let unrolled = UnrolledParams::from(&benchmark_def.params, home.clone());
let param_names = unrolled.param_names;
let param_ranges = unrolled.param_ranges;
let param_types = unrolled.param_types;
let param_conditions = unrolled.param_conditions;

let type_use_generics = match is_instance {
false => quote!(T),
Expand Down Expand Up @@ -882,6 +892,10 @@ fn expand_benchmark(
#home::assert_impl_all!(#param_types: #home::ParamRange);
)*

#(
#param_conditions
)*

#[allow(non_camel_case_types)]
#(
#fn_attrs
Expand Down
12 changes: 7 additions & 5 deletions frame/support/test/tests/benchmark_ui/bad_param_range.stderr
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
error: The start of a `ParamRange` must be less than or equal to the end
--> tests/benchmark_ui/bad_param_range.rs:10:21
|
10 | fn bench(x: Linear<3, 1>) {
| ^
error[E0080]: evaluation of constant value failed
--> tests/benchmark_ui/bad_param_range.rs:5:1
|
5 | #[benchmarks]
| ^^^^^^^^^^^^^ attempt to compute `0_usize - 1_usize`, which would overflow
gupnik marked this conversation as resolved.
Show resolved Hide resolved
|
= note: this error originates in the macro `frame_benchmarking::v2::const_assert` which comes from the expansion of the attribute macro `benchmarks` (in Nightly builds, run with -Z macro-backtrace for more info)
28 changes: 28 additions & 0 deletions frame/support/test/tests/benchmark_ui/pass/valid_const_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use frame_benchmarking::v2::*;
use frame_support_test::Config;
use frame_support::parameter_types;

#[benchmarks]
mod benches {
use super::*;

const MY_CONST: u32 = 100;

const fn my_fn() -> u32 {
200
}

parameter_types! {
const MyConst: u32 = MY_CONST;
}

#[benchmark(skip_meta, extra)]
fn bench(a: Linear<{MY_CONST * 2}, {my_fn() + MyConst::get()}>) {
let a = 2 + 2;
#[block]
{}
assert_eq!(a, 4);
}
}

fn main() {}