Skip to content

Commit

Permalink
Removes requirement to pass a HostFunctionCallContext (#260)
Browse files Browse the repository at this point in the history
This allows users to decouple from wazero code when authoring host
functions. Notably, this allows them to opt out of using a context, or
only using a Go context instead of HostFunctionCallContext.

This backfills docs on how to write host functions (in simple terms).

Finally, this does not optimize engines to avoid propagating context or
looking up memory if it would never be used. That could be done later.

Signed-off-by: Adrian Cole <adrian@tetrate.io>
  • Loading branch information
codefromthecrypt authored Feb 18, 2022
1 parent fbe153b commit 3d25f48
Show file tree
Hide file tree
Showing 16 changed files with 437 additions and 144 deletions.
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ func main() {
// Decode the binary as WebAssembly module.
mod, _ := wazero.DecodeModuleBinary(source)

// Initialize the execution environment called "store" with Interpreter-based engine.
store := wazero.NewStore()

// Instantiate the module, which returns its exported functions
functions, _ := store.Instantiate(mod)
// Instantiate the module with a Wasm Interpreter, to return its exported functions
functions, _ := wazero.NewStore().Instantiate(mod)

// Get the factorial function
fac, _ := functions.GetFunctionI64Return("fac")
Expand Down
3 changes: 1 addition & 2 deletions examples/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/stretchr/testify/require"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/wasm"
)

// Test_Simple implements a basic function in go: hello. This is imported as the Wasm name "$hello" and run on start.
Expand All @@ -20,7 +19,7 @@ func Test_Simple(t *testing.T) {
require.NoError(t, err)

stdout := new(bytes.Buffer)
goFunc := func(wasm.HostFunctionCallContext) {
goFunc := func() {
_, _ = fmt.Fprintln(stdout, "hello!")
}

Expand Down
8 changes: 4 additions & 4 deletions internal/wasi/wasi.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ const (
FunctionPathUnlinkFile = "path_unlink_file"
FunctionPollOneoff = "poll_oneoff"

// ProcExit terminates the execution of the module with an exit code.
// FunctionProcExit terminates the execution of the module with an exit code.
// See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit
FunctionProcExit = "proc_exit"

Expand Down Expand Up @@ -339,7 +339,7 @@ type SnapshotPreview1 interface {
//
// Note: ImportProcExit shows this signature in the WebAssembly 1.0 (MVP) Text Format.
// See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit
ProcExit(ctx wasm.HostFunctionCallContext, rval uint32)
ProcExit(rval uint32)

// TODO: ProcRaise
// TODO: SchedYield
Expand Down Expand Up @@ -505,8 +505,8 @@ func (a *wasiAPI) ClockTimeGet(ctx wasm.HostFunctionCallContext, id uint32, prec
return wasi.ErrnoSuccess
}

// ProcExit implements API.ProcExit
func (a *wasiAPI) ProcExit(ctx wasm.HostFunctionCallContext, exitCode uint32) {
// ProcExit implements SnapshotPreview1.ProcExit
func (a *wasiAPI) ProcExit(exitCode uint32) {
// Panic in a host function is caught by the engines, and the value of the panic is returned as the error of the CallFunction.
// See the document of API.ProcExit.
panic(wasi.ExitCode(exitCode))
Expand Down
129 changes: 109 additions & 20 deletions internal/wasm/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,131 @@ import (
publicwasm "github.com/tetratelabs/wazero/wasm"
)

// FunctionKind identifies the type of function that can be called.
type FunctionKind byte

const (
// FunctionKindWasm is not a host function: it is implemented in Wasm.
FunctionKindWasm FunctionKind = iota
// FunctionKindHostNoContext is a function implemented in Go, with a signature matching FunctionType.
FunctionKindHostNoContext
// FunctionKindHostGoContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is
// a context.Context.
FunctionKindHostGoContext
// FunctionKindHostFunctionCallContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is
// a HostFunctionCallContext.
FunctionKindHostFunctionCallContext
)

type HostFunction struct {
name string
name string
// functionKind is never FunctionKindWasm
functionKind FunctionKind
functionType *FunctionType
goFunc *reflect.Value
}

func NewHostFunction(funcName string, goFunc interface{}) (hf *HostFunction, err error) {
hf = &HostFunction{name: funcName}
fn := reflect.ValueOf(goFunc)
hf.goFunc = &fn
hf.functionKind, hf.functionType, err = GetFunctionType(hf.name, hf.goFunc)
return
}

// Below are reflection code to get the interface type used to parse functions and set values.

var hostFunctionCallContextType = reflect.TypeOf((*publicwasm.HostFunctionCallContext)(nil)).Elem()
var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
var errorType = reflect.TypeOf((*error)(nil)).Elem()

// GetHostFunctionCallContextValue returns a reflect.Value for a context param[0], or nil if there isn't one.
func GetHostFunctionCallContextValue(fk FunctionKind, ctx *HostFunctionCallContext) *reflect.Value {
switch fk {
case FunctionKindHostNoContext: // no special param zero
case FunctionKindHostGoContext:
val := reflect.New(goContextType).Elem()
val.Set(reflect.ValueOf(ctx.Context()))
return &val
case FunctionKindHostFunctionCallContext:
val := reflect.New(hostFunctionCallContextType).Elem()
val.Set(reflect.ValueOf(ctx))
return &val
}
return nil
}

// GetFunctionType returns the function type corresponding to the function signature or errs if invalid.
func GetFunctionType(name string, fn *reflect.Value) (*FunctionType, error) {
func GetFunctionType(name string, fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err error) {
if fn.Kind() != reflect.Func {
return nil, fmt.Errorf("%s value is not a reflect.Func: %s", name, fn.String())
err = fmt.Errorf("%s is a %s, but should be a Func", name, fn.Kind().String())
return
}
p := fn.Type()
if p.NumIn() == 0 { // TODO: actually check the type
return nil, fmt.Errorf("%s must accept wasm.HostFunctionCallContext as the first param", name)
}

paramTypes := make([]ValueType, p.NumIn()-1)
for i := range paramTypes {
kind := p.In(i + 1).Kind()
if t, ok := getTypeOf(kind); !ok {
return nil, fmt.Errorf("%s param[%d] is unsupported: %s", name, i, kind.String())
} else {
paramTypes[i] = t
pOffset := 0
pCount := p.NumIn()
fk = FunctionKindHostNoContext
if pCount > 0 && p.In(0).Kind() == reflect.Interface {
p0 := p.In(0)
if p0.Implements(hostFunctionCallContextType) {
fk = FunctionKindHostFunctionCallContext
pOffset = 1
pCount--
} else if p0.Implements(goContextType) {
fk = FunctionKindHostGoContext
pOffset = 1
pCount--
}
}
rCount := p.NumOut()
switch rCount {
case 0, 1: // ok
default:
err = fmt.Errorf("%s has more than one result", name)
return
}

ft = &FunctionType{Params: make([]ValueType, pCount), Results: make([]ValueType, rCount)}

resultTypes := make([]ValueType, p.NumOut())
for i := range resultTypes {
kind := p.Out(i).Kind()
if t, ok := getTypeOf(kind); !ok {
return nil, fmt.Errorf("%s result[%d] is unsupported: %s", name, i, kind.String())
for i := 0; i < len(ft.Params); i++ {
pI := p.In(i + pOffset)
if t, ok := getTypeOf(pI.Kind()); ok {
ft.Params[i] = t
continue
}

// Now, we will definitely err, decide which message is best
var arg0Type reflect.Type
if hc := pI.Implements(hostFunctionCallContextType); hc {
arg0Type = hostFunctionCallContextType
} else if gc := pI.Implements(goContextType); gc {
arg0Type = goContextType
}

if arg0Type != nil {
err = fmt.Errorf("%s param[%d] is a %s, which may be defined only once as param[0]", name, i+pOffset, arg0Type)
} else {
resultTypes[i] = t
err = fmt.Errorf("%s param[%d] is unsupported: %s", name, i+pOffset, pI.Kind())
}
return
}

if rCount == 0 {
return
}
result := p.Out(0)
if t, ok := getTypeOf(result.Kind()); ok {
ft.Results[0] = t
return
}

if e := result.Implements(errorType); e {
err = fmt.Errorf("%s result[0] is an error, which is unsupported", name)
} else {
err = fmt.Errorf("%s result[0] is unsupported: %s", name, result.Kind())
}
return &FunctionType{Params: paramTypes, Results: resultTypes}, nil
return
}

func getTypeOf(kind reflect.Kind) (ValueType, bool) {
Expand Down
123 changes: 123 additions & 0 deletions internal/wasm/host_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package internalwasm

import (
"context"
"math"
"reflect"
"testing"

"github.com/stretchr/testify/require"

publicwasm "github.com/tetratelabs/wazero/wasm"
)

func TestMemoryInstance_HasLen(t *testing.T) {
Expand Down Expand Up @@ -448,3 +452,122 @@ func TestMemoryInstance_WriteFloat64Le(t *testing.T) {
})
}
}

func TestGetFunctionType(t *testing.T) {
i32, i64, f32, f64 := ValueTypeI32, ValueTypeI64, ValueTypeF32, ValueTypeF64

tests := []struct {
name string
inputFunc interface{}
expectedKind FunctionKind
expectedType *FunctionType
}{
{
name: "nullary",
inputFunc: func() {},
expectedKind: FunctionKindHostNoContext,
expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}},
},
{
name: "wasm.HostFunctionCallContext void return",
inputFunc: func(publicwasm.HostFunctionCallContext) {},
expectedKind: FunctionKindHostFunctionCallContext,
expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}},
},
{
name: "context.Context void return",
inputFunc: func(context.Context) {},
expectedKind: FunctionKindHostGoContext,
expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}},
},
{
name: "all supported params and i32 result",
inputFunc: func(uint32, uint64, float32, float64) uint32 { return 0 },
expectedKind: FunctionKindHostNoContext,
expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}},
},
{
name: "all supported params and i32 result - wasm.HostFunctionCallContext",
inputFunc: func(publicwasm.HostFunctionCallContext, uint32, uint64, float32, float64) uint32 { return 0 },
expectedKind: FunctionKindHostFunctionCallContext,
expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}},
},
{
name: "all supported params and i32 result - context.Context",
inputFunc: func(context.Context, uint32, uint64, float32, float64) uint32 { return 0 },
expectedKind: FunctionKindHostGoContext,
expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}},
},
}

for _, tt := range tests {
tc := tt

t.Run(tc.name, func(t *testing.T) {
rVal := reflect.ValueOf(tc.inputFunc)
fk, ft, err := GetFunctionType("fn", &rVal)
require.NoError(t, err)
require.Equal(t, tc.expectedKind, fk)
require.Equal(t, tc.expectedType, ft)
})
}
}

func TestGetFunctionTypeErrors(t *testing.T) {
tests := []struct {
name string
input interface{}
expectedErr string
}{
{
name: "not a func",
input: struct{}{},
expectedErr: "fn is a struct, but should be a Func",
},
{
name: "unsupported param",
input: func(uint32, string) {},
expectedErr: "fn param[1] is unsupported: string",
},
{
name: "unsupported result",
input: func() string { return "" },
expectedErr: "fn result[0] is unsupported: string",
},
{
name: "error result",
input: func() error { return nil },
expectedErr: "fn result[0] is an error, which is unsupported",
},
{
name: "multiple results",
input: func() (uint64, uint32) { return 0, 0 },
expectedErr: "fn has more than one result",
},
{
name: "multiple context types",
input: func(publicwasm.HostFunctionCallContext, context.Context) error { return nil },
expectedErr: "fn param[1] is a context.Context, which may be defined only once as param[0]",
},
{
name: "multiple context.Context",
input: func(context.Context, uint64, context.Context) error { return nil },
expectedErr: "fn param[2] is a context.Context, which may be defined only once as param[0]",
},
{
name: "multiple wasm.HostFunctionCallContext",
input: func(publicwasm.HostFunctionCallContext, uint64, publicwasm.HostFunctionCallContext) error { return nil },
expectedErr: "fn param[2] is a wasm.HostFunctionCallContext, which may be defined only once as param[0]",
},
}

for _, tt := range tests {
tc := tt

t.Run(tc.name, func(t *testing.T) {
rVal := reflect.ValueOf(tc.input)
_, _, err := GetFunctionType("fn", &rVal)
require.EqualError(t, err, tc.expectedErr)
})
}
}
Loading

0 comments on commit 3d25f48

Please sign in to comment.