From f4ee84922b3c102e265eb553d4fd365cf3c4e13d Mon Sep 17 00:00:00 2001 From: Frank Emrich Date: Fri, 9 Jun 2023 14:24:01 +0100 Subject: [PATCH] Use array calling convention trampoline to enter functions used as continuations (#35) * extend payload slots in VMContext to 16 bytes * enter continuation functions via array_call trampoline * Test passing and returning values to/from continuation functions --- crates/environ/src/vmoffsets.rs | 7 +- crates/runtime/src/continuation.rs | 39 ++++--- tests/all/pooling_allocator.rs | 12 +- .../typed-continuations/cont_args.wast | 104 ++++++++++++++++++ 4 files changed, 139 insertions(+), 23 deletions(-) create mode 100644 tests/misc_testsuite/typed-continuations/cont_args.wast diff --git a/crates/environ/src/vmoffsets.rs b/crates/environ/src/vmoffsets.rs index 3b84b426bfb1..b92e58a09c37 100644 --- a/crates/environ/src/vmoffsets.rs +++ b/crates/environ/src/vmoffsets.rs @@ -40,6 +40,10 @@ fn cast_to_u32(sz: usize) -> u32 { u32::try_from(sz).expect("overflow in cast from usize to u32") } +/// Maximum number of arguments and return values a continuation can have. +/// Also maximum number of arguments and return values any tag can have. +pub const MAXIMUM_CONTINUATION_PAYLOAD_COUNT: u32 = 6; + /// Align an offset used in this module to a specific byte-width by rounding up #[inline] fn align(offset: u32, width: u32) -> u32 { @@ -478,8 +482,9 @@ impl From> for VMOffsets

{ size(typed_continuations_store) = ret.ptr.size(), align(16), + // `size_of_vmglobal_definition` corresponds to maximum size of a value size(typed_continuations_payloads) - = cmul(6, ret.ptr.size()), + = cmul(MAXIMUM_CONTINUATION_PAYLOAD_COUNT, ret.ptr.size_of_vmglobal_definition()), align(16), // TODO(dhil): This could probably be done more // efficiently by packing the pointer into the above 16 byte // alignment diff --git a/crates/runtime/src/continuation.rs b/crates/runtime/src/continuation.rs index b18e72e003d9..efc4b06d4950 100644 --- a/crates/runtime/src/continuation.rs +++ b/crates/runtime/src/continuation.rs @@ -1,9 +1,9 @@ //! Continuations TODO use crate::instance::TopOfStackPointer; -use crate::vmcontext::{VMContext, VMFuncRef, VMOpaqueContext, VMWasmCallFunction}; +use crate::vmcontext::{VMArrayCallFunction, VMFuncRef, VMOpaqueContext, ValRaw}; use crate::{Instance, TrapReason}; -use std::ptr::NonNull; +use wasmtime_environ::MAXIMUM_CONTINUATION_PAYLOAD_COUNT; use wasmtime_fibre::{Fiber, FiberStack, Suspend}; /// TODO @@ -11,36 +11,44 @@ use wasmtime_fibre::{Fiber, FiberStack, Suspend}; pub fn cont_new(instance: &mut Instance, func: *mut u8) -> *mut u8 { let func = func as *mut VMFuncRef; let callee_ctx = unsafe { (*func).vmctx }; - let caller_ctx = instance.vmctx(); + let caller_ctx = VMOpaqueContext::from_vmcontext(instance.vmctx()); let f = unsafe { // TODO(dhil): Not sure whether we should use // VMWasmCallFunction or VMNativeCallFunction here. 
std::mem::transmute::< - NonNull, - unsafe extern "C" fn(*mut VMOpaqueContext, *mut VMContext, ()) -> u32, - >((*func).wasm_call.unwrap()) + VMArrayCallFunction, + unsafe extern "C" fn(*mut VMOpaqueContext, *mut VMOpaqueContext, *mut ValRaw, usize), + >((*func).array_call) }; + let payload_ptr = unsafe { instance.get_typed_continuations_payloads_mut() as *mut ValRaw }; let fiber = Box::new( Fiber::new( FiberStack::new(4096).unwrap(), - move |_first_val: (), _suspend: &Suspend<(), u32, u32>| { + move |_first_val: (), _suspend: &Suspend<(), u32, ()>| { // TODO(frank-emrich): Need to load arguments (if present) from // payload storage and pass to f. // Consider getting the array_call version from func // to achieve this instead. - unsafe { f(callee_ctx, caller_ctx, ()) } + unsafe { + f( + callee_ctx, + caller_ctx, + payload_ptr, + MAXIMUM_CONTINUATION_PAYLOAD_COUNT as usize, + ) + } }, ) .unwrap(), ); - let ptr: *mut Fiber<'static, (), u32, u32> = Box::into_raw(fiber); + let ptr: *mut Fiber<'static, (), u32, ()> = Box::into_raw(fiber); ptr as *mut u8 } /// TODO #[inline(always)] pub fn resume(instance: &mut Instance, cont: *mut u8) -> Result { - let cont = cont as *mut Fiber<'static, (), u32, u32>; + let cont = cont as *mut Fiber<'static, (), u32, ()>; let cont_stack = unsafe { &cont.as_ref().unwrap().stack() }; let tsp = TopOfStackPointer::as_raw(instance.tsp()); unsafe { cont_stack.write_parent(tsp) }; @@ -51,14 +59,13 @@ pub fn resume(instance: &mut Instance, cont: *mut u8) -> Result .get_mut()) = 0 }; match unsafe { cont.as_mut().unwrap().resume(()) } { - Ok(result) => { + Ok(()) => { let drop_box: Box> = unsafe { Box::from_raw(cont) }; drop(drop_box); // I think this would be covered by the close brace below anyway // Store the result. 
- let payloads_addr = unsafe { instance.get_typed_continuations_payloads_mut() }; - unsafe { - std::ptr::write(payloads_addr, result); - } + + // The result of the continuation was written to the first entry of the payload + // store by virtue of using the array calling trampoline to execute it Ok(0) // zero value = return normally. //Ok(9999) @@ -70,7 +77,7 @@ pub fn resume(instance: &mut Instance, cont: *mut u8) -> Result debug_assert_eq!(tag & signal_mask, 0); unsafe { let cont_store_ptr = instance.get_typed_continuations_store_mut() - as *mut *mut Fiber<'static, (), u32, u32>; + as *mut *mut Fiber<'static, (), u32, ()>; cont_store_ptr.write(cont) }; Ok(tag | signal_mask) diff --git a/tests/all/pooling_allocator.rs b/tests/all/pooling_allocator.rs index 883635244b43..422e463171e5 100644 --- a/tests/all/pooling_allocator.rs +++ b/tests/all/pooling_allocator.rs @@ -678,11 +678,11 @@ fn instance_too_large() -> Result<()> { let engine = Engine::new(&config)?; let expected = "\ -instance allocation for this module requires 320 bytes which exceeds the \ +instance allocation for this module requires 368 bytes which exceeds the \ configured maximum of 16 bytes; breakdown of allocation requirement: - * 55.00% - 176 bytes - instance state management - * 15.00% - 48 bytes - typed continuations payloads + * 47.83% - 176 bytes - instance state management + * 26.09% - 96 bytes - typed continuations payloads "; match Module::new(&engine, "(module)") { Ok(_) => panic!("should have failed to compile"), @@ -696,11 +696,11 @@ configured maximum of 16 bytes; breakdown of allocation requirement: lots_of_globals.push_str(")"); let expected = "\ -instance allocation for this module requires 1920 bytes which exceeds the \ +instance allocation for this module requires 1968 bytes which exceeds the \ configured maximum of 16 bytes; breakdown of allocation requirement: - * 9.17% - 176 bytes - instance state management - * 83.33% - 1600 bytes - defined globals + * 8.94% - 176 bytes - 
instance state management + * 81.30% - 1600 bytes - defined globals "; match Module::new(&engine, &lots_of_globals) { Ok(_) => panic!("should have failed to compile"), diff --git a/tests/misc_testsuite/typed-continuations/cont_args.wast b/tests/misc_testsuite/typed-continuations/cont_args.wast new file mode 100644 index 000000000000..a78036801058 --- /dev/null +++ b/tests/misc_testsuite/typed-continuations/cont_args.wast @@ -0,0 +1,104 @@ +;; This file tests passing arguments to functions used as continuations and +;; returning values from such continuations on ordinary (i.e., non-suspend) exit + +(module + + (type $unit_to_unit (func)) + (type $unit_to_int (func (result i32))) + (type $int_to_unit (func (param i32))) + (type $int_to_int (func (param i32) (result i32))) + + + (type $f1_t (func (param i32) (result i32))) + (type $f1_ct (cont $f1_t)) + + (type $f2_t (func (param i32) (result i32))) + (type $f2_ct (cont $f2_t)) + + (type $f3_t (func (param i32) (result i32))) + (type $f3_ct (cont $f3_t)) + + (type $res_unit_to_unit (cont $unit_to_unit)) + (type $res_int_to_unit (cont $int_to_unit)) + (type $res_int_to_int (cont $int_to_int)) + (type $res_unit_to_int (cont $unit_to_int)) + + (tag $e1_unit_to_unit) + (tag $e2_int_to_unit (param i32)) + (tag $e3_int_to_int (param i32) (result i32)) + + (global $i (mut i32) (i32.const 0)) + + + ;; Used for testing the passing of arguments to continuation functions and returning values out of them + (func $f1 (export "f1") (param $x i32) (result i32) + (global.set $i (i32.add (global.get $i) (local.get $x))) + (suspend $e1_unit_to_unit) + (i32.add (i32.const 2) (local.get $x))) + + ;; Used for testing case where no suspend happens at all + (func $f2 (export "f2") (param $x i32) (result i32) + (global.set $i (i32.add (global.get $i) (local.get $x))) + (i32.add (i32.const 2) (local.get $x))) + + ;; Same as $f1, but additionally passes payloads to and from handler + (func $f3 (export "f3") (param $x i32) (result i32) +
(i32.add (local.get $x) (i32.const 1)) + (suspend $e3_int_to_int) + ;; return x + value received back from $e3 + (i32.add (local.get $x))) + + + (func $test_case_1 (export "test_case_1") (result i32) + ;; remove this eventually + (global.set $i (i32.const 0)) + (block $on_e1 (result (ref $res_unit_to_int)) + (resume $f1_ct (tag $e1_unit_to_unit $on_e1) (i32.const 100) (cont.new $f1_ct (ref.func $f1))) + ;; unreachable: we never intend to invoke the resumption when handling + ;; $e1 invoked from $f1 + (unreachable)) + ;; after on_e1, stack: [resumption] + (drop) ;; drop resumption + (global.get $i)) + + (func $test_case_2 (export "test_case_2") (result i32) + ;; remove this eventually + (global.set $i (i32.const 0)) + ;;(local $finish_f3 (ref $res_unit_to_unit)) + (block $on_e1 (result (ref $res_unit_to_int)) + (resume $f1_ct (tag $e1_unit_to_unit $on_e1) (i32.const 49) (cont.new $f1_ct (ref.func $f1))) + (unreachable)) + ;; after on_e1, stack: [resumption] + ;;(local.set $finish_f2) + (resume $res_unit_to_int) + ;; the resume above resumes execution of f1, which finishes without further suspends + (i32.add (global.get $i))) + +(func $test_case_3 (export "test_case_3") (result i32) + ;; remove this eventually + (global.set $i (i32.const 0)) + (resume $f2_ct (i32.const 49) (cont.new $f2_ct (ref.func $f2))) + (i32.add (global.get $i))) + + +(func $test_case_4 (export "test_case_4") (result i32) + (local $k (ref $res_int_to_int)) + + (block $on_e3 (result i32 (ref $res_int_to_int)) + (resume $f3_ct (tag $e3_int_to_int $on_e3) (i32.const 49) (cont.new $f3_ct (ref.func $f3))) + (unreachable)) + ;; after on_e3, expected stack: [50 resumption] + (local.set $k) + + ;; add 1 to value 50 received from f3 via tag e3, thus passing 51 back to it + (i32.add (i32.const 1)) + (resume $res_int_to_int (local.get $k)) + ;; expecting to get 49 (original argument to function) + 51 (passed above) back + ) + +) + +(assert_return (invoke "test_case_1") (i32.const 100))
+(assert_return (invoke "test_case_2") (i32.const 100)) +(assert_return (invoke "test_case_3") (i32.const 100)) +(assert_return (invoke "test_case_4") (i32.const 100))