Skip to content

Commit

Permalink
wazevo: adds support for context cancelation (#1709)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
  • Loading branch information
mathetake authored Sep 14, 2023
1 parent 69c15b1 commit 173fae7
Show file tree
Hide file tree
Showing 13 changed files with 165 additions and 38 deletions.
2 changes: 1 addition & 1 deletion internal/engine/wazevo/backend/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3100,7 +3100,7 @@ L9 (SSA Block: blk6):
t.Run(tc.name, func(t *testing.T) {
ssab := ssa.NewBuilder()
offset := wazevoapi.NewModuleContextOffsetData(tc.m)
fc := frontend.NewFrontendCompiler(tc.m, ssab, &offset)
fc := frontend.NewFrontendCompiler(tc.m, ssab, &offset, false)
machine := newMachine()
machine.DisableStackCheck()
be := backend.NewCompiler(context.Background(), machine, ssab)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func (m *machine) insertStackBoundsCheck(requiredStackSize int64, cur *instructi
ldrAddress.asULoad(operandNR(tmpRegVReg), addressMode{
kind: addressModeKindRegUnsignedImm12,
rn: x0VReg, // execution context is always the first argument
imm: wazevoapi.ExecutionContextOffsets.StackGrowCallSequenceAddress.I64(),
imm: wazevoapi.ExecutionContextOffsets.StackGrowCallTrampolineAddress.I64(),
}, 64)
cur = linkInstr(cur, ldrAddress)

Expand Down
34 changes: 31 additions & 3 deletions internal/engine/wazevo/call_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ type (
stackGrowRequiredSize uintptr
// memoryGrowTrampolineAddress holds the address of memory grow trampoline function.
memoryGrowTrampolineAddress *byte
// stackGrowCallSequenceAddress holds the address of stack grow call sequence function.
stackGrowCallSequenceAddress *byte
// stackGrowCallTrampolineAddress holds the address of stack grow trampoline function.
stackGrowCallTrampolineAddress *byte
// checkModuleExitCodeTrampolineAddress holds the address of check-module-exit-code function.
checkModuleExitCodeTrampolineAddress *byte
// savedRegisters is the opaque spaces for save/restore registers.
// We want to align 16 bytes for each register, so we use [64][2]uint64.
_ uint64
savedRegisters [64][2]uint64
// goFunctionCallCalleeModuleContextOpaque is the pointer to the target Go function's moduleContextOpaque.
goFunctionCallCalleeModuleContextOpaque uintptr
Expand Down Expand Up @@ -138,6 +139,19 @@ func (c *callEngine) CallWithStack(ctx context.Context, paramResultStack []uint6

// CallWithStack implements api.Function.
func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint64) (err error) {
p := c.parent
ensureTermination := p.parent.ensureTermination
m := p.module
if ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the module and return the error.
m.CloseWithCtxErr(ctx)
return m.FailIfClosed()
default:
}
}

var paramResultPtr *uint64
if len(paramResultStack) > 0 {
paramResultPtr = &paramResultStack[0]
Expand Down Expand Up @@ -165,6 +179,11 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
}
}()

if ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}

entrypoint(c.preambleExecutable, c.executable, c.execCtxPtr, c.parent.opaquePtr, paramResultPtr, c.stackTop)
for {
switch ec := c.execCtx.exitCode; ec & wazevoapi.ExitCodeMask {
Expand Down Expand Up @@ -210,6 +229,15 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
f.Call(ctx, mod, c.execCtx.goFunctionCallStack[:])
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr, c.execCtx.stackPointerBeforeGoCall)
case wazevoapi.ExitCodeCheckModuleExitCode:
// Note: this operation must be done in Go, not native code. The reason is that
// native code cannot be preempted and that means it can block forever if there are not
// enough OS threads (which we don't have control over).
if err := m.FailIfClosed(); err != nil {
panic(err)
}
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr, c.execCtx.stackPointerBeforeGoCall)
case wazevoapi.ExitCodeUnreachable:
panic(wasmruntime.ErrRuntimeUnreachable)
case wazevoapi.ExitCodeMemoryOutOfBounds:
Expand Down
49 changes: 34 additions & 15 deletions internal/engine/wazevo/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type (
sharedFunctions struct {
// memoryGrowExecutable is a compiled executable for memory.grow builtin function.
memoryGrowExecutable []byte
// checkModuleExitCode is a compiled executable for checking module instance exit code. This
// is used when ensureTermination is true.
checkModuleExitCode []byte
// stackGrowExecutable is a compiled executable for growing stack builtin function.
stackGrowExecutable []byte
entryPreambles map[*wasm.FunctionType][]byte
Expand All @@ -54,10 +57,11 @@ type (
compiledModule struct {
executable []byte
// functionOffsets maps a local function index to the offset in the executable.
functionOffsets []int
parent *engine
module *wasm.Module
entryPreambles []*byte // indexed-correlated with the type index.
functionOffsets []int
parent *engine
module *wasm.Module
entryPreambles []*byte // indexed-correlated with the type index.
ensureTermination bool

// The followings are only available for non host modules.

Expand All @@ -78,7 +82,7 @@ func NewEngine(ctx context.Context, _ api.CoreFeatures, _ filecache.Cache) wasm.
machine: machine,
be: be,
}
e.compileBuiltinFunctions()
e.compileSharedFunctions()
return e
}

Expand Down Expand Up @@ -108,6 +112,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
e.rels = e.rels[:0]
cm := &compiledModule{
offsets: wazevoapi.NewModuleContextOffsetData(module), parent: e, module: module,
ensureTermination: ensureTermination,
}

if module.IsHostModule {
Expand All @@ -131,7 +136,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene

// Creates new compiler instances which are reused for each function.
ssaBuilder := ssa.NewBuilder()
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets)
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination)
machine := newMachine()
be := backend.NewCompiler(ctx, machine, ssaBuilder)

Expand All @@ -154,7 +159,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
ctx = wazevoapi.SetCurrentFunctionName(ctx, fmt.Sprintf("[%d/%d] \"%s\"", i, len(module.CodeSection)-1, name))
}

body, rels, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, listeners, ensureTermination)
body, rels, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, listeners)
if err != nil {
return nil, fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)
}
Expand Down Expand Up @@ -214,7 +219,7 @@ func (e *engine) compileLocalWasmFunction(
fe *frontend.Compiler,
ssaBuilder ssa.Builder,
be backend.Compiler,
_ []experimental.FunctionListener, _ bool,
_ []experimental.FunctionListener,
) (body []byte, rels []backend.RelocationInfo, err error) {
typ := &module.TypeSection[module.FunctionSection[localFunctionIndex]]
codeSeg := &module.CodeSection[localFunctionIndex]
Expand Down Expand Up @@ -468,7 +473,7 @@ func (e *engine) NewModuleEngine(m *wasm.Module, mi *wasm.ModuleInstance) (wasm.
return me, nil
}

func (e *engine) compileBuiltinFunctions() {
func (e *engine) compileSharedFunctions() {
e.sharedFunctions = &sharedFunctions{entryPreambles: make(map[*wasm.FunctionType][]byte)}

e.be.Init()
Expand All @@ -480,6 +485,15 @@ func (e *engine) compileBuiltinFunctions() {
e.sharedFunctions.memoryGrowExecutable = mmapExecutable(src)
}

e.be.Init()
{
src := e.machine.CompileGoFunctionTrampoline(wazevoapi.ExitCodeCheckModuleExitCode, &ssa.Signature{
Params: []ssa.Type{ssa.TypeI32 /* exec context */},
Results: []ssa.Type{ssa.TypeI32},
}, false)
e.sharedFunctions.checkModuleExitCode = mmapExecutable(src)
}

// TODO: table grow, etc.

e.be.Init()
Expand All @@ -491,21 +505,26 @@ func (e *engine) compileBuiltinFunctions() {
e.setFinalizer(e.sharedFunctions, sharedFunctionsFinalizer)
}

func sharedFunctionsFinalizer(bf *sharedFunctions) {
if err := platform.MunmapCodeSegment(bf.memoryGrowExecutable); err != nil {
func sharedFunctionsFinalizer(sf *sharedFunctions) {
if err := platform.MunmapCodeSegment(sf.memoryGrowExecutable); err != nil {
panic(err)
}
if err := platform.MunmapCodeSegment(sf.checkModuleExitCode); err != nil {
panic(err)
}
if err := platform.MunmapCodeSegment(bf.stackGrowExecutable); err != nil {
if err := platform.MunmapCodeSegment(sf.stackGrowExecutable); err != nil {
panic(err)
}
for _, f := range bf.entryPreambles {
for _, f := range sf.entryPreambles {
if err := platform.MunmapCodeSegment(f); err != nil {
panic(err)
}
}

bf.memoryGrowExecutable = nil
bf.stackGrowExecutable = nil
sf.memoryGrowExecutable = nil
sf.checkModuleExitCode = nil
sf.stackGrowExecutable = nil
sf.entryPreambles = nil
}

func compiledModuleFinalizer(cm *compiledModule) {
Expand Down
30 changes: 24 additions & 6 deletions internal/engine/wazevo/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,37 @@ import (
)

func Test_sharedFunctionsFinalizer(t *testing.T) {
bf := &sharedFunctions{}
sf := &sharedFunctions{}

b1, err := platform.MmapCodeSegment(100)
require.NoError(t, err)

b2, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
bf.memoryGrowExecutable = b1
bf.stackGrowExecutable = b2

sharedFunctionsFinalizer(bf)
require.Nil(t, bf.memoryGrowExecutable)
require.Nil(t, bf.stackGrowExecutable)
b3, err := platform.MmapCodeSegment(100)
require.NoError(t, err)

b4, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
b5, err := platform.MmapCodeSegment(100)
require.NoError(t, err)

preabmles := map[*wasm.FunctionType][]byte{
{Params: []wasm.ValueType{}}: b4,
{Params: []wasm.ValueType{wasm.ValueTypeI32}}: b5,
}

sf.memoryGrowExecutable = b1
sf.stackGrowExecutable = b2
sf.checkModuleExitCode = b3
sf.entryPreambles = preabmles

sharedFunctionsFinalizer(sf)
require.Nil(t, sf.memoryGrowExecutable)
require.Nil(t, sf.stackGrowExecutable)
require.Nil(t, sf.checkModuleExitCode)
require.Nil(t, sf.entryPreambles)
}

func Test_compiledModuleFinalizer(t *testing.T) {
Expand Down
18 changes: 14 additions & 4 deletions internal/engine/wazevo/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ type Compiler struct {
m *wasm.Module
offset *wazevoapi.ModuleContextOffsetData
// ssaBuilder is a ssa.Builder used by this frontend.
ssaBuilder ssa.Builder
signatures map[*wasm.FunctionType]*ssa.Signature
memoryGrowSig ssa.Signature
ssaBuilder ssa.Builder
signatures map[*wasm.FunctionType]*ssa.Signature
memoryGrowSig ssa.Signature
checkModuleExitCodeSig ssa.Signature
checkModuleExitCodeArg [1]ssa.Value
ensureTermination bool

// Followings are reset by per function.

Expand All @@ -43,13 +46,14 @@ type Compiler struct {
}

// NewFrontendCompiler returns a frontend Compiler.
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData) *Compiler {
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool) *Compiler {
c := &Compiler{
m: m,
ssaBuilder: ssaBuilder,
br: bytes.NewReader(nil),
wasmLocalToVariable: make(map[wasm.Index]ssa.Variable),
offset: offset,
ensureTermination: ensureTermination,
}

c.signatures = make(map[*wasm.FunctionType]*ssa.Signature, len(m.TypeSection)+1)
Expand All @@ -70,6 +74,12 @@ func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoa
}
c.ssaBuilder.DeclareSignature(&c.memoryGrowSig)

c.checkModuleExitCodeSig = ssa.Signature{
ID: c.memoryGrowSig.ID + 1,
// Only takes execution context.
Params: []ssa.Type{ssa.TypeI64},
}
c.ssaBuilder.DeclareSignature(&c.checkModuleExitCodeSig)
return c
}

Expand Down
35 changes: 33 additions & 2 deletions internal/engine/wazevo/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ func TestCompiler_LowerToSSA(t *testing.T) {
// what output should look like, you can run:
// `~/wasmtime/target/debug/clif-util wasm --target aarch64-apple-darwin testcase.wat -p -t`
for _, tc := range []struct {
name string
name string
ensureTermination bool
// m is the *wasm.Module to be compiled in this test.
m *wasm.Module
// targetIndex is the index of a local function to be compiled in this test.
Expand Down Expand Up @@ -219,6 +220,36 @@ blk0: (exec_ctx:i64, module_ctx:i64)
blk1: () <-- (blk0,blk1)
Jump blk1
`,
},
{
name: "loop - br / ensure termination", m: testcases.LoopBr.Module,
ensureTermination: true,
exp: `
signatures:
sig2: i64_v
blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk1
blk1: () <-- (blk0,blk1)
v2:i64 = Load exec_ctx, 0x58
CallIndirect v2:sig2, exec_ctx
Jump blk1
blk2: ()
`,
expAfterOpt: `
signatures:
sig2: i64_v
blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk1
blk1: () <-- (blk0,blk1)
v2:i64 = Load exec_ctx, 0x58
CallIndirect v2:sig2, exec_ctx
Jump blk1
`,
},
{
Expand Down Expand Up @@ -1736,7 +1767,7 @@ blk4: () <-- (blk2,blk3)
b := ssa.NewBuilder()

offset := wazevoapi.NewModuleContextOffsetData(tc.m)
fc := NewFrontendCompiler(tc.m, b, &offset)
fc := NewFrontendCompiler(tc.m, b, &offset, tc.ensureTermination)
typeIndex := tc.m.FunctionSection[tc.targetIndex]
code := &tc.m.CodeSection[tc.targetIndex]
fc.Init(tc.targetIndex, &tc.m.TypeSection[typeIndex], code.LocalTypes, code.Body)
Expand Down
13 changes: 13 additions & 0 deletions internal/engine/wazevo/frontend/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,19 @@ func (c *Compiler) lowerCurrentOpcode() {

c.switchTo(originalLen, loopHeader)

if c.ensureTermination {
checkModuleExitCodePtr := builder.AllocateInstruction().
AsLoad(c.execCtxPtrValue,
wazevoapi.ExecutionContextOffsets.CheckModuleExitCodeTrampolineAddress.U32(),
ssa.TypeI64,
).Insert(builder).Return()

c.checkModuleExitCodeArg[0] = c.execCtxPtrValue

builder.AllocateInstruction().
AsCallIndirect(checkModuleExitCodePtr, &c.checkModuleExitCodeSig, c.checkModuleExitCodeArg[:]).
Insert(builder)
}
case wasm.OpcodeIf:
bt := c.readBlockType()

Expand Down
3 changes: 2 additions & 1 deletion internal/engine/wazevo/module_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ func (m *moduleEngine) NewFunction(index wasm.Index) api.Function {
}

ce.execCtx.memoryGrowTrampolineAddress = &m.parent.sharedFunctions.memoryGrowExecutable[0]
ce.execCtx.stackGrowCallSequenceAddress = &m.parent.sharedFunctions.stackGrowExecutable[0]
ce.execCtx.stackGrowCallTrampolineAddress = &m.parent.sharedFunctions.stackGrowExecutable[0]
ce.execCtx.checkModuleExitCodeTrampolineAddress = &m.parent.sharedFunctions.checkModuleExitCode[0]
ce.init()
return ce
}
Expand Down
3 changes: 2 additions & 1 deletion internal/engine/wazevo/wazevo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ func Test_ExecutionContextOffsets(t *testing.T) {
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackPointerBeforeGoCall)), offsets.StackPointerBeforeGoCall)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowRequiredSize)), offsets.StackGrowRequiredSize)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.memoryGrowTrampolineAddress)), offsets.MemoryGrowTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowCallSequenceAddress)), offsets.StackGrowCallSequenceAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowCallTrampolineAddress)), offsets.StackGrowCallTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.checkModuleExitCodeTrampolineAddress)), offsets.CheckModuleExitCodeTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.savedRegisters))%16, wazevoapi.Offset(0),
"SavedRegistersBegin must be aligned to 16 bytes")
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.savedRegisters)), offsets.SavedRegistersBegin)
Expand Down
Loading

0 comments on commit 173fae7

Please sign in to comment.