Skip to content

Commit

Permalink
Add experimental ImportResolver
Browse files Browse the repository at this point in the history
If set in context, the ImportResolver will be used as the first step in resolving imports.

Closes tetratelabs#2294
  • Loading branch information
bep committed Aug 4, 2024
1 parent 5ad3f06 commit e7ef70e
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 26 deletions.
19 changes: 19 additions & 0 deletions experimental/importresolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package experimental

import (
"context"

"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/internal/expctxkeys"
)

// ImportResolver is an experimental func type that, if set,
// will be used as the first step in resolving imports.
// See https://github.com/tetratelabs/wazero/issues/2294
// If the import name is not found, it should return nil.
type ImportResolver func(name string) api.Module

// WithImportResolver returns a new context with the given ImportResolver.
func WithImportResolver(ctx context.Context, resolver ImportResolver) context.Context {
return context.WithValue(ctx, expctxkeys.ImportResolverKey{}, resolver)
}
101 changes: 101 additions & 0 deletions experimental/importresolver_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package experimental_test

import (
"bytes"
"context"
_ "embed"
"fmt"
"log"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
)

var (
// These wasm files were generated by the following:
// cd testdata
// wat2wasm --debug-names inoutdispatcher.wat
// wat2wasm --debug-names inoutdispatcherclient.wat

//go:embed testdata/inoutdispatcher.wasm
inoutdispatcherWasm []byte
//go:embed testdata/inoutdispatcherclient.wasm
inoutdispatcherclientWasm []byte
)

func Example_importResolver() {
ctx := context.Background()

r := wazero.NewRuntime(ctx)
defer r.Close(ctx)

// The client imports the inoutdispatcher module that reads from stdin and writes to stdout.
// This means that we need multiple instances of the inoutdispatcher module to have different stdin/stdout.
// This example demonstrates a way to do that.
type mod struct {
in bytes.Buffer
out bytes.Buffer

client api.Module
}

wasi_snapshot_preview1.MustInstantiate(ctx, r)

const numInstances = 3
mods := make([]*mod, numInstances)
for i := range mods {
mods[i] = &mod{}
m := mods[i]
idm, err := r.CompileModule(ctx, inoutdispatcherWasm)
if err != nil {
log.Panicln(err)
}
idcm, err := r.CompileModule(ctx, inoutdispatcherclientWasm)
if err != nil {
log.Panicln(err)
}

const inoutDispatcherModuleName = "inoutdispatcher"

dispatcherInstance, err := r.InstantiateModule(ctx, idm,
wazero.NewModuleConfig().
WithStdin(&m.in).
WithStdout(&m.out).
WithName("")) // Makes it an anonymous module.
if err != nil {
log.Panicln(err)
}

ctx = experimental.WithImportResolver(ctx, func(name string) api.Module {
if name == inoutDispatcherModuleName {
return dispatcherInstance
}
return nil
})

m.client, err = r.InstantiateModule(ctx, idcm, wazero.NewModuleConfig().WithName(fmt.Sprintf("m%d", i)))
if err != nil {
log.Panicln(err)
}

}

for i, m := range mods {
m.in.WriteString(fmt.Sprintf("Module instance #%d", i))
_, err := m.client.ExportedFunction("dispatch").Call(ctx)
if err != nil {
log.Panicln(err)
}
}

for i, m := range mods {
fmt.Printf("out%d: %s\n", i, m.out.String())
}

// Output:
// out0: Module instance #0
// out1: Module instance #1
// out2: Module instance #2
}
63 changes: 63 additions & 0 deletions experimental/importresolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package experimental_test

import (
"context"
"fmt"
"testing"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/internal/testing/binaryencoding"
"github.com/tetratelabs/wazero/internal/testing/require"
"github.com/tetratelabs/wazero/internal/wasm"
)

func TestImportResolver(t *testing.T) {
ctx := context.Background()

r := wazero.NewRuntime(ctx)
defer r.Close(ctx)

for i := 0; i < 5; i++ {
var callCount int
start := func(ctx context.Context) {
callCount++
}
modImport, err := r.NewHostModuleBuilder(fmt.Sprintf("env%d", i)).
NewFunctionBuilder().WithFunc(start).Export("start").
Compile(ctx)
require.NoError(t, err)
// Anonymous module, it will be resolved by the import resolver.
instanceImport, err := r.InstantiateModule(ctx, modImport, wazero.NewModuleConfig().WithName(""))
require.NoError(t, err)

resolveImport := func(name string) api.Module {
if name == "env" {
return instanceImport
}
return nil
}

// Set the import resolver in the context.
ctx = experimental.WithImportResolver(context.Background(), resolveImport)

one := uint32(1)
binary := binaryencoding.EncodeModule(&wasm.Module{
TypeSection: []wasm.FunctionType{{}},
ImportSection: []wasm.Import{{Module: "env", Name: "start", Type: wasm.ExternTypeFunc, DescFunc: 0}},
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{
{Body: []byte{wasm.OpcodeCall, 0, wasm.OpcodeEnd}}, // Call the imported env.start.
},
StartSection: &one,
})

modMain, err := r.CompileModule(ctx, binary)
require.NoError(t, err)

_, err = r.InstantiateModule(ctx, modMain, wazero.NewModuleConfig())
require.NoError(t, err)
require.Equal(t, 1, callCount)
}
}
Binary file added experimental/testdata/inoutdispatcher.wasm
Binary file not shown.
38 changes: 38 additions & 0 deletions experimental/testdata/inoutdispatcher.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
(module
(import "wasi_snapshot_preview1" "fd_read" (func $fd_read (param i32 i32 i32 i32) (result i32)))
(import "wasi_snapshot_preview1" "fd_write" (func $fd_write (param i32 i32 i32 i32) (result i32)))
(memory 1 1 )
(func (export "dispatch")
;; Buffer of 100 chars to read into.
(i32.store (i32.const 4) (i32.const 12))
(i32.store (i32.const 8) (i32.const 100))

(block $done
(loop $read
;; Read from stdin.
(call $fd_read
(i32.const 0) ;; fd; 0 is stdin.
(i32.const 4) ;; iovs
(i32.const 1) ;; iovs_len
(i32.const 8) ;; nread
)

;; If nread is 0, we're done.
(if (i32.eq (i32.load (i32.const 8)) (i32.const 0))
(then br $done)
)

;; Write to stdout.
(drop (call $fd_write
(i32.const 1) ;; fd; 1 is stdout.
(i32.const 4) ;; iovs
(i32.const 1) ;; iovs_len
(i32.const 0) ;; nwritten
))
(br $read)

)
)
)

)
Binary file added experimental/testdata/inoutdispatcherclient.wasm
Binary file not shown.
7 changes: 7 additions & 0 deletions experimental/testdata/inoutdispatcherclient.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

(module
(import "inoutdispatcher" "dispatch" (func $dispatch))
(func (export "dispatch")
(call $dispatch)
)
)
6 changes: 6 additions & 0 deletions internal/expctxkeys/importresolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package expctxkeys

// ImportResolverKey is a context.Context Value key.
// Its associated value should be an ImportResolver.
// See https://github.com/tetratelabs/wazero/issues/2294
type ImportResolverKey struct{}
20 changes: 15 additions & 5 deletions internal/wasm/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (s *Store) instantiate(
return nil, err
}

if err = m.resolveImports(module); err != nil {
if err = m.resolveImports(ctx, module); err != nil {
return nil, err
}

Expand Down Expand Up @@ -410,12 +410,22 @@ func (s *Store) instantiate(
return
}

func (m *ModuleInstance) resolveImports(module *Module) (err error) {
func (m *ModuleInstance) resolveImports(ctx context.Context, module *Module) (err error) {
// Check if ctx contains an ImportResolver.
resolveImport, _ := ctx.Value(expctxkeys.ImportResolverKey{}).(experimental.ImportResolver)

for moduleName, imports := range module.ImportPerModule {
var importedModule *ModuleInstance
importedModule, err = m.s.module(moduleName)
if err != nil {
return err
if resolveImport != nil {
if v := resolveImport(moduleName); v != nil {
importedModule = v.(*ModuleInstance)
}
}
if importedModule == nil {
importedModule, err = m.s.module(moduleName)
if err != nil {
return err
}
}

for _, i := range imports {
Expand Down
39 changes: 22 additions & 17 deletions internal/wasm/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,13 +701,13 @@ func Test_resolveImports(t *testing.T) {

t.Run("module not instantiated", func(t *testing.T) {
m := &ModuleInstance{s: newStore()}
err := m.resolveImports(&Module{ImportPerModule: map[string][]*Import{"unknown": {{}}}})
err := m.resolveImports(context.Background(), &Module{ImportPerModule: map[string][]*Import{"unknown": {{}}}})
require.EqualError(t, err, "module[unknown] not instantiated")
})
t.Run("export instance not found", func(t *testing.T) {
m := &ModuleInstance{s: newStore()}
m.s.nameToModule[moduleName] = &ModuleInstance{Exports: map[string]*Export{}, ModuleName: moduleName}
err := m.resolveImports(&Module{ImportPerModule: map[string][]*Import{moduleName: {{Name: "unknown"}}}})
err := m.resolveImports(context.Background(), &Module{ImportPerModule: map[string][]*Import{moduleName: {{Name: "unknown"}}}})
require.EqualError(t, err, "\"unknown\" is not exported in module \"test\"")
})
t.Run("func", func(t *testing.T) {
Expand Down Expand Up @@ -743,7 +743,7 @@ func Test_resolveImports(t *testing.T) {
}

m := &ModuleInstance{Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}, s: s, Source: module}
err := m.resolveImports(module)
err := m.resolveImports(context.Background(), module)
require.NoError(t, err)

me := m.Engine.(*mockModuleEngine)
Expand Down Expand Up @@ -773,7 +773,7 @@ func Test_resolveImports(t *testing.T) {
}

m := &ModuleInstance{Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}, s: s, Source: module}
err := m.resolveImports(module)
err := m.resolveImports(context.Background(), module)
require.EqualError(t, err, "import func[test.target]: signature mismatch: v_f32 != v_v")
})
})
Expand All @@ -787,6 +787,7 @@ func Test_resolveImports(t *testing.T) {
Exports: map[string]*Export{name: {Type: ExternTypeGlobal, Index: 0}}, ModuleName: moduleName,
}
err := m.resolveImports(
context.Background(),
&Module{
ImportPerModule: map[string][]*Import{moduleName: {{Name: name, Type: ExternTypeGlobal, DescGlobal: g.Type}}},
},
Expand All @@ -805,11 +806,13 @@ func Test_resolveImports(t *testing.T) {
ModuleName: moduleName,
}
m := &ModuleInstance{Globals: make([]*GlobalInstance, 1), s: s}
err := m.resolveImports(&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{Mutable: true}},
}},
})
err := m.resolveImports(
context.Background(),
&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{Mutable: true}},
}},
})
require.EqualError(t, err, "import global[test.target]: mutability mismatch: true != false")
})
t.Run("type mismatch", func(t *testing.T) {
Expand All @@ -823,11 +826,13 @@ func Test_resolveImports(t *testing.T) {
ModuleName: moduleName,
}
m := &ModuleInstance{Globals: make([]*GlobalInstance, 1), s: s}
err := m.resolveImports(&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{ValType: ValueTypeF64}},
}},
})
err := m.resolveImports(
context.Background(),
&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{ValType: ValueTypeF64}},
}},
})
require.EqualError(t, err, "import global[test.target]: value type mismatch: f64 != i32")
})
})
Expand All @@ -846,7 +851,7 @@ func Test_resolveImports(t *testing.T) {
Engine: importedME,
}
m := &ModuleInstance{s: s, Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}}
err := m.resolveImports(&Module{
err := m.resolveImports(context.Background(), &Module{
ImportPerModule: map[string][]*Import{
moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: &Memory{Max: max}}},
},
Expand All @@ -866,7 +871,7 @@ func Test_resolveImports(t *testing.T) {
ModuleName: moduleName,
}
m := &ModuleInstance{s: s}
err := m.resolveImports(&Module{
err := m.resolveImports(context.Background(), &Module{
ImportPerModule: map[string][]*Import{
moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}},
},
Expand All @@ -886,7 +891,7 @@ func Test_resolveImports(t *testing.T) {
max := uint32(10)
importMemoryType := &Memory{Max: max}
m := &ModuleInstance{s: s}
err := m.resolveImports(&Module{
err := m.resolveImports(context.Background(), &Module{
ImportPerModule: map[string][]*Import{moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}},
})
require.EqualError(t, err, "import memory[test.target]: maximum size mismatch: 10 < 65536")
Expand Down
Loading

0 comments on commit e7ef70e

Please sign in to comment.