Skip to content

Commit

Permalink
Expose reference types in the Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzgen committed Jul 17, 2020
1 parent 6396ddc commit b8532a4
Show file tree
Hide file tree
Showing 10 changed files with 439 additions and 101 deletions.
14 changes: 3 additions & 11 deletions bindgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def __init__(self):
self.ret += '\n'
self.ret += 'from ctypes import *\n'
self.ret += 'from typing import Any\n'
self.ret += 'from ._ffi import dll, wasm_val_t\n'
self.generated_wasm_ref_t = False
self.ret += 'from ._ffi import dll, wasm_val_t, wasm_ref_t\n'

# Skip all function definitions, we don't bind those
def visit_FuncDef(self, node):
Expand All @@ -31,16 +30,9 @@ def visit_Struct(self, node):
return

# This is hand-generated since it has an anonymous union in it
if node.name == 'wasm_val_t':
if node.name == 'wasm_val_t' or node.name == 'wasm_ref_t':
return

# This is defined twice in the header file, but we only want to insert
# one definition.
if node.name == 'wasm_ref_t':
if self.generated_wasm_ref_t:
return
self.generated_wasm_ref_t = True

self.ret += "\n"
self.ret += "class {}(Structure):\n".format(node.name)
if node.decls:
Expand Down Expand Up @@ -84,7 +76,7 @@ def visit_FuncDecl(self, node):
return
if name == 'wasm_module_deserialize':
return
if 'ref_as_' in name:
if '_ref_as_' in name:
return
if 'extern_const' in name:
return
Expand Down
4 changes: 2 additions & 2 deletions tests/test_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def test_errors(self):
store = Store()
ty = GlobalType(ValType.i32(), True)
with self.assertRaises(TypeError):
Global(store, ty, store) # type: ignore
Global(store, ty, store)
with self.assertRaises(TypeError):
Global(store, 1, Val.i32(1)) # type: ignore
with self.assertRaises(TypeError):
Global(1, ty, Val.i32(1)) # type: ignore

g = Global(store, ty, Val.i32(1))
with self.assertRaises(TypeError):
g.value = g # type: ignore
g.value = g

ty = GlobalType(ValType.i32(), False)
g = Global(store, ty, Val.i32(1))
Expand Down
124 changes: 124 additions & 0 deletions tests/test_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import unittest

from wasmtime import *


def ref_types_store():
config = Config()
config.wasm_reference_types = True
engine = Engine(config)
return Store(engine)


def compile_and_instantiate(wat):
store = ref_types_store()
module = Module(store.engine, wat)
return (Instance(store, module, []), store)


class TestExternRef(unittest.TestCase):
def test_smoke(self):
(instance, store) = compile_and_instantiate(
"""
(module
(func (export "f") (param externref) (result externref)
local.get 0
)
(func (export "null_externref") (result externref)
ref.null extern
)
)
"""
)

null_externref = instance.exports.get("null_externref")
self.assertEqual(null_externref(), None)

f = instance.exports.get("f")
externs = [42, True, False, None, "Hello", {"x": 1}, [12, 13, 14], Config()]

for extern in externs:
# We can create an externref for the given extern data.
ref = Val.externref(extern)

# And the externref's value is our extern data.
self.assertEqual(ref.value, extern)

# And we can round trip the externref through Wasm and still get our
# extern data.
result = f(ref)
self.assertEqual(result, extern)

def test_externref_tables(self):
store = ref_types_store()
ty = TableType(ValType.externref(), Limits(10, None))
table = Table(store, ty, "init")

for i in range(0, 10):
self.assertEqual(table[i], "init")

table.grow(2, "grown")

for i in range(0, 10):
self.assertEqual(table[i], "init")
for i in range(10, 12):
self.assertEqual(table[i], "grown")

table[7] = "lucky"

for i in range(0, 7):
self.assertEqual(table[i], "init")
self.assertEqual(table[7], "lucky")
for i in range(8, 10):
self.assertEqual(table[i], "init")
for i in range(10, 12):
self.assertEqual(table[i], "grown")

def test_externref_in_global(self):
store = ref_types_store()
ty = GlobalType(ValType.externref(), True)
g = Global(store, ty, Val.externref("hello"))
self.assertEqual(g.value, "hello")
g.value = "goodbye"
self.assertEqual(g.value, "goodbye")


class TestFuncRef(unittest.TestCase):
def test_smoke(self):
(instance, store) = compile_and_instantiate(
"""
(module
(func (export \"f\") (param funcref) (result funcref)
local.get 0
)
(func (export "null_funcref") (result funcref)
ref.null func
)
)
"""
)

null_funcref = instance.exports.get("null_funcref")
self.assertEqual(null_funcref(), None)

f = instance.exports.get("f")

ty = FuncType([], [ValType.i32()])
g = Func(store, ty, lambda: 42)

# We can create a funcref.
ref_g_val = Val.funcref(g)

# And the funcref's points to `g`.
g2 = ref_g_val.as_funcref()
if isinstance(g2, Func):
self.assertEqual(g2(), 42)
else:
self.fail("g2 is not a funcref: g2 = %r" % g2)

# And we can round trip the funcref through Wasm.
g3 = f(ref_g_val)
if isinstance(g3, Func):
self.assertEqual(g3(), 42)
else:
self.fail("g3 is not a funcref: g3 = %r" % g3)
4 changes: 2 additions & 2 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_new(self):

ty = TableType(ValType.i32(), Limits(1, 2))
store = Store()
with self.assertRaises(WasmtimeError):
with self.assertRaises(TypeError):
Table(store, ty, None)

ty = TableType(ValType.funcref(), Limits(1, 2))
Expand All @@ -35,7 +35,7 @@ def test_grow(self):
with self.assertRaises(TypeError):
table.grow('x', None) # type: ignore
with self.assertRaises(TypeError):
table.grow(2, 'x') # type: ignore
table.grow(2, 'x')

# growth works
table.grow(1, None)
Expand Down
37 changes: 33 additions & 4 deletions wasmtime/_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ctypes import *
from typing import Any
from ._ffi import dll, wasm_val_t
from ._ffi import dll, wasm_val_t, wasm_ref_t

wasm_byte_t = c_ubyte

Expand Down Expand Up @@ -740,9 +740,6 @@ def wasm_exporttype_name(arg0: Any) -> pointer:
def wasm_exporttype_type(arg0: Any) -> pointer:
return _wasm_exporttype_type(arg0) # type: ignore

class wasm_ref_t(Structure):
pass

_wasm_val_delete = dll.wasm_val_delete
_wasm_val_delete.restype = None
_wasm_val_delete.argtypes = [POINTER(wasm_val_t)]
Expand Down Expand Up @@ -1951,6 +1948,18 @@ def wasmtime_func_new(arg0: Any, arg1: Any, callback: Any) -> pointer:
def wasmtime_func_new_with_env(store: Any, type: Any, callback: Any, env: Any, finalizer: Any) -> pointer:
return _wasmtime_func_new_with_env(store, type, callback, env, finalizer) # type: ignore

_wasmtime_func_as_funcref = dll.wasmtime_func_as_funcref
_wasmtime_func_as_funcref.restype = None
_wasmtime_func_as_funcref.argtypes = [POINTER(wasm_func_t), POINTER(wasm_val_t)]
def wasmtime_func_as_funcref(func: Any, funcrefp: Any) -> None:
return _wasmtime_func_as_funcref(func, funcrefp) # type: ignore

_wasmtime_funcref_as_func = dll.wasmtime_funcref_as_func
_wasmtime_funcref_as_func.restype = POINTER(wasm_func_t)
_wasmtime_funcref_as_func.argtypes = [POINTER(wasm_val_t)]
def wasmtime_funcref_as_func(val: Any) -> pointer:
return _wasmtime_funcref_as_func(val) # type: ignore

_wasmtime_caller_export_get = dll.wasmtime_caller_export_get
_wasmtime_caller_export_get.restype = POINTER(wasm_extern_t)
_wasmtime_caller_export_get.argtypes = [POINTER(wasmtime_caller_t), POINTER(wasm_name_t)]
Expand Down Expand Up @@ -2055,3 +2064,23 @@ def wasmtime_funcref_table_set(table: Any, index: Any, value: Any) -> pointer:
_wasmtime_funcref_table_grow.argtypes = [POINTER(wasm_table_t), wasm_table_size_t, POINTER(wasm_func_t), POINTER(wasm_table_size_t)]
def wasmtime_funcref_table_grow(table: Any, delta: Any, init: Any, prev_size: Any) -> pointer:
return _wasmtime_funcref_table_grow(table, delta, init, prev_size) # type: ignore

_wasmtime_externref_new = dll.wasmtime_externref_new
_wasmtime_externref_new.restype = None
_wasmtime_externref_new.argtypes = [c_void_p, POINTER(wasm_val_t)]
def wasmtime_externref_new(data: Any, valp: Any) -> None:
return _wasmtime_externref_new(data, valp) # type: ignore

wasmtime_externref_finalizer_t = CFUNCTYPE(None, c_void_p)

_wasmtime_externref_new_with_finalizer = dll.wasmtime_externref_new_with_finalizer
_wasmtime_externref_new_with_finalizer.restype = None
_wasmtime_externref_new_with_finalizer.argtypes = [c_void_p, wasmtime_externref_finalizer_t, POINTER(wasm_val_t)]
def wasmtime_externref_new_with_finalizer(data: Any, finalizer: Any, valp: Any) -> None:
return _wasmtime_externref_new_with_finalizer(data, finalizer, valp) # type: ignore

_wasmtime_externref_data = dll.wasmtime_externref_data
_wasmtime_externref_data.restype = c_bool
_wasmtime_externref_data.argtypes = [POINTER(wasm_val_t), POINTER(c_void_p)]
def wasmtime_externref_data(val: Any, datap: Any) -> c_bool:
return _wasmtime_externref_data(val, datap) # type: ignore
7 changes: 7 additions & 0 deletions wasmtime/_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import platform
import typing

from wasmtime import WasmtimeError

Expand Down Expand Up @@ -39,18 +40,24 @@
WASM_VAR = c_uint8(1)


class wasm_ref_t(Structure):
pass


class wasm_val_union(Union):
_fields_ = [
("i32", c_int32),
("i64", c_int64),
("f32", c_float),
("f64", c_double),
("ref", POINTER(wasm_ref_t)),
]

i32: int
i64: int
f32: float
f64: float
ref: "typing.Union[pointer[wasm_ref_t], None]"


class wasm_val_t(Structure):
Expand Down
37 changes: 26 additions & 11 deletions wasmtime/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,17 @@ def __call__(self, *params: IntoVal) -> Union[IntoVal, Sequence[IntoVal], None]:

ty = self.type
param_tys = ty.params
if len(params) > len(param_tys):
raise WasmtimeError("too many parameters provided: given %s, expected %s" %
(len(params), len(param_tys)))
if len(params) < len(param_tys):
raise WasmtimeError("too few parameters provided: given %s, expected %s" %
(len(params), len(param_tys)))

param_vals = [Val._convert(ty, params[i]) for i, ty in enumerate(param_tys)]
params_ptr = (ffi.wasm_val_t * len(params))()
for i, param in enumerate(params):
if i >= len(param_tys):
raise WasmtimeError("too many parameters provided")
val = Val._convert(param_tys[i], param)
params_ptr[i] = val._raw
for i, val in enumerate(param_vals):
params_ptr[i] = val._unwrap_raw()

result_tys = ty.results
results_ptr = (ffi.wasm_val_t * len(result_tys))()
Expand All @@ -116,7 +121,7 @@ def __call__(self, *params: IntoVal) -> Union[IntoVal, Sequence[IntoVal], None]:

results = []
for i in range(0, len(result_tys)):
results.append(extract_val(Val(results_ptr[i])))
results.append(Val(results_ptr[i]).value)
if len(results) == 0:
return None
elif len(results) == 1:
Expand Down Expand Up @@ -201,21 +206,31 @@ def invoke(idx, params_ptr, results_ptr, params): # type: ignore

try:
for i in range(0, len(param_tys)):
params.append(extract_val(Val(params_ptr[i])))
params.append(Val._value(params_ptr[i]))
results = func(*params)
if len(result_tys) == 0:
if results is not None:
raise WasmtimeError(
"callback produced results when it shouldn't")
elif len(result_tys) == 1:
val = Val._convert(result_tys[0], results)
results_ptr[0] = val._raw
if isinstance(results, Val):
# Because we are taking the inner value with `_into_raw`, we
# need to ensure that we have a unique `Val`.
val = results._clone()
else:
val = Val._convert(result_tys[0], results)
results_ptr[0] = val._into_raw()
else:
if len(results) != len(result_tys):
raise WasmtimeError("callback produced wrong number of results")
for i, result in enumerate(results):
val = Val._convert(result_tys[i], result)
results_ptr[i] = val._raw
# Because we are taking the inner value with `_into_raw`, we
# need to ensure that we have a unique `Val`.
if isinstance(result, Val):
val = result._clone()
else:
val = Val._convert(result_tys[i], result)
results_ptr[i] = val._into_raw()
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
fmt = traceback.format_exception(exc_type, exc_value, exc_traceback)
Expand Down
4 changes: 2 additions & 2 deletions wasmtime/_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, store: Store, ty: GlobalType, val: IntoVal):
error = ffi.wasmtime_global_new(
store._ptr,
ty._ptr,
byref(val._raw),
byref(val._unwrap_raw()),
byref(ptr))
if error:
raise WasmtimeError._from_ptr(error)
Expand Down Expand Up @@ -61,7 +61,7 @@ def value(self, val: IntoVal) -> None:
Sets the value of this global to a new value
"""
val = Val._convert(self.type.content, val)
error = ffi.wasmtime_global_set(self._ptr, byref(val._raw))
error = ffi.wasmtime_global_set(self._ptr, byref(val._unwrap_raw()))
if error:
raise WasmtimeError._from_ptr(error)

Expand Down
Loading

0 comments on commit b8532a4

Please sign in to comment.