Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wazevo: adds support for context cancelation #1709

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/engine/wazevo/backend/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3100,7 +3100,7 @@ L9 (SSA Block: blk6):
t.Run(tc.name, func(t *testing.T) {
ssab := ssa.NewBuilder()
offset := wazevoapi.NewModuleContextOffsetData(tc.m)
fc := frontend.NewFrontendCompiler(tc.m, ssab, &offset)
fc := frontend.NewFrontendCompiler(tc.m, ssab, &offset, false)
machine := newMachine()
machine.DisableStackCheck()
be := backend.NewCompiler(context.Background(), machine, ssab)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func (m *machine) insertStackBoundsCheck(requiredStackSize int64, cur *instructi
ldrAddress.asULoad(operandNR(tmpRegVReg), addressMode{
kind: addressModeKindRegUnsignedImm12,
rn: x0VReg, // execution context is always the first argument
imm: wazevoapi.ExecutionContextOffsets.StackGrowCallSequenceAddress.I64(),
imm: wazevoapi.ExecutionContextOffsets.StackGrowCallTrampolineAddress.I64(),
}, 64)
cur = linkInstr(cur, ldrAddress)

Expand Down
34 changes: 31 additions & 3 deletions internal/engine/wazevo/call_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ type (
stackGrowRequiredSize uintptr
// memoryGrowTrampolineAddress holds the address of memory grow trampoline function.
memoryGrowTrampolineAddress *byte
// stackGrowCallSequenceAddress holds the address of stack grow call sequence function.
stackGrowCallSequenceAddress *byte
// stackGrowCallTrampolineAddress holds the address of stack grow trampoline function.
stackGrowCallTrampolineAddress *byte
// checkModuleExitCodeTrampolineAddress holds the address of check-module-exit-code function.
checkModuleExitCodeTrampolineAddress *byte
// savedRegisters is the opaque spaces for save/restore registers.
// We want to align 16 bytes for each register, so we use [64][2]uint64.
_ uint64
savedRegisters [64][2]uint64
// goFunctionCallCalleeModuleContextOpaque is the pointer to the target Go function's moduleContextOpaque.
goFunctionCallCalleeModuleContextOpaque uintptr
Expand Down Expand Up @@ -138,6 +139,19 @@ func (c *callEngine) CallWithStack(ctx context.Context, paramResultStack []uint6

// CallWithStack implements api.Function.
func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint64) (err error) {
p := c.parent
ensureTermination := p.parent.ensureTermination
m := p.module
if ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the module and return the error.
m.CloseWithCtxErr(ctx)
return m.FailIfClosed()
default:
}
}

var paramResultPtr *uint64
if len(paramResultStack) > 0 {
paramResultPtr = &paramResultStack[0]
Expand Down Expand Up @@ -165,6 +179,11 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
}
}()

if ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}

entrypoint(c.preambleExecutable, c.executable, c.execCtxPtr, c.parent.opaquePtr, paramResultPtr, c.stackTop)
for {
switch ec := c.execCtx.exitCode; ec & wazevoapi.ExitCodeMask {
Expand Down Expand Up @@ -210,6 +229,15 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
f.Call(ctx, mod, c.execCtx.goFunctionCallStack[:])
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr, c.execCtx.stackPointerBeforeGoCall)
case wazevoapi.ExitCodeCheckModuleExitCode:
// Note: this operation must be done in Go, not native code. The reason is that
// native code cannot be preempted and that means it can block forever if there are not
// enough OS threads (which we don't have control over).
if err := m.FailIfClosed(); err != nil {
panic(err)
}
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr, c.execCtx.stackPointerBeforeGoCall)
case wazevoapi.ExitCodeUnreachable:
panic(wasmruntime.ErrRuntimeUnreachable)
case wazevoapi.ExitCodeMemoryOutOfBounds:
Expand Down
49 changes: 34 additions & 15 deletions internal/engine/wazevo/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type (
sharedFunctions struct {
// memoryGrowExecutable is a compiled executable for memory.grow builtin function.
memoryGrowExecutable []byte
// checkModuleExitCode is a compiled executable for checking module instance exit code. This
// is used when ensureTermination is true.
checkModuleExitCode []byte
// stackGrowExecutable is a compiled executable for growing stack builtin function.
stackGrowExecutable []byte
entryPreambles map[*wasm.FunctionType][]byte
Expand All @@ -54,10 +57,11 @@ type (
compiledModule struct {
executable []byte
// functionOffsets maps a local function index to the offset in the executable.
functionOffsets []int
parent *engine
module *wasm.Module
entryPreambles []*byte // indexed-correlated with the type index.
functionOffsets []int
parent *engine
module *wasm.Module
entryPreambles []*byte // indexed-correlated with the type index.
ensureTermination bool

// The followings are only available for non host modules.

Expand All @@ -78,7 +82,7 @@ func NewEngine(ctx context.Context, _ api.CoreFeatures, _ filecache.Cache) wasm.
machine: machine,
be: be,
}
e.compileBuiltinFunctions()
e.compileSharedFunctions()
return e
}

Expand Down Expand Up @@ -108,6 +112,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
e.rels = e.rels[:0]
cm := &compiledModule{
offsets: wazevoapi.NewModuleContextOffsetData(module), parent: e, module: module,
ensureTermination: ensureTermination,
}

if module.IsHostModule {
Expand All @@ -131,7 +136,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene

// Creates new compiler instances which are reused for each function.
ssaBuilder := ssa.NewBuilder()
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets)
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination)
machine := newMachine()
be := backend.NewCompiler(ctx, machine, ssaBuilder)

Expand All @@ -154,7 +159,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
ctx = wazevoapi.SetCurrentFunctionName(ctx, fmt.Sprintf("[%d/%d] \"%s\"", i, len(module.CodeSection)-1, name))
}

body, rels, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, listeners, ensureTermination)
body, rels, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, listeners)
if err != nil {
return nil, fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)
}
Expand Down Expand Up @@ -214,7 +219,7 @@ func (e *engine) compileLocalWasmFunction(
fe *frontend.Compiler,
ssaBuilder ssa.Builder,
be backend.Compiler,
_ []experimental.FunctionListener, _ bool,
_ []experimental.FunctionListener,
) (body []byte, rels []backend.RelocationInfo, err error) {
typ := &module.TypeSection[module.FunctionSection[localFunctionIndex]]
codeSeg := &module.CodeSection[localFunctionIndex]
Expand Down Expand Up @@ -468,7 +473,7 @@ func (e *engine) NewModuleEngine(m *wasm.Module, mi *wasm.ModuleInstance) (wasm.
return me, nil
}

func (e *engine) compileBuiltinFunctions() {
func (e *engine) compileSharedFunctions() {
e.sharedFunctions = &sharedFunctions{entryPreambles: make(map[*wasm.FunctionType][]byte)}

e.be.Init()
Expand All @@ -480,6 +485,15 @@ func (e *engine) compileBuiltinFunctions() {
e.sharedFunctions.memoryGrowExecutable = mmapExecutable(src)
}

e.be.Init()
{
src := e.machine.CompileGoFunctionTrampoline(wazevoapi.ExitCodeCheckModuleExitCode, &ssa.Signature{
Params: []ssa.Type{ssa.TypeI32 /* exec context */},
Results: []ssa.Type{ssa.TypeI32},
}, false)
e.sharedFunctions.checkModuleExitCode = mmapExecutable(src)
}

// TODO: table grow, etc.

e.be.Init()
Expand All @@ -491,21 +505,26 @@ func (e *engine) compileBuiltinFunctions() {
e.setFinalizer(e.sharedFunctions, sharedFunctionsFinalizer)
}

func sharedFunctionsFinalizer(bf *sharedFunctions) {
if err := platform.MunmapCodeSegment(bf.memoryGrowExecutable); err != nil {
func sharedFunctionsFinalizer(sf *sharedFunctions) {
if err := platform.MunmapCodeSegment(sf.memoryGrowExecutable); err != nil {
panic(err)
}
if err := platform.MunmapCodeSegment(sf.checkModuleExitCode); err != nil {
panic(err)
}
if err := platform.MunmapCodeSegment(bf.stackGrowExecutable); err != nil {
if err := platform.MunmapCodeSegment(sf.stackGrowExecutable); err != nil {
panic(err)
}
for _, f := range bf.entryPreambles {
for _, f := range sf.entryPreambles {
if err := platform.MunmapCodeSegment(f); err != nil {
panic(err)
}
}

bf.memoryGrowExecutable = nil
bf.stackGrowExecutable = nil
sf.memoryGrowExecutable = nil
sf.checkModuleExitCode = nil
sf.stackGrowExecutable = nil
sf.entryPreambles = nil
}

func compiledModuleFinalizer(cm *compiledModule) {
Expand Down
30 changes: 24 additions & 6 deletions internal/engine/wazevo/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,37 @@ import (
)

func Test_sharedFunctionsFinalizer(t *testing.T) {
bf := &sharedFunctions{}
sf := &sharedFunctions{}

b1, err := platform.MmapCodeSegment(100)
require.NoError(t, err)

b2, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
bf.memoryGrowExecutable = b1
bf.stackGrowExecutable = b2

sharedFunctionsFinalizer(bf)
require.Nil(t, bf.memoryGrowExecutable)
require.Nil(t, bf.stackGrowExecutable)
b3, err := platform.MmapCodeSegment(100)
require.NoError(t, err)

b4, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
b5, err := platform.MmapCodeSegment(100)
require.NoError(t, err)

preabmles := map[*wasm.FunctionType][]byte{
{Params: []wasm.ValueType{}}: b4,
{Params: []wasm.ValueType{wasm.ValueTypeI32}}: b5,
}

sf.memoryGrowExecutable = b1
sf.stackGrowExecutable = b2
sf.checkModuleExitCode = b3
sf.entryPreambles = preabmles

sharedFunctionsFinalizer(sf)
require.Nil(t, sf.memoryGrowExecutable)
require.Nil(t, sf.stackGrowExecutable)
require.Nil(t, sf.checkModuleExitCode)
require.Nil(t, sf.entryPreambles)
}

func Test_compiledModuleFinalizer(t *testing.T) {
Expand Down
18 changes: 14 additions & 4 deletions internal/engine/wazevo/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ type Compiler struct {
m *wasm.Module
offset *wazevoapi.ModuleContextOffsetData
// ssaBuilder is a ssa.Builder used by this frontend.
ssaBuilder ssa.Builder
signatures map[*wasm.FunctionType]*ssa.Signature
memoryGrowSig ssa.Signature
ssaBuilder ssa.Builder
signatures map[*wasm.FunctionType]*ssa.Signature
memoryGrowSig ssa.Signature
checkModuleExitCodeSig ssa.Signature
checkModuleExitCodeArg [1]ssa.Value
ensureTermination bool

// Followings are reset by per function.

Expand All @@ -43,13 +46,14 @@ type Compiler struct {
}

// NewFrontendCompiler returns a frontend Compiler.
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData) *Compiler {
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool) *Compiler {
c := &Compiler{
m: m,
ssaBuilder: ssaBuilder,
br: bytes.NewReader(nil),
wasmLocalToVariable: make(map[wasm.Index]ssa.Variable),
offset: offset,
ensureTermination: ensureTermination,
}

c.signatures = make(map[*wasm.FunctionType]*ssa.Signature, len(m.TypeSection)+1)
Expand All @@ -70,6 +74,12 @@ func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoa
}
c.ssaBuilder.DeclareSignature(&c.memoryGrowSig)

c.checkModuleExitCodeSig = ssa.Signature{
ID: c.memoryGrowSig.ID + 1,
// Only takes execution context.
Params: []ssa.Type{ssa.TypeI64},
}
c.ssaBuilder.DeclareSignature(&c.checkModuleExitCodeSig)
return c
}

Expand Down
35 changes: 33 additions & 2 deletions internal/engine/wazevo/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ func TestCompiler_LowerToSSA(t *testing.T) {
// what output should look like, you can run:
// `~/wasmtime/target/debug/clif-util wasm --target aarch64-apple-darwin testcase.wat -p -t`
for _, tc := range []struct {
name string
name string
ensureTermination bool
// m is the *wasm.Module to be compiled in this test.
m *wasm.Module
// targetIndex is the index of a local function to be compiled in this test.
Expand Down Expand Up @@ -219,6 +220,36 @@ blk0: (exec_ctx:i64, module_ctx:i64)

blk1: () <-- (blk0,blk1)
Jump blk1
`,
},
{
name: "loop - br / ensure termination", m: testcases.LoopBr.Module,
ensureTermination: true,
exp: `
signatures:
sig2: i64_v

blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk1

blk1: () <-- (blk0,blk1)
v2:i64 = Load exec_ctx, 0x58
CallIndirect v2:sig2, exec_ctx
Jump blk1

blk2: ()
`,
expAfterOpt: `
signatures:
sig2: i64_v

blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk1

blk1: () <-- (blk0,blk1)
v2:i64 = Load exec_ctx, 0x58
CallIndirect v2:sig2, exec_ctx
Jump blk1
`,
},
{
Expand Down Expand Up @@ -1736,7 +1767,7 @@ blk4: () <-- (blk2,blk3)
b := ssa.NewBuilder()

offset := wazevoapi.NewModuleContextOffsetData(tc.m)
fc := NewFrontendCompiler(tc.m, b, &offset)
fc := NewFrontendCompiler(tc.m, b, &offset, tc.ensureTermination)
typeIndex := tc.m.FunctionSection[tc.targetIndex]
code := &tc.m.CodeSection[tc.targetIndex]
fc.Init(tc.targetIndex, &tc.m.TypeSection[typeIndex], code.LocalTypes, code.Body)
Expand Down
13 changes: 13 additions & 0 deletions internal/engine/wazevo/frontend/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,19 @@ func (c *Compiler) lowerCurrentOpcode() {

c.switchTo(originalLen, loopHeader)

if c.ensureTermination {
checkModuleExitCodePtr := builder.AllocateInstruction().
AsLoad(c.execCtxPtrValue,
wazevoapi.ExecutionContextOffsets.CheckModuleExitCodeTrampolineAddress.U32(),
ssa.TypeI64,
).Insert(builder).Return()

c.checkModuleExitCodeArg[0] = c.execCtxPtrValue

builder.AllocateInstruction().
AsCallIndirect(checkModuleExitCodePtr, &c.checkModuleExitCodeSig, c.checkModuleExitCodeArg[:]).
Insert(builder)
}
case wasm.OpcodeIf:
bt := c.readBlockType()

Expand Down
3 changes: 2 additions & 1 deletion internal/engine/wazevo/module_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ func (m *moduleEngine) NewFunction(index wasm.Index) api.Function {
}

ce.execCtx.memoryGrowTrampolineAddress = &m.parent.sharedFunctions.memoryGrowExecutable[0]
ce.execCtx.stackGrowCallSequenceAddress = &m.parent.sharedFunctions.stackGrowExecutable[0]
ce.execCtx.stackGrowCallTrampolineAddress = &m.parent.sharedFunctions.stackGrowExecutable[0]
ce.execCtx.checkModuleExitCodeTrampolineAddress = &m.parent.sharedFunctions.checkModuleExitCode[0]
ce.init()
return ce
}
Expand Down
3 changes: 2 additions & 1 deletion internal/engine/wazevo/wazevo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ func Test_ExecutionContextOffsets(t *testing.T) {
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackPointerBeforeGoCall)), offsets.StackPointerBeforeGoCall)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowRequiredSize)), offsets.StackGrowRequiredSize)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.memoryGrowTrampolineAddress)), offsets.MemoryGrowTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowCallSequenceAddress)), offsets.StackGrowCallSequenceAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowCallTrampolineAddress)), offsets.StackGrowCallTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.checkModuleExitCodeTrampolineAddress)), offsets.CheckModuleExitCodeTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.savedRegisters))%16, wazevoapi.Offset(0),
"SavedRegistersBegin must be aligned to 16 bytes")
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.savedRegisters)), offsets.SavedRegistersBegin)
Expand Down
Loading
Loading