diff --git a/design/mvp/Async.md b/design/mvp/Async.md index 2a44f8c5..7799c1fb 100644 --- a/design/mvp/Async.md +++ b/design/mvp/Async.md @@ -419,21 +419,25 @@ For now, this remains a [TODO](#todo) and validation will reject `async`-lifted ## TODO -Native async support is being proposed in progressive chunks. The following -features will be added in future chunks to complete "async" in Preview 3: -* `future`/`stream`/`error`: add for use in function types for finer-grained - concurrency -* `subtask.cancel`: allow a supertask to signal to a subtask that its result is - no longer wanted and to please wrap it up promptly -* allow "tail-calling" a subtask so that the current wasm instance can be torn - down eagerly -* `task.index`+`task.wake`: allow tasks in the same instance to wait on and - wake each other (async condvar-style) +Native async support is being proposed incrementally. The following features +will be added in future chunks roughly in the order listed to complete the full +"async" story: +* add `future` type +* add `error` type that can be included when closing a stream/future * `nonblocking` function type attribute: allow a function to declare in its type that it will not transitively do anything blocking +* define what `async` means for `start` functions +* `task.index`+`task.wake`: allow tasks in the same instance to wait on and + wake each other +* `subtask.cancel`: allow a supertask to signal to a subtask that its result is + no longer wanted and to please wrap it up promptly +* `stream.lull` built-in that says "no more elements are coming for a while" +* `recursive` function type attribute: allow a function to be reentered - recursively (instead of trapping) + recursively (instead of trapping) for the benefit of donut wrapping +* built-in to "tail-call" a subtask so that the current wasm instance can be torn + down eagerly while preserving "structured concurrency" +* allow pipelining multiple `stream.read`/`write` 
calls +* allow chaining multiple async calls together ("promise pipelining") * integrate with `shared`: define how to lift and lower functions `async` *and* `shared` diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index a266165c..3a904013 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -2139,6 +2139,7 @@ where `$callee` has type `$ft`, validation specifies: * a `memory` is present if required by lifting and is a subtype of `(memory 1)` * a `realloc` is present if required by lifting and has type `(func (param i32 i32 i32 i32) (result i32))` * there is no `post-return` in `$opts` +* if `contains_async($ft)`, then `$opts.async` must be set When instantiating component instance `$inst`: * Define `$f` to be the partially-bound closure: `canon_lower($opts, $ft, $callee)` diff --git a/design/mvp/Explainer.md b/design/mvp/Explainer.md index 46c418a3..49173d6b 100644 --- a/design/mvp/Explainer.md +++ b/design/mvp/Explainer.md @@ -1222,10 +1222,12 @@ freeing memory after lifting and thus `post-return` may be deprecated in the future. 🔀 The `async` option specifies that the component wants to make (for imports) -or support (for exports) multiple concurrent (asynchronous) calls. This option -can be applied to any component-level function type and changes the derived -Canonical ABI significantly. See the [async explainer](Async.md) for more -details. +or support (for exports) multiple concurrent (asynchronous) calls. This +option can be applied to any component-level function type and changes the +derived Canonical ABI significantly. See the [async explainer](Async.md) for +more details. When a function signature contains a `future` or `stream`, +validation requires the `async` option to be set (since a synchronous call to +a function using these types is likely to deadlock). 
🔀 The `(callback ...)` option may only be present in `canon lift` when the `async` option has also been set and specifies a core function that is diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index 8f9be274..bca799a1 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -171,6 +171,11 @@ class Own(ValType): class Borrow(ValType): rt: ResourceType +@dataclass +class Stream(ValType): + t: ValType + + ### Context class Context: @@ -199,7 +204,7 @@ class CanonicalOptions: class ComponentInstance: handles: HandleTables - async_subtasks: Table[Subtask] + waitables: Table[Subtask|ReadableStream|WritableStream] num_tasks: int may_leave: bool backpressure: bool @@ -208,7 +213,7 @@ class ComponentInstance: def __init__(self): self.handles = HandleTables() - self.async_subtasks = Table[Subtask]() + self.waitables = Table[Subtask|ReadableStream|WritableStream]() self.num_tasks = 0 self.may_leave = True self.backpressure = False @@ -251,6 +256,8 @@ class Table(Generic[ElemT]): array: list[Optional[ElemT]] free: list[int] + MAX_LENGTH = 2**30 - 1 + def __init__(self): self.array = [None] self.free = [] @@ -267,7 +274,7 @@ def add(self, e): self.array[i] = e else: i = len(self.array) - trap_if(i >= 2**30) + trap_if(i > Table.MAX_LENGTH) self.array.append(e) return i @@ -301,8 +308,10 @@ class EventCode(IntEnum): CALL_RETURNED = CallState.RETURNED CALL_DONE = CallState.DONE YIELDED = 4 + STREAM_READ = 5 + STREAM_WRITE = 6 -EventTuple = tuple[EventCode, int] +EventTuple = tuple[EventCode, int, int] EventCallback = Callable[[], EventTuple] OnBlockCallback = Callable[[Awaitable], any] @@ -339,10 +348,9 @@ class Task(Context): caller: Optional[Task] on_return: Optional[Callable] on_block: OnBlockCallback - borrow_count: int + need_to_drop: int events: list[EventCallback] has_events: asyncio.Event - num_async_subtasks: int def __init__(self, opts, inst, ft, caller, on_return, on_block): 
super().__init__(opts, inst, self) @@ -350,10 +358,9 @@ def __init__(self, opts, inst, ft, caller, on_return, on_block): self.caller = caller self.on_return = on_return self.on_block = on_block - self.borrow_count = 0 + self.need_to_drop = 0 self.events = [] self.has_events = asyncio.Event() - self.num_async_subtasks = 0 def trap_if_on_the_stack(self, inst): c = self.caller @@ -433,19 +440,6 @@ async def yield_(self): self.maybe_start_pending_task() await self.wait_on(asyncio.sleep(0)) - def add_async_subtask(self, subtask): - assert(subtask.task is self and not subtask.notify_supertask) - subtask.notify_supertask = True - self.num_async_subtasks += 1 - return self.inst.async_subtasks.add(subtask) - - def create_borrow(self): - self.borrow_count += 1 - - def drop_borrow(self): - assert(self.borrow_count > 0) - self.borrow_count -= 1 - def return_(self, flat_results): trap_if(not self.on_return) if self.opts.sync and not self.opts.sync_task_return: @@ -462,8 +456,7 @@ def exit(self): assert(not self.events) assert(self.inst.num_tasks >= 1) trap_if(self.on_return) - trap_if(self.borrow_count != 0) - trap_if(self.num_async_subtasks != 0) + trap_if(self.need_to_drop != 0) self.inst.num_tasks -= 1 if self.opts.sync: assert(not self.inst.interruptible.is_set()) @@ -503,10 +496,10 @@ def maybe_notify_supertask(self): self.enqueued = True def subtask_event(): self.enqueued = False - i = self.inst.async_subtasks.array.index(self) + i = self.inst.waitables.array.index(self) if self.state == CallState.DONE: self.release_lenders() - return (EventCode(self.state), i) + return (EventCode(self.state), i, 0) self.task.notify(subtask_event) def on_start(self): @@ -534,7 +527,189 @@ def finish(self): self.maybe_notify_supertask() return self.flat_results -### Despecialization + def drop(self): + trap_if(self.enqueued) + trap_if(self.state != CallState.DONE) + self.task.need_to_drop -= 1 + +class Buffer: + _cx: Context + _t: ValType + _begin: int + _length: int + _progress: int + + 
MAX_LENGTH = 2**30 - 1 + + def __init__(self, cx, t, ptr, length): + trap_if(length == 0 or length > Buffer.MAX_LENGTH) + trap_if(ptr != align_to(ptr, alignment(t))) + trap_if(ptr + length * elem_size(t) > len(cx.opts.memory)) + self._cx = cx + self._t = t + self._begin = ptr + self._length = length + self._progress = 0 + + def remain(self): + assert(self._progress <= self._length) + return self._length - self._progress + +class ReadableBuffer(Buffer): + def lift(self, n): + assert(n <= self.remain()) + ptr = self._begin + self._progress * elem_size(self._t) + vs = load_list_from_valid_range(self._cx, ptr, n, self._t) + self._progress += n + return vs + +class WritableBuffer(Buffer): + def lower(self, vs): + assert(len(vs) <= self.remain()) + ptr = self._begin + self._progress * elem_size(self._t) + store_list_into_valid_range(self._cx, vs, ptr, self._t) + self._progress += len(vs) + +class AbstractStream: + def closed(self): + return self._closed + + def close(self): + if not self._closed: + self._closed = True + self._on_close() + + # precondition: !closed + read: Callable[[WritableBuffer, OnBlockCallback], Awaitable] + stop_reading: Callable[[], None] + try_unwrap_writer: Callable[[ComponentInstance], Optional[int]] + + def __init__(self, impl): + self._closed = False + self._on_close = impl.on_close + self.read = impl.read + self.stop_reading = impl.stop_reading + self.try_unwrap_writer = impl.try_unwrap_writer + +class ReadableStream: + stream: AbstractStream + t: ValType + cx: Context + pending_read: bool + + def __init__(self, t, cx, stream): + self.stream = stream + self.t = t + self.cx = cx + self.pending_read = False + + def drop(self): + trap_if(self.pending_read) + self.stream.close() + self.cx.task.need_to_drop -= 1 + +class WritableStream: + stream: AbstractStream + t: ValType + cx: Optional[Context] + pending_write: bool + other_direction: Optional[Literal['read', 'write']] + other_buffer: Optional[Buffer] + other_future: Optional[asyncio.Future] + + 
def __init__(self, t): + self.stream = AbstractStream(self) + self.t = t + self.cx = None + self.pending_write = False + self.other_direction = None + self.other_buffer = None + self.other_future = None + + async def read(self, dst, on_block): + return await self.rendezvous('read', dst, self.other_buffer, dst, on_block) + + async def write(self, src, on_block): + return await self.rendezvous('write', src, src, self.other_buffer, on_block) + + async def rendezvous(self, this_direction, this_buffer, src, dst, on_block): + assert(not self.stream.closed()) + if self.other_buffer: + ncopy = min(src.remain(), dst.remain()) + assert(ncopy > 0) + dst.lower(src.lift(ncopy)) + if not self.other_buffer.remain(): + self.other_buffer = None + self.finish_rendezvous() + else: + assert(not (self.other_direction or self.other_buffer or self.other_future)) + self.other_direction = this_direction + self.other_buffer = this_buffer + self.other_future = asyncio.Future() + await on_block(self.other_future) + if self.other_buffer is this_buffer: + self.other_buffer = None + + def finish_rendezvous(self): + if self.other_future: + self.other_future.set_result(None) + self.other_future = None + self.other_direction = None + + def stop_reading(self): + assert(not self.stream.closed()) + if self.other_direction == 'read': + self.finish_rendezvous() + + def stop_writing(self): + assert(not self.stream.closed()) + if self.other_direction == 'write': + self.finish_rendezvous() + + def try_unwrap_writer(self, inst): + assert(not self.stream.closed()) + if inst is self.cx.inst: + return self.cx.task.inst.waitables.array.index(self) + return None + + def on_close(self): + assert(self.stream.closed()) + self.finish_rendezvous() + + def drop(self): + trap_if(self.pending_write) + self.stream.close() + if self.cx: + self.cx.task.need_to_drop -= 1 + +### Type utilities + +def contains_async(t): + match t: + case Stream(): + return True + case FuncType(): + return any(contains_async(t) for t in 
t.param_types()) or \ + any(contains_async(t) for t in t.result_types()) + case PrimValType(): + return False + case List(t): + return contains_async(t) + case Record(fs): + return any(contains_async(f.t) for f in fs) + case Tuple(ts): + return any(contains_async(t) for t in ts) + case Variant(cs): + return any(contains_async(c.t) for c in cs) + case Option(t): + return contains_async(t) + case Result(o,e): + return contains_async(o) or contains_async(e) + case Own(): + return False + case Borrow(): + return False + assert(False) def despecialize(t): match t: @@ -562,6 +737,7 @@ def alignment(t): case Variant(cases) : return alignment_variant(cases) case Flags(labels) : return alignment_flags(labels) case Own(_) | Borrow(_) : return 4 + case Stream(_) : return 4 def alignment_list(elem_type, maybe_length): if maybe_length is not None: @@ -618,6 +794,7 @@ def elem_size(t): case Variant(cases) : return elem_size_variant(cases) case Flags(labels) : return elem_size_flags(labels) case Own(_) | Borrow(_) : return 4 + case Stream(_) : return 4 def elem_size_list(elem_type, maybe_length): if maybe_length is not None: @@ -677,6 +854,7 @@ def load(cx, ptr, t): case Flags(labels) : return load_flags(cx, ptr, labels) case Own() : return lift_own(cx, load_int(cx, ptr, 4), t) case Borrow() : return lift_borrow(cx, load_int(cx, ptr, 4), t) + case Stream(t) : return lift_stream(cx, load_int(cx, ptr, 4), t) def load_int(cx, ptr, nbytes, signed = False): return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed) @@ -830,6 +1008,24 @@ def lift_borrow(cx, i, t): cx.add_lender(h) return h.rep +def lift_stream(cx, i, elem_type): + w = cx.inst.waitables.get(i) + match w: + case ReadableStream(): + trap_if(w.t != elem_type) + trap_if(w.pending_read) + w.cx.task.need_to_drop -= 1 + cx.inst.waitables.remove(i) + case WritableStream(): + trap_if(w.t != elem_type) + assert(not w.pending_write) + trap_if(w.cx is not None) + w.cx = cx + cx.task.need_to_drop += 1 + case _: 
+ trap() + return w.stream + ### Storing def store(cx, v, t, ptr): @@ -855,6 +1051,7 @@ def store(cx, v, t, ptr): case Flags(labels) : store_flags(cx, v, ptr, labels) case Own() : store_int(cx, lower_own(cx, v, t), ptr, 4) case Borrow() : store_int(cx, lower_borrow(cx, v, t), ptr, 4) + case Stream(t) : store_int(cx, lower_stream(cx, v, t), ptr, 4) def store_int(cx, v, ptr, nbytes, signed = False): cx.opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed) @@ -1114,9 +1311,22 @@ def lower_borrow(cx, rep, t): if cx.inst is t.rt.impl: return rep h = HandleElem(rep, own=False, scope=cx) - cx.create_borrow() + cx.need_to_drop += 1 return cx.inst.handles.add(t.rt, h) +def lower_stream(cx, stream, elem_type): + assert(isinstance(stream, AbstractStream)) + if (i := stream.try_unwrap_writer(cx.inst)): + ws = cx.inst.waitables.array[i] + ws.cx.task.need_to_drop -= 1 + ws.cx = None + assert(2**31 > Table.MAX_LENGTH) + return i | (2**31) + else: + rs = ReadableStream(elem_type, cx, stream) + cx.task.need_to_drop += 1 + return cx.inst.waitables.add(rs) + ### Flattening MAX_FLAT_PARAMS = 16 @@ -1167,6 +1377,7 @@ def flatten_type(t): case Variant(cases) : return flatten_variant(cases) case Flags(labels) : return ['i32'] case Own(_) | Borrow(_) : return ['i32'] + case Stream(_) : return ['i32'] def flatten_list(elem_type, maybe_length): if maybe_length is not None: @@ -1233,6 +1444,7 @@ def lift_flat(cx, vi, t): case Flags(labels) : return lift_flat_flags(vi, labels) case Own() : return lift_own(cx, vi.next('i32'), t) case Borrow() : return lift_borrow(cx, vi.next('i32'), t) + case Stream(t) : return lift_stream(cx, vi.next('i32'), t) def lift_flat_unsigned(vi, core_width, t_width): i = vi.next('i' + str(core_width)) @@ -1324,6 +1536,7 @@ def lower_flat(cx, v, t): case Flags(labels) : return lower_flat_flags(v, labels) case Own() : return [lower_own(cx, v, t)] case Borrow() : return [lower_borrow(cx, v, t)] + case Stream(t) : return [lower_stream(cx, v, 
t)] def lower_flat_signed(i, core_bits): if i < 0: @@ -1438,10 +1651,10 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_blo ctx = packed_ctx & ~1 if is_yield: await task.yield_() - event, payload = (EventCode.YIELDED, 0) + event, p1, p2 = (EventCode.YIELDED, 0, 0) else: - event, payload = await task.wait() - [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, payload]) + event, p1, p2 = await task.wait() + [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, p1, p2]) task.exit() async def call_and_trap_on_throw(callee, task, args): @@ -1456,6 +1669,7 @@ async def canon_lower(opts, ft, callee, task, flat_args): trap_if(not task.inst.may_leave) subtask = Subtask(opts, ft, task, flat_args) if opts.sync: + assert(not contains_async(ft)) await task.call_sync(callee, task, subtask.on_start, subtask.on_return) flat_results = subtask.finish() else: @@ -1463,7 +1677,9 @@ async def do_call(on_block): await callee(task, subtask.on_start, subtask.on_return, on_block) [] = subtask.finish() if await call_and_handle_blocking(do_call): - i = task.add_async_subtask(subtask) + subtask.notify_supertask = True + task.need_to_drop += 1 + i = task.inst.waitables.add(subtask) flat_results = [pack_async_result(i, subtask.state)] else: flat_results = [0] @@ -1505,7 +1721,7 @@ async def canon_resource_drop(rt, sync, task, i): else: task.trap_if_on_the_stack(rt.impl) else: - h.scope.drop_borrow() + h.scope.need_to_drop -= 1 return flat_results ### `canon resource.rep` @@ -1535,8 +1751,9 @@ async def canon_task_return(task, core_ft, flat_args): async def canon_task_wait(task, ptr): trap_if(not task.inst.may_leave) trap_if(task.opts.callback is not None) - event, payload = await task.wait() - store(task, payload, U32(), ptr) + event, p1, p2 = await task.wait() + store(task, p1, U32(), ptr) + store(task, p2, U32(), ptr + 4) return [event] ### 🔀 `canon task.poll` @@ -1546,7 +1763,7 @@ async def 
canon_task_poll(task, ptr): ret = task.poll() if ret is None: return [0] - store(task, ret, Tuple([U32(), U32()]), ptr) + store(task, ret, Tuple([U32(), U32(), U32()]), ptr) return [1] ### 🔀 `canon task.yield` @@ -1557,12 +1774,92 @@ async def canon_task_yield(task): await task.yield_() return [] -### 🔀 `canon subtask.drop` +### 🔀 `canon stream.new` + +async def canon_stream_new(elem_type, task): + trap_if(not task.inst.may_leave) + ws = WritableStream(elem_type) + return [ task.inst.waitables.add(ws) ] + +### 🔀 `canon stream.read` and `canon stream.write` + +PENDING = 2**32 - 1 +CLOSED = 2**31 +assert(CLOSED > Buffer.MAX_LENGTH) +assert(PENDING > (Buffer.MAX_LENGTH | CLOSED)) +def pack_progress(n, buffer, handle): + return (n - buffer.remain()) | (CLOSED if handle.stream.closed() else 0) + +async def canon_stream_read(task, i, ptr, n): + trap_if(not task.inst.may_leave) + rs = task.inst.waitables.get(i) + trap_if(not isinstance(rs, ReadableStream)) + trap_if(rs.pending_read) + dst = WritableBuffer(rs.cx, rs.t, ptr, n) + if rs.stream.closed(): + flat_results = [CLOSED] + else: + async def do_read(on_block): + await rs.stream.read(dst, on_block) + if rs.pending_read: + def read_event(): + rs.pending_read = False + packed_progress = pack_progress(n, dst, rs) + return (EventCode.STREAM_READ, i, packed_progress) + rs.cx.task.notify(read_event) + if await call_and_handle_blocking(do_read): + rs.pending_read = True + flat_results = [PENDING] + else: + flat_results = [pack_progress(n, dst, rs)] + return flat_results + +async def canon_stream_write(task, i, ptr, n): + trap_if(not task.inst.may_leave) + ws = task.inst.waitables.get(i) + trap_if(not isinstance(ws, WritableStream)) + trap_if(not ws.cx) + trap_if(ws.pending_write) + src = ReadableBuffer(ws.cx, ws.t, ptr, n) + if ws.stream.closed(): + flat_results = [CLOSED] + else: + async def do_write(on_block): + await ws.write(src, on_block) + if ws.pending_write: + def write_event(): + ws.pending_write = False + 
packed_progress = pack_progress(n, src, ws) + return (EventCode.STREAM_WRITE, i, packed_progress) + ws.cx.task.notify(write_event) + if await call_and_handle_blocking(do_write): + ws.pending_write = True + flat_results = [PENDING] + else: + flat_results = [pack_progress(n, src, ws)] + return flat_results + +### 🔀 `canon stream.stop-reading` and `canon stream.stop-writing` + +async def canon_stream_stop_reading(task, i): + trap_if(not task.inst.may_leave) + rs = self.inst.waitables.get(i) + trap_if(not isinstance(rs, ReadableStream)) + trap_if(not rs.pending_read) + rs.stop_reading() + return [] + +async def canon_stream_stop_writing(task, i): + trap_if(not task.inst.may_leave) + ws = self.inst.waitables.get(i) + trap_if(not isinstance(ws, WritableStream)) + trap_if(not ws.pending_write) + ws.stop_writing() + return [] + +### 🔀 `canon waitable.drop` -async def canon_subtask_drop(task, i): +async def canon_waitable_drop(task, i): trap_if(not task.inst.may_leave) - subtask = task.inst.async_subtasks.remove(i) - trap_if(subtask.enqueued) - trap_if(subtask.state != CallState.DONE) - subtask.task.num_async_subtasks -= 1 + task.inst.waitables.remove(i).drop() return [] diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index e1e25f6d..e4e42341 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -34,13 +34,14 @@ def realloc(self, original_ptr, original_size, alignment, new_size): self.memory[ret : ret + original_size] = self.memory[original_ptr : original_ptr + original_size] return ret -def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None): +def mk_opts(memory = bytearray(), encoding = 'utf8', realloc = None, post_return = None, sync_task_return = False, sync = True): opts = CanonicalOptions() opts.memory = memory opts.string_encoding = encoding opts.realloc = realloc opts.post_return = post_return - opts.sync = True + opts.sync_task_return = 
sync_task_return + opts.sync = sync opts.callback = None return opts @@ -341,53 +342,56 @@ def test_flatten(t, params, results): test_flatten(FuncType([U8() for _ in range(17)],[]), ['i32' for _ in range(17)], []) test_flatten(FuncType([U8() for _ in range(17)],[Tuple([U8(),U8()])]), ['i32' for _ in range(17)], ['i32','i32']) -def test_roundtrip(t, v): - before = definitions.MAX_FLAT_RESULTS - definitions.MAX_FLAT_RESULTS = 16 - ft = FuncType([t],[t]) - async def callee(task, x): - return x +async def test_roundtrips(): + async def test_roundtrip(t, v): + before = definitions.MAX_FLAT_RESULTS + definitions.MAX_FLAT_RESULTS = 16 - callee_heap = Heap(1000) - callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) - callee_inst = ComponentInstance() - lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) + ft = FuncType([t],[t]) + async def callee(task, x): + return x - caller_heap = Heap(1000) - caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) - caller_inst = ComponentInstance() - caller_task = Task(caller_opts, caller_inst, ft, None, None, None) + callee_heap = Heap(1000) + callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc) + callee_inst = ComponentInstance() + lifted_callee = partial(canon_lift, callee_opts, callee_inst, ft, callee) - flat_args = asyncio.run(caller_task.enter(lambda: [v])) + caller_heap = Heap(1000) + caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc) + caller_inst = ComponentInstance() + caller_task = Task(caller_opts, caller_inst, ft, None, None, None) - return_in_heap = len(flatten_types([t])) > definitions.MAX_FLAT_RESULTS - if return_in_heap: - flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ] + flat_args = await caller_task.enter(lambda: [v]) - flat_results = asyncio.run(canon_lower(caller_opts, ft, lifted_callee, caller_task, flat_args)) + return_in_heap = len(flatten_types([t])) > definitions.MAX_FLAT_RESULTS + if 
return_in_heap: + flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ] - if return_in_heap: - flat_results = [ flat_args[-1] ] + flat_results = await canon_lower(caller_opts, ft, lifted_callee, caller_task, flat_args) - [got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t]) - caller_task.exit() + if return_in_heap: + flat_results = [ flat_args[-1] ] - if got != v: - fail("test_roundtrip({},{}) got {}".format(t, v, got)) + [got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t]) + caller_task.exit() - definitions.MAX_FLAT_RESULTS = before + if got != v: + fail("test_roundtrip({},{}) got {}".format(t, v, got)) -test_roundtrip(S8(), -1) -test_roundtrip(Tuple([U16(),U16()]), mk_tup(3,4)) -test_roundtrip(List(String()), [mk_str("hello there")]) -test_roundtrip(List(List(String())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) -test_roundtrip(List(Option(Tuple([String(),U16()]))), [{'some':mk_tup(mk_str("answer"),42)}]) -test_roundtrip(Variant([Case('x', Tuple([U32(),U32(),U32(),U32(), U32(),U32(),U32(),U32(), - U32(),U32(),U32(),U32(), U32(),U32(),U32(),U32(), String()]))]), - {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) + definitions.MAX_FLAT_RESULTS = before -def test_handles(): + await test_roundtrip(S8(), -1) + await test_roundtrip(Tuple([U16(),U16()]), mk_tup(3,4)) + await test_roundtrip(List(String()), [mk_str("hello there")]) + await test_roundtrip(List(List(String())), [[mk_str("one"),mk_str("two")],[mk_str("three")]]) + await test_roundtrip(List(Option(Tuple([String(),U16()]))), [{'some':mk_tup(mk_str("answer"),42)}]) + await test_roundtrip(Variant([Case('x', Tuple([U32(),U32(),U32(),U32(), U32(),U32(),U32(),U32(), + U32(),U32(),U32(),U32(), U32(),U32(),U32(),U32(), String()]))]), + {'x': mk_tup(1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, mk_str("wat"))}) + + +async def test_handles(): before = definitions.MAX_FLAT_RESULTS 
definitions.MAX_FLAT_RESULTS = 16 @@ -479,7 +483,7 @@ def on_return(results): nonlocal got got = results - asyncio.run(canon_lift(opts, inst, ft, core_wasm, None, on_start, on_return, None)) + await canon_lift(opts, inst, ft, core_wasm, None, on_start, on_return, None) assert(len(got) == 3) assert(got[0] == 46) @@ -490,7 +494,6 @@ def on_return(results): assert(len(inst.handles.table(rt).free) == 4) definitions.MAX_FLAT_RESULTS = before -test_handles() async def test_async_to_async(): producer_heap = Heap(10) @@ -537,40 +540,35 @@ async def consumer(task, args): ptr = consumer_heap.realloc(0, 0, 1, 1) [ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [0, ptr]) assert(ret == 0) - assert(task.num_async_subtasks == 0) u8 = consumer_heap.memory[ptr] assert(u8 == 43) [ret] = await canon_lower(consumer_opts, toggle_ft, toggle_callee, task, []) - assert(ret == (1 | (CallState.STARTED << 30))) - assert(task.num_async_subtasks == 1) + assert((ret >> 30) == CallState.STARTED) + subi = ret & ~(3 << 30) retp = ptr consumer_heap.memory[retp] = 13 [ret] = await canon_lower(consumer_opts, blocking_ft, blocking_callee, task, [83, retp]) assert(ret == (2 | (CallState.STARTING << 30))) - assert(task.num_async_subtasks == 2) assert(consumer_heap.memory[retp] == 13) fut1.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) - [] = await canon_subtask_drop(task, callidx) - assert(task.num_async_subtasks == 1) - event, callidx = await task.wait() + [] = await canon_waitable_drop(task, callidx) + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_STARTED) assert(callidx == 2) assert(consumer_heap.memory[retp] == 13) fut2.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_RETURNED) assert(callidx == 2) assert(consumer_heap.memory[retp] == 44) fut3.set_result(None) - 
assert(task.num_async_subtasks == 1) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - [] = await canon_subtask_drop(task, callidx) - assert(task.num_async_subtasks == 0) + [] = await canon_waitable_drop(task, callidx) dtor_fut = asyncio.Future() dtor_value = None @@ -586,14 +584,12 @@ async def dtor(task, args): assert(dtor_value is None) [ret] = await canon_resource_drop(rt, False, task, 1) assert(ret == (2 | (CallState.STARTED << 30))) - assert(task.num_async_subtasks == 1) assert(dtor_value is None) dtor_fut.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == CallState.DONE) assert(callidx == 2) - [] = await canon_subtask_drop(task, callidx) - assert(task.num_async_subtasks == 0) + [] = await canon_waitable_drop(task, callidx) [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [42]) return [] @@ -613,7 +609,6 @@ def on_return(results): assert(len(got) == 1) assert(got[0] == 42) -asyncio.run(test_async_to_async()) async def test_async_callback(): producer_inst = ComponentInstance() @@ -647,22 +642,25 @@ async def consumer(task, args): return [42] async def callback(task, args): - assert(len(args) == 3) + assert(len(args) == 4) if args[0] == 42: assert(args[1] == EventCode.CALL_DONE) assert(args[2] == 1) - await canon_subtask_drop(task, 1) + assert(args[3] == 0) + await canon_waitable_drop(task, 1) return [53] elif args[0] == 52: assert(args[1] == EventCode.YIELDED) assert(args[2] == 0) + assert(args[3] == 0) fut2.set_result(None) return [62] else: assert(args[0] == 62) assert(args[1] == EventCode.CALL_DONE) assert(args[2] == 2) - await canon_subtask_drop(task, 2) + assert(args[3] == 0) + await canon_waitable_drop(task, 2) [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [83]) return [0] @@ -681,7 +679,6 @@ def on_return(results): await canon_lift(opts, consumer_inst, consumer_ft, consumer, None, 
on_start, on_return) assert(got[0] == 83) -asyncio.run(test_async_callback()) async def test_async_to_sync(): producer_opts = CanonicalOptions() @@ -725,19 +722,19 @@ async def consumer(task, args): fut.set_result(None) assert(producer1_done == False) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) - await canon_subtask_drop(task, callidx) + await canon_waitable_drop(task, callidx) assert(producer1_done == True) assert(producer2_done == False) await canon_task_yield(task) assert(producer2_done == True) - event, callidx = task.poll() + event, callidx, _ = task.poll() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - await canon_subtask_drop(task, callidx) + await canon_waitable_drop(task, callidx) assert(producer2_done == True) assert(task.poll() is None) @@ -756,7 +753,6 @@ def on_return(results): await canon_lift(consumer_opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 83) -asyncio.run(test_async_to_sync()) async def test_async_backpressure(): producer_opts = CanonicalOptions() @@ -804,18 +800,18 @@ async def consumer(task, args): fut.set_result(None) assert(producer1_done == False) assert(producer2_done == False) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) assert(producer1_done == True) assert(producer2_done == True) - event, callidx = task.poll() + event, callidx, _ = task.poll() assert(event == EventCode.CALL_DONE) assert(callidx == 2) assert(producer2_done == True) - await canon_subtask_drop(task, 1) - await canon_subtask_drop(task, 2) + await canon_waitable_drop(task, 1) + await canon_waitable_drop(task, 2) assert(task.poll() is None) @@ -833,8 +829,6 @@ def on_return(results): await canon_lift(consumer_opts, consumer_inst, consumer_ft, consumer, None, on_start, on_return) assert(got[0] == 84) -if definitions.DETERMINISTIC_PROFILE: - 
asyncio.run(test_async_backpressure()) async def test_sync_using_wait(): hostcall_opts = mk_opts() @@ -863,16 +857,16 @@ async def core_func(task, args): assert(ret == (2 | (CallState.STARTED << 30))) fut1.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 1) fut2.set_result(None) - event, callidx = await task.wait() + event, callidx, _ = await task.wait() assert(event == EventCode.CALL_DONE) assert(callidx == 2) - await canon_subtask_drop(task, 1) - await canon_subtask_drop(task, 2) + await canon_waitable_drop(task, 1) + await canon_waitable_drop(task, 2) return [] @@ -881,6 +875,570 @@ def on_start(): return [] def on_return(results): pass await canon_lift(mk_opts(), inst, ft, core_func, None, on_start, on_return) -asyncio.run(test_sync_using_wait()) + +class HostSource(AbstractStream): + remaining: list[int] + close_if_empty: bool + chunk: int + waiting: Optional[asyncio.Future] + + def __init__(self, contents, chunk, close_if_empty = True): + self.remaining = contents + self.close_if_empty = close_if_empty + self.chunk = chunk + self.waiting = None + + def closed(self): + return not self.remaining and self.close_if_empty + + def wake_waiting(self): + if self.waiting: + self.waiting.set_result(None) + self.waiting = None + + def close(self): + self.remaining = [] + self.close_if_empty = True + self.wake_waiting() + + def close_once_empty(self): + self.close_if_empty = True + if self.closed(): + self.wake_waiting() + + async def read(self, dst, on_block): + if not self.remaining: + if self.closed(): + return + self.waiting = asyncio.Future() + await on_block(self.waiting) + if not self.remaining: + return + n = min(dst.remain(), len(self.remaining), self.chunk) + dst.lower(self.remaining[:n]) + del self.remaining[:n] + + def write(self, vs): + assert(vs and not self.closed()) + self.remaining += vs + self.wake_waiting() + + def stop_reading(self): + 
self.wake_waiting() + + def try_unwrap_writer(self, inst): + return None + +class HostSink: + stream: AbstractStream + received: list[int] + chunk: int + write_remain: int + write_event: asyncio.Event + ready_to_consume: asyncio.Event + + def __init__(self, stream, chunk, remain = 2**64): + self.stream = stream + self.received = [] + self.chunk = chunk + self.write_remain = remain + self.write_event = asyncio.Event() + if remain: + self.write_event.set() + self.ready_to_consume = asyncio.Event() + async def read_all(): + while not self.stream.closed(): + async def on_block(f): + return await f + await self.write_event.wait() + await self.stream.read(self, on_block) + asyncio.create_task(read_all()) + + def set_remain(self, n): + self.write_remain = n + if self.write_remain > 0: + self.write_event.set() + + def remain(self): + return self.write_remain + + def lower(self, vs): + self.received += vs + self.ready_to_consume.set() + self.write_remain -= len(vs) + if self.write_remain == 0: + self.write_event.clear() + + async def consume(self, n): + while n > len(self.received): + self.ready_to_consume.clear() + await self.ready_to_consume.wait() + ret = self.received[:n]; + del self.received[:n] + return ret + +async def test_sync_stream_ops(): + ft = FuncType([Stream(U8())], [Stream(U8())]) + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + async def host_import(task, on_start, on_return, on_block): + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], AbstractStream)) + incoming = HostSink(args[0], chunk=4) + outgoing = HostSource([], chunk=4, close_if_empty=False) + on_return([outgoing]) + async def add10(): + while not incoming.stream.closed(): + vs = await incoming.consume(4) + for i in range(len(vs)): + vs[i] += 10 + outgoing.write(vs) + outgoing.close() + asyncio.create_task(add10()) + + src_stream = HostSource([1,2,3,4,5,6,7,8], chunk=4) + def on_start(): + return [src_stream] + + dst_stream = 
None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = HostSink(results[0], chunk=4) + + async def core_func(task, args): + assert(len(args) == 1) + rsi1 = args[0] + assert(rsi1 == 1) + [wsi1] = await canon_stream_new(U8(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1]) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == 4) + assert(mem[0:4] == b'\x01\x02\x03\x04') + [wsi2] = await canon_stream_new(U8(), task) + retp = 12 + [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp]) + assert(ret == 0) + rsi2 = mem[retp] + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == (4 | (2**31))) + assert(mem[0:4] == b'\x05\x06\x07\x08') + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, rsi1) + [] = await canon_waitable_drop(task, rsi2) + [] = await canon_waitable_drop(task, wsi1) + [] = await canon_waitable_drop(task, wsi2) + return [] + + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(dst_stream.received == [11,12,13,14,15,16,17,18]) + + +async def test_async_stream_ops(): + ft = FuncType([Stream(U8())], [Stream(U8())]) + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + host_import_incoming = None + host_import_outgoing = None + async def host_import(task, on_start, on_return, on_block): + nonlocal host_import_incoming, host_import_outgoing + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], AbstractStream)) + host_import_incoming = HostSink(args[0], 
chunk=4, remain = 0) + host_import_outgoing = HostSource([], chunk=4, close_if_empty=False) + on_return([host_import_outgoing]) + while not host_import_incoming.stream.closed(): + vs = await on_block(host_import_incoming.consume(4)) + for i in range(len(vs)): + vs[i] += 10 + host_import_outgoing.write(vs) + host_import_outgoing.close_once_empty() + + src_stream = HostSource([], chunk=4, close_if_empty = False) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = HostSink(results[0], chunk=4, remain = 0) + + async def core_func(task, args): + [rsi1] = args + assert(rsi1 == 1) + [wsi1] = await canon_stream_new(U8(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1]) + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == definitions.PENDING) + src_stream.write([1,2,3,4]) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi1) + assert(p2 == 4) + assert(mem[0:4] == b'\x01\x02\x03\x04') + [wsi2] = await canon_stream_new(U8(), task) + retp = 16 + [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp]) + assert((ret >> 30) == CallState.RETURNED) + subi = ret & ~(3 << 30) + rsi2 = mem[16] + assert(rsi2 == 4) + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == definitions.PENDING) + host_import_incoming.set_remain(100) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi2) + assert(p2 == 4) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == 4) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == definitions.PENDING) + dst_stream.set_remain(100) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi1) + assert(p2 == 4) + src_stream.write([5,6,7,8]) + src_stream.close_once_empty() + [ret] = await canon_stream_read(task, rsi1, 0, 4) + assert(ret == (4 | 
definitions.CLOSED)) + [] = await canon_waitable_drop(task, rsi1) + assert(mem[0:4] == b'\x05\x06\x07\x08') + [ret] = await canon_stream_write(task, wsi2, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, wsi2) + [ret] = await canon_stream_read(task, rsi2, 0, 4) + assert(ret == definitions.PENDING) + event, p1, p2 = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + assert(p2 == 0) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi2) + assert(p2 == (4 | 2**31)) + [] = await canon_waitable_drop(task, subi) + [] = await canon_waitable_drop(task, rsi2) + [ret] = await canon_stream_write(task, wsi1, 0, 4) + assert(ret == 4) + [] = await canon_waitable_drop(task, wsi1) + return [] + + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(dst_stream.received == [11,12,13,14,15,16,17,18]) + + +async def test_stream_forward(): + src_stream = HostSource([1,2,3,4], chunk=4) + def on_start(): + return [src_stream] + + dst_stream = None + def on_return(results): + assert(len(results) == 1) + nonlocal dst_stream + dst_stream = results[0] + + async def core_func(task, args): + assert(len(args) == 1) + rsi1 = args[0] + assert(rsi1 == 1) + return [rsi1] + + opts = mk_opts() + inst = ComponentInstance() + ft = FuncType([Stream(U8())], [Stream(U8())]) + await canon_lift(opts, inst, ft, core_func, None, on_start, on_return) + assert(src_stream is dst_stream) + + +async def test_receive_own_stream(): + inst = ComponentInstance() + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + host_ft = FuncType([Stream(U8())], [Stream(U8())]) + async def host_import(task, on_start, on_return, on_block): + args = on_start() + assert(len(args) == 1) + assert(isinstance(args[0], AbstractStream)) + on_return(args) + + async def core_func(task, args): + assert(len(args) == 0) + [wsi] = await canon_stream_new(U8(), task) + assert(wsi == 1) + retp = 4 + [ret] = await 
canon_lower(opts, host_ft, host_import, task, [wsi, retp]) + assert(ret == 0) + result = int.from_bytes(mem[retp : retp+4], 'little', signed=False) + assert(result == (wsi | 2**31)) + [] = await canon_waitable_drop(task, wsi) + return [] + + def on_start(): return [] + def on_return(results): assert(len(results) == 0) + ft = FuncType([],[]) + await canon_lift(mk_opts(), inst, ft, core_func, None, on_start, on_return) + + +async def test_host_partial_reads_writes(): + mem = bytearray(20) + opts = mk_opts(memory=mem, sync=False) + + src = HostSource([1,2,3,4], chunk=2, close_if_empty = False) + source_ft = FuncType([], [Stream(U8())]) + async def host_source(task, on_start, on_return, on_block): + [] = on_start() + on_return([src]) + + dst = None + sink_ft = FuncType([Stream(U8())], []) + async def host_sink(task, on_start, on_return, on_block): + nonlocal dst + [s] = on_start() + dst = HostSink(s, chunk=1, remain=2) + on_return([]) + + async def core_func(task, args): + assert(len(args) == 0) + retp = 4 + [ret] = await canon_lower(opts, source_ft, host_source, task, [retp]) + assert(ret == 0) + rsi = mem[retp] + assert(rsi == 1) + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == 2) + assert(mem[0:2] == b'\x01\x02') + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == 2) + assert(mem[0:2] == b'\x03\x04') + [ret] = await canon_stream_read(task, rsi, 0, 4) + assert(ret == definitions.PENDING) + src.write([5,6]) + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 2) + [] = await canon_waitable_drop(task, rsi) + + [wsi] = await canon_stream_new(U8(), task) + assert(wsi == 1) + [ret] = await canon_lower(opts, sink_ft, host_sink, task, [wsi]) + assert(ret == 0) + mem[0:6] = b'\x01\x02\x03\x04\x05\x06' + [ret] = await canon_stream_write(task, wsi, 0, 6) + assert(ret == 2) + [ret] = await canon_stream_write(task, wsi, 2, 6) + assert(ret == definitions.PENDING) + dst.set_remain(4) + 
event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi) + assert(p2 == 4) + assert(dst.received == [1,2,3,4,5,6]) + [] = await canon_waitable_drop(task, wsi) + return [] + + opts2 = mk_opts() + inst = ComponentInstance() + def on_start(): return [] + def on_return(results): assert(len(results) == 0) + ft = FuncType([],[]) + await canon_lift(opts2, inst, ft, core_func, None, on_start, on_return) + + +async def test_wasm_to_wasm_stream(): + fut1, fut2, fut3, fut4 = asyncio.Future(), asyncio.Future(), asyncio.Future(), asyncio.Future() + + inst1 = ComponentInstance() + mem1 = bytearray(10) + opts1 = mk_opts(memory=mem1, sync=False) + ft1 = FuncType([], [Stream(U8())]) + async def core_func1(task, args): + assert(not args) + [wsi] = await canon_stream_new(U8(), task) + [] = await canon_task_return(task, CoreFuncType(['i32'], []), [wsi]) + + await task.wait_on(fut1) + + mem1[0:4] = b'\x01\x02\x03\x04' + [ret] = await canon_stream_write(task, wsi, 0, 2) + assert(ret == 2) + [ret] = await canon_stream_write(task, wsi, 2, 2) + assert(ret == 2) + + await task.wait_on(fut2) + + mem1[0:8] = b'\x05\x06\x07\x08\x09\x0a\x0b\x0c' + [ret] = await canon_stream_write(task, wsi, 0, 8) + assert(ret == definitions.PENDING) + + fut3.set_result(None) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_WRITE) + assert(p1 == wsi) + assert(p2 == 4) + + fut4.set_result(None) + + [] = await canon_waitable_drop(task, wsi) + return [] + + func1 = partial(canon_lift, opts1, inst1, ft1, core_func1) + + inst2 = ComponentInstance() + mem2 = bytearray(10) + opts2 = mk_opts(memory=mem2, sync=False) + ft2 = FuncType([], []) + async def core_func2(task, args): + assert(not args) + [] = await canon_task_return(task, CoreFuncType([], []), []) + + retp = 0 + [ret] = await canon_lower(opts2, ft1, func1, task, [retp]) + assert((ret >> 30) == CallState.RETURNED) + subi = ret & ~(3 << 30) + rsi = mem2[0] + assert(rsi == 1) + + [ret] = await 
canon_stream_read(task, rsi, 0, 8) + assert(ret == definitions.PENDING) + + fut1.set_result(None) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == 4) + assert(mem2[0:8] == b'\x01\x02\x03\x04\x00\x00\x00\x00') + + fut2.set_result(None) + await task.wait_on(fut3) + + mem2[0:8] = bytes(8) + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == 2) + assert(mem2[0:6] == b'\x05\x06\x00\x00\x00\x00') + [ret] = await canon_stream_read(task, rsi, 2, 2) + assert(ret == 2) + assert(mem2[0:6] == b'\x05\x06\x07\x08\x00\x00') + + await task.wait_on(fut4) + + [ret] = await canon_stream_read(task, rsi, 0, 2) + assert(ret == definitions.CLOSED) + [] = await canon_waitable_drop(task, rsi) + + event, callidx, _ = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(callidx == subi) + [] = await canon_waitable_drop(task, subi) + return [] + + await canon_lift(opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:()) + + +async def test_borrow_stream(): + rt_inst = ComponentInstance() + rt = ResourceType(rt_inst, None) + + inst1 = ComponentInstance() + mem1 = bytearray(12) + opts1 = mk_opts(memory=mem1) + ft1 = FuncType([Stream(Borrow(rt))], []) + async def core_func1(task, args): + [rsi] = args + + [ret] = await canon_stream_read(task, rsi, 4, 2) + assert(ret == definitions.PENDING) + + event, p1, p2 = await task.wait() + assert(event == EventCode.STREAM_READ) + assert(p1 == rsi) + assert(p2 == (2 | definitions.CLOSED)) + + [] = await canon_waitable_drop(task, rsi) + + h1 = mem1[4] + h2 = mem1[8] + assert(await canon_resource_rep(rt, task, h1) == [42]) + assert(await canon_resource_rep(rt, task, h2) == [43]) + [] = await canon_resource_drop(rt, True, task, h1) + [] = await canon_resource_drop(rt, True, task, h2) + + return [] + + func1 = partial(canon_lift, opts1, inst1, ft1, core_func1) + + inst2 = ComponentInstance() + mem2 = bytearray(10) + sync_opts2 = mk_opts(memory=mem2, sync=True) + 
async_opts2 = mk_opts(memory=mem2, sync=False) + ft2 = FuncType([], []) + async def core_func2(task, args): + assert(not args) + + [wsi] = await canon_stream_new(Borrow(rt), task) + [ret] = await canon_lower(async_opts2, ft1, func1, task, [wsi]) + assert((ret >> 30) == CallState.STARTED) + subi = ret & ~(3 << 30) + + [h1] = await canon_resource_new(rt, task, 42) + [h2] = await canon_resource_new(rt, task, 43) + mem2[0] = h1 + mem2[4] = h2 + + [ret] = await canon_stream_write(task, wsi, 0, 2) + assert(ret == 2) + [] = await canon_waitable_drop(task, wsi) + + event, p1, _ = await task.wait() + assert(event == EventCode.CALL_DONE) + assert(p1 == subi) + + [] = await canon_waitable_drop(task, subi) + return [] + + await canon_lift(sync_opts2, inst2, ft2, core_func2, None, lambda:[], lambda _:()) + + +async def run_async_tests(): + await test_roundtrips() + await test_handles() + await test_async_to_async() + await test_async_callback() + await test_async_to_sync() + await test_async_backpressure() + await test_sync_using_wait() + await test_sync_stream_ops() + await test_stream_forward() + await test_receive_own_stream() + await test_host_partial_reads_writes() + await test_async_stream_ops() + await test_wasm_to_wasm_stream() + await test_borrow_stream() + +asyncio.run(run_async_tests()) print("All tests passed")