Skip to content

Commit

Permalink
Make future and stream element types optional
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewagner committed Jan 16, 2025
1 parent b87bcf4 commit df84d9a
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 9 deletions.
25 changes: 16 additions & 9 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ class BorrowType(ValType):

@dataclass
class StreamType(ValType):
t: ValType
t: Optional[ValType]

@dataclass
class FutureType(ValType):
t: ValType
t: Optional[ValType]

### Lifting and Lowering Context

Expand Down Expand Up @@ -533,9 +533,10 @@ class BufferGuestImpl(Buffer):
length: int

def __init__(self, cx, t, ptr, length):
trap_if(length == 0 or length > Buffer.MAX_LENGTH)
trap_if(ptr != align_to(ptr, alignment(t)))
trap_if(ptr + length * elem_size(t) > len(cx.opts.memory))
if t:
trap_if(length == 0 or length > Buffer.MAX_LENGTH)
trap_if(ptr != align_to(ptr, alignment(t)))
trap_if(ptr + length * elem_size(t) > len(cx.opts.memory))
self.cx = cx
self.t = t
self.ptr = ptr
Expand All @@ -548,16 +549,22 @@ def remain(self):
class ReadableBufferGuestImpl(BufferGuestImpl):
def lift(self, n):
assert(n <= self.remain())
vs = load_list_from_valid_range(self.cx, self.ptr, n, self.t)
self.ptr += n * elem_size(self.t)
if self.t:
vs = load_list_from_valid_range(self.cx, self.ptr, n, self.t)
self.ptr += n * elem_size(self.t)
else:
vs = n * [()]
self.progress += n
return vs

class WritableBufferGuestImpl(BufferGuestImpl, WritableBuffer):
def lower(self, vs):
assert(len(vs) <= self.remain())
store_list_into_valid_range(self.cx, vs, self.ptr, self.t)
self.ptr += len(vs) * elem_size(self.t)
if self.t:
store_list_into_valid_range(self.cx, vs, self.ptr, self.t)
self.ptr += len(vs) * elem_size(self.t)
else:
assert(all(v == () for v in vs))
self.progress += len(vs)

class ReadableStreamGuestImpl(ReadableStream):
Expand Down
86 changes: 86 additions & 0 deletions design/mvp/canonical-abi/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,91 @@ async def core_func2(task, args):
await canon_lift(opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:())


async def test_wasm_to_wasm_stream_empty():
fut1, fut2, fut3, fut4 = asyncio.Future(), asyncio.Future(), asyncio.Future(), asyncio.Future()

inst1 = ComponentInstance()
opts1 = mk_opts(memory=None, sync=False)
ft1 = FuncType([], [StreamType(None)])
async def core_func1(task, args):
assert(not args)
[wsi] = await canon_stream_new(None, task)
[] = await canon_task_return(task, [StreamType(None)], opts1, [wsi])

await task.on_block(fut1)

[ret] = await canon_stream_write(None, opts1, task, wsi, 10000, 2)
assert(ret == 2)
[ret] = await canon_stream_write(None, opts1, task, wsi, 10000, 2)
assert(ret == 2)

await task.on_block(fut2)

[ret] = await canon_stream_write(None, opts1, task, wsi, 0, 8)
assert(ret == definitions.BLOCKED)

fut3.set_result(None)

event, p1, p2 = await task.wait(sync = False)
assert(event == EventCode.STREAM_WRITE)
assert(p1 == wsi)
assert(p2 == 4)

fut4.set_result(None)

[errctxi] = await canon_error_context_new(opts1, task, 0, 0)
[] = await canon_stream_close_writable(None, task, wsi, errctxi)
[] = await canon_error_context_drop(task, errctxi)
return []

func1 = partial(canon_lift, opts1, inst1, ft1, core_func1)

inst2 = ComponentInstance()
heap2 = Heap(10)
mem2 = heap2.memory
opts2 = mk_opts(memory=heap2.memory, realloc=heap2.realloc, sync=False)
ft2 = FuncType([], [])
async def core_func2(task, args):
assert(not args)
[] = await canon_task_return(task, [], opts2, [])

retp = 0
[ret] = await canon_lower(opts2, ft1, func1, task, [retp])
assert(ret == 0)
rsi = mem2[0]
assert(rsi == 1)

[ret] = await canon_stream_read(None, opts2, task, rsi, 0, 8)
assert(ret == definitions.BLOCKED)

fut1.set_result(None)

event, p1, p2 = await task.wait(sync = False)
assert(event == EventCode.STREAM_READ)
assert(p1 == rsi)
assert(p2 == 4)

fut2.set_result(None)
await task.on_block(fut3)

[ret] = await canon_stream_read(None, opts2, task, rsi, 1000000, 2)
assert(ret == 2)
[ret] = await canon_stream_read(None, opts2, task, rsi, 1000000, 2)
assert(ret == 2)

await task.on_block(fut4)

[ret] = await canon_stream_read(None, opts2, task, rsi, 1000000, 2)
errctxi = 1
assert(ret == (definitions.CLOSED | errctxi))
[] = await canon_stream_close_readable(None, task, rsi)
[] = await canon_error_context_debug_message(opts2, task, errctxi, 0)
[] = await canon_error_context_drop(task, errctxi)
return []

await canon_lift(opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:())


async def test_cancel_copy():
inst = ComponentInstance()
mem = bytearray(10)
Expand Down Expand Up @@ -1612,6 +1697,7 @@ async def run_async_tests():
await test_host_partial_reads_writes()
await test_async_stream_ops()
await test_wasm_to_wasm_stream()
await test_wasm_to_wasm_stream_empty()
await test_cancel_copy()
await test_futures()

Expand Down

0 comments on commit df84d9a

Please sign in to comment.