Skip to content

Commit

Permalink
fix running examples as tests
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
  • Loading branch information
achille-roussel committed Jun 18, 2024
1 parent 0918918 commit d016221
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 32 deletions.
6 changes: 4 additions & 2 deletions examples/auto_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@ def third_party_api_call(x):
# Simulate a third-party API call that fails.
print(f"Simulating third-party API call with {x}")
if x < 3:
print("RAISE EXCEPTION")
raise requests.RequestException("Simulated failure")
else:
return "SUCCESS"


# Use the `dispatch.function` decorator to declare a stateful function.
@dispatch.function
def application():
def auto_retry():
x = rng.randint(0, 5)
return third_party_api_call(x)


dispatch.run(application())
if __name__ == "__main__":
print(dispatch.run(auto_retry()))
3 changes: 2 additions & 1 deletion examples/fanout.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ async def fanout():
return await reduce_stargazers(repos)


print(dispatch.run(fanout()))
if __name__ == "__main__":
print(dispatch.run(fanout()))
12 changes: 7 additions & 5 deletions examples/getting_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import dispatch


# Use the `dispatch.function` decorator declare a stateful function.
@dispatch.function
def publish(url, payload):
r = requests.post(url, data=payload)
r.raise_for_status()
return r.text


# Use the `dispatch.run` function to run the function with automatic error
# handling and retries.
res = dispatch.run(publish("https://httpstat.us/200", {"hello": "world"}))
print(res)
@dispatch.function
async def getting_started():
return await publish("https://httpstat.us/200", {"hello": "world"})


if __name__ == "__main__":
print(dispatch.run(getting_started()))
24 changes: 4 additions & 20 deletions examples/github_stats.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
"""Github repository stats example.
This example demonstrates how to use async functions orchestrated by Dispatch.
Make sure to follow the setup instructions at
https://docs.dispatch.run/dispatch/stateful-functions/getting-started/
Run with:
uvicorn app:app
Logs will show a pipeline of functions being called and their results.
"""

import httpx

import dispatch
Expand All @@ -31,21 +15,21 @@ def get_gh_api(url):


@dispatch.function
async def get_repo_info(repo_owner, repo_name):
def get_repo_info(repo_owner, repo_name):
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
repo_info = get_gh_api(url)
return repo_info


@dispatch.function
async def get_contributors(repo_info):
def get_contributors(repo_info):
url = repo_info["contributors_url"]
contributors = get_gh_api(url)
return contributors


@dispatch.function
async def main():
async def github_stats():
repo_info = await get_repo_info("dispatchrun", "coroutine")
print(
f"""Repository: {repo_info['full_name']}
Expand All @@ -57,5 +41,5 @@ async def main():


if __name__ == "__main__":
contributors = dispatch.run(main())
contributors = dispatch.run(github_stats())
print(f"Contributors: {len(contributors)}")
29 changes: 29 additions & 0 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import dispatch.test

from .auto_retry import auto_retry
from .fanout import fanout
from .getting_started import getting_started
from .github_stats import github_stats


@dispatch.test.function
async def test_auto_retry():
assert await auto_retry() == "SUCCESS"


@dispatch.test.function
async def test_fanout():
contributors = await fanout()
assert len(contributors) >= 15
assert "achille-roussel" in contributors


@dispatch.test.function
async def test_getting_started():
assert await getting_started() == "200 OK"


@dispatch.test.function
async def test_github_stats():
contributors = await github_stats()
assert len(contributors) >= 6
11 changes: 9 additions & 2 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
CoroutineID: TypeAlias = int
CorrelationID: TypeAlias = int

_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False)
_in_function_call = contextvars.ContextVar(
"dispatch.scheduler.in_function_call", default=False
)


def in_function_call() -> bool:
return bool(_in_function_call.get())
Expand Down Expand Up @@ -523,7 +526,11 @@ def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]
if isinstance(coroutine_yield, RaceDirective):
return set_coroutine_race(state, coroutine, coroutine_yield.awaitables)

yield coroutine_yield
try:
yield coroutine_yield
except Exception as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
return set_coroutine_result(state, coroutine, coroutine_result)


def set_coroutine_result(
Expand Down
30 changes: 28 additions & 2 deletions src/dispatch/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,17 @@
from dispatch.sdk.v1.error_pb2 import Error
from dispatch.sdk.v1.function_pb2 import RunRequest, RunResponse
from dispatch.sdk.v1.poll_pb2 import PollResult
from dispatch.sdk.v1.status_pb2 import STATUS_OK
from dispatch.sdk.v1.status_pb2 import (
STATUS_DNS_ERROR,
STATUS_HTTP_ERROR,
STATUS_INCOMPATIBLE_STATE,
STATUS_OK,
STATUS_TCP_ERROR,
STATUS_TEMPORARY_ERROR,
STATUS_THROTTLED,
STATUS_TIMEOUT,
STATUS_TLS_ERROR,
)

from .client import EndpointClient
from .server import DispatchServer
Expand Down Expand Up @@ -183,7 +193,18 @@ def make_request(call: Call) -> RunRequest:
res = await self.run(call.endpoint, req)

if res.status != STATUS_OK:
# TODO: emulate retries etc...
if res.status in (
STATUS_TIMEOUT,
STATUS_THROTTLED,
STATUS_TEMPORARY_ERROR,
STATUS_INCOMPATIBLE_STATE,
STATUS_DNS_ERROR,
STATUS_TCP_ERROR,
STATUS_TLS_ERROR,
STATUS_HTTP_ERROR,
):
continue # emulate retries, without backoff for now

if (
res.HasField("exit")
and res.exit.HasField("result")
Expand Down Expand Up @@ -263,14 +284,19 @@ async def main(coro: Coroutine[Any, Any, None]) -> None:
api = Service()
app = Dispatch(reg)
try:
print("Starting bakend")
async with Server(api) as backend:
print("Starting server")
async with Server(app) as server:
# Here we break through the abstraction layers a bit, it's not
# ideal but it works for now.
reg.client.api_url.value = backend.url
reg.endpoint = server.url
print("BACKEND:", backend.url)
print("SERVER:", server.url)
await coro
finally:
print("DONE!")
await api.close()
# TODO: let's figure out how to get rid of this global registry
# state at some point, which forces tests to be run sequentially.
Expand Down

0 comments on commit d016221

Please sign in to comment.