From 090bcea6e4921ce2100d514634ad5ac92aeb7e71 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 30 Jan 2024 11:50:10 -0500 Subject: [PATCH] Sketch Output API --- src/dispatch/coroutine.py | 25 +++++++++++++++++++++++++ src/dispatch/fastapi.py | 10 +++++++--- tests/test_fastapi.py | 22 ++++++++++++++-------- 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index 97f121e9..1309130a 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -1,6 +1,16 @@ """Dispatch coroutine interface. + +Coroutines are currently created using the @app.dispatch_coroutine() decorator +in a FastAPI app. See dispatch.fastapi for more details and examples. This +module describes how to write functions that get turned into coroutines. + +Coroutines are functions that can yield at any point in their execution to save +progress and coordinate with other coroutines. They take exactly one argument of +type Input, and return an Output value. + """ +from __future__ import annotations from typing import Any from dataclasses import dataclass import pickle @@ -42,3 +52,18 @@ def input(self) -> Any: if not self._has_input: raise ValueError("This input is for a resumed coroutine") return self._input + + +class Output: + """The output of a coroutine. + + This class is meant to be instantiated and returned by authors of coroutines + to indicate the follow up action they need to take. + """ + + def __init__(self, value: None | Any = None): + self._value = pickle.dumps(value) + + @classmethod + def value(cls, value: Any) -> Output: + return Output(value=value) diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index f08c2c55..c32cd920 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -132,8 +132,12 @@ async def execute(request: fastapi.Request): ) output = coroutine(coro_input) - # TODO pack any - output_pb = google.protobuf.wrappers_pb2.StringValue(value=output) + if not isinstance(output, dispatch.coroutine.Output): + raise ValueError( + f"coroutine output should be an instance of {dispatch.coroutine.Output}, not {type(output)}" + ) + + output_pb = google.protobuf.wrappers_pb2.BytesValue(value=output._value) output_any = google.protobuf.any_pb2.Any() output_any.Pack(output_pb) @@ -145,6 +149,6 @@ async def execute(request: fastapi.Request): ), ) - return resp.SerializeToString() + return fastapi.Response(content=resp.SerializeToString()) return app diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 2b6bedd8..ce29d909 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,6 +1,7 @@ import pickle import unittest import dispatch.coroutine +from dispatch.coroutine import Input, Output import dispatch.fastapi import fastapi from fastapi.testclient import TestClient @@ -48,8 +49,10 @@ def test_fastapi_simple_request(self): app = dispatch.fastapi._new_app() @app.dispatch_coroutine() - def my_cool_coroutine(input: dispatch.coroutine.Input): - return f"You told me: '{input.input}' ({len(input.input)} characters)" + def my_cool_coroutine(input: Input) -> Output: + return Output.value( + f"You told me: '{input.input}' ({len(input.input)} characters)" + ) http_client = TestClient(app) @@ -72,9 +75,11 @@ def my_cool_coroutine(input: dispatch.coroutine.Input): self.assertEqual(resp.coroutine_version, req.coroutine_version) resp.exit.result.output.Unpack( - output := google.protobuf.wrappers_pb2.StringValue() + output_bytes := google.protobuf.wrappers_pb2.BytesValue() ) - self.assertEqual(output.value, "You told me: 'Hello World!' (12 characters)") + output = pickle.loads(output_bytes.value) + + self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") class TestCoroutine(unittest.TestCase): @@ -93,14 +98,15 @@ def execute(self, coroutine): def test_no_input(self): @self.app.dispatch_coroutine() - def my_cool_coroutine(input: dispatch.coroutine.Input): - return "Hello World!" + def my_cool_coroutine(input: Input) -> Output: + return Output.value("Hello World!") resp = self.execute(my_cool_coroutine) self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse) resp.exit.result.output.Unpack( - output := google.protobuf.wrappers_pb2.StringValue() + output_bytes := google.protobuf.wrappers_pb2.BytesValue() ) - self.assertEqual(output.value, "Hello World!") + output = pickle.loads(output_bytes.value) + self.assertEqual(output, "Hello World!")