Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experimental support for snapshot/restore #1808

Merged
merged 13 commits into from
Jan 24, 2024
25 changes: 25 additions & 0 deletions experimental/checkpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package experimental

// Snapshot holds the execution state at the time of a Snapshotter.Snapshot call.
type Snapshot interface {
// Restore sets the Wasm execution state to the capture. Because a host function
// calling this is resetting the pointer to the executation stack, the host function
// will not be able to return values in the normal way. ret is a slice of values the
// host function intends to return from the restored function.
Restore(ret []uint64)
}

// Snapshotter allows host functions to snapshot the WebAssembly execution environment.
type Snapshotter interface {
// Snapshot captures the current execution state.
Snapshot() Snapshot
}

// EnableSnapshotterKey is a context key to indicate that snapshotting should be enabled.
// The context.Context passed to a exported function invocation should have this key set
// to a non-nil value, and host functions will be able to retrieve it using SnapshotterKey.
type EnableSnapshotterKey struct{}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this API should be more akin to:
https://pkg.go.dev/github.com/tetratelabs/wazero@v1.6.0/experimental/sock#WithConfig

So a WithSnapshotter that adds something like an internalEnableSnapshotter to a context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the pattern of function listener when writing this, I could change it to the sock pattern instead though. Any suggestion on what package the internal context key would go in?


// SnapshotterKey is a context key to access a Snapshotter from a host function.
// It is only present if EnableSnapshotter was set in the function invocation context.
type SnapshotterKey struct{}
100 changes: 100 additions & 0 deletions experimental/checkpoint_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package experimental_test

import (
"context"
_ "embed"
"fmt"
"log"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
)

// snapshotWasm was generated by the following:
//
// cd testdata; wat2wasm snapshot.wat
//
//go:embed testdata/snapshot.wasm
var snapshotWasm []byte

type snapshotsKey struct{}

func Example_enableSnapshotterKey() {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx) // This closes everything this Runtime created.

// Enable experimental snapshotting functionality by setting it to context. We use this
// context when invoking functions, indicating to wazero to enable it.
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

// Also place a mutable holder of snapshots to be referenced during restore.
var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)

// Register host functions using snapshot and restore. Generally snapshot is saved
// into a mutable location in context to be referenced during restore.
_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
// Because we set EnableSnapshotterKey to context, this is non-nil.
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()

// Get our mutable snapshots holder to be able to add to it. Our example only calls snapshot
// and restore once but real programs will often call them at multiple layers within a call
// stack with various e.g., try/catch statements.
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)

// Write a value to be passed back to restore. This is meant to be opaque to the guest
// and used to re-reference the snapshot.
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
if !ok {
log.Panicln("failed to write snapshot index")
}

return 0
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
// Read the value written by snapshot to re-reference the snapshot.
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
if !ok {
log.Panicln("failed to read snapshot index")
}

// Get the snapshot
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

// Restore! The invocation of this function will end as soon as we invoke
// Restore, so we also pass in our return value. The guest function run
// will finish with this return value.
snapshot.Restore([]uint64{5})
}).
Export("restore").
Instantiate(ctx)
if err != nil {
log.Panicln(err)
}

mod, err := rt.Instantiate(ctx, snapshotWasm) // Instantiate the actual code
if err != nil {
log.Panicln(err)
}

// Call the guest entrypoint.
res, err := mod.ExportedFunction("run").Call(ctx)
if err != nil {
log.Panicln(err)
}
// We restored and returned the restore value, so it's our result. If restore
// was instead a no-op, we would have returned 10 from normal code flow.
fmt.Println(res[0])
// Output:
// 5
}
121 changes: 121 additions & 0 deletions experimental/checkpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package experimental_test

import (
"context"
"testing"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/internal/testing/require"
)

func TestSnapshotNestedWasmInvocation(t *testing.T) {
mathetake marked this conversation as resolved.
Show resolved Hide resolved
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

sidechannel := 0

_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
defer func() {
sidechannel = 10
}()
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
require.True(t, ok)

_, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr))
require.NoError(t, err)

return 2
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
require.True(t, ok)
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

snapshot.Restore([]uint64{12})
}).
Export("restore").
Instantiate(ctx)
require.NoError(t, err)

mod, err := rt.Instantiate(ctx, snapshotWasm)
require.NoError(t, err)

var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

snapshotPtr := uint64(0)
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
require.NoError(t, err)
// return value from restore
require.Equal(t, uint64(12), res[0])
// Host function defers within the call stack work fine
require.Equal(t, 10, sidechannel)
}

func TestSnapshotMultipleWasmInvocations(t *testing.T) {
ctx := context.Background()

rt := wazero.NewRuntime(ctx)
defer rt.Close(ctx)

_, err := rt.NewHostModuleBuilder("example").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
idx := len(*snapshots)
*snapshots = append(*snapshots, snapshot)
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
require.True(t, ok)

return 0
}).
Export("snapshot").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
require.True(t, ok)
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
snapshot := (*snapshots)[idx]

snapshot.Restore([]uint64{12})
}).
Export("restore").
Instantiate(ctx)
require.NoError(t, err)

mod, err := rt.Instantiate(ctx, snapshotWasm)
require.NoError(t, err)

var snapshots []experimental.Snapshot
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})

snapshotPtr := uint64(0)
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
require.NoError(t, err)
// snapshot returned zero
require.Equal(t, uint64(0), res[0])

// Fails, snapshot and restore are called from different wasm invocations. Currently, this
// results in a panic.
err = require.CapturePanic(func() {
_, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr)
})
require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming ctx is propagated through a nested invocation, I think it'd be possible to handle this as an error, but I don't know if it's worth it since it would be an error either way but add some tricky handling

"exported function invocation than snapshot")
}
2 changes: 1 addition & 1 deletion experimental/listener_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

// listenerWasm was generated by the following:
//
// cd testdata; wat2wasm --debug-names listener.wat
// cd logging/testdata; wat2wasm --debug-names listener.wat
//
//go:embed logging/testdata/listener.wasm
var listenerWasm []byte
Expand Down
Binary file added experimental/testdata/snapshot.wasm
Binary file not shown.
34 changes: 34 additions & 0 deletions experimental/testdata/snapshot.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
(module
(import "example" "snapshot" (func $snapshot (param i32) (result i32)))
(import "example" "restore" (func $restore (param i32)))

(func $helper (result i32)
(call $restore (i32.const 0))
;; Not executed
i32.const 10
)

(func (export "run") (result i32) (local i32)
(call $snapshot (i32.const 0))
local.set 0
local.get 0
(if (result i32)
(then ;; restore return, finish with the value returned by it
local.get 0
)
(else ;; snapshot return, call heloer
(call $helper)
)
)
)

(func (export "snapshot") (param i32) (result i32)
(call $snapshot (local.get 0))
)

(func (export "restore") (param i32)
(call $restore (local.get 0))
)

(memory (export "memory") 1 1)
)
Loading
Loading