diff --git a/README.md b/README.md index 49f06b7d76..0730899574 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ func main() { // Instantiate the module and return its exported functions module, _ := wazero.NewRuntime().InstantiateModuleFromCode(ctx, source) - defer module.Close() + defer module.Close(ctx) // Discover 7! is 5040 fmt.Println(module.ExportedFunction("fac").Call(ctx, 7)) @@ -63,7 +63,7 @@ env, err := r.NewModuleBuilder("env"). if err != nil { log.Fatal(err) } -defer env.Close() +defer env.Close(ctx) ``` While not a standards body like W3C, there is another dominant community in the @@ -77,11 +77,11 @@ For example, here's how you can allow WebAssembly modules to read "/work/home/a.txt" as "/a.txt" or "./a.txt": ```go wm, err := wasi.InstantiateSnapshotPreview1(ctx, r) -defer wm.Close() +defer wm.Close(ctx) config := wazero.ModuleConfig().WithFS(os.DirFS("/work/home")) module, err := r.InstantiateModule(ctx, binary, config) -defer module.Close() +defer module.Close(ctx) ... ``` @@ -302,7 +302,7 @@ top-level project. That said, Takeshi's original motivation is as relevant today as when he started the project, and worthwhile reading: If you want to provide Wasm host environments in your Go programs, currently -there's no other choice than using CGO andleveraging the state-of-the-art +there's no other choice than using CGO leveraging the state-of-the-art runtimes written in C++/Rust (e.g. V8, Wasmtime, Wasmer, WAVM, etc.), and there's no pure Go Wasm runtime out there. (There's only one exception named [wagon](https://github.com/go-interpreter/wagon), but it was archived with the @@ -313,7 +313,7 @@ plugin systems in your Go project and want these plugin systems to be safe/fast/flexible, and enable users to write plugins in their favorite languages. That's where Wasm comes into play. You write your own Wasm host environments and embed Wasm runtime in your projects, and now users are able to -write plugins in their own favorite lanugages (AssembyScript, C, C++, Rust, +write plugins in their own favorite languages (AssemblyScript, C, C++, Rust, Zig, etc.). As a specific example, you maybe write proxy severs in Go and want to allow users to extend the proxy via [Proxy-Wasm ABI](https://github.com/proxy-wasm/spec). Maybe you are writing server-side rendering applications via Wasm, or diff --git a/api/wasm.go b/api/wasm.go index 1983889e83..34b987d34d 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -70,8 +70,8 @@ type Module interface { Name() string // Close is a convenience that invokes CloseWithExitCode with zero. - Close() error - // ^^ not io.Closer as the rationale (static analysis of leaks) is invalid when there are multiple close methods. + // Note: When the context is nil, it defaults to context.Background. + Close(context.Context) error // CloseWithExitCode releases resources allocated for this Module. Use a non-zero exitCode parameter to indicate a // failure to ExportedFunction callers. @@ -82,7 +82,8 @@ type Module interface { // // Calling this inside a host function is safe, and may cause ExportedFunction callers to receive a sys.ExitError // with the exitCode. - CloseWithExitCode(exitCode uint32) error + // Note: When the context is nil, it defaults to context.Background. + CloseWithExitCode(ctx context.Context, exitCode uint32) error // Memory returns a memory defined in this module or nil if there are none wasn't. Memory() Memory @@ -121,7 +122,7 @@ type Function interface { // encoded according to ResultTypes. An error is returned for any failure looking up or invoking the function // including signature mismatch. // - // Note: when `ctx` is nil, it defaults to context.Background. + // Note: When the context is nil, it defaults to context.Background. // Note: If Module.Close or Module.CloseWithExitCode were invoked during this call, the error returned may be a // sys.ExitError. Interpreting this is specific to the module. For example, some "main" functions always call a // function that exits. @@ -153,7 +154,9 @@ type Global interface { // Get returns the last known value of this global. // See Type for how to encode this value from a Go type. - Get() uint64 + // + // Note: When the context is nil, it defaults to context.Background. + Get(context.Context) uint64 } // MutableGlobal is a Global whose value can be updated at runtime (variable). @@ -162,11 +165,14 @@ type MutableGlobal interface { // Set updates the value of this global. // See Global.Type for how to decode this value to a Go type. - Set(v uint64) + // + // Note: When the context is nil, it defaults to context.Background. + Set(ctx context.Context, v uint64) } // Memory allows restricted access to a module's memory. Notably, this does not allow growing. // +// Note: All functions accept a context.Context, which when nil, default to context.Background. // Note: This is an interface for decoupling, not third-party implementations. All implementations are in wazero. // Note: This includes all value types available in WebAssembly 1.0 (20191205) and all are encoded little-endian. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#storage%E2%91%A0 @@ -177,59 +183,67 @@ type Memory interface { // memory has min 0 and max 2 pages, this returns zero. // // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#-hrefsyntax-instr-memorymathsfmemorysize%E2%91%A0 - Size() uint32 + Size(context.Context) uint32 // IndexByte returns the index of the first instance of c in the underlying buffer at the offset or returns false if // not found or out of range. - IndexByte(offset uint32, c byte) (uint32, bool) + IndexByte(ctx context.Context, offset uint32, c byte) (uint32, bool) // ReadByte reads a single byte from the underlying buffer at the offset or returns false if out of range. - ReadByte(offset uint32) (byte, bool) + ReadByte(ctx context.Context, offset uint32) (byte, bool) + + // ReadUint16Le reads a uint16 in little-endian encoding from the underlying buffer at the offset in or returns + // false if out of range. + ReadUint16Le(ctx context.Context, offset uint32) (uint16, bool) // ReadUint32Le reads a uint32 in little-endian encoding from the underlying buffer at the offset in or returns // false if out of range. - ReadUint32Le(offset uint32) (uint32, bool) + ReadUint32Le(ctx context.Context, offset uint32) (uint32, bool) // ReadFloat32Le reads a float32 from 32 IEEE 754 little-endian encoded bits in the underlying buffer at the offset // or returns false if out of range. // See math.Float32bits - ReadFloat32Le(offset uint32) (float32, bool) + ReadFloat32Le(ctx context.Context, offset uint32) (float32, bool) // ReadUint64Le reads a uint64 in little-endian encoding from the underlying buffer at the offset or returns false // if out of range. - ReadUint64Le(offset uint32) (uint64, bool) + ReadUint64Le(ctx context.Context, offset uint32) (uint64, bool) // ReadFloat64Le reads a float64 from 64 IEEE 754 little-endian encoded bits in the underlying buffer at the offset // or returns false if out of range. // See math.Float64bits - ReadFloat64Le(offset uint32) (float64, bool) + ReadFloat64Le(ctx context.Context, offset uint32) (float64, bool) // Read reads byteCount bytes from the underlying buffer at the offset or returns false if out of range. - Read(offset, byteCount uint32) ([]byte, bool) + Read(ctx context.Context, offset, byteCount uint32) ([]byte, bool) // WriteByte writes a single byte to the underlying buffer at the offset in or returns false if out of range. - WriteByte(offset uint32, v byte) bool + WriteByte(ctx context.Context, offset uint32, v byte) bool + + // WriteUint16Le writes the value in little-endian encoding to the underlying buffer at the offset in or returns + // false if out of range. + WriteUint16Le(ctx context.Context, offset uint32, v uint16) bool // WriteUint32Le writes the value in little-endian encoding to the underlying buffer at the offset in or returns // false if out of range. - WriteUint32Le(offset, v uint32) bool + WriteUint32Le(ctx context.Context, offset, v uint32) bool // WriteFloat32Le writes the value in 32 IEEE 754 little-endian encoded bits to the underlying buffer at the offset // or returns false if out of range. // See math.Float32bits - WriteFloat32Le(offset uint32, v float32) bool + WriteFloat32Le(ctx context.Context, offset uint32, v float32) bool // WriteUint64Le writes the value in little-endian encoding to the underlying buffer at the offset in or returns // false if out of range. - WriteUint64Le(offset uint32, v uint64) bool + WriteUint64Le(ctx context.Context, offset uint32, v uint64) bool // WriteFloat64Le writes the value in 64 IEEE 754 little-endian encoded bits to the underlying buffer at the offset // or returns false if out of range. // See math.Float64bits - WriteFloat64Le(offset uint32, v float64) bool + WriteFloat64Le(ctx context.Context, offset uint32, v float64) bool // Write writes the slice to the underlying buffer at the offset or returns false if out of range. - Write(offset uint32, v []byte) bool + Write(ctx context.Context, offset uint32, v []byte) bool } // EncodeI32 encodes the input as a ValueTypeI32. diff --git a/builder.go b/builder.go index 37441f36a5..cb44181fc6 100644 --- a/builder.go +++ b/builder.go @@ -25,10 +25,10 @@ import ( // env, _ := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Build(ctx) // // env1, _ := r.InstantiateModuleWithConfig(ctx, env, NewModuleConfig().WithName("env.1")) -// defer env1.Close() +// defer env1.Close(ctx) // // env2, _ := r.InstantiateModuleWithConfig(ctx, env, NewModuleConfig().WithName("env.2")) -// defer env2.Close() +// defer env2.Close(ctx) // // Note: Builder methods do not return errors, to allow chaining. Any validation errors are deferred until Build. // Note: Insertion order is not retained. Anything defined by this builder is sorted lexicographically on Build. @@ -53,17 +53,17 @@ type ModuleBuilder interface { // // Ex. This uses a Go Context: // - // addInts := func(m context.Context, x uint32, uint32) uint32 { + // addInts := func(ctx context.Context, x uint32, uint32) uint32 { // // add a little extra if we put some in the context! - // return x + y + m.Value(extraKey).(uint32) + // return x + y + ctx.Value(extraKey).(uint32) // } // // Ex. This uses an api.Module to reads the parameters from memory. This is important because there are only numeric // types in Wasm. The only way to share other data is via writing memory and sharing offsets. // - // addInts := func(m api.Module, offset uint32) uint32 { - // x, _ := m.Memory().ReadUint32Le(offset) - // y, _ := m.Memory().ReadUint32Le(offset + 4) // 32 bits == 4 bytes! + // addInts := func(ctx context.Context, m api.Module, offset uint32) uint32 { + // x, _ := m.Memory().ReadUint32Le(ctx, offset) + // y, _ := m.Memory().ReadUint32Le(ctx, offset + 4) // 32 bits == 4 bytes! // return x + y // } // @@ -151,12 +151,12 @@ type ModuleBuilder interface { ExportGlobalF64(name string, v float64) ModuleBuilder // Build returns a module to instantiate, or returns an error if any of the configuration is invalid. - Build(ctx context.Context) (*CompiledCode, error) + Build(context.Context) (*CompiledCode, error) // Instantiate is a convenience that calls Build, then Runtime.InstantiateModule // // Note: Fields in the builder are copied during instantiation: Later changes do not affect the instantiated result. - Instantiate(ctx context.Context) (api.Module, error) + Instantiate(context.Context) (api.Module, error) } // moduleBuilder implements ModuleBuilder @@ -274,8 +274,8 @@ func (b *moduleBuilder) Instantiate(ctx context.Context) (api.Module, error) { if err = b.r.store.Engine.CompileModule(ctx, module.module); err != nil { return nil, err } - // *wasm.ModuleInstance cannot be tracked, so we release the cache inside of this function. - defer module.Close() + // *wasm.ModuleInstance cannot be tracked, so we release the cache inside this function. + defer module.Close(ctx) return b.r.InstantiateModuleWithConfig(ctx, module, NewModuleConfig().WithName(b.moduleName)) } } diff --git a/config.go b/config.go index ebf6e44294..dc187024fd 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package wazero import ( + "context" "errors" "fmt" "io" @@ -151,13 +152,12 @@ type CompiledCode struct { compiledEngine wasm.Engine } -// compile-time check to ensure CompiledCode implements io.Closer (consistent with api.Module) -var _ io.Closer = &CompiledCode{} - // Close releases all the allocated resources for this CompiledCode. // // Note: It is safe to call Close while having outstanding calls from Modules instantiated from this *CompiledCode. -func (c *CompiledCode) Close() error { +func (c *CompiledCode) Close(_ context.Context) error { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + c.compiledEngine.DeleteCompiledModule(c.module) // It is possible the underlying may need to return an error later, but in any case this matches api.Module.Close. return nil diff --git a/config_test.go b/config_test.go index 887f856832..1963e32e2a 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package wazero import ( + "context" "io" "math" "testing" @@ -777,3 +778,27 @@ func requireSysContext(t *testing.T, max uint32, args, environ []string, stdin i require.NoError(t, err) return sys } + +func TestCompiledCode_Close(t *testing.T) { + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + e := &mockEngine{name: "1", cachedModules: map[*wasm.Module]struct{}{}} + + var cs []*CompiledCode + for i := 0; i < 10; i++ { + m := &wasm.Module{} + err := e.CompileModule(ctx, m) + require.NoError(t, err) + cs = append(cs, &CompiledCode{module: m, compiledEngine: e}) + } + + // Before Close. + require.Equal(t, 10, len(e.cachedModules)) + + for _, c := range cs { + require.NoError(t, c.Close(ctx)) + } + + // After Close. + require.Zero(t, len(e.cachedModules)) + } +} diff --git a/example_test.go b/example_test.go index 4d6431c1d5..4c35949bc1 100644 --- a/example_test.go +++ b/example_test.go @@ -29,7 +29,7 @@ func Example() { if err != nil { log.Fatal(err) } - defer mod.Close() + defer mod.Close(ctx) // Get a function that can be reused until its module is closed: add := mod.ExportedFunction("add") diff --git a/examples/allocation/rust/greet.go b/examples/allocation/rust/greet.go index a41e7594a0..0e396734fd 100644 --- a/examples/allocation/rust/greet.go +++ b/examples/allocation/rust/greet.go @@ -34,7 +34,7 @@ func main() { if err != nil { log.Fatal(err) } - defer env.Close() + defer env.Close(ctx) // Instantiate a WebAssembly module that imports the "log" function defined // in "env" and exports "memory" and functions we'll use in this example. @@ -42,7 +42,7 @@ func main() { if err != nil { log.Fatal(err) } - defer mod.Close() + defer mod.Close(ctx) // Get references to WebAssembly functions we'll use in this example. greet := mod.ExportedFunction("greet") @@ -67,9 +67,9 @@ func main() { defer deallocate.Call(ctx, namePtr, nameSize) // The pointer is a linear memory offset, which is where we write the name. - if !mod.Memory().Write(uint32(namePtr), []byte(name)) { + if !mod.Memory().Write(ctx, uint32(namePtr), []byte(name)) { log.Fatalf("Memory.Write(%d, %d) out of range of memory size %d", - namePtr, nameSize, mod.Memory().Size()) + namePtr, nameSize, mod.Memory().Size(ctx)) } // Now, we can call "greet", which reads the string we wrote to memory! @@ -91,16 +91,16 @@ func main() { defer deallocate.Call(ctx, uint64(greetingPtr), uint64(greetingSize)) // The pointer is a linear memory offset, which is where we write the name. - if bytes, ok := mod.Memory().Read(greetingPtr, greetingSize); !ok { + if bytes, ok := mod.Memory().Read(ctx, greetingPtr, greetingSize); !ok { log.Fatalf("Memory.Read(%d, %d) out of range of memory size %d", - greetingPtr, greetingSize, mod.Memory().Size()) + greetingPtr, greetingSize, mod.Memory().Size(ctx)) } else { fmt.Println("go >>", string(bytes)) } } -func logString(m api.Module, offset, byteCount uint32) { - buf, ok := m.Memory().Read(offset, byteCount) +func logString(ctx context.Context, m api.Module, offset, byteCount uint32) { + buf, ok := m.Memory().Read(ctx, offset, byteCount) if !ok { log.Fatalf("Memory.Read(%d, %d) out of range", offset, byteCount) } diff --git a/examples/allocation/tinygo/greet.go b/examples/allocation/tinygo/greet.go index 4c77180d7d..59d701293d 100644 --- a/examples/allocation/tinygo/greet.go +++ b/examples/allocation/tinygo/greet.go @@ -35,7 +35,7 @@ func main() { if err != nil { log.Fatal(err) } - defer env.Close() + defer env.Close(ctx) // Note: testdata/greet.go doesn't use WASI, but TinyGo needs it to // implement functions such as panic. @@ -43,7 +43,7 @@ func main() { if err != nil { log.Fatal(err) } - defer wm.Close() + defer wm.Close(ctx) // Instantiate a WebAssembly module that imports the "log" function defined // in "env" and exports "memory" and functions we'll use in this example. @@ -51,7 +51,7 @@ func main() { if err != nil { log.Fatal(err) } - defer mod.Close() + defer mod.Close(ctx) // Get references to WebAssembly functions we'll use in this example. greet := mod.ExportedFunction("greet") @@ -77,9 +77,9 @@ func main() { defer free.Call(ctx, namePtr) // The pointer is a linear memory offset, which is where we write the name. - if !mod.Memory().Write(uint32(namePtr), []byte(name)) { + if !mod.Memory().Write(ctx, uint32(namePtr), []byte(name)) { log.Fatalf("Memory.Write(%d, %d) out of range of memory size %d", - namePtr, nameSize, mod.Memory().Size()) + namePtr, nameSize, mod.Memory().Size(ctx)) } // Now, we can call "greet", which reads the string we wrote to memory! @@ -98,16 +98,16 @@ func main() { greetingPtr := uint32(ptrSize[0] >> 32) greetingSize := uint32(ptrSize[0]) // The pointer is a linear memory offset, which is where we write the name. - if bytes, ok := mod.Memory().Read(greetingPtr, greetingSize); !ok { + if bytes, ok := mod.Memory().Read(ctx, greetingPtr, greetingSize); !ok { log.Fatalf("Memory.Read(%d, %d) out of range of memory size %d", - greetingPtr, greetingSize, mod.Memory().Size()) + greetingPtr, greetingSize, mod.Memory().Size(ctx)) } else { fmt.Println("go >>", string(bytes)) } } -func logString(m api.Module, offset, byteCount uint32) { - buf, ok := m.Memory().Read(offset, byteCount) +func logString(ctx context.Context, m api.Module, offset, byteCount uint32) { + buf, ok := m.Memory().Read(ctx, offset, byteCount) if !ok { log.Fatalf("Memory.Read(%d, %d) out of range", offset, byteCount) } diff --git a/examples/basic/add.go b/examples/basic/add.go index 44d5a14f2a..fa476ae8ec 100644 --- a/examples/basic/add.go +++ b/examples/basic/add.go @@ -32,7 +32,7 @@ func main() { if err != nil { log.Fatal(err) } - defer wasm.Close() + defer wasm.Close(ctx) // Add a module to the runtime named "host/math" which exports one function "add", implemented in Go. host, err := r.NewModuleBuilder("host/math"). @@ -42,7 +42,7 @@ func main() { if err != nil { log.Fatal(err) } - defer host.Close() + defer host.Close(ctx) // Read two args to add. x, y := readTwoArgs() diff --git a/examples/import-go/age-calculator.go b/examples/import-go/age-calculator.go index 41d4a32c87..838411b718 100644 --- a/examples/import-go/age-calculator.go +++ b/examples/import-go/age-calculator.go @@ -44,7 +44,7 @@ func main() { if err != nil { log.Fatal(err) } - defer env.Close() + defer env.Close(ctx) // Instantiate a WebAssembly module named "age-calculator" that imports // functions defined in "env". @@ -87,7 +87,7 @@ func main() { if err != nil { log.Fatal(err) } - defer ageCalculator.Close() + defer ageCalculator.Close(ctx) // Read the birthYear from the arguments to main birthYear, err := strconv.ParseUint(os.Args[1], 10, 64) diff --git a/examples/multiple-results/multiple-results.go b/examples/multiple-results/multiple-results.go index aa8b375aa6..30302c3356 100644 --- a/examples/multiple-results/multiple-results.go +++ b/examples/multiple-results/multiple-results.go @@ -36,14 +36,14 @@ func main() { if err != nil { log.Fatal(err) } - defer wasm.Close() + defer wasm.Close(ctx) // Add a module that uses offset parameters for multiple results, with functions defined in Go. host, err := resultOffsetHostFunctions(ctx, runtime) if err != nil { log.Fatal(err) } - defer host.Close() + defer host.Close(ctx) // wazero enables only W3C recommended features by default. Opt-in to other features like so: runtimeWithMultiValue := wazero.NewRuntimeWithConfig( @@ -56,14 +56,14 @@ func main() { if err != nil { log.Fatal(err) } - defer wasmWithMultiValue.Close() + defer wasmWithMultiValue.Close(ctx) // Add a module that uses multiple results values, with functions defined in Go. hostWithMultiValue, err := multiValueHostFunctions(ctx, runtimeWithMultiValue) if err != nil { log.Fatal(err) } - defer hostWithMultiValue.Close() + defer hostWithMultiValue.Close(ctx) // Call the same function in all modules and print the results to the console. for _, mod := range []api.Module{wasm, host, wasmWithMultiValue, hostWithMultiValue} { @@ -121,8 +121,8 @@ func resultOffsetHostFunctions(ctx context.Context, r wazero.Runtime) (api.Modul // To use result parameters, we need scratch memory. Allocate the least possible: 1 page (64KB). ExportMemoryWithMax("mem", 1, 1). // Define a function that returns a result, while a second result is written to memory. - ExportFunction("get_age", func(m api.Module, resultOffsetAge uint32) (errno uint32) { - if m.Memory().WriteUint64Le(resultOffsetAge, 37) { + ExportFunction("get_age", func(ctx context.Context, m api.Module, resultOffsetAge uint32) (errno uint32) { + if m.Memory().WriteUint64Le(ctx, resultOffsetAge, 37) { return 0 } return 1 // overflow @@ -132,7 +132,7 @@ func resultOffsetHostFunctions(ctx context.Context, r wazero.Runtime) (api.Modul ExportFunction("call_get_age", func(ctx context.Context, m api.Module) (age uint64) { resultOffsetAge := uint32(8) // arbitrary memory offset (in bytes) _, _ = m.ExportedFunction("get_age").Call(ctx, uint64(resultOffsetAge)) - age, _ = m.Memory().ReadUint64Le(resultOffsetAge) + age, _ = m.Memory().ReadUint64Le(ctx, resultOffsetAge) return }).Instantiate(ctx) } diff --git a/examples/replace-import/replace-import.go b/examples/replace-import/replace-import.go index 70a6b14017..3a4e135321 100644 --- a/examples/replace-import/replace-import.go +++ b/examples/replace-import/replace-import.go @@ -22,13 +22,13 @@ func main() { // Instantiate a Go-defined module named "assemblyscript" that exports a // function to close the module that calls "abort". host, err := r.NewModuleBuilder("assemblyscript"). - ExportFunction("abort", func(m api.Module, messageOffset, fileNameOffset, line, col uint32) { - _ = m.CloseWithExitCode(255) + ExportFunction("abort", func(ctx context.Context, m api.Module, messageOffset, fileNameOffset, line, col uint32) { + _ = m.CloseWithExitCode(ctx, 255) }).Instantiate(ctx) if err != nil { log.Fatal(err) } - defer host.Close() + defer host.Close(ctx) // Compile WebAssembly code that needs the function "env.abort". code, err := r.CompileModule(ctx, []byte(`(module $needs-import @@ -39,7 +39,7 @@ func main() { if err != nil { log.Fatal(err) } - defer code.Close() + defer code.Close(ctx) // Instantiate the WebAssembly module, replacing the import "env.abort" // with "assemblyscript.abort". @@ -48,7 +48,7 @@ func main() { if err != nil { log.Fatal(err) } - defer mod.Close() + defer mod.Close(ctx) // Since the above worked, the exported function closes the module. _, err = mod.ExportedFunction("abort").Call(ctx, 0, 0, 0, 0) diff --git a/examples/wasi/cat.go b/examples/wasi/cat.go index 1da8a25a44..0105db1827 100644 --- a/examples/wasi/cat.go +++ b/examples/wasi/cat.go @@ -45,7 +45,7 @@ func main() { if err != nil { log.Fatal(err) } - defer wm.Close() + defer wm.Close(ctx) // InstantiateModuleFromCodeWithConfig runs the "_start" function which is what TinyGo compiles "main" to. // * Set the program name (arg[0]) to "wasi" and add args to write "test.txt" to stdout twice. @@ -54,5 +54,5 @@ func main() { if err != nil { log.Fatal(err) } - defer cat.Close() + defer cat.Close(ctx) } diff --git a/internal/integration_test/bench/bench_test.go b/internal/integration_test/bench/bench_test.go index a2b921ff87..218e03ae72 100644 --- a/internal/integration_test/bench/bench_test.go +++ b/internal/integration_test/bench/bench_test.go @@ -23,13 +23,13 @@ var caseWasm []byte func BenchmarkInvocation(b *testing.B) { b.Run("interpreter", func(b *testing.B) { m := instantiateHostFunctionModuleWithEngine(b, wazero.NewRuntimeConfigInterpreter()) - defer m.Close() + defer m.Close(testCtx) runAllInvocationBenches(b, m) }) if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { b.Run("jit", func(b *testing.B) { m := instantiateHostFunctionModuleWithEngine(b, wazero.NewRuntimeConfigJIT()) - defer m.Close() + defer m.Close(testCtx) runAllInvocationBenches(b, m) }) } @@ -54,14 +54,14 @@ func runInitializationBench(b *testing.B, r wazero.Runtime) { if err != nil { b.Fatal(err) } - defer compiled.Close() + defer compiled.Close(testCtx) b.ResetTimer() for i := 0; i < b.N; i++ { mod, err := r.InstantiateModule(testCtx, compiled) if err != nil { b.Fatal(err) } - mod.Close() + mod.Close(testCtx) } } @@ -172,11 +172,11 @@ func createRuntime(b *testing.B, engine *wazero.RuntimeConfig) wazero.Runtime { } offset := uint32(results[0]) - m.Memory().WriteUint32Le(retBufPtr, offset) - m.Memory().WriteUint32Le(retBufSize, 10) + m.Memory().WriteUint32Le(ctx, retBufPtr, offset) + m.Memory().WriteUint32Le(ctx, retBufSize, 10) b := make([]byte, 10) _, _ = rand.Read(b) - m.Memory().Write(offset, b) + m.Memory().Write(ctx, offset, b) } r := wazero.NewRuntimeWithConfig(engine) diff --git a/internal/integration_test/bench/memory_bench_test.go b/internal/integration_test/bench/memory_bench_test.go index 8691ab5d52..58678fd0d9 100644 --- a/internal/integration_test/bench/memory_bench_test.go +++ b/internal/integration_test/bench/memory_bench_test.go @@ -8,13 +8,13 @@ import ( func BenchmarkMemory(b *testing.B) { var mem = &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize), Min: 1} - if !mem.WriteByte(10, 16) { + if !mem.WriteByte(testCtx, 10, 16) { b.Fail() } b.Run("ReadByte", func(b *testing.B) { for i := 0; i < b.N; i++ { - if v, ok := mem.ReadByte(10); !ok || v != 16 { + if v, ok := mem.ReadByte(testCtx, 10); !ok || v != 16 { b.Fail() } } @@ -22,7 +22,7 @@ func BenchmarkMemory(b *testing.B) { b.Run("ReadUint32Le", func(b *testing.B) { for i := 0; i < b.N; i++ { - if v, ok := mem.ReadUint32Le(10); !ok || v != 16 { + if v, ok := mem.ReadUint32Le(testCtx, 10); !ok || v != 16 { b.Fail() } } @@ -30,7 +30,7 @@ func BenchmarkMemory(b *testing.B) { b.Run("WriteByte", func(b *testing.B) { for i := 0; i < b.N; i++ { - if !mem.WriteByte(10, 16) { + if !mem.WriteByte(testCtx, 10, 16) { b.Fail() } } @@ -38,7 +38,7 @@ func BenchmarkMemory(b *testing.B) { b.Run("WriteUint32Le", func(b *testing.B) { for i := 0; i < b.N; i++ { - if !mem.WriteUint32Le(10, 16) { + if !mem.WriteUint32Le(testCtx, 10, 16) { b.Fail() } } diff --git a/internal/integration_test/engine/adhoc_test.go b/internal/integration_test/engine/adhoc_test.go index 7a40bcd130..23920e05f7 100644 --- a/internal/integration_test/engine/adhoc_test.go +++ b/internal/integration_test/engine/adhoc_test.go @@ -64,7 +64,7 @@ var ( func testHugeStack(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, hugestackWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) fn := module.ExportedFunction("main") require.NotNil(t, fn) @@ -83,7 +83,7 @@ func testUnreachable(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, unreachableWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) _, err = module.ExportedFunction("main").Call(testCtx) exp := `panic in host function (recovered by wazero) @@ -106,7 +106,7 @@ func testRecursiveEntry(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, recursiveWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) _, err = module.ExportedFunction("main").Call(testCtx, 1) require.NoError(t, err) @@ -121,8 +121,8 @@ func TestImportedAndExportedFunc(t *testing.T) { // Notably, this uses memory, which ensures api.Module is valid in both interpreter and JIT engines. func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { var memory *wasm.MemoryInstance - storeInt := func(m api.Module, offset uint32, val uint64) uint32 { - if !m.Memory().WriteUint64Le(offset, val) { + storeInt := func(ctx context.Context, m api.Module, offset uint32, val uint64) uint32 { + if !m.Memory().WriteUint64Le(ctx, offset, val) { return 1 } // sneak a reference to the memory, so we can check it later @@ -132,7 +132,7 @@ func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { host, err := r.NewModuleBuilder("").ExportFunction("store_int", storeInt).Instantiate(testCtx) require.NoError(t, err) - defer host.Close() + defer host.Close(testCtx) module, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $test (import "" "store_int" @@ -143,7 +143,7 @@ func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { (export "store_int" (func $store_int)) )`)) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) // Call store_int and ensure it didn't return an error code. fn := module.ExportedFunction("store_int") @@ -177,7 +177,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) - defer imported.Close() + defer imported.Close(testCtx) for test := range fns { t.Run(test, func(t *testing.T) { @@ -188,7 +188,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { (export "call->%[3]s" (func $call_%[3]s)) )`, importingName, importedName, test))) require.NoError(t, err) - defer importing.Close() + defer importing.Close(testCtx) results, err := importing.ExportedFunction("call->"+test).Call(testCtx, math.MaxUint32-1) require.NoError(t, err) @@ -219,7 +219,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) - defer imported.Close() + defer imported.Close(testCtx) for _, test := range []struct { name string @@ -254,7 +254,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { (export "call->%[3]s" (func $call_%[3]s)) )`, importingName, importedName, test.name))) require.NoError(t, err) - defer importing.Close() + defer importing.Close(testCtx) results, err := importing.ExportedFunction("call->"+test.name).Call(testCtx, test.input) require.NoError(t, err) @@ -322,18 +322,18 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { var importingCode, importedCode *wazero.CompiledCode var imported, importing api.Module var err error - closeAndReturn := func(x uint32) uint32 { + closeAndReturn := func(ctx context.Context, x uint32) uint32 { if tc.closeImporting != 0 { - require.NoError(t, importing.CloseWithExitCode(tc.closeImporting)) + require.NoError(t, importing.CloseWithExitCode(ctx, tc.closeImporting)) } if tc.closeImported != 0 { - require.NoError(t, imported.CloseWithExitCode(tc.closeImported)) + require.NoError(t, imported.CloseWithExitCode(ctx, tc.closeImported)) } if tc.closeImportedCode { - importedCode.Close() + importedCode.Close(testCtx) } if tc.closeImportingCode { - importingCode.Close() + importingCode.Close(testCtx) } return x } @@ -345,7 +345,7 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { imported, err = r.InstantiateModule(testCtx, importedCode) require.NoError(t, err) - defer imported.Close() + defer imported.Close(testCtx) // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") @@ -354,7 +354,7 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { importing, err = r.InstantiateModule(testCtx, importingCode) require.NoError(t, err) - defer importing.Close() + defer importing.Close(testCtx) var expectedErr error if tc.closeImported != 0 && tc.closeImporting != 0 { @@ -388,7 +388,7 @@ func testMemOps(t *testing.T, r wazero.Runtime) { (export "memory" (memory 0)) )`)) require.NoError(t, err) - defer memory.Close() + defer memory.Close(testCtx) // Check the export worked require.Equal(t, memory.Memory(), memory.ExportedMemory("memory")) @@ -397,7 +397,7 @@ func testMemOps(t *testing.T, r wazero.Runtime) { results, err := memory.ExportedFunction("size").Call(testCtx) require.NoError(t, err) require.Zero(t, results[0]) - require.Zero(t, memory.ExportedMemory("memory").Size()) + require.Zero(t, memory.ExportedMemory("memory").Size(testCtx)) // Try to grow the memory by one page results, err = memory.ExportedFunction("grow").Call(testCtx, 1) @@ -407,8 +407,8 @@ func testMemOps(t *testing.T, r wazero.Runtime) { // Check the size command works! results, err = memory.ExportedFunction("size").Call(testCtx) require.NoError(t, err) - require.Equal(t, uint64(1), results[0]) // 1 page - require.Equal(t, uint32(65536), memory.Memory().Size()) // 64KB + require.Equal(t, uint64(1), results[0]) // 1 page + require.Equal(t, uint32(65536), memory.Memory().Size(testCtx)) // 64KB } func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { @@ -422,16 +422,16 @@ func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { (export "store" (func $store)) )`)) require.NoError(t, err) - defer compiled.Close() + defer compiled.Close(testCtx) // Instantiate multiple modules with the same source (*CompiledCode). for i := 0; i < 100; i++ { module, err := r.InstantiateModuleWithConfig(testCtx, compiled, wazero.NewModuleConfig().WithName(strconv.Itoa(i))) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) // Ensure that compilation cache doesn't cause race on memory instance. - before, ok := module.Memory().ReadUint64Le(1) + before, ok := module.Memory().ReadUint64Le(testCtx, 1) require.True(t, ok) // Value must be zero as the memory must not be affected by the previously instantiated modules. require.Zero(t, before) @@ -443,7 +443,7 @@ func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { require.NoError(t, err) // After the call, the value must be set properly. - after, ok := module.Memory().ReadUint64Le(1) + after, ok := module.Memory().ReadUint64Le(testCtx, 1) require.True(t, ok) require.Equal(t, uint64(1000), after) } diff --git a/internal/integration_test/engine/hammer_test.go b/internal/integration_test/engine/hammer_test.go index c798084db3..ba2d969c09 100644 --- a/internal/integration_test/engine/hammer_test.go +++ b/internal/integration_test/engine/hammer_test.go @@ -31,7 +31,7 @@ func TestEngineInterpreter_hammer(t *testing.T) { func closeImportingModuleWhileInUse(t *testing.T, r wazero.Runtime) { closeModuleWhileInUse(t, r, func(imported, importing api.Module) (api.Module, api.Module) { // Close the importing module, despite calls being in-flight. - require.NoError(t, importing.Close()) + require.NoError(t, importing.Close(testCtx)) // Prove a module can be redefined even with in-flight calls. source := callReturnImportSource(imported.Name(), importing.Name()) @@ -44,8 +44,8 @@ func closeImportingModuleWhileInUse(t *testing.T, r wazero.Runtime) { func closeImportedModuleWhileInUse(t *testing.T, r wazero.Runtime) { closeModuleWhileInUse(t, r, func(imported, importing api.Module) (api.Module, api.Module) { // Close the importing and imported module, despite calls being in-flight. - require.NoError(t, importing.Close()) - require.NoError(t, imported.Close()) + require.NoError(t, importing.Close(testCtx)) + require.NoError(t, imported.Close(testCtx)) // Redefine the imported module, with a function that no longer blocks. imported, err := r.NewModuleBuilder(imported.Name()).ExportFunction("return_input", func(x uint32) uint32 { @@ -80,13 +80,13 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported imported, err := r.NewModuleBuilder(t.Name()+"-imported"). ExportFunction("return_input", blockAndReturn).Instantiate(testCtx) require.NoError(t, err) - defer imported.Close() + defer imported.Close(testCtx) // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") importing, err := r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) - defer importing.Close() + defer importing.Close(testCtx) // As this is a blocking function call, only run 1 per goroutine. i := importing // pin the module used inside goroutines @@ -99,8 +99,8 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported calls.Add(-P) }) // As references may have changed, ensure we close both. - defer imported.Close() - defer importing.Close() + defer imported.Close(testCtx) + defer importing.Close(testCtx) if t.Failed() { return // At least one test failed, so return now. } diff --git a/internal/integration_test/post1_0/multi-value/spec_test.go b/internal/integration_test/post1_0/multi-value/spec_test.go index 2b6377532a..795400c205 100644 --- a/internal/integration_test/post1_0/multi-value/spec_test.go +++ b/internal/integration_test/post1_0/multi-value/spec_test.go @@ -39,7 +39,7 @@ func testMultiValue(t *testing.T, newRuntimeConfig func() *wazero.RuntimeConfig) r := wazero.NewRuntimeWithConfig(newRuntimeConfig().WithFeatureMultiValue(true)) module, err := r.InstantiateModuleFromCode(testCtx, multiValueWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) swap := module.ExportedFunction("swap") results, err := swap.Call(testCtx, 100, 200) @@ -92,7 +92,7 @@ var brWasm []byte func testBr(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, brWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) testFunctions(t, module, []funcTest{ {name: "type-i32-i32"}, {name: "type-i64-i64"}, {name: "type-f32-f32"}, {name: "type-f64-f64"}, @@ -115,7 +115,7 @@ var callWasm []byte func testCall(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, callWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) testFunctions(t, module, []funcTest{ {name: "type-i32-i64", expected: []uint64{0x132, 0x164}}, @@ -136,7 +136,7 @@ var callIndirectWasm []byte func testCallIndirect(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, callIndirectWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) testFunctions(t, module, []funcTest{ {name: "type-f64-i32", expected: []uint64{api.EncodeF64(0xf64), 32}}, @@ -158,7 +158,7 @@ var facWasm []byte func testFac(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, facWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) fac := module.ExportedFunction("fac-ssa") results, err := fac.Call(testCtx, 25) @@ -173,7 +173,7 @@ var funcWasm []byte func testFunc(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, funcWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) testFunctions(t, module, []funcTest{ {name: "value-i32-f64", expected: []uint64{77, api.EncodeF64(7)}}, @@ -221,7 +221,7 @@ var ifWasm []byte func testIf(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, ifWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) testFunctions(t, module, []funcTest{ {name: "multi", params: []uint64{0}, expected: []uint64{9, api.EncodeI32(-1)}}, @@ -273,7 +273,7 @@ var loopWasm []byte func testLoop(t *testing.T, r wazero.Runtime) { module, err := r.InstantiateModuleFromCode(testCtx, loopWasm) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) testFunctions(t, module, []funcTest{ {name: "as-binary-operands", expected: []uint64{12}}, diff --git a/internal/integration_test/spectest/spec_test.go b/internal/integration_test/spectest/spec_test.go index 8c9ada2027..ddf300081c 100644 --- a/internal/integration_test/spectest/spec_test.go +++ b/internal/integration_test/spectest/spec_test.go @@ -389,7 +389,7 @@ func runTest(t *testing.T, newEngine func(wasm.Features) wasm.Engine) { expType = wasm.ValueTypeF64 } require.Equal(t, expType, global.Type(), msg) - require.Equal(t, exps[0], global.Get(), msg) + require.Equal(t, exps[0], global.Get(testCtx), msg) default: t.Fatalf("unsupported action type type: %v", c) } diff --git a/internal/integration_test/vs/bench_test.go b/internal/integration_test/vs/bench_test.go index bf15e3c0d0..72181d7a63 100644 --- a/internal/integration_test/vs/bench_test.go +++ b/internal/integration_test/vs/bench_test.go @@ -76,7 +76,7 @@ func benchmarkCompile(b *testing.B, rtCfg *runtimeConfig) { if err := rt.Compile(testCtx, rtCfg); err != nil { b.Fatal(err) } - if err := rt.Close(); err != nil { + if err := rt.Close(testCtx); err != nil { b.Fatal(err) } } @@ -92,7 +92,7 @@ func benchmarkInstantiate(b *testing.B, rtCfg *runtimeConfig) { if err := rt.Compile(testCtx, rtCfg); err != nil { b.Fatal(err) } - defer rt.Close() + defer rt.Close(testCtx) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -100,7 +100,7 @@ func benchmarkInstantiate(b *testing.B, rtCfg *runtimeConfig) { if err != nil { b.Fatal(err) } - err = mod.Close() + err = mod.Close(testCtx) if err != nil { b.Fatal(err) } @@ -127,12 +127,12 @@ func benchmarkFn(rt runtime, rtCfg *runtimeConfig, call func(module) (uint64, er if err := rt.Compile(testCtx, rtCfg); err != nil { b.Fatal(err) } - defer rt.Close() + defer rt.Close(testCtx) mod, err := rt.Instantiate(testCtx, rtCfg) if err != nil { b.Fatal(err) } - defer mod.Close() + defer mod.Close(testCtx) b.ResetTimer() for i := 0; i < b.N; i++ { if _, err := call(mod); err != nil { @@ -153,7 +153,7 @@ func testCallFn(rt runtime, rtCfg *runtimeConfig, testCall func(*testing.T, modu return func(t *testing.T) { err := rt.Compile(testCtx, rtCfg) require.NoError(t, err) - defer rt.Close() + defer rt.Close(testCtx) // Ensure the module can be re-instantiated times, even if not all runtimes allow renaming. for i := 0; i < 10; i++ { @@ -165,7 +165,7 @@ func testCallFn(rt runtime, rtCfg *runtimeConfig, testCall func(*testing.T, modu testCall(t, m) } - require.NoError(t, m.Close()) + require.NoError(t, m.Close(testCtx)) } } } diff --git a/internal/integration_test/vs/codec_test.go b/internal/integration_test/vs/codec_test.go index c0fe7cb8b0..2377465515 100644 --- a/internal/integration_test/vs/codec_test.go +++ b/internal/integration_test/vs/codec_test.go @@ -118,12 +118,12 @@ func TestExampleUpToDate(t *testing.T) { // Add WASI to satisfy import tests wm, err := wasi.InstantiateSnapshotPreview1(testCtx, r) require.NoError(t, err) - defer wm.Close() + defer wm.Close(testCtx) // Decode and instantiate the module module, err := r.InstantiateModuleFromCode(testCtx, exampleBinary) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) // Call the swap function as a smoke test results, err := module.ExportedFunction("swap").Call(testCtx, 1, 2) diff --git a/internal/integration_test/vs/runtime.go b/internal/integration_test/vs/runtime.go index f51203ffd9..98472836ba 100644 --- a/internal/integration_test/vs/runtime.go +++ b/internal/integration_test/vs/runtime.go @@ -3,7 +3,6 @@ package vs import ( "context" "fmt" - "io" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" @@ -16,14 +15,14 @@ type runtimeConfig struct { } type runtime interface { - Compile(ctx context.Context, cfg *runtimeConfig) error - Instantiate(ctx context.Context, cfg *runtimeConfig) (module, error) - io.Closer + Compile(context.Context, *runtimeConfig) error + Instantiate(context.Context, *runtimeConfig) (module, error) + Close(context.Context) error } type module interface { CallI64_I64(ctx context.Context, funcName string, param uint64) (uint64, error) - io.Closer + Close(context.Context) error } func newWazeroInterpreterRuntime() runtime { @@ -72,9 +71,9 @@ func (r *wazeroRuntime) Instantiate(ctx context.Context, cfg *runtimeConfig) (mo return } -func (r *wazeroRuntime) Close() (err error) { +func (r *wazeroRuntime) Close(ctx context.Context) (err error) { if compiled := r.compiled; compiled != nil { - err = compiled.Close() + err = compiled.Close(ctx) } r.compiled = nil return @@ -89,9 +88,9 @@ func (m *wazeroModule) CallI64_I64(ctx context.Context, funcName string, param u return 0, nil } -func (m *wazeroModule) Close() (err error) { +func (m *wazeroModule) Close(ctx context.Context) (err error) { if mod := m.mod; mod != nil { - err = mod.Close() + err = mod.Close(ctx) } m.mod = nil return diff --git a/internal/integration_test/vs/wasm3_test.go b/internal/integration_test/vs/wasm3_test.go index 7f82452a04..f1744febe5 100644 --- a/internal/integration_test/vs/wasm3_test.go +++ b/internal/integration_test/vs/wasm3_test.go @@ -59,7 +59,7 @@ func (r *wasm3Runtime) Instantiate(_ context.Context, cfg *runtimeConfig) (mod m return } -func (r *wasm3Runtime) Close() error { +func (r *wasm3Runtime) Close(_ context.Context) error { if r := r.runtime; r != nil { r.Destroy() } @@ -77,7 +77,7 @@ func (m *wasm3Module) CallI64_I64(_ context.Context, funcName string, param uint } } -func (m *wasm3Module) Close() error { +func (m *wasm3Module) Close(_ context.Context) error { // module can't be destroyed m.module = nil m.funcs = nil diff --git a/internal/integration_test/vs/wasmedge_test.go b/internal/integration_test/vs/wasmedge_test.go index 9e604d7567..1970fdb2af 100644 --- a/internal/integration_test/vs/wasmedge_test.go +++ b/internal/integration_test/vs/wasmedge_test.go @@ -59,7 +59,7 @@ func (r *wasmedgeRuntime) Instantiate(_ context.Context, cfg *runtimeConfig) (mo return } -func (r *wasmedgeRuntime) Close() error { +func (r *wasmedgeRuntime) Close(_ context.Context) error { if conf := r.conf; conf != nil { conf.Release() } @@ -75,7 +75,7 @@ func (m *wasmedgeModule) CallI64_I64(_ context.Context, funcName string, param u } } -func (m *wasmedgeModule) Close() error { +func (m *wasmedgeModule) Close(_ context.Context) error { if vm := m.vm; vm != nil { vm.Release() } diff --git a/internal/integration_test/vs/wasmer_test.go b/internal/integration_test/vs/wasmer_test.go index dd5f6d10b7..cea284d089 100644 --- a/internal/integration_test/vs/wasmer_test.go +++ b/internal/integration_test/vs/wasmer_test.go @@ -61,7 +61,7 @@ func (r *wasmerRuntime) Instantiate(_ context.Context, cfg *runtimeConfig) (mod return } -func (r *wasmerRuntime) Close() error { +func (r *wasmerRuntime) Close(_ context.Context) error { r.engine = nil return nil } @@ -75,7 +75,7 @@ func (m *wasmerModule) CallI64_I64(_ context.Context, funcName string, param uin } } -func (m *wasmerModule) Close() error { +func (m *wasmerModule) Close(_ context.Context) error { if instance := m.instance; instance != nil { instance.Close() } diff --git a/internal/integration_test/vs/wasmtime_test.go b/internal/integration_test/vs/wasmtime_test.go index 2387f51650..b0529d185c 100644 --- a/internal/integration_test/vs/wasmtime_test.go +++ b/internal/integration_test/vs/wasmtime_test.go @@ -65,7 +65,7 @@ func (r *wasmtimeRuntime) Instantiate(_ context.Context, cfg *runtimeConfig) (mo return } -func (r *wasmtimeRuntime) Close() error { +func (r *wasmtimeRuntime) Close(_ context.Context) error { r.engine = nil return nil // wasmtime only closes via finalizer } @@ -79,7 +79,7 @@ func (m *wasmtimeModule) CallI64_I64(_ context.Context, funcName string, param u } } -func (m *wasmtimeModule) Close() error { +func (m *wasmtimeModule) Close(_ context.Context) error { m.store = nil m.instance = nil m.funcs = nil diff --git a/internal/modgen/modgen_test.go b/internal/modgen/modgen_test.go index 17c4b2a411..cb58dabf3e 100644 --- a/internal/modgen/modgen_test.go +++ b/internal/modgen/modgen_test.go @@ -14,6 +14,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasm/binary" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + const ( i32 = wasm.ValueTypeI32 i64 = wasm.ValueTypeI64 @@ -46,9 +49,10 @@ func TestModGen(t *testing.T) { // Encode the generated module (*wasm.Module) as binary. bin := binary.EncodeModule(m) // Pass the generated binary into our compilers. - code, err := runtime.CompileModule(context.Background(), bin) + code, err := runtime.CompileModule(testCtx, bin) + require.NoError(t, err) + err = code.Close(testCtx) require.NoError(t, err) - code.Close() }) } } diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index 034130f909..130d6375d1 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -74,12 +74,14 @@ func (m *CallContext) String() string { } // Close implements the same method as documented on api.Module. -func (m *CallContext) Close() (err error) { - return m.CloseWithExitCode(0) +func (m *CallContext) Close(ctx context.Context) (err error) { + return m.CloseWithExitCode(ctx, 0) } // CloseWithExitCode implements the same method as documented on api.Module. -func (m *CallContext) CloseWithExitCode(exitCode uint32) (err error) { +func (m *CallContext) CloseWithExitCode(_ context.Context, exitCode uint32) (err error) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + closed := uint64(1) + uint64(exitCode)<<32 // Store exitCode as high-order bits. if !atomic.CompareAndSwapUint64(m.closed, 0, closed) { return nil @@ -91,12 +93,12 @@ func (m *CallContext) CloseWithExitCode(exitCode uint32) (err error) { return } -// Memory implements api.Module Memory +// Memory implements the same method as documented on api.Module. func (m *CallContext) Memory() api.Memory { return m.module.Memory } -// ExportedMemory implements api.Module ExportedMemory +// ExportedMemory implements the same method as documented on api.Module. func (m *CallContext) ExportedMemory(name string) api.Memory { exp, err := m.module.getExport(name, ExternTypeMemory) if err != nil { @@ -105,7 +107,7 @@ func (m *CallContext) ExportedMemory(name string) api.Memory { return exp.Memory } -// ExportedFunction implements api.Module ExportedFunction +// ExportedFunction implements the same method as documented on api.Module. func (m *CallContext) ExportedFunction(name string) api.Function { exp, err := m.module.getExport(name, ExternTypeFunc) if err != nil { @@ -124,17 +126,17 @@ type importedFn struct { importedFn *FunctionInstance } -// ParamTypes implements the same method as documented on api.Function +// ParamTypes implements the same method as documented on api.Function. func (f *importedFn) ParamTypes() []api.ValueType { return f.importedFn.ParamTypes() } -// ResultTypes implements the same method as documented on api.Function +// ResultTypes implements the same method as documented on api.Function. func (f *importedFn) ResultTypes() []api.ValueType { return f.importedFn.ResultTypes() } -// Call implements the same method as documented on api.Function +// Call implements the same method as documented on api.Function. func (f *importedFn) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { if ctx == nil { ctx = context.Background() @@ -143,17 +145,17 @@ func (f *importedFn) Call(ctx context.Context, params ...uint64) (ret []uint64, return f.importedFn.Module.Engine.Call(ctx, mod, f.importedFn, params...) } -// ParamTypes implements the same method as documented on api.Function +// ParamTypes implements the same method as documented on api.Function. func (f *FunctionInstance) ParamTypes() []api.ValueType { return f.Type.Params } -// ResultTypes implements the same method as documented on api.Function +// ResultTypes implements the same method as documented on api.Function. func (f *FunctionInstance) ResultTypes() []api.ValueType { return f.Type.Results } -// Call implements the same method as documented on api.Function +// Call implements the same method as documented on api.Function. func (f *FunctionInstance) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { if ctx == nil { ctx = context.Background() @@ -162,7 +164,7 @@ func (f *FunctionInstance) Call(ctx context.Context, params ...uint64) (ret []ui return mod.Engine.Call(ctx, mod.CallCtx, f, params...) } -// ExportedGlobal implements api.Module ExportedGlobal +// ExportedGlobal implements the same method as documented on api.Module. func (m *CallContext) ExportedGlobal(name string) api.Global { exp, err := m.module.getExport(name, ExternTypeGlobal) if err != nil { diff --git a/internal/wasm/call_context_test.go b/internal/wasm/call_context_test.go index 3cc9e8caf7..f17c0f2835 100644 --- a/internal/wasm/call_context_test.go +++ b/internal/wasm/call_context_test.go @@ -2,6 +2,7 @@ package wasm import ( "context" + "fmt" "path" "testing" @@ -80,7 +81,7 @@ func TestCallContext_String(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Ensure paths that can create the host module can see the name. m, err := s.Instantiate(context.Background(), &Module{}, tc.moduleName, nil) - defer m.Close() //nolint + defer m.Close(testCtx) //nolint require.NoError(t, err) require.Equal(t, tc.expected, m.String()) @@ -92,24 +93,52 @@ func TestCallContext_String(t *testing.T) { func TestCallContext_Close(t *testing.T) { s := newStore() - t.Run("calls store.CloseWithExitCode(module.name)", func(t *testing.T) { - moduleName := t.Name() - m, err := s.Instantiate(context.Background(), &Module{}, moduleName, nil) - require.NoError(t, err) + tests := []struct { + name string + closer func(context.Context, *CallContext) error + expectedClosed uint64 + }{ + { + name: "Close()", + closer: func(ctx context.Context, callContext *CallContext) error { + return callContext.Close(ctx) + }, + expectedClosed: uint64(1), + }, + { + name: "CloseWithExitCode(255)", + closer: func(ctx context.Context, callContext *CallContext) error { + return callContext.CloseWithExitCode(ctx, 255) + }, + expectedClosed: uint64(255)<<32 + 1, + }, + } - // We use side effects to determine if Close in fact called store.CloseWithExitCode (without repeating store_test.go). - // One side effect of store.CloseWithExitCode is that the moduleName can no longer be looked up. Verify our base case. - require.Equal(t, s.Module(moduleName), m) + for _, tt := range tests { + tc := tt + t.Run(fmt.Sprintf("%s calls store.CloseWithExitCode(module.name))", tc.name), func(t *testing.T) { + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + moduleName := t.Name() + m, err := s.Instantiate(ctx, &Module{}, moduleName, nil) + require.NoError(t, err) - // Closing should not err. - require.NoError(t, m.Close()) + // We use side effects to see if Close called store.CloseWithExitCode (without repeating store_test.go). + // One side effect of store.CloseWithExitCode is that the moduleName can no longer be looked up. + require.Equal(t, s.Module(moduleName), m) - // Verify our intended side-effect - require.Nil(t, s.Module(moduleName)) + // Closing should not err. + require.NoError(t, tc.closer(ctx, m)) - // Verify no error closing again. - require.NoError(t, m.Close()) - }) + require.Equal(t, tc.expectedClosed, *m.closed) + + // Verify our intended side-effect + require.Nil(t, s.Module(moduleName)) + + // Verify no error closing again. + require.NoError(t, tc.closer(ctx, m)) + } + }) + } t.Run("calls SysContext.Close()", func(t *testing.T) { tempDir := t.TempDir() @@ -139,12 +168,12 @@ func TestCallContext_Close(t *testing.T) { require.True(t, len(sys.openedFiles) > 0, "sys.openedFiles was empty") // Closing should not err. - require.NoError(t, m.Close()) + require.NoError(t, m.Close(testCtx)) // Verify our intended side-effect require.Equal(t, 0, len(sys.openedFiles), "expected no opened files") // Verify no error closing again. - require.NoError(t, m.Close()) + require.NoError(t, m.Close(testCtx)) }) } diff --git a/internal/wasm/global.go b/internal/wasm/global.go index b088207194..d55ee015fe 100644 --- a/internal/wasm/global.go +++ b/internal/wasm/global.go @@ -1,6 +1,7 @@ package wasm import ( + "context" "fmt" "github.com/tetratelabs/wazero/api" @@ -10,21 +11,25 @@ type mutableGlobal struct { g *GlobalInstance } -// compile-time check to ensure mutableGlobal is a api.Global +// compile-time check to ensure mutableGlobal is a api.Global. var _ api.Global = &mutableGlobal{} -// Type implements api.Global Type +// Type implements the same method as documented on api.Global. func (g *mutableGlobal) Type() api.ValueType { return g.g.Type.ValType } -// Get implements api.Global Get -func (g *mutableGlobal) Get() uint64 { +// Get implements the same method as documented on api.Global. +func (g *mutableGlobal) Get(_ context.Context) uint64 { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + return g.g.Val } -// Set implements api.MutableGlobal Set -func (g *mutableGlobal) Set(v uint64) { +// Set implements the same method as documented on api.MutableGlobal. +func (g *mutableGlobal) Set(_ context.Context, v uint64) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + g.g.Val = v } @@ -32,11 +37,11 @@ func (g *mutableGlobal) Set(v uint64) { func (g *mutableGlobal) String() string { switch g.Type() { case ValueTypeI32, ValueTypeI64: - return fmt.Sprintf("global(%d)", g.Get()) + return fmt.Sprintf("global(%d)", g.Get(context.Background())) case ValueTypeF32: - return fmt.Sprintf("global(%f)", api.DecodeF32(g.Get())) + return fmt.Sprintf("global(%f)", api.DecodeF32(g.Get(context.Background()))) case ValueTypeF64: - return fmt.Sprintf("global(%f)", api.DecodeF64(g.Get())) + return fmt.Sprintf("global(%f)", api.DecodeF64(g.Get(context.Background()))) default: panic(fmt.Errorf("BUG: unknown value type %X", g.Type())) } @@ -47,13 +52,15 @@ type globalI32 uint64 // compile-time check to ensure globalI32 is a api.Global var _ api.Global = globalI32(0) -// Type implements api.Global Type +// Type implements the same method as documented on api.Global. func (g globalI32) Type() api.ValueType { return ValueTypeI32 } -// Get implements api.Global Get -func (g globalI32) Get() uint64 { +// Get implements the same method as documented on api.Global. +func (g globalI32) Get(_ context.Context) uint64 { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + return uint64(g) } @@ -67,13 +74,15 @@ type globalI64 uint64 // compile-time check to ensure globalI64 is a api.Global var _ api.Global = globalI64(0) -// Type implements api.Global Type +// Type implements the same method as documented on api.Global. func (g globalI64) Type() api.ValueType { return ValueTypeI64 } -// Get implements api.Global Get -func (g globalI64) Get() uint64 { +// Get implements the same method as documented on api.Global. +func (g globalI64) Get(_ context.Context) uint64 { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + return uint64(g) } @@ -87,19 +96,21 @@ type globalF32 uint64 // compile-time check to ensure globalF32 is a api.Global var _ api.Global = globalF32(0) -// Type implements api.Global Type +// Type implements the same method as documented on api.Global. func (g globalF32) Type() api.ValueType { return ValueTypeF32 } -// Get implements api.Global Get -func (g globalF32) Get() uint64 { +// Get implements the same method as documented on api.Global. +func (g globalF32) Get(_ context.Context) uint64 { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + return uint64(g) } // String implements fmt.Stringer func (g globalF32) String() string { - return fmt.Sprintf("global(%f)", api.DecodeF32(g.Get())) + return fmt.Sprintf("global(%f)", api.DecodeF32(g.Get(context.Background()))) } type globalF64 uint64 @@ -107,17 +118,19 @@ type globalF64 uint64 // compile-time check to ensure globalF64 is a api.Global var _ api.Global = globalF64(0) -// Type implements api.Global Type +// Type implements the same method as documented on api.Global. func (g globalF64) Type() api.ValueType { return ValueTypeF64 } -// Get implements api.Global Get -func (g globalF64) Get() uint64 { +// Get implements the same method as documented on api.Global. +func (g globalF64) Get(_ context.Context) uint64 { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + return uint64(g) } // String implements fmt.Stringer func (g globalF64) String() string { - return fmt.Sprintf("global(%f)", api.DecodeF64(g.Get())) + return fmt.Sprintf("global(%f)", api.DecodeF64(g.Get(context.Background()))) } diff --git a/internal/wasm/global_test.go b/internal/wasm/global_test.go index 9a2167890a..2a249f2034 100644 --- a/internal/wasm/global_test.go +++ b/internal/wasm/global_test.go @@ -126,15 +126,20 @@ func TestGlobalTypes(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedType, tc.global.Type()) - require.Equal(t, tc.expectedVal, tc.global.Get()) - require.Equal(t, tc.expectedString, tc.global.String()) + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + require.Equal(t, tc.expectedType, tc.global.Type()) + require.Equal(t, tc.expectedVal, tc.global.Get(ctx)) + require.Equal(t, tc.expectedString, tc.global.String()) - mutable, ok := tc.global.(api.MutableGlobal) - require.Equal(t, tc.expectedMutable, ok) - if ok { - mutable.Set(2) - require.Equal(t, uint64(2), tc.global.Get()) + mutable, ok := tc.global.(api.MutableGlobal) + require.Equal(t, tc.expectedMutable, ok) + if ok { + mutable.Set(ctx, 2) + require.Equal(t, uint64(2), tc.global.Get(ctx)) + + mutable.Set(ctx, tc.expectedVal) // Set it back! + require.Equal(t, tc.expectedVal, tc.global.Get(ctx)) + } } }) } diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index 912dfd7692..45c9259df1 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -2,7 +2,6 @@ package interpreter import ( "context" - "encoding/binary" "fmt" "math" "math/bits" @@ -164,8 +163,15 @@ func (c *code) instantiate(f *wasm.FunctionInstance) *function { } } -// Non-interface union of all the wazeroir operations. +// interpreterOp is the compilation (engine.lowerIR) result of a wazeroir.Operation. +// +// Not all operations result in an interpreterOp, e.g. wazeroir.OperationI32ReinterpretFromF32, and some operations are +// more complex than others, e.g. wazeroir.OperationBrTable. +// +// Note: This is a form of union type as it can store fields needed for any operation. Hence, most fields are opaque and +// only relevant when in context of its kind. type interpreterOp struct { + // kind determines how to interpret the other fields in this struct. kind wazeroir.OperationKind b1, b2 byte b3 bool @@ -181,7 +187,7 @@ func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { funcs := make([]*code, 0, len(module.FunctionSection)) if module.IsHostModule() { - // If this is the host module, there's nothing to do as the runtime reprsentation of + // If this is the host module, there's nothing to do as the runtime representation of // host function in interpreter is its Go function itself as opposed to Wasm functions, // which need to be compiled down to wazeroir. for _, hf := range module.HostFunctionSection { @@ -679,131 +685,131 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallCont } case wazeroir.OperationKindGlobalGet: { - g := globals[op.us[0]] + g := globals[op.us[0]] // TODO: Not yet traceable as it doesn't use the types in global.go ce.pushValue(g.Val) frame.pc++ } case wazeroir.OperationKindGlobalSet: { - g := globals[op.us[0]] + g := globals[op.us[0]] // TODO: Not yet traceable as it doesn't use the types in global.go g.Val = ce.popValue() frame.pc++ } case wazeroir.OperationKindLoad: { - base := op.us[1] + ce.popValue() + offset := ce.popMemoryOffset(op) switch wazeroir.UnsignedType(op.b1) { case wazeroir.UnsignedTypeI32, wazeroir.UnsignedTypeF32: - if uint64(len(memoryInst.Buffer)) < base+4 { + if val, ok := memoryInst.ReadUint32Le(ctx, offset); !ok { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) + } else { + ce.pushValue(uint64(val)) } - ce.pushValue(uint64(binary.LittleEndian.Uint32(memoryInst.Buffer[base:]))) case wazeroir.UnsignedTypeI64, wazeroir.UnsignedTypeF64: - if uint64(len(memoryInst.Buffer)) < base+8 { + if val, ok := memoryInst.ReadUint64Le(ctx, offset); !ok { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) + } else { + ce.pushValue(val) } - ce.pushValue(binary.LittleEndian.Uint64(memoryInst.Buffer[base:])) } frame.pc++ } case wazeroir.OperationKindLoad8: { - base := op.us[1] + ce.popValue() - if uint64(len(memoryInst.Buffer)) < base+1 { + val, ok := memoryInst.ReadByte(ctx, ce.popMemoryOffset(op)) + if !ok { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } + switch wazeroir.SignedInt(op.b1) { case wazeroir.SignedInt32, wazeroir.SignedInt64: - ce.pushValue(uint64(int8(memoryInst.Buffer[base]))) + ce.pushValue(uint64(int8(val))) case wazeroir.SignedUint32, wazeroir.SignedUint64: - ce.pushValue(uint64(uint8(memoryInst.Buffer[base]))) + ce.pushValue(uint64(val)) } frame.pc++ } case wazeroir.OperationKindLoad16: { - base := op.us[1] + ce.popValue() - if uint64(len(memoryInst.Buffer)) < base+2 { + val, ok := memoryInst.ReadUint16Le(ctx, ce.popMemoryOffset(op)) + if !ok { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } + switch wazeroir.SignedInt(op.b1) { case wazeroir.SignedInt32, wazeroir.SignedInt64: - ce.pushValue(uint64(int16(binary.LittleEndian.Uint16(memoryInst.Buffer[base:])))) + ce.pushValue(uint64(int16(val))) case wazeroir.SignedUint32, wazeroir.SignedUint64: - ce.pushValue(uint64(binary.LittleEndian.Uint16(memoryInst.Buffer[base:]))) + ce.pushValue(uint64(val)) } frame.pc++ } case wazeroir.OperationKindLoad32: { - base := op.us[1] + ce.popValue() - if uint64(len(memoryInst.Buffer)) < base+4 { + val, ok := memoryInst.ReadUint32Le(ctx, ce.popMemoryOffset(op)) + if !ok { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } - if op.b1 == 1 { - ce.pushValue(uint64(int32(binary.LittleEndian.Uint32(memoryInst.Buffer[base:])))) + + if op.b1 == 1 { // Signed + ce.pushValue(uint64(int32(val))) } else { - ce.pushValue(uint64(binary.LittleEndian.Uint32(memoryInst.Buffer[base:]))) + ce.pushValue(uint64(val)) } frame.pc++ } case wazeroir.OperationKindStore: { val := ce.popValue() - base := op.us[1] + ce.popValue() + offset := ce.popMemoryOffset(op) switch wazeroir.UnsignedType(op.b1) { case wazeroir.UnsignedTypeI32, wazeroir.UnsignedTypeF32: - if uint64(len(memoryInst.Buffer)) < base+4 { + if !memoryInst.WriteUint32Le(ctx, offset, uint32(val)) { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } - binary.LittleEndian.PutUint32(memoryInst.Buffer[base:], uint32(val)) case wazeroir.UnsignedTypeI64, wazeroir.UnsignedTypeF64: - if uint64(len(memoryInst.Buffer)) < base+8 { + if !memoryInst.WriteUint64Le(ctx, offset, val) { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } - binary.LittleEndian.PutUint64(memoryInst.Buffer[base:], val) } frame.pc++ } case wazeroir.OperationKindStore8: { val := byte(ce.popValue()) - base := op.us[1] + ce.popValue() - if uint64(len(memoryInst.Buffer)) < base+1 { + offset := ce.popMemoryOffset(op) + if !memoryInst.WriteByte(ctx, offset, val) { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } - memoryInst.Buffer[base] = val frame.pc++ } case wazeroir.OperationKindStore16: { val := uint16(ce.popValue()) - base := op.us[1] + ce.popValue() - if uint64(len(memoryInst.Buffer)) < base+2 { + offset := ce.popMemoryOffset(op) + if !memoryInst.WriteUint16Le(ctx, offset, val) { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } - binary.LittleEndian.PutUint16(memoryInst.Buffer[base:], val) frame.pc++ } case wazeroir.OperationKindStore32: { val := uint32(ce.popValue()) - base := op.us[1] + ce.popValue() - if uint64(len(memoryInst.Buffer)) < base+4 { + offset := ce.popMemoryOffset(op) + if !memoryInst.WriteUint32Le(ctx, offset, val) { panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) } - binary.LittleEndian.PutUint32(memoryInst.Buffer[base:], val) frame.pc++ } case wazeroir.OperationKindMemorySize: { - ce.pushValue(uint64(memoryInst.PageSize())) + ce.pushValue(uint64(memoryInst.PageSize(ctx))) frame.pc++ } case wazeroir.OperationKindMemoryGrow: { n := ce.popValue() - res := memoryInst.Grow(uint32(n)) + res := memoryInst.Grow(ctx, uint32(n)) ce.pushValue(uint64(res)) frame.pc++ } @@ -1670,6 +1676,17 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallCont ce.popFrame() } +// popMemoryOffset takes a memory offset off the stack for use in load and store instructions. +// As the top of stack value is 64-bit, this ensures it is in range before returning it. +func (ce *callEngine) popMemoryOffset(op *interpreterOp) uint32 { + // TODO: Document what 'us' is and why we expect to look at value 1. + offset := op.us[1] + ce.popValue() + if offset > math.MaxUint32 { + panic(wasmruntime.ErrRuntimeOutOfBoundsMemoryAccess) + } + return uint32(offset) +} + func (ce *callEngine) callGoFuncWithStack(ctx context.Context, callCtx *wasm.CallContext, f *function) { params := wasm.PopGoFuncParams(f.source, ce.popValue) results := ce.callGoFunc(ctx, callCtx, f, params) diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index 1ca3af8992..4aab0c6966 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -690,7 +690,7 @@ jitentry: switch ce.exitContext.builtinFunctionCallIndex { case builtinFunctionIndexMemoryGrow: callercode := ce.callFrameTop().function - ce.builtinFunctionMemoryGrow(callercode.source.Module.Memory) + ce.builtinFunctionMemoryGrow(ctx, callercode.source.Module.Memory) case builtinFunctionIndexGrowValueStack: callercode := ce.callFrameTop().function ce.builtinFunctionGrowValueStack(callercode.stackPointerCeil) @@ -740,10 +740,10 @@ func (ce *callEngine) builtinFunctionGrowCallFrameStack() { ce.globalContext.callFrameStackElementZeroAddress = stackSliceHeader.Data } -func (ce *callEngine) builtinFunctionMemoryGrow(mem *wasm.MemoryInstance) { +func (ce *callEngine) builtinFunctionMemoryGrow(ctx context.Context, mem *wasm.MemoryInstance) { newPages := ce.popValue() - res := mem.Grow(uint32(newPages)) + res := mem.Grow(ctx, uint32(newPages)) ce.pushValue(uint64(res)) // Update the moduleContext fields as they become stale after the update ^^. @@ -782,7 +782,7 @@ func compileWasmFunction(enabledFeatures wasm.Features, ir *wazeroir.Compilation var skip bool for _, op := range ir.Operations { - // Compiler determines whether or not skip the entire label. + // Compiler determines whether skip the entire label. // For example, if the label doesn't have any caller, // we don't need to generate native code at all as we never reach the region. if op.Kind() == wazeroir.OperationKindLabel { diff --git a/internal/wasm/memory.go b/internal/wasm/memory.go index f873d2cbd9..95560be50a 100644 --- a/internal/wasm/memory.go +++ b/internal/wasm/memory.go @@ -2,6 +2,7 @@ package wasm import ( "bytes" + "context" "encoding/binary" "fmt" "math" @@ -35,18 +36,17 @@ type MemoryInstance struct { } // Size implements the same method as documented on api.Memory. -func (m *MemoryInstance) Size() uint32 { - return uint32(len(m.Buffer)) -} +func (m *MemoryInstance) Size(_ context.Context) uint32 { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! -// hasSize returns true if Len is sufficient for sizeInBytes at the given offset. -func (m *MemoryInstance) hasSize(offset uint32, sizeInBytes uint32) bool { - return uint64(offset)+uint64(sizeInBytes) <= uint64(m.Size()) // uint64 prevents overflow on add + return m.size() } // IndexByte implements the same method as documented on api.Memory. -func (m *MemoryInstance) IndexByte(offset uint32, c byte) (uint32, bool) { - if offset >= m.Size() { +func (m *MemoryInstance) IndexByte(_ context.Context, offset uint32, c byte) (uint32, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + if offset >= uint32(len(m.Buffer)) { return 0, false } b := m.Buffer[offset:] @@ -58,24 +58,37 @@ func (m *MemoryInstance) IndexByte(offset uint32, c byte) (uint32, bool) { } // ReadByte implements the same method as documented on api.Memory. -func (m *MemoryInstance) ReadByte(offset uint32) (byte, bool) { - if offset >= m.Size() { +func (m *MemoryInstance) ReadByte(_ context.Context, offset uint32) (byte, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + if offset >= m.size() { return 0, false } return m.Buffer[offset], true } -// ReadUint32Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) ReadUint32Le(offset uint32) (uint32, bool) { - if !m.hasSize(offset, 4) { +// ReadUint16Le implements the same method as documented on api.Memory. +func (m *MemoryInstance) ReadUint16Le(_ context.Context, offset uint32) (uint16, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + if !m.hasSize(offset, 2) { return 0, false } - return binary.LittleEndian.Uint32(m.Buffer[offset : offset+4]), true + return binary.LittleEndian.Uint16(m.Buffer[offset : offset+2]), true +} + +// ReadUint32Le implements the same method as documented on api.Memory. +func (m *MemoryInstance) ReadUint32Le(_ context.Context, offset uint32) (uint32, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + return m.readUint32Le(offset) } // ReadFloat32Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) ReadFloat32Le(offset uint32) (float32, bool) { - v, ok := m.ReadUint32Le(offset) +func (m *MemoryInstance) ReadFloat32Le(_ context.Context, offset uint32) (float32, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + v, ok := m.readUint32Le(offset) if !ok { return 0, false } @@ -83,16 +96,17 @@ func (m *MemoryInstance) ReadFloat32Le(offset uint32) (float32, bool) { } // ReadUint64Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) ReadUint64Le(offset uint32) (uint64, bool) { - if !m.hasSize(offset, 8) { - return 0, false - } - return binary.LittleEndian.Uint64(m.Buffer[offset : offset+8]), true +func (m *MemoryInstance) ReadUint64Le(_ context.Context, offset uint32) (uint64, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + return m.readUint64Le(offset) } // ReadFloat64Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) ReadFloat64Le(offset uint32) (float64, bool) { - v, ok := m.ReadUint64Le(offset) +func (m *MemoryInstance) ReadFloat64Le(_ context.Context, offset uint32) (float64, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + v, ok := m.readUint64Le(offset) if !ok { return 0, false } @@ -100,7 +114,9 @@ func (m *MemoryInstance) ReadFloat64Le(offset uint32) (float64, bool) { } // Read implements the same method as documented on api.Memory. -func (m *MemoryInstance) Read(offset, byteCount uint32) ([]byte, bool) { +func (m *MemoryInstance) Read(_ context.Context, offset, byteCount uint32) ([]byte, bool) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + if !m.hasSize(offset, byteCount) { return nil, false } @@ -108,44 +124,58 @@ func (m *MemoryInstance) Read(offset, byteCount uint32) ([]byte, bool) { } // WriteByte implements the same method as documented on api.Memory. -func (m *MemoryInstance) WriteByte(offset uint32, v byte) bool { - if offset >= m.Size() { +func (m *MemoryInstance) WriteByte(_ context.Context, offset uint32, v byte) bool { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + if offset >= m.size() { return false } m.Buffer[offset] = v return true } -// WriteUint32Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) WriteUint32Le(offset, v uint32) bool { - if !m.hasSize(offset, 4) { +// WriteUint16Le implements the same method as documented on api.Memory. +func (m *MemoryInstance) WriteUint16Le(_ context.Context, offset uint32, v uint16) bool { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + if !m.hasSize(offset, 2) { return false } - binary.LittleEndian.PutUint32(m.Buffer[offset:], v) + binary.LittleEndian.PutUint16(m.Buffer[offset:], v) return true } +// WriteUint32Le implements the same method as documented on api.Memory. +func (m *MemoryInstance) WriteUint32Le(_ context.Context, offset, v uint32) bool { + + return m.writeUint32Le(offset, v) +} + // WriteFloat32Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) WriteFloat32Le(offset uint32, v float32) bool { - return m.WriteUint32Le(offset, math.Float32bits(v)) +func (m *MemoryInstance) WriteFloat32Le(_ context.Context, offset uint32, v float32) bool { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + return m.writeUint32Le(offset, math.Float32bits(v)) } // WriteUint64Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) WriteUint64Le(offset uint32, v uint64) bool { - if !m.hasSize(offset, 8) { - return false - } - binary.LittleEndian.PutUint64(m.Buffer[offset:], v) - return true +func (m *MemoryInstance) WriteUint64Le(_ context.Context, offset uint32, v uint64) bool { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + return m.writeUint64Le(offset, v) } // WriteFloat64Le implements the same method as documented on api.Memory. -func (m *MemoryInstance) WriteFloat64Le(offset uint32, v float64) bool { - return m.WriteUint64Le(offset, math.Float64bits(v)) +func (m *MemoryInstance) WriteFloat64Le(_ context.Context, offset uint32, v float64) bool { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + + return m.writeUint64Le(offset, math.Float64bits(v)) } // Write implements the same method as documented on api.Memory. -func (m *MemoryInstance) Write(offset uint32, val []byte) bool { +func (m *MemoryInstance) Write(_ context.Context, offset uint32, val []byte) bool { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + if !m.hasSize(offset, uint32(len(val))) { return false } @@ -158,17 +188,14 @@ func MemoryPagesToBytesNum(pages uint32) (bytesNum uint64) { return uint64(pages) << MemoryPageSizeInBits } -// memoryBytesNumToPages converts the given number of bytes into the number of pages. -func memoryBytesNumToPages(bytesNum uint64) (pages uint32) { - return uint32(bytesNum >> MemoryPageSizeInBits) -} - // Grow extends the memory buffer by "newPages" * memoryPageSize. // The logic here is described in https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#grow-mem. // // Returns -1 if the operation resulted in exceeding the maximum memory pages. // Otherwise, returns the prior memory size after growing the memory buffer. -func (m *MemoryInstance) Grow(newPages uint32) (result uint32) { +func (m *MemoryInstance) Grow(_ context.Context, newPages uint32) (result uint32) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + currentPages := memoryBytesNumToPages(uint64(len(m.Buffer))) // If exceeds the max of memory size, we push -1 according to the spec. @@ -182,7 +209,9 @@ func (m *MemoryInstance) Grow(newPages uint32) (result uint32) { } // PageSize returns the current memory buffer size in pages. -func (m *MemoryInstance) PageSize() (result uint32) { +func (m *MemoryInstance) PageSize(_ context.Context) (result uint32) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + return memoryBytesNumToPages(uint64(len(m.Buffer))) } @@ -204,3 +233,58 @@ func PagesToUnitOfBytes(pages uint32) string { } return fmt.Sprintf("%d Ti", g/1024) } + +// Below are raw functions used to implement the api.Memory API: + +// memoryBytesNumToPages converts the given number of bytes into the number of pages. +func memoryBytesNumToPages(bytesNum uint64) (pages uint32) { + return uint32(bytesNum >> MemoryPageSizeInBits) +} + +// size returns the size in bytes of the buffer. +func (m *MemoryInstance) size() uint32 { + return uint32(len(m.Buffer)) +} + +// hasSize returns true if Len is sufficient for sizeInBytes at the given offset. +func (m *MemoryInstance) hasSize(offset uint32, sizeInBytes uint32) bool { + return uint64(offset)+uint64(sizeInBytes) <= uint64(len(m.Buffer)) // uint64 prevents overflow on add +} + +// readUint32Le implements ReadUint32Le without using a context. This is extracted as both ints and floats are stored in +// memory as uint32le. +func (m *MemoryInstance) readUint32Le(offset uint32) (uint32, bool) { + if !m.hasSize(offset, 4) { + return 0, false + } + return binary.LittleEndian.Uint32(m.Buffer[offset : offset+4]), true +} + +// readUint64Le implements ReadUint64Le without using a context. This is extracted as both ints and floats are stored in +// memory as uint64le. +func (m *MemoryInstance) readUint64Le(offset uint32) (uint64, bool) { + if !m.hasSize(offset, 8) { + return 0, false + } + return binary.LittleEndian.Uint64(m.Buffer[offset : offset+8]), true +} + +// writeUint32Le implements WriteUint32Le without using a context. This is extracted as both ints and floats are stored +// in memory as uint32le. +func (m *MemoryInstance) writeUint32Le(offset uint32, v uint32) bool { + if !m.hasSize(offset, 4) { + return false + } + binary.LittleEndian.PutUint32(m.Buffer[offset:], v) + return true +} + +// writeUint64Le implements WriteUint64Le without using a context. This is extracted as both ints and floats are stored +// in memory as uint64le. +func (m *MemoryInstance) writeUint64Le(offset uint32, v uint64) bool { + if !m.hasSize(offset, 8) { + return false + } + binary.LittleEndian.PutUint64(m.Buffer[offset:], v) + return true +} diff --git a/internal/wasm/memory_test.go b/internal/wasm/memory_test.go index dab251ee29..ee44d5a4bc 100644 --- a/internal/wasm/memory_test.go +++ b/internal/wasm/memory_test.go @@ -1,6 +1,7 @@ package wasm import ( + "context" "math" "testing" @@ -26,70 +27,84 @@ func Test_MemoryBytesNumToPages(t *testing.T) { } func TestMemoryInstance_Grow_Size(t *testing.T) { - max := uint32(10) - m := &MemoryInstance{Max: max, Buffer: make([]byte, 0)} - require.Equal(t, uint32(0), m.Grow(5)) - require.Equal(t, uint32(5), m.PageSize()) - // Zero page grow is well-defined, should return the current page correctly. - require.Equal(t, uint32(5), m.Grow(0)) - require.Equal(t, uint32(5), m.PageSize()) - require.Equal(t, uint32(5), m.Grow(4)) - require.Equal(t, uint32(9), m.PageSize()) - // At this point, the page size equal 9, - // so trying to grow two pages should result in failure. - require.Equal(t, int32(-1), int32(m.Grow(2))) - require.Equal(t, uint32(9), m.PageSize()) - // But growing one page is still permitted. - require.Equal(t, uint32(9), m.Grow(1)) - // Ensure that the current page size equals the max. - require.Equal(t, max, m.PageSize()) + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + max := uint32(10) + m := &MemoryInstance{Max: max, Buffer: make([]byte, 0)} + require.Equal(t, uint32(0), m.Grow(ctx, 5)) + require.Equal(t, uint32(5), m.PageSize(ctx)) + + // Zero page grow is well-defined, should return the current page correctly. + require.Equal(t, uint32(5), m.Grow(ctx, 0)) + require.Equal(t, uint32(5), m.PageSize(ctx)) + require.Equal(t, uint32(5), m.Grow(ctx, 4)) + require.Equal(t, uint32(9), m.PageSize(ctx)) + + // At this point, the page size equal 9, + // so trying to grow two pages should result in failure. + require.Equal(t, int32(-1), int32(m.Grow(ctx, 2))) + require.Equal(t, uint32(9), m.PageSize(ctx)) + + // But growing one page is still permitted. + require.Equal(t, uint32(9), m.Grow(ctx, 1)) + + // Ensure that the current page size equals the max. + require.Equal(t, max, m.PageSize(ctx)) + } } func TestIndexByte(t *testing.T) { - var mem = &MemoryInstance{Buffer: []byte{0, 0, 0, 0, 16, 0, 0, 0}, Min: 1} - v, ok := mem.IndexByte(4, 16) - require.True(t, ok) - require.Equal(t, uint32(4), v) + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + var mem = &MemoryInstance{Buffer: []byte{0, 0, 0, 0, 16, 0, 0, 0}, Min: 1} + v, ok := mem.IndexByte(ctx, 4, 16) + require.True(t, ok) + require.Equal(t, uint32(4), v) - _, ok = mem.IndexByte(5, 16) - require.False(t, ok) + _, ok = mem.IndexByte(ctx, 5, 16) + require.False(t, ok) - _, ok = mem.IndexByte(9, 16) - require.False(t, ok) + _, ok = mem.IndexByte(ctx, 9, 16) + require.False(t, ok) + } } func TestReadByte(t *testing.T) { - var mem = &MemoryInstance{Buffer: []byte{0, 0, 0, 0, 0, 0, 0, 16}, Min: 1} - v, ok := mem.ReadByte(7) - require.True(t, ok) - require.Equal(t, byte(16), v) + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + var mem = &MemoryInstance{Buffer: []byte{0, 0, 0, 0, 0, 0, 0, 16}, Min: 1} + v, ok := mem.ReadByte(ctx, 7) + require.True(t, ok) + require.Equal(t, byte(16), v) - _, ok = mem.ReadByte(8) - require.False(t, ok) + _, ok = mem.ReadByte(ctx, 8) + require.False(t, ok) - _, ok = mem.ReadByte(9) - require.False(t, ok) + _, ok = mem.ReadByte(ctx, 9) + require.False(t, ok) + } } func TestReadUint32Le(t *testing.T) { - var mem = &MemoryInstance{Buffer: []byte{0, 0, 0, 0, 16, 0, 0, 0}, Min: 1} - v, ok := mem.ReadUint32Le(4) - require.True(t, ok) - require.Equal(t, uint32(16), v) + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + var mem = &MemoryInstance{Buffer: []byte{0, 0, 0, 0, 16, 0, 0, 0}, Min: 1} + v, ok := mem.ReadUint32Le(ctx, 4) + require.True(t, ok) + require.Equal(t, uint32(16), v) - _, ok = mem.ReadUint32Le(5) - require.False(t, ok) + _, ok = mem.ReadUint32Le(ctx, 5) + require.False(t, ok) - _, ok = mem.ReadUint32Le(9) - require.False(t, ok) + _, ok = mem.ReadUint32Le(ctx, 9) + require.False(t, ok) + } } func TestWriteUint32Le(t *testing.T) { - var mem = &MemoryInstance{Buffer: make([]byte, 8), Min: 1} - require.True(t, mem.WriteUint32Le(4, 16)) - require.Equal(t, []byte{0, 0, 0, 0, 16, 0, 0, 0}, mem.Buffer) - require.False(t, mem.WriteUint32Le(5, 16)) - require.False(t, mem.WriteUint32Le(9, 16)) + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + var mem = &MemoryInstance{Buffer: make([]byte, 8), Min: 1} + require.True(t, mem.WriteUint32Le(ctx, 4, 16)) + require.Equal(t, []byte{0, 0, 0, 0, 16, 0, 0, 0}, mem.Buffer) + require.False(t, mem.WriteUint32Le(ctx, 5, 16)) + require.False(t, mem.WriteUint32Le(ctx, 9, 16)) + } } func TestPagesToUnitOfBytes(t *testing.T) { @@ -135,7 +150,7 @@ func TestPagesToUnitOfBytes(t *testing.T) { } func TestMemoryInstance_HasSize(t *testing.T) { - memory := &MemoryInstance{Buffer: make([]byte, 100)} + memory := &MemoryInstance{Buffer: make([]byte, MemoryPageSize)} tests := []struct { name string @@ -151,19 +166,19 @@ func TestMemoryInstance_HasSize(t *testing.T) { }, { name: "maximum valid sizeInBytes", - offset: memory.Size() - 8, + offset: memory.Size(testCtx) - 8, sizeInBytes: 8, expected: true, }, { name: "sizeInBytes exceeds the valid size by 1", offset: 100, // arbitrary valid offset - sizeInBytes: uint64(memory.Size() - 99), + sizeInBytes: uint64(memory.Size(testCtx) - 99), expected: false, }, { name: "offset exceeds the memory size", - offset: memory.Size(), + offset: memory.Size(testCtx), sizeInBytes: 1, // arbitrary size expected: false, }, @@ -173,6 +188,12 @@ func TestMemoryInstance_HasSize(t *testing.T) { sizeInBytes: 4, // if there's overflow, offset + sizeInBytes is 3, and it may pass the check expected: false, }, + { + name: "address.wast:200", + offset: 4294967295, + sizeInBytes: 1, + expected: false, + }, } for _, tt := range tests { @@ -184,6 +205,57 @@ func TestMemoryInstance_HasSize(t *testing.T) { } } +func TestMemoryInstance_ReadUint16Le(t *testing.T) { + tests := []struct { + name string + memory []byte + offset uint32 + expected uint16 + expectedOk bool + }{ + { + name: "valid offset with an endian-insensitive v", + memory: []byte{0xff, 0xff}, + offset: 0, // arbitrary valid offset. + expected: math.MaxUint16, + expectedOk: true, + }, + { + name: "valid offset with an endian-sensitive v", + memory: []byte{0xfe, 0xff}, + offset: 0, // arbitrary valid offset. + expected: math.MaxUint16 - 1, + expectedOk: true, + }, + { + name: "maximum boundary valid offset", + offset: 1, + memory: []byte{0x00, 0x1, 0x00}, + expected: 1, // arbitrary valid v + expectedOk: true, + }, + { + name: "offset exceeds the maximum valid offset by 1", + memory: []byte{0xff, 0xff}, + offset: 1, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + memory := &MemoryInstance{Buffer: tc.memory} + + v, ok := memory.ReadUint16Le(ctx, tc.offset) + require.Equal(t, tc.expectedOk, ok) + require.Equal(t, tc.expected, v) + } + }) + } +} + func TestMemoryInstance_ReadUint32Le(t *testing.T) { tests := []struct { name string @@ -224,11 +296,13 @@ func TestMemoryInstance_ReadUint32Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - memory := &MemoryInstance{Buffer: tc.memory} + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + memory := &MemoryInstance{Buffer: tc.memory} - v, ok := memory.ReadUint32Le(tc.offset) - require.Equal(t, tc.expectedOk, ok) - require.Equal(t, tc.expected, v) + v, ok := memory.ReadUint32Le(ctx, tc.offset) + require.Equal(t, tc.expectedOk, ok) + require.Equal(t, tc.expected, v) + } }) } } @@ -273,11 +347,13 @@ func TestMemoryInstance_ReadUint64Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - memory := &MemoryInstance{Buffer: tc.memory} + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + memory := &MemoryInstance{Buffer: tc.memory} - v, ok := memory.ReadUint64Le(tc.offset) - require.Equal(t, tc.expectedOk, ok) - require.Equal(t, tc.expected, v) + v, ok := memory.ReadUint64Le(ctx, tc.offset) + require.Equal(t, tc.expectedOk, ok) + require.Equal(t, tc.expected, v) + } }) } } @@ -322,11 +398,13 @@ func TestMemoryInstance_ReadFloat32Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - memory := &MemoryInstance{Buffer: tc.memory} + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + memory := &MemoryInstance{Buffer: tc.memory} - v, ok := memory.ReadFloat32Le(tc.offset) - require.Equal(t, tc.expectedOk, ok) - require.Equal(t, tc.expected, v) + v, ok := memory.ReadFloat32Le(ctx, tc.offset) + require.Equal(t, tc.expectedOk, ok) + require.Equal(t, tc.expected, v) + } }) } } @@ -371,11 +449,66 @@ func TestMemoryInstance_ReadFloat64Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - memory := &MemoryInstance{Buffer: tc.memory} + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + memory := &MemoryInstance{Buffer: tc.memory} + + v, ok := memory.ReadFloat64Le(ctx, tc.offset) + require.Equal(t, tc.expectedOk, ok) + require.Equal(t, tc.expected, v) + } + }) + } +} - v, ok := memory.ReadFloat64Le(tc.offset) - require.Equal(t, tc.expectedOk, ok) - require.Equal(t, tc.expected, v) +func TestMemoryInstance_WriteUint16Le(t *testing.T) { + memory := &MemoryInstance{Buffer: make([]byte, 100)} + + tests := []struct { + name string + offset uint32 + v uint16 + expectedOk bool + expectedBytes []byte + }{ + { + name: "valid offset with an endian-insensitive v", + offset: 0, // arbitrary valid offset. + v: math.MaxUint16, + expectedOk: true, + expectedBytes: []byte{0xff, 0xff}, + }, + { + name: "valid offset with an endian-sensitive v", + offset: 0, // arbitrary valid offset. + v: math.MaxUint16 - 1, + expectedOk: true, + expectedBytes: []byte{0xfe, 0xff}, + }, + { + name: "maximum boundary valid offset", + offset: memory.Size(testCtx) - 2, // 2 is the size of uint16 + v: 1, // arbitrary valid v + expectedOk: true, + expectedBytes: []byte{0x1, 0x00}, + }, + { + name: "offset exceeds the maximum valid offset by 1", + offset: memory.Size(testCtx) - 2 + 1, // 2 is the size of uint16 + v: 1, // arbitrary valid v + expectedBytes: []byte{0xff, 0xff}, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + require.Equal(t, tc.expectedOk, memory.WriteUint16Le(ctx, tc.offset, tc.v)) + if tc.expectedOk { + require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+2]) // 2 is the size of uint16 + } + } }) } } @@ -406,15 +539,15 @@ func TestMemoryInstance_WriteUint32Le(t *testing.T) { }, { name: "maximum boundary valid offset", - offset: memory.Size() - 4, // 4 is the size of uint32 - v: 1, // arbitrary valid v + offset: memory.Size(testCtx) - 4, // 4 is the size of uint32 + v: 1, // arbitrary valid v expectedOk: true, expectedBytes: []byte{0x1, 0x00, 0x00, 0x00}, }, { name: "offset exceeds the maximum valid offset by 1", - offset: memory.Size() - 4 + 1, // 4 is the size of uint32 - v: 1, // arbitrary valid v + offset: memory.Size(testCtx) - 4 + 1, // 4 is the size of uint32 + v: 1, // arbitrary valid v expectedBytes: []byte{0xff, 0xff, 0xff, 0xff}, }, } @@ -423,9 +556,11 @@ func TestMemoryInstance_WriteUint32Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedOk, memory.WriteUint32Le(tc.offset, tc.v)) - if tc.expectedOk { - require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+4]) // 4 is the size of uint32 + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + require.Equal(t, tc.expectedOk, memory.WriteUint32Le(ctx, tc.offset, tc.v)) + if tc.expectedOk { + require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+4]) // 4 is the size of uint32 + } } }) } @@ -456,15 +591,15 @@ func TestMemoryInstance_WriteUint64Le(t *testing.T) { }, { name: "maximum boundary valid offset", - offset: memory.Size() - 8, // 8 is the size of uint64 - v: 1, // arbitrary valid v + offset: memory.Size(testCtx) - 8, // 8 is the size of uint64 + v: 1, // arbitrary valid v expectedOk: true, expectedBytes: []byte{0x1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, { name: "offset exceeds the maximum valid offset by 1", - offset: memory.Size() - 8 + 1, // 8 is the size of uint64 - v: 1, // arbitrary valid v + offset: memory.Size(testCtx) - 8 + 1, // 8 is the size of uint64 + v: 1, // arbitrary valid v expectedOk: false, }, } @@ -473,9 +608,11 @@ func TestMemoryInstance_WriteUint64Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedOk, memory.WriteUint64Le(tc.offset, tc.v)) - if tc.expectedOk { - require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+8]) // 8 is the size of uint64 + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + require.Equal(t, tc.expectedOk, memory.WriteUint64Le(ctx, tc.offset, tc.v)) + if tc.expectedOk { + require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+8]) // 8 is the size of uint64 + } } }) } @@ -507,15 +644,15 @@ func TestMemoryInstance_WriteFloat32Le(t *testing.T) { }, { name: "maximum boundary valid offset", - offset: memory.Size() - 4, // 4 is the size of float32 - v: 0.1, // arbitrary valid v + offset: memory.Size(testCtx) - 4, // 4 is the size of float32 + v: 0.1, // arbitrary valid v expectedOk: true, expectedBytes: []byte{0xcd, 0xcc, 0xcc, 0x3d}, }, { name: "offset exceeds the maximum valid offset by 1", - offset: memory.Size() - 4 + 1, // 4 is the size of float32 - v: math.MaxFloat32, // arbitrary valid v + offset: memory.Size(testCtx) - 4 + 1, // 4 is the size of float32 + v: math.MaxFloat32, // arbitrary valid v expectedBytes: []byte{0xff, 0xff, 0xff, 0xff}, }, } @@ -524,9 +661,11 @@ func TestMemoryInstance_WriteFloat32Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedOk, memory.WriteFloat32Le(tc.offset, tc.v)) - if tc.expectedOk { - require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+4]) // 4 is the size of float32 + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + require.Equal(t, tc.expectedOk, memory.WriteFloat32Le(ctx, tc.offset, tc.v)) + if tc.expectedOk { + require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+4]) // 4 is the size of float32 + } } }) } @@ -557,15 +696,15 @@ func TestMemoryInstance_WriteFloat64Le(t *testing.T) { }, { name: "maximum boundary valid offset", - offset: memory.Size() - 8, // 8 is the size of float64 - v: math.MaxFloat64, // arbitrary valid v + offset: memory.Size(testCtx) - 8, // 8 is the size of float64 + v: math.MaxFloat64, // arbitrary valid v expectedOk: true, expectedBytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0x7f}, }, { name: "offset exceeds the maximum valid offset by 1", - offset: memory.Size() - 8 + 1, // 8 is the size of float64 - v: math.MaxFloat64, // arbitrary valid v + offset: memory.Size(testCtx) - 8 + 1, // 8 is the size of float64 + v: math.MaxFloat64, // arbitrary valid v expectedOk: false, }, } @@ -574,9 +713,11 @@ func TestMemoryInstance_WriteFloat64Le(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.expectedOk, memory.WriteFloat64Le(tc.offset, tc.v)) - if tc.expectedOk { - require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+8]) // 8 is the size of float64 + for _, ctx := range []context.Context{nil, testCtx} { // Ensure it doesn't crash on nil! + require.Equal(t, tc.expectedOk, memory.WriteFloat64Le(ctx, tc.offset, tc.v)) + if tc.expectedOk { + require.Equal(t, tc.expectedBytes, memory.Buffer[tc.offset:tc.offset+8]) // 8 is the size of float64 + } } }) } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 31714a0780..05773ce6d0 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -75,12 +75,12 @@ func TestModuleInstance_Memory(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := newStore() - instance, err := s.Instantiate(context.Background(), tc.input, "test", nil) + instance, err := s.Instantiate(testCtx, tc.input, "test", nil) require.NoError(t, err) mem := instance.ExportedMemory("memory") if tc.expected { - require.Equal(t, tc.expectedLen, mem.Size()) + require.Equal(t, tc.expectedLen, mem.Size(testCtx)) } else { require.Nil(t, mem) } @@ -102,7 +102,7 @@ func TestStore_Instantiate(t *testing.T) { sys := &SysContext{} mod, err := s.Instantiate(testCtx, m, "", sys) require.NoError(t, err) - defer mod.Close() + defer mod.Close(testCtx) t.Run("CallContext defaults", func(t *testing.T) { require.Equal(t, s.modules[""], mod.module) @@ -131,14 +131,14 @@ func TestStore_CloseModule(t *testing.T) { Features20191205, ) require.NoError(t, err) - _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) + _, err = s.Instantiate(testCtx, m, importedModuleName, nil) require.NoError(t, err) }, }, { name: "Module imports Module", initializer: func(t *testing.T, s *Store) { - _, err := s.Instantiate(context.Background(), &Module{ + _, err := s.Instantiate(testCtx, &Module{ TypeSection: []*FunctionType{{}}, FunctionSection: []uint32{0}, CodeSection: []*Code{{Body: []byte{OpcodeEnd}}}, @@ -153,7 +153,7 @@ func TestStore_CloseModule(t *testing.T) { s := newStore() tc.initializer(t, s) - _, err := s.Instantiate(context.Background(), &Module{ + _, err := s.Instantiate(testCtx, &Module{ TypeSection: []*FunctionType{{}}, ImportSection: []*Import{{Type: ExternTypeFunc, Module: importedModuleName, Name: "fn", DescFunc: 0}}, MemorySection: &Memory{Min: 1}, @@ -169,14 +169,14 @@ func TestStore_CloseModule(t *testing.T) { require.True(t, ok) // Close the importing module - require.NoError(t, importing.CallCtx.CloseWithExitCode(0)) + require.NoError(t, importing.CallCtx.CloseWithExitCode(testCtx, 0)) require.Nil(t, s.modules[importingModuleName]) // Can re-close the importing module - require.NoError(t, importing.CallCtx.CloseWithExitCode(0)) + require.NoError(t, importing.CallCtx.CloseWithExitCode(testCtx, 0)) // Now we close the imported module. - require.NoError(t, imported.CallCtx.CloseWithExitCode(0)) + require.NoError(t, imported.CallCtx.CloseWithExitCode(testCtx, 0)) require.Nil(t, s.modules[importedModuleName]) }) } @@ -195,7 +195,7 @@ func TestStore_hammer(t *testing.T) { require.NoError(t, err) s := newStore() - imported, err := s.Instantiate(context.Background(), m, importedModuleName, nil) + imported, err := s.Instantiate(testCtx, m, importedModuleName, nil) require.NoError(t, err) _, ok := s.modules[imported.Name()] @@ -222,16 +222,16 @@ func TestStore_hammer(t *testing.T) { N = 100 } hammer.NewHammer(t, P, N).Run(func(name string) { - mod, instantiateErr := s.Instantiate(context.Background(), importingModule, name, DefaultSysContext()) + mod, instantiateErr := s.Instantiate(testCtx, importingModule, name, DefaultSysContext()) require.NoError(t, instantiateErr) - require.NoError(t, mod.CloseWithExitCode(0)) + require.NoError(t, mod.Close(testCtx)) }, nil) if t.Failed() { return // At least one test failed, so return now. } // Close the imported module. - require.NoError(t, imported.CloseWithExitCode(0)) + require.NoError(t, imported.Close(testCtx)) // All instances are freed. require.Zero(t, len(s.modules)) @@ -252,23 +252,23 @@ func TestStore_Instantiate_Errors(t *testing.T) { t.Run("Fails if module name already in use", func(t *testing.T) { s := newStore() - _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) + _, err = s.Instantiate(testCtx, m, importedModuleName, nil) require.NoError(t, err) // Trying to register it again should fail - _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) + _, err = s.Instantiate(testCtx, m, importedModuleName, nil) require.EqualError(t, err, "module imported has already been instantiated") }) t.Run("fail resolve import", func(t *testing.T) { s := newStore() - _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) + _, err = s.Instantiate(testCtx, m, importedModuleName, nil) require.NoError(t, err) hm := s.modules[importedModuleName] require.NotNil(t, hm) - _, err = s.Instantiate(context.Background(), &Module{ + _, err = s.Instantiate(testCtx, &Module{ TypeSection: []*FunctionType{{}}, ImportSection: []*Import{ // The first import resolve succeeds -> increment hm.dependentCount. @@ -283,7 +283,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { t.Run("compilation failed", func(t *testing.T) { s := newStore() - _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) + _, err = s.Instantiate(testCtx, m, importedModuleName, nil) require.NoError(t, err) hm := s.modules[importedModuleName] @@ -292,7 +292,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { engine := s.Engine.(*mockEngine) engine.shouldCompileFail = true - _, err = s.Instantiate(context.Background(), &Module{ + _, err = s.Instantiate(testCtx, &Module{ TypeSection: []*FunctionType{{}}, FunctionSection: []uint32{0, 0}, CodeSection: []*Code{ @@ -311,14 +311,14 @@ func TestStore_Instantiate_Errors(t *testing.T) { engine := s.Engine.(*mockEngine) engine.callFailIndex = 1 - _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) + _, err = s.Instantiate(testCtx, m, importedModuleName, nil) require.NoError(t, err) hm := s.modules[importedModuleName] require.NotNil(t, hm) startFuncIndex := uint32(1) - _, err = s.Instantiate(context.Background(), &Module{ + _, err = s.Instantiate(testCtx, &Module{ TypeSection: []*FunctionType{{}}, FunctionSection: []uint32{0}, CodeSection: []*Code{{Body: []byte{OpcodeEnd}}}, @@ -344,19 +344,19 @@ func TestCallContext_ExportedFunction(t *testing.T) { s := newStore() // Add the host module - imported, err := s.Instantiate(context.Background(), host, host.NameSection.ModuleName, nil) + imported, err := s.Instantiate(testCtx, host, host.NameSection.ModuleName, nil) require.NoError(t, err) - defer imported.Close() + defer imported.Close(testCtx) t.Run("imported function", func(t *testing.T) { - importing, err := s.Instantiate(context.Background(), &Module{ + importing, err := s.Instantiate(testCtx, &Module{ TypeSection: []*FunctionType{{}}, ImportSection: []*Import{{Type: ExternTypeFunc, Module: "host", Name: "host_fn", DescFunc: 0}}, MemorySection: &Memory{Min: 1}, ExportSection: []*Export{{Type: ExternTypeFunc, Name: "host.fn", Index: 0}}, }, "test", nil) require.NoError(t, err) - defer importing.Close() + defer importing.Close(testCtx) fn := importing.ExportedFunction("host.fn") require.NotNil(t, fn) @@ -409,7 +409,7 @@ func (e *mockModuleEngine) Call(ctx context.Context, callCtx *CallContext, f *Fu } // Close implements the same method as documented on wasm.ModuleEngine. -func (e *mockModuleEngine) Close() { +func (e *mockModuleEngine) Close(_ context.Context) { } func TestStore_getFunctionTypeID(t *testing.T) { diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 11a08c3902..45cb0090ae 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -182,6 +182,8 @@ type CompilationResult struct { } func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module *wasm.Module) ([]*CompilationResult, error) { + // Note: If you use the context.Context param, don't forget to coerce nil to context.Background()! + functions, globals, mem, table, err := module.AllDeclarations() if err != nil { return nil, err diff --git a/wasi/example_test.go b/wasi/example_test.go index 461b2f51d5..e8169095fb 100644 --- a/wasi/example_test.go +++ b/wasi/example_test.go @@ -25,7 +25,7 @@ func Example() { if err != nil { log.Fatal(err) } - defer wm.Close() + defer wm.Close(testCtx) // Override default configuration (which discards stdout). config := wazero.NewModuleConfig().WithStdout(os.Stdout) diff --git a/wasi/usage_test.go b/wasi/usage_test.go index b3048649c8..cf8effd415 100644 --- a/wasi/usage_test.go +++ b/wasi/usage_test.go @@ -22,11 +22,11 @@ func TestInstantiateModuleWithConfig(t *testing.T) { sys := wazero.NewModuleConfig().WithStdout(stdout) wm, err := InstantiateSnapshotPreview1(testCtx, r) require.NoError(t, err) - defer wm.Close() + defer wm.Close(testCtx) compiled, err := r.CompileModule(testCtx, wasiArg) require.NoError(t, err) - defer compiled.Close() + defer compiled.Close(testCtx) // Re-use the same module many times. for _, tc := range []string{"a", "b", "c"} { @@ -37,6 +37,6 @@ func TestInstantiateModuleWithConfig(t *testing.T) { require.Equal(t, append([]byte(tc), 0), stdout.Bytes()) stdout.Reset() - require.NoError(t, mod.Close()) + require.NoError(t, mod.Close(testCtx)) } } diff --git a/wasi/wasi.go b/wasi/wasi.go index 6d9db6ec68..ddc15d6dca 100644 --- a/wasi/wasi.go +++ b/wasi/wasi.go @@ -515,9 +515,9 @@ func snapshotPreview1Functions() (a *snapshotPreview1, nameToGoFunc map[string]i // See ArgsSizesGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#args_get // See https://en.wikipedia.org/wiki/Null-terminated_string -func (a *snapshotPreview1) ArgsGet(m api.Module, argv, argvBuf uint32) Errno { +func (a *snapshotPreview1) ArgsGet(ctx context.Context, m api.Module, argv, argvBuf uint32) Errno { sys := sysCtx(m) - return writeOffsetsAndNullTerminatedValues(m.Memory(), sys.Args(), argv, argvBuf) + return writeOffsetsAndNullTerminatedValues(ctx, m.Memory(), sys.Args(), argv, argvBuf) } // ArgsSizesGet is the WASI function named functionArgsSizesGet that reads command-line argument data (WithArgs) @@ -546,14 +546,14 @@ func (a *snapshotPreview1) ArgsGet(m api.Module, argv, argvBuf uint32) Errno { // See ArgsGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#args_sizes_get // See https://en.wikipedia.org/wiki/Null-terminated_string -func (a *snapshotPreview1) ArgsSizesGet(m api.Module, resultArgc, resultArgvBufSize uint32) Errno { +func (a *snapshotPreview1) ArgsSizesGet(ctx context.Context, m api.Module, resultArgc, resultArgvBufSize uint32) Errno { sys := sysCtx(m) mem := m.Memory() - if !mem.WriteUint32Le(resultArgc, uint32(len(sys.Args()))) { + if !mem.WriteUint32Le(ctx, resultArgc, uint32(len(sys.Args()))) { return ErrnoFault } - if !mem.WriteUint32Le(resultArgvBufSize, sys.ArgsSize()) { + if !mem.WriteUint32Le(ctx, resultArgvBufSize, sys.ArgsSize()) { return ErrnoFault } return ErrnoSuccess @@ -585,9 +585,9 @@ func (a *snapshotPreview1) ArgsSizesGet(m api.Module, resultArgc, resultArgvBufS // See EnvironSizesGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#environ_get // See https://en.wikipedia.org/wiki/Null-terminated_string -func (a *snapshotPreview1) EnvironGet(m api.Module, environ uint32, environBuf uint32) Errno { +func (a *snapshotPreview1) EnvironGet(ctx context.Context, m api.Module, environ uint32, environBuf uint32) Errno { sys := sysCtx(m) - return writeOffsetsAndNullTerminatedValues(m.Memory(), sys.Environ(), environ, environBuf) + return writeOffsetsAndNullTerminatedValues(ctx, m.Memory(), sys.Environ(), environ, environBuf) } // EnvironSizesGet is the WASI function named functionEnvironSizesGet that reads environment variable @@ -617,14 +617,14 @@ func (a *snapshotPreview1) EnvironGet(m api.Module, environ uint32, environBuf u // See EnvironGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#environ_sizes_get // See https://en.wikipedia.org/wiki/Null-terminated_string -func (a *snapshotPreview1) EnvironSizesGet(m api.Module, resultEnvironc uint32, resultEnvironBufSize uint32) Errno { +func (a *snapshotPreview1) EnvironSizesGet(ctx context.Context, m api.Module, resultEnvironc uint32, resultEnvironBufSize uint32) Errno { sys := sysCtx(m) mem := m.Memory() - if !mem.WriteUint32Le(resultEnvironc, uint32(len(sys.Environ()))) { + if !mem.WriteUint32Le(ctx, resultEnvironc, uint32(len(sys.Environ()))) { return ErrnoFault } - if !mem.WriteUint32Le(resultEnvironBufSize, sys.EnvironSize()) { + if !mem.WriteUint32Le(ctx, resultEnvironBufSize, sys.EnvironSize()) { return ErrnoFault } @@ -632,7 +632,7 @@ func (a *snapshotPreview1) EnvironSizesGet(m api.Module, resultEnvironc uint32, } // ClockResGet is the WASI function named functionClockResGet and is stubbed for GrainLang per #271 -func (a *snapshotPreview1) ClockResGet(m api.Module, id uint32, resultResolution uint32) Errno { +func (a *snapshotPreview1) ClockResGet(ctx context.Context, m api.Module, id uint32, resultResolution uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -656,21 +656,21 @@ func (a *snapshotPreview1) ClockResGet(m api.Module, id uint32, resultResolution // Note: This is similar to `clock_gettime` in POSIX. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-clock_time_getid-clockid-precision-timestamp---errno-timestamp // See https://linux.die.net/man/3/clock_gettime -func (a *snapshotPreview1) ClockTimeGet(m api.Module, id uint32, precision uint64, resultTimestamp uint32) Errno { +func (a *snapshotPreview1) ClockTimeGet(ctx context.Context, m api.Module, id uint32, precision uint64, resultTimestamp uint32) Errno { // TODO: id and precision are currently ignored. - if !m.Memory().WriteUint64Le(resultTimestamp, a.timeNowUnixNano()) { + if !m.Memory().WriteUint64Le(ctx, resultTimestamp, a.timeNowUnixNano()) { return ErrnoFault } return ErrnoSuccess } // FdAdvise is the WASI function named functionFdAdvise and is stubbed for GrainLang per #271 -func (a *snapshotPreview1) FdAdvise(m api.Module, fd uint32, offset, len uint64, resultAdvice uint32) Errno { +func (a *snapshotPreview1) FdAdvise(ctx context.Context, m api.Module, fd uint32, offset, len uint64, resultAdvice uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdAllocate is the WASI function named functionFdAllocate and is stubbed for GrainLang per #271 -func (a *snapshotPreview1) FdAllocate(m api.Module, fd uint32, offset, len uint64) Errno { +func (a *snapshotPreview1) FdAllocate(ctx context.Context, m api.Module, fd uint32, offset, len uint64) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -682,7 +682,7 @@ func (a *snapshotPreview1) FdAllocate(m api.Module, fd uint32, offset, len uint6 // Note: This is similar to `close` in POSIX. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#fd_close // See https://linux.die.net/man/3/close -func (a *snapshotPreview1) FdClose(m api.Module, fd uint32) Errno { +func (a *snapshotPreview1) FdClose(ctx context.Context, m api.Module, fd uint32) Errno { sys := sysCtx(m) if ok, err := sys.CloseFile(fd); err != nil { @@ -695,7 +695,7 @@ func (a *snapshotPreview1) FdClose(m api.Module, fd uint32) Errno { } // FdDatasync is the WASI function named functionFdDatasync and is stubbed for GrainLang per #271 -func (a *snapshotPreview1) FdDatasync(m api.Module, fd uint32) Errno { +func (a *snapshotPreview1) FdDatasync(ctx context.Context, m api.Module, fd uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -731,7 +731,7 @@ func (a *snapshotPreview1) FdDatasync(m api.Module, fd uint32) Errno { // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fdstat // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#fd_fdstat_get // See https://linux.die.net/man/3/fsync -func (a *snapshotPreview1) FdFdstatGet(m api.Module, fd uint32, resultStat uint32) Errno { +func (a *snapshotPreview1) FdFdstatGet(ctx context.Context, m api.Module, fd uint32, resultStat uint32) Errno { sys := sysCtx(m) if _, ok := sys.OpenedFile(fd); !ok { @@ -767,7 +767,7 @@ func (a *snapshotPreview1) FdFdstatGet(m api.Module, fd uint32, resultStat uint3 // See FdPrestatDirName // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#prestat // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#fd_prestat_get -func (a *snapshotPreview1) FdPrestatGet(m api.Module, fd uint32, resultPrestat uint32) Errno { +func (a *snapshotPreview1) FdPrestatGet(ctx context.Context, m api.Module, fd uint32, resultPrestat uint32) Errno { sys := sysCtx(m) entry, ok := sys.OpenedFile(fd) @@ -776,11 +776,11 @@ func (a *snapshotPreview1) FdPrestatGet(m api.Module, fd uint32, resultPrestat u } // Zero-value 8-bit tag, and 3-byte zero-value paddings, which is uint32le(0) in short. - if !m.Memory().WriteUint32Le(resultPrestat, uint32(0)) { + if !m.Memory().WriteUint32Le(ctx, resultPrestat, uint32(0)) { return ErrnoFault } // Write the length of the directory name at offset 4. - if !m.Memory().WriteUint32Le(resultPrestat+4, uint32(len(entry.Path))) { + if !m.Memory().WriteUint32Le(ctx, resultPrestat+4, uint32(len(entry.Path))) { return ErrnoFault } @@ -788,33 +788,33 @@ func (a *snapshotPreview1) FdPrestatGet(m api.Module, fd uint32, resultPrestat u } // FdFdstatSetFlags is the WASI function named functionFdFdstatSetFlags and is stubbed for GrainLang per #271 -func (a *snapshotPreview1) FdFdstatSetFlags(m api.Module, fd uint32, flags uint32) Errno { +func (a *snapshotPreview1) FdFdstatSetFlags(ctx context.Context, m api.Module, fd uint32, flags uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdFdstatSetRights implements snapshotPreview1.FdFdstatSetRights // Note: This will never be implemented per https://github.com/WebAssembly/WASI/issues/469#issuecomment-1045251844 -func (a *snapshotPreview1) FdFdstatSetRights(m api.Module, fd uint32, fsRightsBase, fsRightsInheriting uint64) Errno { +func (a *snapshotPreview1) FdFdstatSetRights(ctx context.Context, m api.Module, fd uint32, fsRightsBase, fsRightsInheriting uint64) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdFilestatGet is the WASI function named functionFdFilestatGet -func (a *snapshotPreview1) FdFilestatGet(m api.Module, fd uint32, resultBuf uint32) Errno { +func (a *snapshotPreview1) FdFilestatGet(ctx context.Context, m api.Module, fd uint32, resultBuf uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdFilestatSetSize is the WASI function named functionFdFilestatSetSize -func (a *snapshotPreview1) FdFilestatSetSize(m api.Module, fd uint32, size uint64) Errno { +func (a *snapshotPreview1) FdFilestatSetSize(ctx context.Context, m api.Module, fd uint32, size uint64) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdFilestatSetTimes is the WASI function named functionFdFilestatSetTimes -func (a *snapshotPreview1) FdFilestatSetTimes(m api.Module, fd uint32, atim, mtim uint64, fstFlags uint32) Errno { +func (a *snapshotPreview1) FdFilestatSetTimes(ctx context.Context, m api.Module, fd uint32, atim, mtim uint64, fstFlags uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdPread is the WASI function named functionFdPread -func (a *snapshotPreview1) FdPread(m api.Module, fd, iovs, iovsCount uint32, offset uint64, resultNread uint32) Errno { +func (a *snapshotPreview1) FdPread(ctx context.Context, m api.Module, fd, iovs, iovsCount uint32, offset uint64, resultNread uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -842,7 +842,7 @@ func (a *snapshotPreview1) FdPread(m api.Module, fd, iovs, iovsCount uint32, off // Note: importFdPrestatDirName shows this signature in the WebAssembly 1.0 (20191205) Text Format. // See FdPrestatGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_prestat_dir_name -func (a *snapshotPreview1) FdPrestatDirName(m api.Module, fd uint32, pathPtr uint32, pathLen uint32) Errno { +func (a *snapshotPreview1) FdPrestatDirName(ctx context.Context, m api.Module, fd uint32, pathPtr uint32, pathLen uint32) Errno { sys := sysCtx(m) f, ok := sys.OpenedFile(fd) @@ -856,14 +856,14 @@ func (a *snapshotPreview1) FdPrestatDirName(m api.Module, fd uint32, pathPtr uin } // TODO: FdPrestatDirName may have to return ErrnoNotdir if the type of the prestat data of `fd` is not a PrestatDir. - if !m.Memory().Write(pathPtr, []byte(f.Path)[:pathLen]) { + if !m.Memory().Write(ctx, pathPtr, []byte(f.Path)[:pathLen]) { return ErrnoFault } return ErrnoSuccess } // FdPwrite is the WASI function named functionFdPwrite -func (a *snapshotPreview1) FdPwrite(m api.Module, fd, iovs, iovsCount uint32, offset uint64, resultNwritten uint32) Errno { +func (a *snapshotPreview1) FdPwrite(ctx context.Context, m api.Module, fd, iovs, iovsCount uint32, offset uint64, resultNwritten uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -910,7 +910,7 @@ func (a *snapshotPreview1) FdPwrite(m api.Module, fd, iovs, iovsCount uint32, of // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_read // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#iovec // See https://linux.die.net/man/3/readv -func (a *snapshotPreview1) FdRead(m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { +func (a *snapshotPreview1) FdRead(ctx context.Context, m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { sys := sysCtx(m) var reader io.Reader @@ -926,15 +926,15 @@ func (a *snapshotPreview1) FdRead(m api.Module, fd, iovs, iovsCount, resultSize var nread uint32 for i := uint32(0); i < iovsCount; i++ { iovPtr := iovs + i*8 - offset, ok := m.Memory().ReadUint32Le(iovPtr) + offset, ok := m.Memory().ReadUint32Le(ctx, iovPtr) if !ok { return ErrnoFault } - l, ok := m.Memory().ReadUint32Le(iovPtr + 4) + l, ok := m.Memory().ReadUint32Le(ctx, iovPtr+4) if !ok { return ErrnoFault } - b, ok := m.Memory().Read(offset, l) + b, ok := m.Memory().Read(ctx, offset, l) if !ok { return ErrnoFault } @@ -946,19 +946,19 @@ func (a *snapshotPreview1) FdRead(m api.Module, fd, iovs, iovsCount, resultSize return ErrnoIo } } - if !m.Memory().WriteUint32Le(resultSize, nread) { + if !m.Memory().WriteUint32Le(ctx, resultSize, nread) { return ErrnoFault } return ErrnoSuccess } // FdReaddir is the WASI function named functionFdReaddir -func (a *snapshotPreview1) FdReaddir(m api.Module, fd, buf, bufLen uint32, cookie uint64, resultBufused uint32) Errno { +func (a *snapshotPreview1) FdReaddir(ctx context.Context, m api.Module, fd, buf, bufLen uint32, cookie uint64, resultBufused uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdRenumber is the WASI function named functionFdRenumber -func (a *snapshotPreview1) FdRenumber(m api.Module, fd, to uint32) Errno { +func (a *snapshotPreview1) FdRenumber(ctx context.Context, m api.Module, fd, to uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -993,7 +993,7 @@ func (a *snapshotPreview1) FdRenumber(m api.Module, fd, to uint32) Errno { // Note: This is similar to `lseek` in POSIX. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_seek // See https://linux.die.net/man/3/lseek -func (a *snapshotPreview1) FdSeek(m api.Module, fd uint32, offset uint64, whence uint32, resultNewoffset uint32) Errno { +func (a *snapshotPreview1) FdSeek(ctx context.Context, m api.Module, fd uint32, offset uint64, whence uint32, resultNewoffset uint32) Errno { sys := sysCtx(m) var seeker io.Seeker @@ -1013,7 +1013,7 @@ func (a *snapshotPreview1) FdSeek(m api.Module, fd uint32, offset uint64, whence return ErrnoIo } - if !m.Memory().WriteUint32Le(resultNewoffset, uint32(newOffset)) { + if !m.Memory().WriteUint32Le(ctx, resultNewoffset, uint32(newOffset)) { return ErrnoFault } @@ -1021,12 +1021,12 @@ func (a *snapshotPreview1) FdSeek(m api.Module, fd uint32, offset uint64, whence } // FdSync is the WASI function named functionFdSync -func (a *snapshotPreview1) FdSync(m api.Module, fd uint32) Errno { +func (a *snapshotPreview1) FdSync(ctx context.Context, m api.Module, fd uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // FdTell is the WASI function named functionFdTell -func (a *snapshotPreview1) FdTell(m api.Module, fd, resultOffset uint32) Errno { +func (a *snapshotPreview1) FdTell(ctx context.Context, m api.Module, fd, resultOffset uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -1079,7 +1079,7 @@ func (a *snapshotPreview1) FdTell(m api.Module, fd, resultOffset uint32) Errno { // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#ciovec // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_write // See https://linux.die.net/man/3/writev -func (a *snapshotPreview1) FdWrite(m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { +func (a *snapshotPreview1) FdWrite(ctx context.Context, m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { sys := sysCtx(m) var writer io.Writer @@ -1102,15 +1102,15 @@ func (a *snapshotPreview1) FdWrite(m api.Module, fd, iovs, iovsCount, resultSize var nwritten uint32 for i := uint32(0); i < iovsCount; i++ { iovPtr := iovs + i*8 - offset, ok := m.Memory().ReadUint32Le(iovPtr) + offset, ok := m.Memory().ReadUint32Le(ctx, iovPtr) if !ok { return ErrnoFault } - l, ok := m.Memory().ReadUint32Le(iovPtr + 4) + l, ok := m.Memory().ReadUint32Le(ctx, iovPtr+4) if !ok { return ErrnoFault } - b, ok := m.Memory().Read(offset, l) + b, ok := m.Memory().Read(ctx, offset, l) if !ok { return ErrnoFault } @@ -1120,29 +1120,29 @@ func (a *snapshotPreview1) FdWrite(m api.Module, fd, iovs, iovsCount, resultSize } nwritten += uint32(n) } - if !m.Memory().WriteUint32Le(resultSize, nwritten) { + if !m.Memory().WriteUint32Le(ctx, resultSize, nwritten) { return ErrnoFault } return ErrnoSuccess } // PathCreateDirectory is the WASI function named functionPathCreateDirectory -func (a *snapshotPreview1) PathCreateDirectory(m api.Module, fd, path, pathLen uint32) Errno { +func (a *snapshotPreview1) PathCreateDirectory(ctx context.Context, m api.Module, fd, path, pathLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathFilestatGet is the WASI function named functionPathFilestatGet -func (a *snapshotPreview1) PathFilestatGet(m api.Module, fd, flags, path, pathLen, resultBuf uint32) Errno { +func (a *snapshotPreview1) PathFilestatGet(ctx context.Context, m api.Module, fd, flags, path, pathLen, resultBuf uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathFilestatSetTimes is the WASI function named functionPathFilestatSetTimes -func (a *snapshotPreview1) PathFilestatSetTimes(m api.Module, fd, flags, path, pathLen uint32, atim, mtime uint64, fstFlags uint32) Errno { +func (a *snapshotPreview1) PathFilestatSetTimes(ctx context.Context, m api.Module, fd, flags, path, pathLen uint32, atim, mtime uint64, fstFlags uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathLink is the WASI function named functionPathLink -func (a *snapshotPreview1) PathLink(m api.Module, oldFd, oldFlags, oldPath, oldPathLen, newFd, newPath, newPathLen uint32) Errno { +func (a *snapshotPreview1) PathLink(ctx context.Context, m api.Module, oldFd, oldFlags, oldPath, oldPathLen, newFd, newPath, newPathLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -1191,7 +1191,7 @@ func (a *snapshotPreview1) PathLink(m api.Module, oldFd, oldFlags, oldPath, oldP // Note: Rights will never be implemented per https://github.com/WebAssembly/WASI/issues/469#issuecomment-1045251844 // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#path_open // See https://linux.die.net/man/3/openat -func (a *snapshotPreview1) PathOpen(m api.Module, fd, dirflags, pathPtr, pathLen, oflags uint32, fsRightsBase, +func (a *snapshotPreview1) PathOpen(ctx context.Context, m api.Module, fd, dirflags, pathPtr, pathLen, oflags uint32, fsRightsBase, fsRightsInheriting uint64, fdflags, resultOpenedFd uint32) (errno Errno) { sys := sysCtx(m) @@ -1200,7 +1200,7 @@ func (a *snapshotPreview1) PathOpen(m api.Module, fd, dirflags, pathPtr, pathLen return ErrnoBadf } - b, ok := m.Memory().Read(pathPtr, pathLen) + b, ok := m.Memory().Read(ctx, pathPtr, pathLen) if !ok { return ErrnoFault } @@ -1216,7 +1216,7 @@ func (a *snapshotPreview1) PathOpen(m api.Module, fd, dirflags, pathPtr, pathLen if newFD, ok := sys.OpenFile(entry); !ok { _ = entry.File.Close() return ErrnoIo - } else if !m.Memory().WriteUint32Le(resultOpenedFd, newFD) { + } else if !m.Memory().WriteUint32Le(ctx, resultOpenedFd, newFD) { _ = entry.File.Close() return ErrnoFault } @@ -1224,32 +1224,32 @@ func (a *snapshotPreview1) PathOpen(m api.Module, fd, dirflags, pathPtr, pathLen } // PathReadlink is the WASI function named functionPathReadlink -func (a *snapshotPreview1) PathReadlink(m api.Module, fd, path, pathLen, buf, bufLen, resultBufused uint32) Errno { +func (a *snapshotPreview1) PathReadlink(ctx context.Context, m api.Module, fd, path, pathLen, buf, bufLen, resultBufused uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathRemoveDirectory is the WASI function named functionPathRemoveDirectory -func (a *snapshotPreview1) PathRemoveDirectory(m api.Module, fd, path, pathLen uint32) Errno { +func (a *snapshotPreview1) PathRemoveDirectory(ctx context.Context, m api.Module, fd, path, pathLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathRename is the WASI function named functionPathRename -func (a *snapshotPreview1) PathRename(m api.Module, fd, oldPath, oldPathLen, newFd, newPath, newPathLen uint32) Errno { +func (a *snapshotPreview1) PathRename(ctx context.Context, m api.Module, fd, oldPath, oldPathLen, newFd, newPath, newPathLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathSymlink is the WASI function named functionPathSymlink -func (a *snapshotPreview1) PathSymlink(m api.Module, oldPath, oldPathLen, fd, newPath, newPathLen uint32) Errno { +func (a *snapshotPreview1) PathSymlink(ctx context.Context, m api.Module, oldPath, oldPathLen, fd, newPath, newPathLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PathUnlinkFile is the WASI function named functionPathUnlinkFile -func (a *snapshotPreview1) PathUnlinkFile(m api.Module, fd, path, pathLen uint32) Errno { +func (a *snapshotPreview1) PathUnlinkFile(ctx context.Context, m api.Module, fd, path, pathLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // PollOneoff is the WASI function named functionPollOneoff -func (a *snapshotPreview1) PollOneoff(m api.Module, in, out, nsubscriptions, resultNevents uint32) Errno { +func (a *snapshotPreview1) PollOneoff(ctx context.Context, m api.Module, in, out, nsubscriptions, resultNevents uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -1262,12 +1262,12 @@ func (a *snapshotPreview1) PollOneoff(m api.Module, in, out, nsubscriptions, res // // Note: importProcExit shows this signature in the WebAssembly 1.0 (20191205) Text Format. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit -func (a *snapshotPreview1) ProcExit(m api.Module, exitCode uint32) { - _ = m.CloseWithExitCode(exitCode) +func (a *snapshotPreview1) ProcExit(ctx context.Context, m api.Module, exitCode uint32) { + _ = m.CloseWithExitCode(ctx, exitCode) } // ProcRaise is the WASI function named functionProcRaise -func (a *snapshotPreview1) ProcRaise(m api.Module, sig uint32) Errno { +func (a *snapshotPreview1) ProcRaise(ctx context.Context, m api.Module, sig uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -1276,7 +1276,7 @@ func (a *snapshotPreview1) SchedYield(m api.Module) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } -// RandomGet is the WASI function named functionRandomGet that write random data in buffer (rand.Read()). +// RandomGet is the WASI function named functionRandomGet that write random data in buffer (rand.Read(ctx, )). // // * buf - is the m.Memory offset to write random values // * bufLen - size of random data in bytes @@ -1291,7 +1291,7 @@ func (a *snapshotPreview1) SchedYield(m api.Module) Errno { // // Note: importRandomGet shows this signature in the WebAssembly 1.0 (20191205) Text Format. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-random_getbuf-pointeru8-bufLen-size---errno -func (a *snapshotPreview1) RandomGet(m api.Module, buf uint32, bufLen uint32) (errno Errno) { +func (a *snapshotPreview1) RandomGet(ctx context.Context, m api.Module, buf uint32, bufLen uint32) (errno Errno) { randomBytes := make([]byte, bufLen) err := a.randSource(randomBytes) if err != nil { @@ -1299,7 +1299,7 @@ func (a *snapshotPreview1) RandomGet(m api.Module, buf uint32, bufLen uint32) (e return ErrnoIo } - if !m.Memory().Write(buf, randomBytes) { + if !m.Memory().Write(ctx, buf, randomBytes) { return ErrnoFault } @@ -1307,17 +1307,17 @@ func (a *snapshotPreview1) RandomGet(m api.Module, buf uint32, bufLen uint32) (e } // SockRecv is the WASI function named functionSockRecv -func (a *snapshotPreview1) SockRecv(m api.Module, fd, riData, riDataCount, riFlags, resultRoDataLen, resultRoFlags uint32) Errno { +func (a *snapshotPreview1) SockRecv(ctx context.Context, m api.Module, fd, riData, riDataCount, riFlags, resultRoDataLen, resultRoFlags uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // SockSend is the WASI function named functionSockSend -func (a *snapshotPreview1) SockSend(m api.Module, fd, siData, siDataCount, siFlags, resultSoDataLen uint32) Errno { +func (a *snapshotPreview1) SockSend(ctx context.Context, m api.Module, fd, siData, siDataCount, siFlags, resultSoDataLen uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } // SockShutdown is the WASI function named functionSockShutdown -func (a *snapshotPreview1) SockShutdown(m api.Module, fd, how uint32) Errno { +func (a *snapshotPreview1) SockShutdown(ctx context.Context, m api.Module, fd, how uint32) Errno { return ErrnoNosys // stubbed for GrainLang per #271 } @@ -1366,20 +1366,20 @@ func openFileEntry(rootFS fs.FS, pathName string) (*wasm.FileEntry, Errno) { return &wasm.FileEntry{Path: pathName, FS: rootFS, File: f}, ErrnoSuccess } -func writeOffsetsAndNullTerminatedValues(mem api.Memory, values []string, offsets, bytes uint32) Errno { +func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, values []string, offsets, bytes uint32) Errno { for _, value := range values { // Write current offset and advance it. - if !mem.WriteUint32Le(offsets, bytes) { + if !mem.WriteUint32Le(ctx, offsets, bytes) { return ErrnoFault } offsets += 4 // size of uint32 // Write the next value to memory with a NUL terminator - if !mem.Write(bytes, []byte(value)) { + if !mem.Write(ctx, bytes, []byte(value)) { return ErrnoFault } bytes += uint32(len(value)) - if !mem.WriteByte(bytes, 0) { + if !mem.WriteByte(ctx, bytes, 0) { return ErrnoFault } bytes++ diff --git a/wasi/wasi_bench_test.go b/wasi/wasi_bench_test.go index 7820a18729..7932bf18c0 100644 --- a/wasi/wasi_bench_test.go +++ b/wasi/wasi_bench_test.go @@ -24,11 +24,11 @@ func Test_EnvironGet(t *testing.T) { sys, err := newSysContext(nil, []string{"a=b", "b=cd"}, nil) require.NoError(t, err) - testCtx := newCtx(make([]byte, 20), sys) + m := newModule(make([]byte, 20), sys) environGet := newSnapshotPreview1().EnvironGet - require.Equal(t, ErrnoSuccess, environGet(testCtx, 11, 1)) - require.Equal(t, testCtx.Memory(), testMem) + require.Equal(t, ErrnoSuccess, environGet(testCtx, m, 11, 1)) + require.Equal(t, m.Memory(), testMem) } func Benchmark_EnvironGet(b *testing.B) { @@ -37,7 +37,7 @@ func Benchmark_EnvironGet(b *testing.B) { b.Fatal(err) } - testCtx := newCtx([]byte{ + m := newModule([]byte{ 0, // environBuf is after this 'a', '=', 'b', 0, // null terminated "a=b", 'b', '=', 'c', 'd', 0, // null terminated "b=cd" @@ -50,14 +50,14 @@ func Benchmark_EnvironGet(b *testing.B) { environGet := newSnapshotPreview1().EnvironGet b.Run("EnvironGet", func(b *testing.B) { for i := 0; i < b.N; i++ { - if environGet(testCtx, 0, 4) != ErrnoSuccess { + if environGet(testCtx, m, 0, 4) != ErrnoSuccess { b.Fatal() } } }) } -func newCtx(buf []byte, sys *wasm.SysContext) *wasm.CallContext { +func newModule(buf []byte, sys *wasm.SysContext) *wasm.CallContext { return wasm.NewCallContext(nil, &wasm.ModuleInstance{ Memory: &wasm.MemoryInstance{Min: 1, Buffer: buf}, }, sys) diff --git a/wasi/wasi_test.go b/wasi/wasi_test.go index 5c6993efda..a05544ae88 100644 --- a/wasi/wasi_test.go +++ b/wasi/wasi_test.go @@ -41,29 +41,29 @@ func TestSnapshotPreview1_ArgsGet(t *testing.T) { } a, mod, fn := instantiateModule(t, functionArgsGet, importArgsGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.ArgsGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke ArgsGet directly and check the memory side effects. - errno := a.ArgsGet(mod, argv, argvBuf) + errno := a.ArgsGet(testCtx, mod, argv, argvBuf) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionArgsGet, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(argv), uint64(argvBuf)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -74,9 +74,9 @@ func TestSnapshotPreview1_ArgsGet_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionArgsGet, importArgsGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) validAddress := uint32(0) // arbitrary valid address as arguments to args_get. We chose 0 here. tests := []struct { @@ -112,7 +112,7 @@ func TestSnapshotPreview1_ArgsGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.ArgsGet(mod, tc.argv, tc.argvBuf) + errno := a.ArgsGet(testCtx, mod, tc.argv, tc.argvBuf) require.NoError(t, err) require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) }) @@ -134,29 +134,29 @@ func TestSnapshotPreview1_ArgsSizesGet(t *testing.T) { } a, mod, fn := instantiateModule(t, functionArgsSizesGet, importArgsSizesGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.ArgsSizesGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke ArgsSizesGet directly and check the memory side effects. - errno := a.ArgsSizesGet(mod, resultArgc, resultArgvBufSize) + errno := a.ArgsSizesGet(testCtx, mod, resultArgc, resultArgvBufSize) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionArgsSizesGet, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(resultArgc), uint64(resultArgvBufSize)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -167,9 +167,9 @@ func TestSnapshotPreview1_ArgsSizesGet_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionArgsSizesGet, importArgsSizesGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) validAddress := uint32(0) // arbitrary valid address as arguments to args_sizes_get. We chose 0 here. tests := []struct { @@ -203,7 +203,7 @@ func TestSnapshotPreview1_ArgsSizesGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.ArgsSizesGet(mod, tc.argc, tc.argvBufSize) + errno := a.ArgsSizesGet(testCtx, mod, tc.argc, tc.argvBufSize) require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) }) } @@ -226,29 +226,29 @@ func TestSnapshotPreview1_EnvironGet(t *testing.T) { } a, mod, fn := instantiateModule(t, functionEnvironGet, importEnvironGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.EnvironGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke EnvironGet directly and check the memory side effects. - errno := a.EnvironGet(mod, resultEnviron, resultEnvironBuf) + errno := a.EnvironGet(testCtx, mod, resultEnviron, resultEnvironBuf) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionEnvironGet, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(resultEnviron), uint64(resultEnvironBuf)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -259,9 +259,9 @@ func TestSnapshotPreview1_EnvironGet_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionEnvironGet, importEnvironGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) validAddress := uint32(0) // arbitrary valid address as arguments to environ_get. We chose 0 here. tests := []struct { @@ -297,7 +297,7 @@ func TestSnapshotPreview1_EnvironGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.EnvironGet(mod, tc.environ, tc.environBuf) + errno := a.EnvironGet(testCtx, mod, tc.environ, tc.environBuf) require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) }) } @@ -318,29 +318,29 @@ func TestSnapshotPreview1_EnvironSizesGet(t *testing.T) { } a, mod, fn := instantiateModule(t, functionEnvironSizesGet, importEnvironSizesGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.EnvironSizesGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke EnvironSizesGet directly and check the memory side effects. - errno := a.EnvironSizesGet(mod, resultEnvironc, resultEnvironBufSize) + errno := a.EnvironSizesGet(testCtx, mod, resultEnvironc, resultEnvironBufSize) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionEnvironSizesGet, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(resultEnvironc), uint64(resultEnvironBufSize)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -351,9 +351,9 @@ func TestSnapshotPreview1_EnvironSizesGet_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionEnvironSizesGet, importEnvironSizesGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) validAddress := uint32(0) // arbitrary valid address as arguments to environ_sizes_get. We chose 0 here. tests := []struct { @@ -387,7 +387,7 @@ func TestSnapshotPreview1_EnvironSizesGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.EnvironSizesGet(mod, tc.environc, tc.environBufSize) + errno := a.EnvironSizesGet(testCtx, mod, tc.environc, tc.environBufSize) require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) }) } @@ -396,10 +396,10 @@ func TestSnapshotPreview1_EnvironSizesGet_Errors(t *testing.T) { // TestSnapshotPreview1_ClockResGet only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_ClockResGet(t *testing.T) { a, mod, fn := instantiateModule(t, functionClockResGet, importClockResGet, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.ClockResGet", func(t *testing.T) { - require.Equal(t, ErrnoNosys, a.ClockResGet(mod, 0, 0)) + require.Equal(t, ErrnoNosys, a.ClockResGet(testCtx, mod, 0, 0)) }) t.Run(functionClockResGet, func(t *testing.T) { @@ -420,31 +420,31 @@ func TestSnapshotPreview1_ClockTimeGet(t *testing.T) { } a, mod, fn := instantiateModule(t, functionClockTimeGet, importClockTimeGet, nil) - defer mod.Close() + defer mod.Close(testCtx) a.timeNowUnixNano = func() uint64 { return epochNanos } t.Run("snapshotPreview1.ClockTimeGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) // invoke ClockTimeGet directly and check the memory side effects! - errno := a.ClockTimeGet(mod, 0 /* TODO: id */, 0 /* TODO: precision */, resultTimestamp) + errno := a.ClockTimeGet(testCtx, mod, 0 /* TODO: id */, 0 /* TODO: precision */, resultTimestamp) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionClockTimeGet, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(resultTimestamp)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -454,11 +454,11 @@ func TestSnapshotPreview1_ClockTimeGet_Errors(t *testing.T) { epochNanos := uint64(1640995200000000000) // midnight UTC 2022-01-01 a, mod, fn := instantiateModule(t, functionClockTimeGet, importClockTimeGet, nil) - defer mod.Close() + defer mod.Close(testCtx) a.timeNowUnixNano = func() uint64 { return epochNanos } - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) tests := []struct { name string @@ -491,10 +491,10 @@ func TestSnapshotPreview1_ClockTimeGet_Errors(t *testing.T) { // TestSnapshotPreview1_FdAdvise only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdAdvise(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdAdvise, importFdAdvise, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdAdvise", func(t *testing.T) { - errno := a.FdAdvise(mod, 0, 0, 0, 0) + errno := a.FdAdvise(testCtx, mod, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -509,10 +509,10 @@ func TestSnapshotPreview1_FdAdvise(t *testing.T) { // TestSnapshotPreview1_FdAllocate only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdAllocate(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdAllocate, importFdAllocate, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdAllocate", func(t *testing.T) { - errno := a.FdAllocate(mod, 0, 0, 0) + errno := a.FdAllocate(testCtx, mod, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -559,16 +559,16 @@ func TestSnapshotPreview1_FdClose(t *testing.T) { t.Run("snapshotPreview1.FdClose", func(t *testing.T) { mod, _, api := setupFD() - defer mod.Close() + defer mod.Close(testCtx) - errno := api.FdClose(mod, fdToClose) + errno := api.FdClose(testCtx, mod, fdToClose) require.Zero(t, errno, ErrnoName(errno)) verify(mod) }) t.Run(functionFdClose, func(t *testing.T) { mod, fn, _ := setupFD() - defer mod.Close() + defer mod.Close(testCtx) results, err := fn.Call(testCtx, uint64(fdToClose)) require.NoError(t, err) @@ -579,9 +579,9 @@ func TestSnapshotPreview1_FdClose(t *testing.T) { }) t.Run("ErrnoBadF for an invalid FD", func(t *testing.T) { mod, _, api := setupFD() - defer mod.Close() + defer mod.Close(testCtx) - errno := api.FdClose(mod, 42) // 42 is an arbitrary invalid FD + errno := api.FdClose(testCtx, mod, 42) // 42 is an arbitrary invalid FD require.Equal(t, ErrnoBadf, errno) }) } @@ -589,10 +589,10 @@ func TestSnapshotPreview1_FdClose(t *testing.T) { // TestSnapshotPreview1_FdDatasync only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdDatasync(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdDatasync, importFdDatasync, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdDatasync", func(t *testing.T) { - errno := a.FdDatasync(mod, 0) + errno := a.FdDatasync(testCtx, mod, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -613,10 +613,10 @@ func TestSnapshotPreview1_FdFdstatGet(t *testing.T) { // TestSnapshotPreview1_FdFdstatSetFlags only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdFdstatSetFlags(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdFdstatSetFlags, importFdFdstatSetFlags, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdFdstatSetFlags", func(t *testing.T) { - errno := a.FdFdstatSetFlags(mod, 0, 0) + errno := a.FdFdstatSetFlags(testCtx, mod, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -631,10 +631,10 @@ func TestSnapshotPreview1_FdFdstatSetFlags(t *testing.T) { // TestSnapshotPreview1_FdFdstatSetRights only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdFdstatSetRights(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdFdstatSetRights, importFdFdstatSetRights, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdFdstatSetRights", func(t *testing.T) { - errno := a.FdFdstatSetRights(mod, 0, 0, 0) + errno := a.FdFdstatSetRights(testCtx, mod, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -649,10 +649,10 @@ func TestSnapshotPreview1_FdFdstatSetRights(t *testing.T) { // TestSnapshotPreview1_FdFilestatGet only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdFilestatGet(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdFilestatGet, importFdFilestatGet, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdFilestatGet", func(t *testing.T) { - errno := a.FdFilestatGet(mod, 0, 0) + errno := a.FdFilestatGet(testCtx, mod, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -667,10 +667,10 @@ func TestSnapshotPreview1_FdFilestatGet(t *testing.T) { // TestSnapshotPreview1_FdFilestatSetSize only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdFilestatSetSize(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdFilestatSetSize, importFdFilestatSetSize, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdFilestatSetSize", func(t *testing.T) { - errno := a.FdFilestatSetSize(mod, 0, 0) + errno := a.FdFilestatSetSize(testCtx, mod, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -685,10 +685,10 @@ func TestSnapshotPreview1_FdFilestatSetSize(t *testing.T) { // TestSnapshotPreview1_FdFilestatSetTimes only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdFilestatSetTimes(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdFilestatSetTimes, importFdFilestatSetTimes, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdFilestatSetTimes", func(t *testing.T) { - errno := a.FdFilestatSetTimes(mod, 0, 0, 0, 0) + errno := a.FdFilestatSetTimes(testCtx, mod, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -703,10 +703,10 @@ func TestSnapshotPreview1_FdFilestatSetTimes(t *testing.T) { // TestSnapshotPreview1_FdPread only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdPread(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdPread, importFdPread, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdPread", func(t *testing.T) { - errno := a.FdPread(mod, 0, 0, 0, 0, 0) + errno := a.FdPread(testCtx, mod, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -726,7 +726,7 @@ func TestSnapshotPreview1_FdPrestatGet(t *testing.T) { require.NoError(t, err) a, mod, fn := instantiateModule(t, functionFdPrestatGet, importFdPrestatGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) resultPrestat := uint32(1) // arbitrary offset expectedMemory := []byte{ @@ -739,25 +739,25 @@ func TestSnapshotPreview1_FdPrestatGet(t *testing.T) { } t.Run("snapshotPreview1.FdPrestatGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) - errno := a.FdPrestatGet(mod, fd, resultPrestat) + errno := a.FdPrestatGet(testCtx, mod, fd, resultPrestat) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionFdPrestatDirName, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(fd), uint64(resultPrestat)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -771,9 +771,9 @@ func TestSnapshotPreview1_FdPrestatGet_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionFdPrestatGet, importFdPrestatGet, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) tests := []struct { name string @@ -800,7 +800,7 @@ func TestSnapshotPreview1_FdPrestatGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.FdPrestatGet(mod, tc.fd, tc.resultPrestat) + errno := a.FdPrestatGet(testCtx, mod, tc.fd, tc.resultPrestat) require.Equal(t, tc.expectedErrno, errno, ErrnoName(errno)) }) } @@ -813,7 +813,7 @@ func TestSnapshotPreview1_FdPrestatDirName(t *testing.T) { require.NoError(t, err) a, mod, fn := instantiateModule(t, functionFdPrestatDirName, importFdPrestatDirName, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) path := uint32(1) // arbitrary offset pathLen := uint32(3) // shorter than len("/tmp") to test the path is written for the length of pathLen @@ -824,25 +824,25 @@ func TestSnapshotPreview1_FdPrestatDirName(t *testing.T) { } t.Run("snapshotPreview1.FdPrestatDirName", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) - errno := a.FdPrestatDirName(mod, fd, path, pathLen) + errno := a.FdPrestatDirName(testCtx, mod, fd, path, pathLen) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionFdPrestatDirName, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(fd), uint64(path), uint64(pathLen)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -854,9 +854,9 @@ func TestSnapshotPreview1_FdPrestatDirName_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionFdPrestatDirName, importFdPrestatDirName, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) validAddress := uint32(0) // Arbitrary valid address as arguments to fd_prestat_dir_name. We chose 0 here. pathLen := uint32(len("/tmp")) @@ -902,7 +902,7 @@ func TestSnapshotPreview1_FdPrestatDirName_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.FdPrestatDirName(mod, tc.fd, tc.path, tc.pathLen) + errno := a.FdPrestatDirName(testCtx, mod, tc.fd, tc.path, tc.pathLen) require.Equal(t, tc.expectedErrno, errno, ErrnoName(errno)) }) } @@ -911,10 +911,10 @@ func TestSnapshotPreview1_FdPrestatDirName_Errors(t *testing.T) { // TestSnapshotPreview1_FdPwrite only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdPwrite(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdPwrite, importFdPwrite, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdPwrite", func(t *testing.T) { - errno := a.FdPwrite(mod, 0, 0, 0, 0, 0) + errno := a.FdPwrite(testCtx, mod, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -950,7 +950,7 @@ func TestSnapshotPreview1_FdRead(t *testing.T) { ) // TestSnapshotPreview1_FdRead uses a matrix because setting up test files is complicated and has to be clean each time. - type fdReadFn func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno + type fdReadFn func(ctx context.Context, m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno tests := []struct { name string fdRead func(*snapshotPreview1, api.Module, api.Function) fdReadFn @@ -959,7 +959,7 @@ func TestSnapshotPreview1_FdRead(t *testing.T) { return a.FdRead }}, {functionFdRead, func(_ *snapshotPreview1, mod api.Module, fn api.Function) fdReadFn { - return func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { + return func(ctx context.Context, m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { results, err := fn.Call(testCtx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) require.NoError(t, err) return Errno(results[0]) @@ -978,17 +978,17 @@ func TestSnapshotPreview1_FdRead(t *testing.T) { require.NoError(t, err) a, mod, fn := instantiateModule(t, functionFdRead, importFdRead, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) - ok := mod.Memory().Write(0, initialMemory) + ok := mod.Memory().Write(testCtx, 0, initialMemory) require.True(t, ok) - errno := tc.fdRead(a, mod, fn)(mod, fd, iovs, iovsCount, resultSize) + errno := tc.fdRead(a, mod, fn)(testCtx, mod, fd, iovs, iovsCount, resultSize) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -1005,7 +1005,7 @@ func TestSnapshotPreview1_FdRead_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionFdRead, importFdRead, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) tests := []struct { name string @@ -1078,10 +1078,10 @@ func TestSnapshotPreview1_FdRead_Errors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { offset := uint32(wasm.MemoryPagesToBytesNum(testMemoryPageSize) - uint64(len(tc.memory))) - memoryWriteOK := mod.Memory().Write(offset, tc.memory) + memoryWriteOK := mod.Memory().Write(testCtx, offset, tc.memory) require.True(t, memoryWriteOK) - errno := a.FdRead(mod, tc.fd, tc.iovs+offset, tc.iovsCount+offset, tc.resultSize+offset) + errno := a.FdRead(testCtx, mod, tc.fd, tc.iovs+offset, tc.iovsCount+offset, tc.resultSize+offset) require.Equal(t, tc.expectedErrno, errno, ErrnoName(errno)) }) } @@ -1090,10 +1090,10 @@ func TestSnapshotPreview1_FdRead_Errors(t *testing.T) { // TestSnapshotPreview1_FdReaddir only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdReaddir(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdReaddir, importFdReaddir, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdReaddir", func(t *testing.T) { - errno := a.FdReaddir(mod, 0, 0, 0, 0, 0) + errno := a.FdReaddir(testCtx, mod, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1108,10 +1108,10 @@ func TestSnapshotPreview1_FdReaddir(t *testing.T) { // TestSnapshotPreview1_FdRenumber only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdRenumber(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdRenumber, importFdRenumber, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdRenumber", func(t *testing.T) { - errno := a.FdRenumber(mod, 0, 0) + errno := a.FdRenumber(testCtx, mod, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1134,10 +1134,10 @@ func TestSnapshotPreview1_FdSeek(t *testing.T) { require.NoError(t, err) a, mod, fn := instantiateModule(t, functionFdSeek, importFdSeek, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) // TestSnapshotPreview1_FdSeek uses a matrix because setting up test files is complicated and has to be clean each time. - type fdSeekFn func(ctx api.Module, fd uint32, offset uint64, whence, resultNewOffset uint32) Errno + type fdSeekFn func(ctx context.Context, m api.Module, fd uint32, offset uint64, whence, resultNewOffset uint32) Errno seekFns := []struct { name string fdSeek func() fdSeekFn @@ -1146,8 +1146,8 @@ func TestSnapshotPreview1_FdSeek(t *testing.T) { return a.FdSeek }}, {functionFdSeek, func() fdSeekFn { - return func(ctx api.Module, fd uint32, offset uint64, whence, resultNewoffset uint32) Errno { - results, err := fn.Call(testCtx, uint64(fd), offset, uint64(whence), uint64(resultNewoffset)) + return func(ctx context.Context, m api.Module, fd uint32, offset uint64, whence, resultNewoffset uint32) Errno { + results, err := fn.Call(ctx, uint64(fd), offset, uint64(whence), uint64(resultNewoffset)) require.NoError(t, err) return Errno(results[0]) } @@ -1202,7 +1202,7 @@ func TestSnapshotPreview1_FdSeek(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - maskMemory(t, mod, len(tc.expectedMemory)) + maskMemory(t, testCtx, mod, len(tc.expectedMemory)) // Since we initialized this file, we know it is a seeker (because it is a MapFile) f, ok := sysCtx.OpenedFile(fd) @@ -1214,10 +1214,10 @@ func TestSnapshotPreview1_FdSeek(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(1), offset) - errno := sf.fdSeek()(mod, fd, uint64(tc.offset), uint32(tc.whence), resultNewoffset) + errno := sf.fdSeek()(testCtx, mod, fd, uint64(tc.offset), uint32(tc.whence), resultNewoffset) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(tc.expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(tc.expectedMemory))) require.True(t, ok) require.Equal(t, tc.expectedMemory, actual) @@ -1240,9 +1240,9 @@ func TestSnapshotPreview1_FdSeek_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionFdSeek, importFdSeek, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) tests := []struct { name string @@ -1273,7 +1273,7 @@ func TestSnapshotPreview1_FdSeek_Errors(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.FdSeek(mod, tc.fd, tc.offset, tc.whence, tc.resultNewoffset) + errno := a.FdSeek(testCtx, mod, tc.fd, tc.offset, tc.whence, tc.resultNewoffset) require.Equal(t, tc.expectedErrno, errno, ErrnoName(errno)) }) } @@ -1283,10 +1283,10 @@ func TestSnapshotPreview1_FdSeek_Errors(t *testing.T) { // TestSnapshotPreview1_FdSync only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdSync(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdSync, importFdSync, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdSync", func(t *testing.T) { - errno := a.FdSync(mod, 0) + errno := a.FdSync(testCtx, mod, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1301,10 +1301,10 @@ func TestSnapshotPreview1_FdSync(t *testing.T) { // TestSnapshotPreview1_FdTell only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_FdTell(t *testing.T) { a, mod, fn := instantiateModule(t, functionFdTell, importFdTell, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.FdTell", func(t *testing.T) { - errno := a.FdTell(mod, 0, 0) + errno := a.FdTell(testCtx, mod, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1340,7 +1340,7 @@ func TestSnapshotPreview1_FdWrite(t *testing.T) { ) // TestSnapshotPreview1_FdWrite uses a matrix because setting up test files is complicated and has to be clean each time. - type fdWriteFn func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno + type fdWriteFn func(ctx context.Context, m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno tests := []struct { name string fdWrite func(*snapshotPreview1, api.Module, api.Function) fdWriteFn @@ -1349,8 +1349,8 @@ func TestSnapshotPreview1_FdWrite(t *testing.T) { return a.FdWrite }}, {functionFdWrite, func(_ *snapshotPreview1, mod api.Module, fn api.Function) fdWriteFn { - return func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - results, err := fn.Call(testCtx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) + return func(ctx context.Context, m api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { + results, err := fn.Call(ctx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) require.NoError(t, err) return Errno(results[0]) } @@ -1371,16 +1371,16 @@ func TestSnapshotPreview1_FdWrite(t *testing.T) { require.NoError(t, err) a, mod, fn := instantiateModule(t, functionFdWrite, importFdWrite, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) - maskMemory(t, mod, len(expectedMemory)) - ok := mod.Memory().Write(0, initialMemory) + maskMemory(t, testCtx, mod, len(expectedMemory)) + ok := mod.Memory().Write(testCtx, 0, initialMemory) require.True(t, ok) - errno := tc.fdWrite(a, mod, fn)(mod, fd, iovs, iovsCount, resultSize) + errno := tc.fdWrite(a, mod, fn)(testCtx, mod, fd, iovs, iovsCount, resultSize) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) @@ -1406,7 +1406,7 @@ func TestSnapshotPreview1_FdWrite_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionFdWrite, importFdWrite, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) // Setup valid test memory iovs, iovsCount := uint32(0), uint32(1) @@ -1465,7 +1465,7 @@ func TestSnapshotPreview1_FdWrite_Errors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { mod.Memory().(*wasm.MemoryInstance).Buffer = tc.memory - errno := a.FdWrite(mod, tc.fd, iovs, iovsCount, tc.resultSize) + errno := a.FdWrite(testCtx, mod, tc.fd, iovs, iovsCount, tc.resultSize) require.Equal(t, tc.expectedErrno, errno, ErrnoName(errno)) }) } @@ -1474,10 +1474,10 @@ func TestSnapshotPreview1_FdWrite_Errors(t *testing.T) { // TestSnapshotPreview1_PathCreateDirectory only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathCreateDirectory(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathCreateDirectory, importPathCreateDirectory, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathCreateDirectory", func(t *testing.T) { - errno := a.PathCreateDirectory(mod, 0, 0, 0) + errno := a.PathCreateDirectory(testCtx, mod, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1492,10 +1492,10 @@ func TestSnapshotPreview1_PathCreateDirectory(t *testing.T) { // TestSnapshotPreview1_PathFilestatGet only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathFilestatGet(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathFilestatGet, importPathFilestatGet, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathFilestatGet", func(t *testing.T) { - errno := a.PathFilestatGet(mod, 0, 0, 0, 0, 0) + errno := a.PathFilestatGet(testCtx, mod, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1510,10 +1510,10 @@ func TestSnapshotPreview1_PathFilestatGet(t *testing.T) { // TestSnapshotPreview1_PathFilestatSetTimes only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathFilestatSetTimes(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathFilestatSetTimes, importPathFilestatSetTimes, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathFilestatSetTimes", func(t *testing.T) { - errno := a.PathFilestatSetTimes(mod, 0, 0, 0, 0, 0, 0, 0) + errno := a.PathFilestatSetTimes(testCtx, mod, 0, 0, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1528,10 +1528,10 @@ func TestSnapshotPreview1_PathFilestatSetTimes(t *testing.T) { // TestSnapshotPreview1_PathLink only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathLink(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathLink, importPathLink, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathLink", func(t *testing.T) { - errno := a.PathLink(mod, 0, 0, 0, 0, 0, 0, 0) + errno := a.PathLink(testCtx, mod, 0, 0, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1574,8 +1574,8 @@ func TestSnapshotPreview1_PathOpen(t *testing.T) { }) require.NoError(t, err) a, mod, fn := instantiateModule(t, functionPathOpen, importPathOpen, sysCtx) - maskMemory(t, mod, len(expectedMemory)) - ok := mod.Memory().Write(0, initialMemory) + maskMemory(t, testCtx, mod, len(expectedMemory)) + ok := mod.Memory().Write(testCtx, 0, initialMemory) require.True(t, ok) return a, mod, fn } @@ -1583,7 +1583,7 @@ func TestSnapshotPreview1_PathOpen(t *testing.T) { verify := func(errno Errno, mod api.Module) { require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory))) + actual, ok := mod.Memory().Read(testCtx, 0, uint32(len(expectedMemory))) require.True(t, ok) require.Equal(t, expectedMemory, actual) @@ -1595,7 +1595,7 @@ func TestSnapshotPreview1_PathOpen(t *testing.T) { t.Run("snapshotPreview1.PathOpen", func(t *testing.T) { a, mod, _ := setup() - errno := a.PathOpen(mod, workdirFD, dirflags, path, pathLen, oflags, fsRightsBase, fsRightsInheriting, fdFlags, resultOpenedFd) + errno := a.PathOpen(testCtx, mod, workdirFD, dirflags, path, pathLen, oflags, fsRightsBase, fsRightsInheriting, fdFlags, resultOpenedFd) verify(errno, mod) }) @@ -1619,11 +1619,11 @@ func TestSnapshotPreview1_PathOpen_Errors(t *testing.T) { require.NoError(t, err) a, mod, _ := instantiateModule(t, functionPathOpen, importPathOpen, sysCtx) - defer mod.Close() + defer mod.Close(testCtx) validPath := uint32(0) // arbitrary offset validPathLen := uint32(6) // the length of "wazero" - mod.Memory().Write(validPath, []byte(pathName)) + mod.Memory().Write(testCtx, validPath, []byte(pathName)) tests := []struct { name string @@ -1638,7 +1638,7 @@ func TestSnapshotPreview1_PathOpen_Errors(t *testing.T) { { name: "out-of-memory reading path", fd: validFD, - path: mod.Memory().Size(), + path: mod.Memory().Size(testCtx), pathLen: validPathLen, expectedErrno: ErrnoFault, }, @@ -1646,7 +1646,7 @@ func TestSnapshotPreview1_PathOpen_Errors(t *testing.T) { name: "out-of-memory reading pathLen", fd: validFD, path: validPath, - pathLen: mod.Memory().Size() + 1, // path is in the valid memory range, but pathLen is out-of-memory for path + pathLen: mod.Memory().Size(testCtx) + 1, // path is in the valid memory range, but pathLen is out-of-memory for path expectedErrno: ErrnoFault, }, { @@ -1661,7 +1661,7 @@ func TestSnapshotPreview1_PathOpen_Errors(t *testing.T) { fd: validFD, path: validPath, pathLen: validPathLen, - resultOpenedFd: mod.Memory().Size(), // path and pathLen correctly point to the right path, but where to write the opened FD is outside memory. + resultOpenedFd: mod.Memory().Size(testCtx), // path and pathLen correctly point to the right path, but where to write the opened FD is outside memory. expectedErrno: ErrnoFault, }, } @@ -1669,7 +1669,7 @@ func TestSnapshotPreview1_PathOpen_Errors(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.PathOpen(mod, tc.fd, 0, tc.path, tc.pathLen, tc.oflags, 0, 0, 0, tc.resultOpenedFd) + errno := a.PathOpen(testCtx, mod, tc.fd, 0, tc.path, tc.pathLen, tc.oflags, 0, 0, 0, tc.resultOpenedFd) require.Equal(t, tc.expectedErrno, errno, ErrnoName(errno)) }) } @@ -1678,10 +1678,10 @@ func TestSnapshotPreview1_PathOpen_Errors(t *testing.T) { // TestSnapshotPreview1_PathReadlink only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathReadlink(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathReadlink, importPathReadlink, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathLink", func(t *testing.T) { - errno := a.PathReadlink(mod, 0, 0, 0, 0, 0, 0) + errno := a.PathReadlink(testCtx, mod, 0, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1696,10 +1696,10 @@ func TestSnapshotPreview1_PathReadlink(t *testing.T) { // TestSnapshotPreview1_PathRemoveDirectory only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathRemoveDirectory(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathRemoveDirectory, importPathRemoveDirectory, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathRemoveDirectory", func(t *testing.T) { - errno := a.PathRemoveDirectory(mod, 0, 0, 0) + errno := a.PathRemoveDirectory(testCtx, mod, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1714,10 +1714,10 @@ func TestSnapshotPreview1_PathRemoveDirectory(t *testing.T) { // TestSnapshotPreview1_PathRename only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathRename(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathRename, importPathRename, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathRename", func(t *testing.T) { - errno := a.PathRename(mod, 0, 0, 0, 0, 0, 0) + errno := a.PathRename(testCtx, mod, 0, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1732,10 +1732,10 @@ func TestSnapshotPreview1_PathRename(t *testing.T) { // TestSnapshotPreview1_PathSymlink only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathSymlink(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathSymlink, importPathSymlink, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathSymlink", func(t *testing.T) { - errno := a.PathSymlink(mod, 0, 0, 0, 0, 0) + errno := a.PathSymlink(testCtx, mod, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1750,10 +1750,10 @@ func TestSnapshotPreview1_PathSymlink(t *testing.T) { // TestSnapshotPreview1_PathUnlinkFile only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PathUnlinkFile(t *testing.T) { a, mod, fn := instantiateModule(t, functionPathUnlinkFile, importPathUnlinkFile, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PathUnlinkFile", func(t *testing.T) { - errno := a.PathUnlinkFile(mod, 0, 0, 0) + errno := a.PathUnlinkFile(testCtx, mod, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1768,10 +1768,10 @@ func TestSnapshotPreview1_PathUnlinkFile(t *testing.T) { // TestSnapshotPreview1_PollOneoff only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_PollOneoff(t *testing.T) { a, mod, fn := instantiateModule(t, functionPollOneoff, importPollOneoff, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.PollOneoff", func(t *testing.T) { - errno := a.PollOneoff(mod, 0, 0, 0, 0) + errno := a.PollOneoff(testCtx, mod, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1805,7 +1805,7 @@ func TestSnapshotPreview1_ProcExit(t *testing.T) { // Note: Unlike most tests, this uses fn, not the 'a' result parameter. This is because currently, this function // body panics, and we expect Call to unwrap the panic. _, mod, fn := instantiateModule(t, functionProcExit, importProcExit, nil) - defer mod.Close() + defer mod.Close(testCtx) // When ProcExit is called, store.Callfunction returns immediately, returning the exit code as the error. _, err := fn.Call(testCtx, uint64(tc.exitCode)) @@ -1817,10 +1817,10 @@ func TestSnapshotPreview1_ProcExit(t *testing.T) { // TestSnapshotPreview1_ProcRaise only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_ProcRaise(t *testing.T) { a, mod, fn := instantiateModule(t, functionProcRaise, importProcRaise, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.ProcRaise", func(t *testing.T) { - errno := a.ProcRaise(mod, 0) + errno := a.ProcRaise(testCtx, mod, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1835,7 +1835,7 @@ func TestSnapshotPreview1_ProcRaise(t *testing.T) { // TestSnapshotPreview1_SchedYield only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_SchedYield(t *testing.T) { a, mod, fn := instantiateModule(t, functionSchedYield, importSchedYield, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.SchedYield", func(t *testing.T) { errno := a.SchedYield(mod) @@ -1862,7 +1862,7 @@ func TestSnapshotPreview1_RandomGet(t *testing.T) { seed := int64(42) // and seed value a, mod, fn := instantiateModule(t, functionRandomGet, importRandomGet, nil) - defer mod.Close() + defer mod.Close(testCtx) a.randSource = func(p []byte) error { s := rand.NewSource(seed) @@ -1873,26 +1873,26 @@ func TestSnapshotPreview1_RandomGet(t *testing.T) { } t.Run("snapshotPreview1.RandomGet", func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke RandomGet directly and check the memory side effects! - errno := a.RandomGet(mod, offset, length) + errno := a.RandomGet(testCtx, mod, offset, length) require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, offset+length+1) + actual, ok := mod.Memory().Read(testCtx, 0, offset+length+1) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) t.Run(functionRandomGet, func(t *testing.T) { - maskMemory(t, mod, len(expectedMemory)) + maskMemory(t, testCtx, mod, len(expectedMemory)) results, err := fn.Call(testCtx, uint64(offset), uint64(length)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) - actual, ok := mod.Memory().Read(0, offset+length+1) + actual, ok := mod.Memory().Read(testCtx, 0, offset+length+1) require.True(t, ok) require.Equal(t, expectedMemory, actual) }) @@ -1902,9 +1902,9 @@ func TestSnapshotPreview1_RandomGet_Errors(t *testing.T) { validAddress := uint32(0) // arbitrary valid address a, mod, _ := instantiateModule(t, functionRandomGet, importRandomGet, nil) - defer mod.Close() + defer mod.Close(testCtx) - memorySize := mod.Memory().Size() + memorySize := mod.Memory().Size(testCtx) tests := []struct { name string @@ -1928,7 +1928,7 @@ func TestSnapshotPreview1_RandomGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - errno := a.RandomGet(mod, tc.offset, tc.length) + errno := a.RandomGet(testCtx, mod, tc.offset, tc.length) require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) }) } @@ -1936,23 +1936,23 @@ func TestSnapshotPreview1_RandomGet_Errors(t *testing.T) { func TestSnapshotPreview1_RandomGet_SourceError(t *testing.T) { a, mod, _ := instantiateModule(t, functionRandomGet, importRandomGet, nil) - defer mod.Close() + defer mod.Close(testCtx) a.randSource = func(p []byte) error { return errors.New("random source error") } - errno := a.RandomGet(mod, uint32(1), uint32(5)) // arbitrary offset and length + errno := a.RandomGet(testCtx, mod, uint32(1), uint32(5)) // arbitrary offset and length require.Equal(t, ErrnoIo, errno, ErrnoName(errno)) } // TestSnapshotPreview1_SockRecv only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_SockRecv(t *testing.T) { a, mod, fn := instantiateModule(t, functionSockRecv, importSockRecv, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.SockRecv", func(t *testing.T) { - errno := a.SockRecv(mod, 0, 0, 0, 0, 0, 0) + errno := a.SockRecv(testCtx, mod, 0, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1967,10 +1967,10 @@ func TestSnapshotPreview1_SockRecv(t *testing.T) { // TestSnapshotPreview1_SockSend only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_SockSend(t *testing.T) { a, mod, fn := instantiateModule(t, functionSockSend, importSockSend, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.SockSend", func(t *testing.T) { - errno := a.SockSend(mod, 0, 0, 0, 0, 0) + errno := a.SockSend(testCtx, mod, 0, 0, 0, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -1985,10 +1985,10 @@ func TestSnapshotPreview1_SockSend(t *testing.T) { // TestSnapshotPreview1_SockShutdown only tests it is stubbed for GrainLang per #271 func TestSnapshotPreview1_SockShutdown(t *testing.T) { a, mod, fn := instantiateModule(t, functionSockShutdown, importSockShutdown, nil) - defer mod.Close() + defer mod.Close(testCtx) t.Run("snapshotPreview1.SockShutdown", func(t *testing.T) { - errno := a.SockShutdown(mod, 0, 0) + errno := a.SockShutdown(testCtx, mod, 0, 0) require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) }) @@ -2003,9 +2003,9 @@ func TestSnapshotPreview1_SockShutdown(t *testing.T) { const testMemoryPageSize = 1 // maskMemory sets the first memory in the store to '?' * size, so tests can see what's written. -func maskMemory(t *testing.T, mod api.Module, size int) { +func maskMemory(t *testing.T, ctx context.Context, mod api.Module, size int) { for i := uint32(0); i < uint32(size); i++ { - require.True(t, mod.Memory().WriteByte(i, '?')) + require.True(t, mod.Memory().WriteByte(ctx, i, '?')) } } @@ -2025,7 +2025,7 @@ func instantiateModule(t *testing.T, wasifunction, wasiimport string, sysCtx *wa (export "%[1]s" (func $wasi.%[1]s)) )`, wasifunction, wasiimport))) require.NoError(t, err) - defer compiled.Close() + defer compiled.Close(testCtx) mod, err := r.InstantiateModuleWithConfig(testCtx, compiled, wazero.NewModuleConfig().WithName(t.Name())) require.NoError(t, err) diff --git a/wasm.go b/wasm.go index 62c9094025..d583af9e66 100644 --- a/wasm.go +++ b/wasm.go @@ -45,7 +45,7 @@ type Runtime interface { // * Improve performance when the same module is instantiated multiple times under different names // * Reduce the amount of errors that can occur during InstantiateModule. // - // Note: when `ctx` is nil, it defaults to context.Background. + // Note: When the context is nil, it defaults to context.Background. // Note: The resulting module name defaults to what was binary from the custom name section. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#name-section%E2%91%A0 CompileModule(ctx context.Context, source []byte) (*CompiledCode, error) @@ -58,7 +58,7 @@ type Runtime interface { // module, _ := wazero.NewRuntime().InstantiateModuleFromCode(ctx, source) // defer module.Close() // - // Note: when `ctx` is nil, it defaults to context.Background. + // Note: When the context is nil, it defaults to context.Background. // Note: This is a convenience utility that chains CompileModule with InstantiateModule. To instantiate the same // source multiple times, use CompileModule as InstantiateModule avoids redundant decoding and/or compilation. InstantiateModuleFromCode(ctx context.Context, source []byte) (api.Module, error) @@ -74,7 +74,7 @@ type Runtime interface { // ) // defer wasm.Close() // - // Note: When `ctx` is nil, it defaults to context.Background. + // Note: When the context is nil, it defaults to context.Background. InstantiateModuleFromCodeWithConfig(ctx context.Context, source []byte, config *ModuleConfig) (api.Module, error) // InstantiateModule instantiates the module namespace or errs if the configuration was invalid. @@ -92,7 +92,7 @@ type Runtime interface { // * The module has a table element initializer that resolves to an index outside the Table minimum size. // * The module has a start function, and it failed to execute. // - // Note: When `ctx` is nil, it defaults to context.Background. + // Note: When the context is nil, it defaults to context.Background. InstantiateModule(ctx context.Context, compiled *CompiledCode) (api.Module, error) // InstantiateModuleWithConfig is like InstantiateModule, except you can override configuration such as the module @@ -111,7 +111,7 @@ type Runtime interface { // // Assign different configuration on each instantiation // module, _ := r.InstantiateModuleWithConfig(ctx, compiled, config.WithName("rotate").WithArgs("rotate", "angle=90", "dir=cw")) // - // Note: when `ctx` is nil, it defaults to context.Background. + // Note: When the context is nil, it defaults to context.Background. // Note: Config is copied during instantiation: Later changes to config do not affect the instantiated result. InstantiateModuleWithConfig(ctx context.Context, compiled *CompiledCode, config *ModuleConfig) (mod api.Module, err error) } @@ -188,8 +188,8 @@ func (r *runtime) InstantiateModuleFromCode(ctx context.Context, source []byte) if compiled, err := r.CompileModule(ctx, source); err != nil { return nil, err } else { - // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. - defer compiled.Close() + // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside this function. + defer compiled.Close(ctx) return r.InstantiateModule(ctx, compiled) } } @@ -199,8 +199,8 @@ func (r *runtime) InstantiateModuleFromCodeWithConfig(ctx context.Context, sourc if compiled, err := r.CompileModule(ctx, source); err != nil { return nil, err } else { - // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. - defer compiled.Close() + // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside this function. + defer compiled.Close(ctx) return r.InstantiateModuleWithConfig(ctx, compiled, config) } } diff --git a/wasm_test.go b/wasm_test.go index 915ac609ac..295fe8e21c 100644 --- a/wasm_test.go +++ b/wasm_test.go @@ -59,7 +59,7 @@ func TestRuntime_DecodeModule(t *testing.T) { t.Run(tc.name, func(t *testing.T) { code, err := r.CompileModule(testCtx, tc.source) require.NoError(t, err) - defer code.Close() + defer code.Close(testCtx) if tc.expectedName != "" { require.Equal(t, tc.expectedName, code.module.NameSection.ModuleName) } @@ -156,11 +156,11 @@ func TestModule_Memory(t *testing.T) { // Instantiate the module and get the export of the above memory module, err := tc.builder(r).Instantiate(testCtx) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) mem := module.ExportedMemory("memory") if tc.expected { - require.Equal(t, tc.expectedLen, mem.Size()) + require.Equal(t, tc.expectedLen, mem.Size(testCtx)) } else { require.Nil(t, mem) } @@ -236,20 +236,20 @@ func TestModule_Global(t *testing.T) { // Instantiate the module and get the export of the above global module, err := r.InstantiateModule(testCtx, code) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) global := module.ExportedGlobal("global") if !tc.expected { require.Nil(t, global) return } - require.Equal(t, uint64(globalVal), global.Get()) + require.Equal(t, uint64(globalVal), global.Get(testCtx)) mutable, ok := global.(api.MutableGlobal) require.Equal(t, tc.expectedMutable, ok) if ok { - mutable.Set(2) - require.Equal(t, uint64(2), global.Get()) + mutable.Set(testCtx, 2) + require.Equal(t, uint64(2), global.Get(testCtx)) } }) } @@ -287,12 +287,12 @@ func TestFunction_Context(t *testing.T) { return expectedResult } source, closer := requireImportAndExportFunction(t, r, hostFn, functionName) - defer closer() // nolint + defer closer(testCtx) // nolint // Instantiate the module and get the export of the above hostFn module, err := r.InstantiateModuleFromCodeWithConfig(tc.ctx, source, NewModuleConfig().WithName(t.Name())) require.NoError(t, err) - defer module.Close() + defer module.Close(testCtx) // This fails if the function wasn't invoked, or had an unexpected context. results, err := module.ExportedFunction(functionName).Call(tc.ctx) @@ -316,19 +316,19 @@ func TestRuntime_InstantiateModule_UsesContext(t *testing.T) { ExportFunction("start", start). Instantiate(testCtx) require.NoError(t, err) - defer env.Close() + defer env.Close(testCtx) code, err := r.CompileModule(testCtx, []byte(`(module $runtime_test.go (import "env" "start" (func $start)) (start $start) )`)) require.NoError(t, err) - defer code.Close() + defer code.Close(testCtx) // Instantiate the module, which calls the start function. This will fail if the context wasn't as intended. m, err := r.InstantiateModule(testCtx, code) require.NoError(t, err) - defer m.Close() + defer m.Close(testCtx) require.True(t, calledStart) } @@ -342,7 +342,7 @@ func TestInstantiateModuleFromCode_DoesntEnforce_Start(t *testing.T) { (export "memory" (memory 0)) )`)) require.NoError(t, err) - require.NoError(t, mod.Close()) + require.NoError(t, mod.Close(testCtx)) } func TestRuntime_InstantiateModuleFromCode_UsesContext(t *testing.T) { @@ -359,7 +359,7 @@ func TestRuntime_InstantiateModuleFromCode_UsesContext(t *testing.T) { ExportFunction("start", start). Instantiate(testCtx) require.NoError(t, err) - defer host.Close() + defer host.Close(testCtx) // Start the module as a WASI command. This will fail if the context wasn't as intended. mod, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $start @@ -369,7 +369,7 @@ func TestRuntime_InstantiateModuleFromCode_UsesContext(t *testing.T) { (export "memory" (memory 0)) )`)) require.NoError(t, err) - defer mod.Close() + defer mod.Close(testCtx) require.True(t, calledStart) } @@ -380,7 +380,7 @@ func TestInstantiateModuleWithConfig_WithName(t *testing.T) { r := NewRuntime() base, err := r.CompileModule(testCtx, []byte(`(module $0 (memory 1))`)) require.NoError(t, err) - defer base.Close() + defer base.Close(testCtx) require.Equal(t, "0", base.module.NameSection.ModuleName) @@ -388,14 +388,14 @@ func TestInstantiateModuleWithConfig_WithName(t *testing.T) { internal := r.(*runtime).store m1, err := r.InstantiateModuleWithConfig(testCtx, base, NewModuleConfig().WithName("1")) require.NoError(t, err) - defer m1.Close() + defer m1.Close(testCtx) require.Nil(t, internal.Module("0")) require.Equal(t, internal.Module("1"), m1) m2, err := r.InstantiateModuleWithConfig(testCtx, base, NewModuleConfig().WithName("2")) require.NoError(t, err) - defer m2.Close() + defer m2.Close(testCtx) require.Nil(t, internal.Module("0")) require.Equal(t, internal.Module("2"), m2) @@ -404,8 +404,8 @@ func TestInstantiateModuleWithConfig_WithName(t *testing.T) { func TestInstantiateModuleWithConfig_ExitError(t *testing.T) { r := NewRuntime() - start := func(m api.Module) { - require.NoError(t, m.CloseWithExitCode(2)) + start := func(ctx context.Context, m api.Module) { + require.NoError(t, m.CloseWithExitCode(ctx, 2)) } _, err := r.NewModuleBuilder("env").ExportFunction("_start", start).Instantiate(testCtx) @@ -415,7 +415,7 @@ func TestInstantiateModuleWithConfig_ExitError(t *testing.T) { } // requireImportAndExportFunction re-exports a host function because only host functions can see the propagated context. -func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx context.Context) uint64, functionName string) ([]byte, func() error) { +func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx context.Context) uint64, functionName string) ([]byte, func(context.Context) error) { mod, err := r.NewModuleBuilder("host").ExportFunction(functionName, hostFn).Instantiate(testCtx) require.NoError(t, err) @@ -424,28 +424,6 @@ func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx con )), mod.Close } -func TestCompiledCode_Close(t *testing.T) { - e := &mockEngine{name: "1", cachedModules: map[*wasm.Module]struct{}{}} - - var cs []*CompiledCode - for i := 0; i < 10; i++ { - m := &wasm.Module{} - err := e.CompileModule(testCtx, m) - require.NoError(t, err) - cs = append(cs, &CompiledCode{module: m, compiledEngine: e}) - } - - // Before Close. - require.Equal(t, 10, len(e.cachedModules)) - - for _, c := range cs { - c.Close() - } - - // After Close. - require.Zero(t, len(e.cachedModules)) -} - type mockEngine struct { name string cachedModules map[*wasm.Module]struct{}