diff --git a/design/mvp/Async.md b/design/mvp/Async.md index 2a44f8c5..7799c1fb 100644 --- a/design/mvp/Async.md +++ b/design/mvp/Async.md @@ -419,21 +419,25 @@ For now, this remains a [TODO](#todo) and validation will reject `async`-lifted ## TODO -Native async support is being proposed in progressive chunks. The following -features will be added in future chunks to complete "async" in Preview 3: -* `future`/`stream`/`error`: add for use in function types for finer-grained - concurrency -* `subtask.cancel`: allow a supertask to signal to a subtask that its result is - no longer wanted and to please wrap it up promptly -* allow "tail-calling" a subtask so that the current wasm instance can be torn - down eagerly -* `task.index`+`task.wake`: allow tasks in the same instance to wait on and - wake each other (async condvar-style) +Native async support is being proposed incrementally. The following features +will be added in future chunks roughly in the order list to complete the full +"async" story: +* add `future` type +* add `error` type that can be included when closing a stream/future * `nonblocking` function type attribute: allow a function to declare in its type that it will not transitively do anything blocking +* define what `async` means for `start` functions +* `task.index`+`task.wake`: allow tasks in the same instance to wait on and + wake each other +* `subtask.cancel`: allow a supertask to signal to a subtask that its result is + no longer wanted and to please wrap it up promptly +* `stream.lull` built-in that says "no more elements are coming for a while" * `recursive` function type attribute: allow a function to be reentered - recursively (instead of trapping) -* enable `async` `start` functions + recursively (instead of trapping) for the benefit of donut wrapping +* built-in to "tail-call" a subtask so that the current wasm instance can be torn + down eagerly while preserving "structured concurrency" +* allow pipelining multiple `stream.read`/`write` calls +* allow chaining multiple async calls together ("promise pipelining") * integrate with `shared`: define how to lift and lower functions `async` *and* `shared` diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index b7497d2f..b575a430 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -2126,6 +2126,7 @@ where `$callee` has type `$ft`, validation specifies: * a `memory` is present if required by lifting and is a subtype of `(memory 1)` * a `realloc` is present if required by lifting and has type `(func (param i32 i32 i32 i32) (result i32))` * there is no `post-return` in `$opts` +* if `contains_async($ft)`, then `$opts.async` must be set When instantiating component instance `$inst`: * Define `$f` to be the partially-bound closure: `canon_lower($opts, $ft, $callee)` diff --git a/design/mvp/Explainer.md b/design/mvp/Explainer.md index 46c418a3..49173d6b 100644 --- a/design/mvp/Explainer.md +++ b/design/mvp/Explainer.md @@ -1222,10 +1222,12 @@ freeing memory after lifting and thus `post-return` may be deprecated in the future. 🔀 The `async` option specifies that the component wants to make (for imports) -or support (for exports) multiple concurrent (asynchronous) calls. This option -can be applied to any component-level function type and changes the derived -Canonical ABI significantly. See the [async explainer](Async.md) for more -details. +or support (for exports) multiple concurrent (asynchronous) calls. This +option can be applied to any component-level function type and changes the +derived Canonical ABI significantly. See the [async explainer](Async.md) for +more details. When a function signature contains a `future` or `stream`, +validation requires the `async` option to be set (since a synchronous call to +a function using these types is likely to deadlock). 🔀 The `(callback ...)` option may only be present in `canon lift` when the `async` option has also been set and specifies a core function that is diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index 2ff543f8..8acfe856 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -171,6 +171,10 @@ class OwnType(ValType): class BorrowType(ValType): rt: ResourceType +@dataclass +class StreamType(ValType): + t: ValType + ### CallContext class CallContext: @@ -199,7 +203,7 @@ class CanonicalOptions: class ComponentInstance: resources: ResourceTables - async_subtasks: Table[Subtask] + waitables: Table[Subtask|StreamHandle] num_tasks: int may_leave: bool backpressure: bool @@ -208,7 +212,7 @@ class ComponentInstance: def __init__(self): self.resources = ResourceTables() - self.async_subtasks = Table[Subtask]() + self.waitables = Table[Subtask|StreamHandle]() self.num_tasks = 0 self.may_leave = True self.backpressure = False @@ -251,6 +255,8 @@ class Table(Generic[ElemT]): array: list[Optional[ElemT]] free: list[int] + MAX_LENGTH = 2**30 - 1 + def __init__(self): self.array = [None] self.free = [] @@ -267,7 +273,7 @@ def add(self, e): self.array[i] = e else: i = len(self.array) - trap_if(i >= 2**30) + trap_if(i > Table.MAX_LENGTH) self.array.append(e) return i @@ -301,8 +307,10 @@ class EventCode(IntEnum): CALL_RETURNED = CallState.RETURNED CALL_DONE = CallState.DONE YIELDED = 4 + STREAM_READ = 5 + STREAM_WRITE = 6 -EventTuple = tuple[EventCode, int] +EventTuple = tuple[EventCode, int, int] EventCallback = Callable[[], EventTuple] OnBlockCallback = Callable[[Awaitable], any] @@ -489,10 +497,10 @@ def maybe_notify_supertask(self): self.enqueued = True def subtask_event(): self.enqueued = False - i = self.inst.async_subtasks.array.index(self) + i = self.inst.waitables.array.index(self) if self.state == CallState.DONE: self.release_lenders() - return (EventCode(self.state), i) + return (EventCode(self.state), i, 0) self.task.notify(subtask_event) def on_start(self): @@ -525,7 +533,165 @@ def drop(self): trap_if(self.state != CallState.DONE) self.task.need_to_drop -= 1 -### Despecialization +class Buffer: + MAX_LENGTH = 2**30 - 1 + + def __init__(self, cx, t, ptr, length): + trap_if(length == 0 or length > Buffer.MAX_LENGTH) + trap_if(ptr != align_to(ptr, alignment(t))) + trap_if(ptr + length * elem_size(t) > len(cx.opts.memory)) + self._cx = cx + self._t = t + self._begin = ptr + self._length = length + self._progress = 0 + + def remain(self): + assert(self._progress <= self._length) + return self._length - self._progress + +class ReadableBuffer(Buffer): + def lift(self, n): + assert(n <= self.remain()) + ptr = self._begin + self._progress * elem_size(self._t) + vs = load_list_from_valid_range(self._cx, ptr, n, self._t) + self._progress += n + return vs + +class WritableBuffer(Buffer): + def lower(self, vs): + assert(len(vs) <= self.remain()) + ptr = self._begin + self._progress * elem_size(self._t) + store_list_into_valid_range(self._cx, vs, ptr, self._t) + self._progress += len(vs) + +class Stream: + def closed(self): + return self._closed + + def close(self): + if not self._closed: + self._closed = True + self._on_close() + + # precondition: !closed + read: Callable[[WritableBuffer, OnBlockCallback], Awaitable] + cancel_read: Callable[[WritableBuffer, OnBlockCallback], Awaitable] + maybe_writer_handle_index: Callable[[ComponentInstance], Optional[int]] + + def __init__(self, impl): + self._closed = False + self._on_close = impl.on_close + self.read = impl.read + self.cancel_read = impl.cancel_read + self.maybe_writer_handle_index = impl.maybe_writer_handle_index + +class StreamHandle: + stream: Stream + t: ValType + cx: Optional[CallContext] + copying_buffer: Optional[Buffer] + + def __init__(self, stream, t, cx): + self.stream = stream + self.t = t + self.cx = cx + self.copying_buffer = None + + def drop(self): + trap_if(self.copying_buffer) + self.stream.close() + if self.cx: + self.cx.task.need_to_drop -= 1 + +class ReadableStreamHandle(StreamHandle): + async def copy(self, dst, on_block): + await self.stream.read(dst, on_block) + async def cancel_copy(self, dst, on_block): + await self.stream.cancel_read(dst, on_block) + +class WritableStreamHandle(StreamHandle): + waiting_buffer: Optional[Buffer] + waiting_future: Optional[asyncio.Future[Optional[asyncio.Future]]] + + def __init__(self, t): + super().__init__(Stream(self), t, cx = None) + self.waiting_buffer = None + self.waiting_future = None + + async def copy(self, src, on_block): + await self.rendezvous('write', src, on_block) + async def read(self, dst, on_block): + await self.rendezvous('read', dst, on_block) + async def rendezvous(self, direction, buffer, on_block): + assert(not self.stream.closed()) + if self.waiting_buffer: + ncopy = min(buffer.remain(), self.waiting_buffer.remain()) + assert(ncopy > 0) + match direction: + case 'read': buffer.lower(self.waiting_buffer.lift(ncopy)) + case 'write': self.waiting_buffer.lower(buffer.lift(ncopy)) + if not self.waiting_buffer.remain(): + self.waiting_buffer = None + self.unblock_waiting() + else: + assert(not (self.waiting_buffer or self.waiting_future)) + self.waiting_buffer = buffer + self.waiting_future = asyncio.Future[Optional[asyncio.Future]]() + await on_block(self.waiting_future) + if self.waiting_buffer is buffer: + self.waiting_buffer = None + + def unblock_waiting(self): + if self.waiting_future: + self.waiting_future.set_result(None) + self.waiting_future = None + + async def cancel_copy(self, src, on_block): + await self.cancel_rendezvous('write', src, on_block) + async def cancel_read(self, dst, on_block): + await self.cancel_rendezvous('read', dst, on_block) + async def cancel_rendezvous(self, direction, buffer, on_block): + assert(not self.stream.closed()) + if self.waiting_buffer is buffer: + self.waiting_buffer = None + self.unblock_waiting() + + def on_close(self): + assert(self.stream.closed()) + self.unblock_waiting() + + def maybe_writer_handle_index(self, inst): + assert(not self.stream.closed()) + if inst is self.cx.inst: + return self.cx.task.inst.waitables.array.index(self) + return None + +### Type utilities + +def contains_async(t): + match t: + case StreamType(): + return True + case PrimValType() | OwnType() | BorrowType(): + return False + case FuncType(): + return any(contains_async(t) for t in t.param_types()) or \ + any(contains_async(t) for t in t.result_types()) + case ListType(t): + return contains_async(t) + case RecordType(fs): + return any(contains_async(f.t) for f in fs) + case TupleType(ts): + return any(contains_async(t) for t in ts) + case VariantType(cs): + return any(contains_async(c.t) for c in cs) + case OptionType(t): + return contains_async(t) + case ResultType(o,e): + return contains_async(o) or contains_async(e) + case _: + assert(False) def despecialize(t): match t: @@ -554,6 +720,7 @@ def alignment(t): case FlagsType(labels) : return alignment_flags(labels) case OwnType() : return 4 case BorrowType() : return 4 + case StreamType() : return 4 def alignment_list(elem_type, maybe_length): if maybe_length is not None: @@ -611,6 +778,7 @@ def elem_size(t): case FlagsType(labels) : return elem_size_flags(labels) case OwnType() : return 4 case BorrowType() : return 4 + case StreamType() : return 4 def elem_size_list(elem_type, maybe_length): if maybe_length is not None: @@ -670,6 +838,7 @@ def load(cx, ptr, t): case FlagsType(labels) : return load_flags(cx, ptr, labels) case OwnType() : return lift_own(cx, load_int(cx, ptr, 4), t) case BorrowType() : return lift_borrow(cx, load_int(cx, ptr, 4), t) + case StreamType(t) : return lift_stream(cx, load_int(cx, ptr, 4), t) def load_int(cx, ptr, nbytes, signed = False): return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed) @@ -823,6 +992,22 @@ def lift_borrow(cx, i, t): cx.add_lender(h) return h.rep +def lift_stream(cx, i, elem_type): + h = cx.inst.waitables.get(i) + trap_if(not isinstance(h, StreamHandle)) + trap_if(h.t != elem_type) + match h: + case ReadableStreamHandle(): + trap_if(h.copying_buffer) + h.cx.task.need_to_drop -= 1 + cx.inst.waitables.remove(i) + case WritableStreamHandle(): + trap_if(h.cx is not None) + assert(not h.copying_buffer) + h.cx = cx + h.cx.task.need_to_drop += 1 + return h.stream + ### Storing def store(cx, v, t, ptr): @@ -848,6 +1033,7 @@ def store(cx, v, t, ptr): case FlagsType(labels) : store_flags(cx, v, ptr, labels) case OwnType() : store_int(cx, lower_own(cx, v, t), ptr, 4) case BorrowType() : store_int(cx, lower_borrow(cx, v, t), ptr, 4) + case StreamType(t) : store_int(cx, lower_stream(cx, v, t), ptr, 4) def store_int(cx, v, ptr, nbytes, signed = False): cx.opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed) @@ -1110,6 +1296,20 @@ def lower_borrow(cx, rep, t): cx.need_to_drop += 1 return cx.inst.resources.add(t.rt, h) +def lower_stream(cx, stream, elem_type): + assert(isinstance(stream, Stream)) + if (i := stream.maybe_writer_handle_index(cx.inst)): + h = cx.inst.waitables.array[i] + assert(isinstance(h, WritableStreamHandle)) + h.cx.task.need_to_drop -= 1 + h.cx = None + assert(2**31 > Table.MAX_LENGTH >= i) + return i | (2**31) + else: + h = ReadableStreamHandle(stream, elem_type, cx) + cx.task.need_to_drop += 1 + return cx.inst.waitables.add(h) + ### Flattening MAX_FLAT_PARAMS = 16 @@ -1161,6 +1361,7 @@ def flatten_type(t): case FlagsType(labels) : return ['i32'] case OwnType() : return ['i32'] case BorrowType() : return ['i32'] + case StreamType() : return ['i32'] def flatten_list(elem_type, maybe_length): if maybe_length is not None: @@ -1227,6 +1428,7 @@ def lift_flat(cx, vi, t): case FlagsType(labels) : return lift_flat_flags(vi, labels) case OwnType() : return lift_own(cx, vi.next('i32'), t) case BorrowType() : return lift_borrow(cx, vi.next('i32'), t) + case StreamType(t) : return lift_stream(cx, vi.next('i32'), t) def lift_flat_unsigned(vi, core_width, t_width): i = vi.next('i' + str(core_width)) @@ -1318,6 +1520,7 @@ def lower_flat(cx, v, t): case FlagsType(labels) : return lower_flat_flags(v, labels) case OwnType() : return [lower_own(cx, v, t)] case BorrowType() : return [lower_borrow(cx, v, t)] + case StreamType(t) : return [lower_stream(cx, v, t)] def lower_flat_signed(i, core_bits): if i < 0: @@ -1432,10 +1635,10 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_blo ctx = packed_ctx & ~1 if is_yield: await task.yield_() - event, payload = (EventCode.YIELDED, 0) + event, p1, p2 = (EventCode.YIELDED, 0, 0) else: - event, payload = await task.wait() - [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload]) + event, p1, p2 = await task.wait() + [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, p1, p2]) task.exit() async def call_and_trap_on_throw(callee, task, args): @@ -1450,6 +1653,7 @@ async def canon_lower(opts, ft, callee, task, flat_args): trap_if(not task.inst.may_leave) subtask = Subtask(opts, ft, task, flat_args) if opts.sync: + assert(not contains_async(ft)) await task.call_sync(callee, task, subtask.on_start, subtask.on_return) flat_results = subtask.finish() else: @@ -1460,7 +1664,7 @@ async def do_call(on_block): case Blocked(): subtask.notify_supertask = True task.need_to_drop += 1 - i = task.inst.async_subtasks.add(subtask) + i = task.inst.waitables.add(subtask) flat_results = [pack_async_result(i, subtask.state)] case _: flat_results = [0] @@ -1532,8 +1736,9 @@ async def canon_task_return(task, core_ft, flat_args): async def canon_task_wait(task, ptr): trap_if(not task.inst.may_leave) trap_if(task.opts.callback is not None) - event, payload = await task.wait() - store(task, payload, U32Type(), ptr) + event, p1, p2 = await task.wait() + store(task, p1, U32Type(), ptr) + store(task, p2, U32Type(), ptr + 4) return [event] ### 🔀 `canon task.poll` @@ -1543,7 +1748,7 @@ async def canon_task_poll(task, ptr): ret = task.poll() if ret is None: return [0] - store(task, ret, TupleType([U32Type(), U32Type()]), ptr) + store(task, ret, TupleType([U32Type(), U32Type(), U32Type()]), ptr) return [1] ### 🔀 `canon task.yield` @@ -1554,9 +1759,87 @@ async def canon_task_yield(task): await task.yield_() return [] -### 🔀 `canon subtask.drop` +### 🔀 `canon stream.new` + +async def canon_stream_new(elem_type, task): + trap_if(not task.inst.may_leave) + h = WritableStreamHandle(elem_type) + return [ task.inst.waitables.add(h) ] + +### 🔀 `canon stream.read` and `canon stream.write` + +async def canon_stream_read(task, i, ptr, n): + return await stream_copy(ReadableStreamHandle, WritableBuffer, + task, i, ptr, n, EventCode.STREAM_READ) + +async def canon_stream_write(task, i, ptr, n): + return await stream_copy(WritableStreamHandle, ReadableBuffer, + task, i, ptr, n, EventCode.STREAM_WRITE) + +async def stream_copy(StreamHandleT, BufferT, task, i, ptr, n, event_code): + trap_if(not task.inst.may_leave) + h = task.inst.waitables.get(i) + trap_if(not isinstance(h, StreamHandleT)) + trap_if(not h.cx) + trap_if(h.copying_buffer) + buffer = BufferT(h.cx, h.t, ptr, n) + if h.stream.closed(): + flat_results = [CLOSED] + else: + async def do_copy(on_block): + await h.copy(buffer, on_block) + if h.copying_buffer: + def stream_event(): + h.copying_buffer = None + return (event_code, i, pack_stream_result(buffer, h)) + h.cx.task.notify(stream_event) + match await call_and_handle_blocking(do_copy): + case Blocked(): + h.copying_buffer = buffer + flat_results = [BLOCKED] + case _: + flat_results = [pack_stream_result(buffer, h)] + return flat_results + +def pack_stream_result(buffer, handle): + assert(buffer._progress <= PROGRESS_BITS) + return buffer._progress | (CLOSED if handle.stream.closed() else 0) + +BLOCKED = 0xffff_ffff +CLOSED = 0x8000_0000 +PROGRESS_BITS = 0x3fff_ffff + +assert(BLOCKED != (CLOSED | PROGRESS_BITS)) + +### 🔀 `canon stream.cancel-read` and `canon stream.cancel-writing` + +async def canon_stream_cancel_read(sync, task, i): + return await stream_cancel_copy(ReadableStreamHandle, sync, task, i) + +async def canon_stream_cancel_write(sync, task, i): + return await stream_cancel_copy(WritableStreamHandle, sync, task, i) + +async def stream_cancel_copy(StreamHandleT, sync, task, i): + trap_if(not task.inst.may_leave) + h = self.inst.waitables.get(i) + trap_if(not isinstance(h, StreamHandleT)) + trap_if(not h.copying_buffer) + if sync: + await task.call_sync(h.cancel_copy, h.copying_buffer) + h.copying_buffer = None + flat_results = [pack_stream_result(buffer, h)] + else: + match await call_and_handle_blocking(h.cancel_copy, h.copying_buffer): + case Blocked(): + flat_results = [BLOCKED] + case _: + flat_results = [pack_stream_result(buffer, h)] + h.copying_buffer = None + return flat_results + +### 🔀 `canon waitable.drop` -async def canon_subtask_drop(task, i): +async def canon_waitable_drop(task, i): trap_if(not task.inst.may_leave) - task.inst.async_subtasks.remove(i).drop() + task.inst.waitables.remove(i).drop() return [] diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index cc0fb688..7b05f3f7 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -34,13 +34,14 @@ def realloc(self, original_ptr, original_size, alignment, new_size): self.memory[ret : ret + original_size] = self.memory[original_ptr : original_ptr + original_size] return ret -def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None): +def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None, sync_task_return = False, sync = True): opts = CanonicalOptions() opts.memory = memory opts.string_encoding = encoding opts.realloc = realloc opts.post_return = post_return - opts.sync = True + opts.sync_task_return = sync_task_return + opts.sync = sync opts.callback = None return opts @@ -361,56 +362,59 @@ def test_flatten(t, params, results): test_flatten(FuncType([U8Type() for _ in range(17)],[]), ['i32' for _ in range(17)], []) test_flatten(FuncType([U8Type() for _ in range(17)],[TupleType([U8Type(),U8Type()])]), ['i32' for _ in range(17)], ['i32','i32']) -def test_roundtrip(t, v): - before = definitions.MAX_FLAT_RESULTS - definitions.MAX_FLAT_RESULTS = 16 - ft = FuncType([t],[t]) - async def callee(task, x): - return x +async def test_roundtrips(): + async def test_roundtrip(t, v): + before = definitions.MAX_FLAT_RESULTS + definitions.MAX_FLAT_RESULTS = 16 - callee_heap = Heap(1000) - callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) - callee_inst = ComponentInstance() - lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) + ft = FuncType([t],[t]) + async def callee(task, x): + return x - caller_heap = Heap(1000) - caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) - caller_inst = ComponentInstance() - caller_task = Task(caller_opts, caller_inst, ft, None, None, None) + callee_heap = Heap(1000) + callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) + callee_inst = ComponentInstance() + lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) - flat_args = asyncio.run(caller_task.enter(lambda: [v])) + caller_heap = Heap(1000) + caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) + caller_inst = ComponentInstance() + caller_task = Task(caller_opts, caller_inst, ft, None, None, None) - return_in_heap = len(flatten_types([t])) > definitions.MAX_FLAT_RESULTS - if return_in_heap: - flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ] + flat_args = await caller_task.enter(lambda: [v]) - flat_results = asyncio.run(canon_lower(caller_opts, ft, lifted_callee, caller_task, flat_args)) + return_in_heap = len(flatten_types([t])) > definitions.MAX_FLAT_RESULTS + if return_in_heap: + flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ] - if return_in_heap: - flat_results = [ flat_args[-1] ] + flat_results = await canon_lower(caller_opts, ft, lifted_callee, caller_task, flat_args) - [got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t]) - caller_task.exit() + if return_in_heap: + flat_results = [ flat_args[-1] ] - if got != v: - fail("test_roundtrip({},{}) got {}".format(t, v, got)) + [got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t]) + caller_task.exit() - definitions.MAX_FLAT_RESULTS = before + if got != v: + fail("test_roundtrip({},{}) got {}".format(t, v, got)) -test_roundtrip(S8Type(), -1) -test_roundtrip(TupleType([U16Type(),U16Type()]), mk_tup(3,4)) -test_roundtrip(ListType(StringType()), [mk_str("hello there")]) -test_roundtrip(ListType(ListType(StringType())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) -test_roundtrip(ListType(OptionType(TupleType([StringType(),U16Type()]))), [{'some':mk_tup(mk_str("answer"),42)}]) -test_roundtrip(VariantType([CaseType('x', TupleType([U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - U32Type(),U32Type(),U32Type(),U32Type(), - StringType()]))]), - {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) - -def test_handles(): + definitions.MAX_FLAT_RESULTS = before + + await test_roundtrip(S8Type(), -1) + await test_roundtrip(TupleType([U16Type(),U16Type()]), mk_tup(3,4)) + await test_roundtrip(ListType(StringType()), [mk_str("hello there")]) + await test_roundtrip(ListType(ListType(StringType())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) + await test_roundtrip(ListType(OptionType(TupleType([StringType(),U16Type()]))), [{'some':mk_tup(mk_str("answer"),42)}]) + await test_roundtrip(VariantType([CaseType('x', TupleType([U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + U32Type(),U32Type(),U32Type(),U32Type(), + StringType()]))]), + {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) + + +async def test_handles(): before = definitions.MAX_FLAT_RESULTS definitions.MAX_FLAT_RESULTS = 16 @@ -502,7 +506,7 @@ def on_return(results): nonlocal got got = results - asyncio.run(canon_lift(opts, inst, ft, core_wasm, None, on_start, on_return, None)) + await canon_lift(opts, inst, ft, core_wasm, None, on_start, on_return, None) assert(len(got) == 3) assert(got[0] == 46) @@ -513,7 +517,6 @@ def on_return(results): assert(len(inst.resources.table(rt).free) == 4) definitions.MAX_FLAT_RESULTS = before -test_handles() async def test_async_to_async(): producer_heap = Heap(10) @@ -563,31 +566,32 @@ async def consumer(task, args): u8 = consumer_heap.memory[ptr] assert(u8 == 43) [ret] = await canon_lower(consumer_opts, toggle_ft, toggle_callee, task, []) - assert(ret == (1 | (CallState.STARTED << 30))) + assert((ret >> 30) == CallState.STARTED) + subi = ret & ~(3 << 30) retp = ptr consumer_heap.memory[retp] = 13 [ret] = await canon_lower(consumer_opts, blocking_ft, blocking_callee, task, [83, retp]) assert(ret == (2 | (CallState.STARTING << 30))) assert(consumer_heap.memory[retp] == 13) fut1.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) - [] = await canon_subtask_drop(task, callidx) - event, callidx = await task.wait() + [] = await canon_waitable_drop(task, callidx) + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_STARTED) assert(callidx == 2) assert(consumer_heap.memory[retp] == 13) fut2.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_RETURNED) assert(callidx == 2) assert(consumer_heap.memory[retp] == 44) fut3.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - [] = await canon_subtask_drop(task, callidx) + [] = await canon_waitable_drop(task, callidx) dtor_fut = asyncio.Future() dtor_value = None @@ -605,10 +609,10 @@ async def dtor(task, args): assert(ret == (2 | (CallState.STARTED << 30))) assert(dtor_value is None) dtor_fut.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == CallState.DONE) assert(callidx == 2) - [] = await canon_subtask_drop(task, callidx) + [] = await canon_waitable_drop(task, callidx) [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [42]) return [] @@ -628,7 +632,6 @@ def on_return(results): assert(len(got) == 1) assert(got[0] == 42) -asyncio.run(test_async_to_async()) async def test_async_callback(): producer_inst = ComponentInstance() @@ -662,22 +665,25 @@ async def consumer(task, args): return [42] async def callback(task, args): - assert(len(args) == 3) + assert(len(args) == 4) if args[0] == 42: assert(args[1] == EventCode.CALL_DONE) assert(args[2] == 1) - await canon_subtask_drop(task, 1) + assert(args[3] == 0) + await canon_waitable_drop(task, 1) return [53] elif args[0] == 52: assert(args[1] == EventCode.YIELDED) assert(args[2] == 0) + assert(args[3] == 0) fut2.set_result(None) return [62] else: assert(args[0] == 62) assert(args[1] == EventCode.CALL_DONE) assert(args[2] == 2) - await canon_subtask_drop(task, 2) + assert(args[3] == 0) + await canon_waitable_drop(task, 2) [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [83]) return [0] @@ -696,7 +702,6 @@ def on_return(results): await canon_lift(opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 83) -asyncio.run(test_async_callback()) async def test_async_to_sync(): producer_opts = CanonicalOptions() @@ -740,19 +745,19 @@ async def consumer(task, args): fut.set_result(None) assert(producer1_done == False) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) - await canon_subtask_drop(task, callidx) + await canon_waitable_drop(task, callidx) assert(producer1_done == True) assert(producer2_done == False) await canon_task_yield(task) assert(producer2_done == True) - event, callidx = task.poll() + event, callidx, _ = task.poll() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - await canon_subtask_drop(task, callidx) + await canon_waitable_drop(task, callidx) assert(producer2_done == True) assert(task.poll() is None) @@ -771,7 +776,6 @@ def on_return(results): await canon_lift(consumer_opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 83) -asyncio.run(test_async_to_sync()) async def test_async_backpressure(): producer_opts = CanonicalOptions() @@ -819,18 +823,18 @@ async def consumer(task, args): fut.set_result(None) assert(producer1_done == False) assert(producer2_done == False) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) assert(producer1_done == True) assert(producer2_done == True) - event, callidx = task.poll() + event, callidx, _ = task.poll() assert(event == EventCode.CALL_DONE) assert(callidx == 2) assert(producer2_done == True) - await canon_subtask_drop(task, 1) - await canon_subtask_drop(task, 2) + await canon_waitable_drop(task, 1) + await canon_waitable_drop(task, 2) assert(task.poll() is None) @@ -848,8 +852,6 @@ def on_return(results): await canon_lift(consumer_opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 84) -if definitions.DETERMINISTIC_PROFILE: - asyncio.run(test_async_backpressure()) async def test_sync_using_wait(): hostcall_opts = mk_opts() @@ -878,16 +880,16 @@ async def core_func(task, args): assert(ret == (2 | (CallState.STARTED << 30))) fut1.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) fut2.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - await canon_subtask_drop(task, 1) - await canon_subtask_drop(task, 2) + await canon_waitable_drop(task, 1) + await canon_waitable_drop(task, 2) return [] @@ -896,6 +898,570 @@ def on_start(): return [] def on_return(results): pass await canon_lift(mk_opts(), inst, ft, core_func, None, on_start, on_return) -asyncio.run(test_sync_using_wait()) + +class HostSource(Stream): + remaining: list[int] + close_if_empty: bool + chunk: int + waiting: Optional[asyncio.Future] + + def __init__(self, contents, chunk, close_if_empty = True): + self.remaining = contents + self.close_if_empty = close_if_empty + self.chunk = chunk + self.waiting = None + + def closed(self): + return not self.remaining and self.close_if_empty + + def wake_waiting(self): + if self.waiting: + self.waiting.set_result(None) + self.waiting = None + + def close(self): + self.remaining = [] + self.close_if_empty = True + self.wake_waiting() + + def close_once_empty(self): + self.close_if_empty = True + if self.closed(): + self.wake_waiting() + + async def read(self, dst, on_block): + if not self.remaining: + if self.closed(): + return + self.waiting = asyncio.Future() + await on_block(self.waiting) + if not self.remaining: + return + n = min(dst.remain(), len(self.remaining), self.chunk) + dst.lower(self.remaining[:n]) + del self.remaining[:n] + + def write(self, vs): + assert(vs and not self.closed()) + self.remaining += vs + self.wake_waiting() + + def stop_reading(self): + self.wake_waiting() + + def maybe_writer_handle_index(self, inst): + return None + +class HostSink: + stream: Stream + received: list[int] + chunk: int + write_remain: int + write_event: asyncio.Event + ready_to_consume: asyncio.Event + + def __init__(self, stream, chunk, remain = 2**64): + self.stream = stream + self.received = [] + self.chunk = chunk + self.write_remain = remain + self.write_event = asyncio.Event() + if remain: + self.write_event.set() + self.ready_to_consume = asyncio.Event() + async def read_all(): + while not self.stream.closed(): + async def on_block(f): + return await f + await self.write_event.wait() + await self.stream.read(self, on_block) + asyncio.create_task(read_all()) + + def set_remain(self, n): + self.write_remain = n + if self.write_remain > 0: + self.write_event.set() + + def remain(self): + return self.write_remain + + def lower(self, vs): + self.received += vs + self.ready_to_consume.set() + self.write_remain -= len(vs) + if self.write_remain == 0: + self.write_event.clear() + + async def consume(self, n): + while n > len(self.received): + self.ready_to_consume.clear() + await self.ready_to_consume.wait() + ret = self.received[:n]; + del self.received[:n] + return ret + +async def test_eager_stream_completion(): + ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + async def host_import(task, on_start, on_return, on_block): + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], Stream)) + incoming = HostSink(args[0], chunk=4) + outgoing = HostSource([], chunk=4, close_if_empty=False) + on_return([outgoing]) + async def add10(): + while not incoming.stream.closed(): + vs = await incoming.consume(4) + for i in range(len(vs)): + vs[i] += 10 + outgoing.write(vs) + outgoing.close() + asyncio.create_task(add10()) + + src_stream = HostSource([1,2,3,4,5,6,7,8], chunk=4) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = HostSink(results[0], chunk=4) + + async def core_func(task, args): + assert(len(args) == 1) + rsi1 = args[0] + assert(rsi1 == 1) + [wsi1] = await canon_stream_new(U8Type(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1]) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == 4) + assert(mem[0:4] == b'\x01\x02\x03\x04') + [wsi2] = await canon_stream_new(U8Type(), task) + retp = 12 + [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp]) + assert(ret == 0) + rsi2 = mem[retp] + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == (4 | (2**31))) + assert(mem[0:4] == b'\x05\x06\x07\x08') + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, rsi1) + [] = await canon_waitable_drop(task, rsi2) + [] = await canon_waitable_drop(task, wsi1) + [] = await canon_waitable_drop(task, wsi2) + return [] + + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(dst_stream.received == [11,12,13,14,15,16,17,18]) + + +async def test_async_stream_ops(): + ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + host_import_incoming = None + host_import_outgoing = None + async def host_import(task, on_start, on_return, on_block): + nonlocal host_import_incoming, host_import_outgoing + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], Stream)) + host_import_incoming = HostSink(args[0], chunk=4, remain = 0) + host_import_outgoing = HostSource([], chunk=4, close_if_empty=False) + on_return([host_import_outgoing]) + while not host_import_incoming.stream.closed(): + vs = await on_block(host_import_incoming.consume(4)) + for i in range(len(vs)): + vs[i] += 10 + host_import_outgoing.write(vs) + host_import_outgoing.close_once_empty() + + src_stream = HostSource([], chunk=4, close_if_empty = False) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = HostSink(results[0], chunk=4, remain = 0) + + async def core_func(task, args): + [rsi1] = args + assert(rsi1 == 1) + [wsi1] = await canon_stream_new(U8Type(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1]) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == definitions.BLOCKED) + src_stream.write([1,2,3,4]) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi1) + assert(p2 == 4) + assert(mem[0:4] == b'\x01\x02\x03\x04') + [wsi2] = await canon_stream_new(U8Type(), task) + retp = 16 + [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp]) + assert((ret >> 30) == CallState.RETURNED) + subi = ret & ~(3 << 30) + rsi2 = mem[16] + assert(rsi2 == 4) + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == definitions.BLOCKED) + host_import_incoming.set_remain(100) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi2) + assert(p2 == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == definitions.BLOCKED) + dst_stream.set_remain(100) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi1) + assert(p2 == 4) + src_stream.write([5,6,7,8]) + src_stream.close_once_empty() + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == (4 | definitions.CLOSED)) + [] = await canon_waitable_drop(task, rsi1) + assert(mem[0:4] == b'\x05\x06\x07\x08') + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, wsi2) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == definitions.BLOCKED) + event, p1, p2 = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + assert(p2 == 0) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi2) + assert(p2 == (4 | 2**31)) + [] = await canon_waitable_drop(task, subi) + [] = await canon_waitable_drop(task, rsi2) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, wsi1) + return [] + + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(dst_stream.received == [11,12,13,14,15,16,17,18]) + + +async def test_stream_forward(): + src_stream = HostSource([1,2,3,4], chunk=4) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = results[0] + + async def core_func(task, args): + assert(len(args) == 1) + rsi1 = args[0] + assert(rsi1 == 1) + return [rsi1] + + opts = mk_opts() + inst = ComponentInstance() + ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(src_stream is dst_stream) + + +async def test_receive_own_stream(): + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + host_ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())]) + async def host_import(task, on_start, on_return, on_block): + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], Stream)) + on_return(args) + + async def core_func(task, args): + assert(len(args) == 0) + [wsi] = await canon_stream_new(U8Type(), task) + assert(wsi == 1) + retp = 4 + [ret] = await canon_lower(opts, host_ft, host_import, task, [wsi, retp]) + assert(ret == 0) + result = int.from_bytes(mem[retp : retp+4], 'little', signed=False) + assert(result == (wsi | 2**31)) + [] = await canon_waitable_drop(task, wsi) + return [] + + def on_start(): return [] + def on_return(results): assert(len(results) == 0) + ft = FuncType([],[]) + await canon_lift(mk_opts(), inst, ft, core_func, None, on_start, on_return) + + +async def test_host_partial_reads_writes(): + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + src = HostSource([1,2,3,4], chunk=2, close_if_empty = False) + source_ft = FuncType([], [StreamType(U8Type())]) + async def host_source(task, on_start, on_return, on_block): + [] = on_start() + on_return([src]) + + dst = None + sink_ft = FuncType([StreamType(U8Type())], []) + async def host_sink(task, on_start, on_return, on_block): + nonlocal dst + [s] = on_start() + dst = HostSink(s, chunk=1, remain=2) + on_return([]) + + async def core_func(task, args): + assert(len(args) == 0) + retp = 4 + [ret] = await canon_lower(opts, source_ft, host_source, task, [retp]) + assert(ret == 0) + rsi = mem[retp] + assert(rsi == 1) + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == 2) + assert(mem[0:2] == b'\x01\x02') + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == 2) + assert(mem[0:2] == b'\x03\x04') + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == definitions.BLOCKED) + src.write([5,6]) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 2) + [] = await canon_waitable_drop(task, rsi) + + [wsi] = await canon_stream_new(U8Type(), task) + assert(wsi == 1) + [ret] = await canon_lower(opts, sink_ft, host_sink, task, [wsi]) + assert(ret == 0) + mem[0:6] = b'\x01\x02\x03\x04\x05\x06' + [ret] = await canon_stream_write(task, wsi, 0, 6) + assert(ret == 2) + [ret] = await canon_stream_write(task, wsi, 2, 6) + assert(ret == definitions.BLOCKED) + dst.set_remain(4) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi) + assert(p2 == 4) + assert(dst.received == [1,2,3,4,5,6]) + [] = await canon_waitable_drop(task, wsi) + return [] + + opts2 = mk_opts() + inst = ComponentInstance() + def on_start(): return [] + def on_return(results): assert(len(results) == 0) + ft = FuncType([],[]) + await canon_lift(opts2, inst, ft, core_func, None, on_start, on_return) + + +async def test_wasm_to_wasm_stream(): + fut1, fut2, fut3, fut4 = asyncio.Future(), asyncio.Future(), asyncio.Future(), asyncio.Future() + + inst1 = ComponentInstance() + mem1 = bytearray(10) + opts1 = mk_opts(memory=mem1, sync=False) + ft1 = FuncType([], [StreamType(U8Type())]) + async def core_func1(task, args): + assert(not args) + [wsi] = await canon_stream_new(U8Type(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'], []), [wsi]) + + await task.wait_on(fut1) + + mem1[0:4] = b'\x01\x02\x03\x04' + [ret] = await canon_stream_write(task, wsi, 0, 2) + assert(ret == 2) + [ret] = await canon_stream_write(task, wsi, 2, 2) + assert(ret == 2) + + await task.wait_on(fut2) + + mem1[0:8] = b'\x05\x06\x07\x08\x09\x0a\x0b\x0c' + [ret] = await canon_stream_write(task, wsi, 0, 8) + assert(ret == definitions.BLOCKED) + + fut3.set_result(None) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi) + assert(p2 == 4) + + fut4.set_result(None) + + [] = await canon_waitable_drop(task, wsi) + return [] + + func1 = partial(canon_lift, opts1, inst1, ft1, core_func1) + + inst2 = ComponentInstance() + mem2 = bytearray(10) + opts2 = mk_opts(memory=mem2, sync=False) + ft2 = FuncType([], []) + async def core_func2(task, args): + assert(not args) + [] = await canon_task_return(task, CoreFuncType([], []), []) + + retp = 0 + [ret] = await canon_lower(opts2, ft1, func1, task, [retp]) + assert((ret >> 30) == CallState.RETURNED) + subi = ret & ~(3 << 30) + rsi = mem2[0] + assert(rsi == 1) + + [ret] = await canon_stream_read(task, rsi, 0, 8) + assert(ret == definitions.BLOCKED) + + fut1.set_result(None) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 4) + assert(mem2[0:8] == b'\x01\x02\x03\x04\x00\x00\x00\x00') + + fut2.set_result(None) + await task.wait_on(fut3) + + mem2[0:8] = bytes(8) + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == 2) + assert(mem2[0:6] == b'\x05\x06\x00\x00\x00\x00') + [ret] = await canon_stream_read(task, rsi, 2, 2) + assert(ret == 2) + assert(mem2[0:6] == b'\x05\x06\x07\x08\x00\x00') + + await task.wait_on(fut4) + + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == definitions.CLOSED) + [] = await canon_waitable_drop(task, rsi) + + event, callidx, _ = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(callidx == subi) + [] = await canon_waitable_drop(task, subi) + return [] + + await canon_lift(opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:()) + + +async def test_borrow_stream(): + rt_inst = ComponentInstance() + rt = ResourceType(rt_inst, None) + + inst1 = ComponentInstance() + mem1 = bytearray(12) + opts1 = mk_opts(memory=mem1) + ft1 = FuncType([StreamType(BorrowType(rt))], []) + async def core_func1(task, args): + [rsi] = args + + [ret] = await canon_stream_read(task, rsi, 4, 2) + assert(ret == definitions.BLOCKED) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == (2 | definitions.CLOSED)) + + [] = await canon_waitable_drop(task, rsi) + + h1 = mem1[4] + h2 = mem1[8] + assert(await canon_resource_rep(rt, task, h1) == [42]) + assert(await canon_resource_rep(rt, task, h2) == [43]) + [] = await canon_resource_drop(rt, True, task, h1) + [] = await canon_resource_drop(rt, True, task, h2) + + return [] + + func1 = partial(canon_lift, opts1, inst1, ft1, core_func1) + + inst2 = ComponentInstance() + mem2 = bytearray(10) + sync_opts2 = mk_opts(memory=mem2, sync=True) + async_opts2 = mk_opts(memory=mem2, sync=False) + ft2 = FuncType([], []) + async def core_func2(task, args): + assert(not args) + + [wsi] = await canon_stream_new(BorrowType(rt), task) + [ret] = await canon_lower(async_opts2, ft1, func1, task, [wsi]) + assert((ret >> 30) == CallState.STARTED) + subi = ret & ~(3 << 30) + + [h1] = await canon_resource_new(rt, task, 42) + [h2] = await canon_resource_new(rt, task, 43) + mem2[0] = h1 + mem2[4] = h2 + + [ret] = await canon_stream_write(task, wsi, 0, 2) + assert(ret == 2) + [] = await canon_waitable_drop(task, wsi) + + event, p1, _ = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + + [] = await canon_waitable_drop(task, subi) + return [] + + await canon_lift(sync_opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:()) + + +async def run_async_tests(): + await test_roundtrips() + await test_handles() + await test_async_to_async() + await test_async_callback() + await test_async_to_sync() + await test_async_backpressure() + await test_sync_using_wait() + await test_eager_stream_completion() + await test_stream_forward() + await test_receive_own_stream() + await test_host_partial_reads_writes() + await test_async_stream_ops() + await test_wasm_to_wasm_stream() + await test_borrow_stream() + +asyncio.run(run_async_tests()) print("All tests passed")