diff --git a/lib/middlewares/src/metering.rs b/lib/middlewares/src/metering.rs index ba9cff4576b..a90faca2388 100644 --- a/lib/middlewares/src/metering.rs +++ b/lib/middlewares/src/metering.rs @@ -13,8 +13,8 @@ use std::fmt; use std::sync::{Arc, Mutex}; use wasmer::wasmparser::{Operator, Type as WpType, TypeOrFuncType as WpTypeOrFuncType}; use wasmer::{ - ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, LocalFunctionIndex, - MiddlewareError, MiddlewareReaderState, ModuleMiddleware, Mutability, Type, + AsContextMut, ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, + LocalFunctionIndex, MiddlewareError, MiddlewareReaderState, ModuleMiddleware, Mutability, Type, }; use wasmer_types::{GlobalIndex, ModuleInfo}; @@ -281,12 +281,15 @@ impl u64 + Send + Sync> FunctionMiddleware for FunctionMeter /// matches!(get_remaining_points(instance), MeteringPoints::Remaining(points) if points > 0) /// } /// ``` -pub fn get_remaining_points(instance: &Instance) -> MeteringPoints { +pub fn get_remaining_points( + mut ctx: impl AsContextMut, + instance: &Instance, +) -> MeteringPoints { let exhausted: i32 = instance .exports .get_global("wasmer_metering_points_exhausted") .expect("Can't get `wasmer_metering_points_exhausted` from Instance") - .get() + .get(&mut ctx) .try_into() .expect("`wasmer_metering_points_exhausted` from Instance has wrong type"); @@ -298,7 +301,7 @@ pub fn get_remaining_points(instance: &Instance) -> MeteringPoints { .exports .get_global("wasmer_metering_remaining_points") .expect("Can't get `wasmer_metering_remaining_points` from Instance") - .get() + .get(&mut ctx) .try_into() .expect("`wasmer_metering_remaining_points` from Instance has wrong type"); @@ -331,19 +334,23 @@ pub fn get_remaining_points(instance: &Instance) -> MeteringPoints { /// set_remaining_points(instance, new_limit); /// } /// ``` -pub fn set_remaining_points(instance: &Instance, points: u64) { +pub fn set_remaining_points( + mut ctx: impl AsContextMut, + instance: &Instance, + points: u64, +) { instance .exports .get_global("wasmer_metering_remaining_points") .expect("Can't get `wasmer_metering_remaining_points` from Instance") - .set(points.into()) + .set(&mut ctx, points.into()) .expect("Can't set `wasmer_metering_remaining_points` in Instance"); instance .exports .get_global("wasmer_metering_points_exhausted") .expect("Can't get `wasmer_metering_points_exhausted` from Instance") - .set(0i32.into()) + .set(&mut ctx, 0i32.into()) .expect("Can't set `wasmer_metering_points_exhausted` in Instance"); } @@ -352,7 +359,10 @@ mod tests { use super::*; use std::sync::Arc; - use wasmer::{imports, wat2wasm, CompilerConfig, Cranelift, Module, Store, Universal}; + use wasmer::{ + imports, wat2wasm, AsContextMut, CompilerConfig, Context, Cranelift, Module, Store, + TypedFunction, Universal, + }; fn cost_function(operator: &Operator) -> u64 { match operator { @@ -385,11 +395,12 @@ mod tests { compiler_config.push_middleware(metering); let store = Store::new_with_engine(&Universal::new(compiler_config).engine()); let module = Module::new(&store, bytecode()).unwrap(); + let mut ctx = Context::new(module.store(), ()); // Instantiate - let instance = Instance::new(&module, &imports! {}).unwrap(); + let instance = Instance::new(&mut ctx, &module, &imports! {}).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(10) ); @@ -399,28 +410,31 @@ mod tests { // * `local.get $value` is a `Operator::LocalGet` which costs 1 point; // * `i32.const` is a `Operator::I32Const` which costs 1 point; // * `i32.add` is a `Operator::I32Add` which costs 2 points. - let add_one = instance + let add_one: TypedFunction = instance .exports .get_function("add_one") .unwrap() - .native::() + .native(&ctx) .unwrap(); - add_one.call(1).unwrap(); + add_one.call(&mut ctx.as_context_mut(), 1).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(6) ); // Second call - add_one.call(1).unwrap(); + add_one.call(&mut ctx.as_context_mut(), 1).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(2) ); // Third call fails due to limit - assert!(add_one.call(1).is_err()); - assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted); + assert!(add_one.call(&mut ctx.as_context_mut(), 1).is_err()); + assert_eq!( + get_remaining_points(ctx.as_context_mut(), &instance), + MeteringPoints::Exhausted + ); } #[test] @@ -430,49 +444,53 @@ mod tests { compiler_config.push_middleware(metering); let store = Store::new_with_engine(&Universal::new(compiler_config).engine()); let module = Module::new(&store, bytecode()).unwrap(); + let mut ctx = Context::new(module.store(), ()); // Instantiate - let instance = Instance::new(&module, &imports! {}).unwrap(); + let instance = Instance::new(&mut ctx, &module, &imports! {}).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(10) ); - let add_one = instance + let add_one: TypedFunction = instance .exports .get_function("add_one") .unwrap() - .native::() + .native(&ctx) .unwrap(); // Increase a bit to have enough for 3 calls - set_remaining_points(&instance, 12); + set_remaining_points(ctx.as_context_mut(), &instance, 12); // Ensure we can use the new points now - add_one.call(1).unwrap(); + add_one.call(&mut ctx.as_context_mut(), 1).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(8) ); - add_one.call(1).unwrap(); + add_one.call(&mut ctx.as_context_mut(), 1).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(4) ); - add_one.call(1).unwrap(); + add_one.call(&mut ctx.as_context_mut(), 1).unwrap(); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(0) ); - assert!(add_one.call(1).is_err()); - assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted); + assert!(add_one.call(&mut ctx.as_context_mut(), 1).is_err()); + assert_eq!( + get_remaining_points(ctx.as_context_mut(), &instance), + MeteringPoints::Exhausted + ); // Add some points for another call - set_remaining_points(&instance, 4); + set_remaining_points(ctx.as_context_mut(), &instance, 4); assert_eq!( - get_remaining_points(&instance), + get_remaining_points(ctx.as_context_mut(), &instance), MeteringPoints::Remaining(4) ); }