Skip to content

Commit

Permalink
Sketch Output API
Browse files Browse the repository at this point in the history
  • Loading branch information
pelletier committed Jan 30, 2024
1 parent d117e28 commit 090bcea
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
25 changes: 25 additions & 0 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""Dispatch coroutine interface.
Coroutines are currently created using the @app.dispatch_coroutine() decorator
in a FastAPI app. See dispatch.fastapi for more details and examples. This
module describes how to write functions that get turned into coroutines.
Coroutines are functions that can yield at any point in their execution to save
progress and coordinate with other coroutines. They take exactly one argument of
type Input, and return an Output value.
"""

from __future__ import annotations
from typing import Any
from dataclasses import dataclass
import pickle
Expand Down Expand Up @@ -42,3 +52,18 @@ def input(self) -> Any:
if not self._has_input:
raise ValueError("This input is for a resumed coroutine")
return self._input


class Output:
"""The output of a coroutine.
This class is meant to be instantiated and returned by authors of coroutines
to indicate the follow up action they need to take.
"""

def __init__(self, value: None | Any = None):
self._value = pickle.dumps(value)

@classmethod
def value(cls, value: Any) -> Output:
return Output(value=value)
10 changes: 7 additions & 3 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,12 @@ async def execute(request: fastapi.Request):
)
output = coroutine(coro_input)

# TODO pack any
output_pb = google.protobuf.wrappers_pb2.StringValue(value=output)
if not isinstance(output, dispatch.coroutine.Output):
raise ValueError(
f"coroutine output should be an instance of {dispatch.coroutine.Output}, not {type(output)}"
)

output_pb = google.protobuf.wrappers_pb2.BytesValue(value=output._value)
output_any = google.protobuf.any_pb2.Any()
output_any.Pack(output_pb)

Expand All @@ -145,6 +149,6 @@ async def execute(request: fastapi.Request):
),
)

return resp.SerializeToString()
return fastapi.Response(content=resp.SerializeToString())

return app
22 changes: 14 additions & 8 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pickle
import unittest
import dispatch.coroutine
from dispatch.coroutine import Input, Output
import dispatch.fastapi
import fastapi
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -48,8 +49,10 @@ def test_fastapi_simple_request(self):
app = dispatch.fastapi._new_app()

@app.dispatch_coroutine()
def my_cool_coroutine(input: dispatch.coroutine.Input):
return f"You told me: '{input.input}' ({len(input.input)} characters)"
def my_cool_coroutine(input: Input) -> Output:
return Output.value(
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

http_client = TestClient(app)

Expand All @@ -72,9 +75,11 @@ def my_cool_coroutine(input: dispatch.coroutine.Input):
self.assertEqual(resp.coroutine_version, req.coroutine_version)

resp.exit.result.output.Unpack(
output := google.protobuf.wrappers_pb2.StringValue()
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
self.assertEqual(output.value, "You told me: 'Hello World!' (12 characters)")
output = pickle.loads(output_bytes.value)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")


class TestCoroutine(unittest.TestCase):
Expand All @@ -93,14 +98,15 @@ def execute(self, coroutine):

def test_no_input(self):
@self.app.dispatch_coroutine()
def my_cool_coroutine(input: dispatch.coroutine.Input):
return "Hello World!"
def my_cool_coroutine(input: Input) -> Output:
return Output.value("Hello World!")

resp = self.execute(my_cool_coroutine)

self.assertIsInstance(resp, ring.coroutine.v1.coroutine_pb2.ExecuteResponse)

resp.exit.result.output.Unpack(
output := google.protobuf.wrappers_pb2.StringValue()
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
self.assertEqual(output.value, "Hello World!")
output = pickle.loads(output_bytes.value)
self.assertEqual(output, "Hello World!")

0 comments on commit 090bcea

Please sign in to comment.