diff --git a/tests/test_memory.py b/tests/test_memory.py index e08f4474..bc00d311 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -53,3 +53,54 @@ def test_large(self): MemoryType(Limits(0x100000000, None)) with self.assertRaises(WasmtimeError): MemoryType(Limits(1, 0x100000000)) + + def test_slices(self): + store = Store() + ty = MemoryType(Limits(1, None)) + memory = Memory(store, ty) + memory.grow(store, 2) + data_ptr = memory.data_ptr(store) + ba = bytearray([i for i in range(200)]) + size_bytes = memory.data_len(store) + # happy cases + offset = 2048 + ba_size = len(ba) + # write with start and ommit stop + memory.write(store, ba, offset) + # check write success byte by byte, whole is asserted with read + self.assertEqual(data_ptr[offset], 0) + self.assertEqual(data_ptr[offset + 1], 1) + self.assertEqual(data_ptr[offset + 199], 199) + # read while and assert whole area + out = memory.read(store, offset, offset + ba_size) + self.assertEqual(ba, out) + self.assertEqual(len(memory.read(store, -10)), 10) + # write with start and stop + memory.write(store, ba, offset + ba_size) + out = memory.read(store, offset + ba_size, offset + ba_size + ba_size) + self.assertEqual(ba, out) + # assert old + self.assertEqual(data_ptr[offset], 0) + self.assertEqual(data_ptr[offset + 1], 1) + self.assertEqual(data_ptr[offset + 199], 199) + # assert new + self.assertEqual(data_ptr[offset + ba_size], 0) + self.assertEqual(data_ptr[offset + ba_size + 1], 1) + self.assertEqual(data_ptr[offset + ba_size + 199], 199) + # edge cases + # empty slices + self.assertEqual(len(memory.read(store, 0, 0)), 0) + self.assertEqual(len(memory.read(store, offset, offset)), 0) + self.assertEqual(len(memory.read(store, offset, offset - 1)), 0) + # out of bound access returns empty array similar to list slice + self.assertEqual(len(memory.read(store, size_bytes + 1)), 0) + # write empty + self.assertEqual(memory.write(store, bytearray(0), offset), 0) + self.assertEqual(memory.write(store, bytearray(b""), offset), 0) + with self.assertRaises(IndexError): + memory.write(store, ba, size_bytes) + with self.assertRaises(IndexError): + memory.write(store, ba, size_bytes - ba_size + 1) + self.assertEqual(memory.write(store, ba, -ba_size), ba_size) + out = memory.read(store, -ba_size) + self.assertEqual(ba, out) diff --git a/wasmtime/_memory.py b/wasmtime/_memory.py index 8aef6715..c07ecbf9 100644 --- a/wasmtime/_memory.py +++ b/wasmtime/_memory.py @@ -1,6 +1,7 @@ from . import _ffi as ffi from ctypes import * import ctypes +import typing from wasmtime import MemoryType, WasmtimeError from ._store import Storelike @@ -62,6 +63,78 @@ def data_ptr(self, store: Storelike) -> "ctypes._Pointer[c_ubyte]": """ return ffi.wasmtime_memory_data(store._context, byref(self._memory)) + def get_buffer_ptr(self, store: Storelike, + size: typing.Optional[int] = None, + offset: int = 0) -> ctypes.Array: + """ + return raw pointer to buffer suitable for creating zero-copy writable NumPy Buffer Protocol + this method is also used internally by `read()` and `write()` + + np_mem = np.frombuffer(memory.get_buffer_ptr(store), dtype=np.uint8) + np_mem[start:end] = A # write + B = np_mem[start:end] # read + """ + if size is None: + size = self.data_len(store) + ptr_type = ctypes.c_ubyte * size + return ptr_type.from_address(ctypes.addressof(self.data_ptr(store).contents) + offset) + + def read( + self, + store: Storelike, + start: typing.Optional[int] = 0, + stop: typing.Optional[int] = None) -> bytearray: + """ + Reads this memory starting from `start` and up to `stop` + and returns a copy of the contents as a `bytearray`. + + The indexing behavior of this method is similar to `list[start:stop]` + where negative starts can be used to read from the end, for example. + """ + size = self.data_len(store) + key = slice(start, stop, None) + start, stop, _ = key.indices(size) + val_size = stop - start + if val_size <= 0: + # return bytearray of size zero + return bytearray(0) + src_ptr = self.get_buffer_ptr(store, val_size, start) + return bytearray(src_ptr) + + def write( + self, + store: Storelike, + value: typing.Union[bytearray, bytes], + start: typing.Optional[int] = None) -> int: + """ + write a bytearray value into a possibly large slice of memory + negative start is allowed in a way similat to list slice mylist[-10:] + if value is not bytearray it will be used to construct an intermediate bytearray (copyied twice) + return number of bytes written + raises IndexError when trying to write outside the memory range + this happens when start offset is >= size or when end side of value is >= size + """ + size = self.data_len(store) + key = slice(start, None) + start = key.indices(size)[0] + if start >= size: + raise IndexError("index out of range") + # value must be bytearray ex. cast bytes() to bytearray + if not isinstance(value, bytearray): + value = bytearray(value) + val_size = len(value) + if val_size == 0: + return val_size + # stop is exclusive + stop = start + val_size + if stop > size: + raise IndexError("index out of range") + ptr_type = ctypes.c_ubyte * val_size + src_ptr = ptr_type.from_buffer(value) + dst_ptr = self.get_buffer_ptr(store, val_size, start) + ctypes.memmove(dst_ptr, src_ptr, val_size) + return val_size + def data_len(self, store: Storelike) -> int: """ Returns the raw byte length of this memory.