diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..e5304e02 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +Dockerfile +__pycache__ +*.md +*.yaml +*.yml +dist/* diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..51bfe11e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.12 +WORKDIR /usr/src/dispatch-py + +COPY pyproject.toml . +RUN python -m pip install -e .[dev] + +COPY . . +RUN python -m pip install -e .[dev] + +ENTRYPOINT ["python"] diff --git a/README.md b/README.md index 3c76203a..e24570ac 100644 --- a/README.md +++ b/README.md @@ -13,17 +13,19 @@ Python package to develop applications with the Dispatch platform. [fastapi]: https://fastapi.tiangolo.com/tutorial/first-steps/ -[ngrok]: https://ngrok.com/ [pypi]: https://pypi.org/project/dispatch-py/ [signup]: https://console.dispatch.run/ - [What is Dispatch?](#what-is-dispatch) - [Installation](#installation) + - [Installing the Dispatch CLI](#installing-the-dispatch-cli) + - [Installing the Dispatch SDK](#installing-the-dispatch-sdk) - [Usage](#usage) - - [Configuration](#configuration) + - [Writing Dispatch Applications](#writing-dispatch-applications) + - [Running Dispatch Applications](#running-dispatch-applications) + - [Writing Transactional Applications with Dispatch](#writing-transactional-applications-with-dispatch) - [Integration with FastAPI](#integration-with-fastapi) - - [Local Testing](#local-testing) - - [Distributed Coroutines for Python](#distributed-coroutines-for-python) + - [Configuration](#configuration) - [Serialization](#serialization) - [Examples](#examples) - [Contributing](#contributing) @@ -32,166 +34,75 @@ Python package to develop applications with the Dispatch platform. Dispatch is a platform for developing scalable & reliable distributed systems. -Dispatch provides a simple programming model based on *Distributed Coroutines*, -allowing complex, dynamic workflows to be expressed with regular code and -control flow. 
- -Dispatch schedules function calls across a fleet of service instances, -incorporating **fair scheduling**, transparent **retry of failed operations**, -and **durability**. - To get started, follow the instructions to [sign up for Dispatch][signup] 🚀. ## Installation -This package is published on [PyPI][pypi] as **dispatch-py**, to install: -```sh -pip install dispatch-py -``` - -## Usage +### Installing the Dispatch CLI -The SDK allows Python applications to declare functions that Dispatch can -orchestrate: +As a pre-requisite, we recommend installing the Dispatch CLI to simplify the +configuration and execution of applications that use Dispatch. On macOS, this +can be done easily using [Homebrew](https://docs.brew.sh/): -```python -@dispatch.function -def action(msg): - ... +```console +brew tap stealthrocket/dispatch +brew install dispatch ``` -The **@dispatch.function** decorator declares a function that can be run by -Dispatch. The call has durable execution semantics; if the function fails -with a temporary error, it is automatically retried, even if the program is -restarted, or if multiple instances are deployed. +Alternatively, you can download the latest `dispatch` binary from the +[Releases](https://github.com/stealthrocket/dispatch/releases) page. -The SDK adds a method to the `action` object, allowing the program to -dispatch an asynchronous invocation of the function; for example: - -```python -action.dispatch('hello') -``` +*Note that this step is optional, applications that use Dispatch can run without +the CLI, passing configuration through environment variables or directly in the +code. However, the CLI automates the onboarding flow and simplifies the +configuration, so we recommend starting with it.* -### Configuration +### Installing the Dispatch SDK -In order for Dispatch to interact with functions remotely, the SDK needs to be -configured with the address at which the server can be reached. 
The Dispatch -API Key must also be set, and optionally, a public signing key should be -configured to verify that requests originated from Dispatch. These -configuration options can be passed as arguments to the -the `Dispatch` constructor, but by default they will be loaded from environment -variables: +The Python package is published on [PyPI][pypi] as **dispatch-py**, to install: +```console +pip install dispatch-py +``` -| Environment Variable | Value Example | -| :-------------------------- | :--------------------------------- | -| `DISPATCH_API_KEY` | `d4caSl21a5wdx5AxMjdaMeWehaIyXVnN` | -| `DISPATCH_ENDPOINT_URL` | `https://service.domain.com` | -| `DISPATCH_VERIFICATION_KEY` | `-----BEGIN PUBLIC KEY-----...` | +## Usage -Finally, the `Dispatch` instance needs to mount a route on a HTTP server in to -receive requests from Dispatch. At this time, the SDK integrates with -FastAPI; adapters for other popular Python frameworks will be added in the -future. +### Writing Dispatch Applications -### Integration with FastAPI +The following snippet shows how to write a very simple Dispatch application +that does the following: -The following code snippet is a complete example showing how to install a -`Dispatch` instance on a [FastAPI][fastapi] server: +1. declare a dispatch function named `greet` which can run asynchronously +2. schedule a call to `greet` with the argument `World` +3. 
run until all dispatched calls have completed ```python -from fastapi import FastAPI -from dispatch.fastapi import Dispatch -import requests - -app = FastAPI() -dispatch = Dispatch(app) +# main.py +import dispatch @dispatch.function -def publish(url, payload): - r = requests.post(url, data=payload) - r.raise_for_status() +def greet(msg: str): + print(f"Hello, {msg}!") -@app.get('/') -def root(): - publish.dispatch('https://httpstat.us/200', {'hello': 'world'}) - return {'answer': 42} +dispatch.run(lambda: greet.dispatch('World')) ``` -In this example, GET requests on the HTTP server dispatch calls to the -`publish` function. The function runs concurrently to the rest of the -program, driven by the Dispatch SDK. - -The instantiation of the `Dispatch` object on the `FastAPI` application -automatically installs the HTTP route needed for Dispatch to invoke functions. +Obviously, this is just an example, a real application would perform much more +interesting work, but it's a good start to get a sense of how to use Dispatch. -### Local Testing - -#### Mock Dispatch - -The SDK ships with a mock Dispatch server. It can be used to quickly test your -local functions, without requiring internet access. - -Note that the mock Dispatch server has very limited scheduling capabilities. +### Running Dispatch Applications +The simplest way to run a Dispatch application is to use the Dispatch CLI, first +we need to login: ```console -python -m dispatch.test $DISPATCH_ENDPOINT_URL +dispatch login ``` -The command will start a mock Dispatch server and print the configuration -for the SDK. 
- -For example, if your functions were exposed through a local endpoint -listening on `http://127.0.0.1:8000`, you could run: - +Then we are ready to run the example program we wrote above: ```console -$ python -m dispatch.test http://127.0.0.1:8000 -Spawned a mock Dispatch server on 127.0.0.1:4450 - -Dispatching function calls to the endpoint at http://127.0.0.1:8000 - -The Dispatch SDK can be configured with: - - export DISPATCH_API_URL="http://127.0.0.1:4450" - export DISPATCH_API_KEY="test" - export DISPATCH_ENDPOINT_URL="http://127.0.0.1:8000" - export DISPATCH_VERIFICATION_KEY="Z+nTe2VRcw8t8Ihx++D+nXtbO28nwjWIOTLRgzrelYs=" -``` - -#### Real Dispatch - -To test local functions with the production instance of Dispatch, it needs -to be able to access your local endpoint. - -A common approach consists of using [ngrok][ngrok] to setup a public endpoint -that forwards to the server running on localhost. - -For example, assuming the server is running on port 8000 (which is the default -with FastAPI), the command to create a ngrok tunnel is: -```sh -ngrok http http://localhost:8000 +dispatch run -- python3 main.py ``` -Running this command opens a terminal interface that looks like this: -``` -ngrok -Build better APIs with ngrok. 
Early access: ngrok.com/early-access - -Session Status online -Account Alice (Plan: Free) -Version 3.6.0 -Region United States (California) (us-cal-1) -Latency - -Web Interface http://127.0.0.1:4040 -Forwarding https://f441-2600-1700-2802-e01f-6861-dbc9-d551-ecfb.ngrok-free.app -> http://localhost:8000 -``` -To configure the Dispatch SDK, set the endpoint URL to the endpoint for the -**Forwarding** parameter; each ngrok instance is unique, so you would have a -different value, but in this example it would be: -```sh -export DISPATCH_ENDPOINT_URL="https://f441-2600-1700-2802-e01f-6861-dbc9-d551-ecfb.ngrok-free.app" -``` - -### Distributed Coroutines for Python +### Writing Transactional Applications with Dispatch The `@dispatch.function` decorator can also be applied to Python coroutines (a.k.a. *async* functions), in which case each `await` point becomes a @@ -243,11 +154,67 @@ async def transform(msg): ``` Dispatch converts Python coroutines to *Distributed Coroutines*, which can be -suspended and resumed on any instance of a service across a fleet. +suspended and resumed on any instance of a service across a fleet. For a deep +dive on these concepts, read our blog post on +[*Distributed Coroutines with a Native Python Extension and Dispatch*](https://stealthrocket.tech/blog/distributed-coroutines-in-python). + +### Integration with FastAPI + +Many web applications written in Python are developed using [FastAPI][fastapi]. +Dispatch can integrate with these applications by instantiating a +`dispatch.fastapi.Dispatch` object. When doing so, the Dispatch functions +declared by the program can be invoked remotely over the same HTTP interface +used for the [FastAPI][fastapi] handlers. 
+ +The following code snippet is a complete example showing how to install a +`Dispatch` instance on a [FastAPI][fastapi] server: + +```python +from fastapi import FastAPI +from dispatch.fastapi import Dispatch +import requests + +app = FastAPI() +dispatch = Dispatch(app) + +@dispatch.function +def publish(url, payload): + r = requests.post(url, data=payload) + r.raise_for_status() + +@app.get('/') +def root(): + publish.dispatch('https://httpstat.us/200', {'hello': 'world'}) + return {'answer': 42} +``` + +In this example, GET requests on the HTTP server dispatch calls to the +`publish` function. The function runs concurrently to the rest of the +program, driven by the Dispatch SDK. + +### Configuration + +The Dispatch CLI automatically configures the SDK, so manual configuration is +usually not required when running Dispatch applications. However, in some +advanced cases, it might be useful to explicitly set configuration options. + +In order for Dispatch to interact with functions remotely, the SDK needs to be +configured with the address at which the server can be reached. The Dispatch +API Key must also be set, and optionally, a public signing key should be +configured to verify that requests originated from Dispatch. These +configuration options can be passed as arguments to +the `Dispatch` constructor, but by default they will be loaded from environment +variables: + +| Environment Variable | Value Example | +| :-------------------------- | :--------------------------------- | +| `DISPATCH_API_KEY` | `d4caSl21a5wdx5AxMjdaMeWehaIyXVnN` | +| `DISPATCH_ENDPOINT_URL` | `https://service.domain.com` | +| `DISPATCH_VERIFICATION_KEY` | `-----BEGIN PUBLIC KEY-----...` | ### Serialization -Dispatch uses the [pickle] library to serialize coroutines. +Dispatch uses the [pickle][pickle] library to serialize coroutines. 
[pickle]: https://docs.python.org/3/library/pickle.html @@ -266,7 +233,6 @@ For help with a serialization issues, please submit a [GitHub issue][issues]. [issues]: https://github.com/stealthrocket/dispatch-py/issues - ## Examples Check out the [examples](examples/) directory for code samples to help you get diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 78c4be57..f51106ca 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -2,27 +2,97 @@ from __future__ import annotations +import os +from concurrent import futures +from http.server import ThreadingHTTPServer +from typing import Any, Callable, Coroutine, Optional, TypeVar, overload +from urllib.parse import urlsplit + +from typing_extensions import ParamSpec, TypeAlias + import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import DEFAULT_API_URL, Client, Registry, Reset +from dispatch.function import DEFAULT_API_URL, Client, Function, Registry, Reset +from dispatch.http import Dispatch from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status __all__ = [ + "Call", "Client", - "DispatchID", "DEFAULT_API_URL", + "DispatchID", + "Error", "Input", "Output", - "Call", - "Error", + "Registry", "Reset", "Status", - "call", - "gather", "all", "any", + "call", + "function", + "gather", "race", - "Registry", + "run", + "serve", ] + + +P = ParamSpec("P") +T = TypeVar("T") + +_registry: Optional[Registry] = None + + +def default_registry(): + global _registry + if not _registry: + _registry = Registry() + return _registry + + +@overload +def function(func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... + + +@overload +def function(func: Callable[P, T]) -> Function[P, T]: ... 
def function(func):
    """Register `func` with the default registry.

    This is the implementation behind the `@dispatch.function` decorator; it
    delegates to the `function` method of the registry returned by
    `default_registry()`.
    """
    return default_registry().function(func)


def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwargs):
    """Run the default dispatch server.

    The default server uses a function registry where functions tagged by the
    `@dispatch.function` decorator are registered.

    This function is intended to be used with the `dispatch` CLI tool, which
    automatically configures environment variables to connect the local server
    to the Dispatch bridge API. The listening address is read from the
    DISPATCH_ENDPOINT_ADDR environment variable and defaults to
    "localhost:8000".

    Args:
        init: An initialization function called after binding the server address
            but before entering the event loop to handle requests.

        args: Positional arguments to pass to the entrypoint.

        kwargs: Keyword arguments to pass to the entrypoint.

    Returns:
        None. The call blocks serving requests until the event loop is
        interrupted (for example by KeyboardInterrupt); any exception raised
        by `init` or by the event loop propagates to the caller after the
        listening socket has been closed.
    """
    address = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")
    # Prefix with "//" so that urlsplit parses the value as a network
    # location (host:port) instead of as a scheme or a path.
    parsed_url = urlsplit("//" + address)
    server_address = (parsed_url.hostname or "", parsed_url.port or 0)
    # The server is used as a context manager so the listening socket is
    # always closed. shutdown() is deliberately NOT called here: it blocks
    # until serve_forever() acknowledges the request, so calling it from this
    # thread when serve_forever() never started (e.g. `init` raised) would
    # deadlock. Once serve_forever() has returned or raised there is nothing
    # left to shut down.
    with ThreadingHTTPServer(server_address, Dispatch(default_registry())) as server:
        if init is not None:
            init(*args, **kwargs)
        server.serve_forever()
Example: @@ -18,7 +18,6 @@ def read_root(): """ import asyncio -import base64 import logging import os from datetime import timedelta @@ -36,8 +35,7 @@ def read_root(): CaseInsensitiveDict, Ed25519PublicKey, Request, - public_key_from_bytes, - public_key_from_pem, + parse_verification_key, verify_request, ) from dispatch.status import Status @@ -46,9 +44,7 @@ def read_root(): class Dispatch(Registry): - """A Dispatch programmable endpoint, powered by FastAPI.""" - - __slots__ = ("client",) + """A Dispatch instance, powered by FastAPI.""" def __init__( self, @@ -65,9 +61,9 @@ def __init__( Args: app: The FastAPI app to configure. - endpoint: Full URL of the application the Dispatch programmable - endpoint will be running on. Uses the value of the - DISPATCH_ENDPOINT_URL environment variable by default. + endpoint: Full URL of the application the Dispatch instance will + be running on. Uses the value of the DISPATCH_ENDPOINT_URL + environment variable by default. verification_key: Key to use when verifying signed requests. 
Uses the value of the DISPATCH_VERIFICATION_KEY environment variable @@ -90,73 +86,12 @@ def __init__( raise ValueError( "missing FastAPI app as first argument of the Dispatch constructor" ) - - endpoint_from = "endpoint argument" - if not endpoint: - endpoint = os.getenv("DISPATCH_ENDPOINT_URL") - endpoint_from = "DISPATCH_ENDPOINT_URL" - if not endpoint: - raise ValueError( - "missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable" - ) - - logger.info("configuring Dispatch endpoint %s", endpoint) - - parsed_url = urlparse(endpoint) - if not parsed_url.netloc or not parsed_url.scheme: - raise ValueError( - f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)" - ) - - verification_key = parse_verification_key(verification_key) - if verification_key: - base64_key = base64.b64encode(verification_key.public_bytes_raw()).decode() - logger.info("verifying request signatures using key %s", base64_key) - elif parsed_url.scheme != "bridge": - logger.warning( - "request verification is disabled because DISPATCH_VERIFICATION_KEY is not set" - ) - super().__init__(endpoint, api_key=api_key, api_url=api_url) - + verification_key = parse_verification_key(verification_key, endpoint=endpoint) function_service = _new_app(self, verification_key) app.mount("/dispatch.sdk.v1.FunctionService", function_service) -def parse_verification_key( - verification_key: Optional[Union[Ed25519PublicKey, str, bytes]], -) -> Optional[Ed25519PublicKey]: - if isinstance(verification_key, Ed25519PublicKey): - return verification_key - - from_env = False - if not verification_key: - try: - verification_key = os.environ["DISPATCH_VERIFICATION_KEY"] - except KeyError: - return None - from_env = True - - if isinstance(verification_key, bytes): - verification_key = verification_key.decode() - - # Be forgiving when accepting keys in PEM format, which may span - # multiple lines. 
Users attempting to pass a PEM key via an environment - # variable may accidentally include literal "\n" bytes rather than a - # newline char (0xA). - try: - return public_key_from_pem(verification_key.replace("\\n", "\n")) - except ValueError: - pass - - try: - return public_key_from_bytes(base64.b64decode(verification_key.encode())) - except ValueError: - if from_env: - raise ValueError(f"invalid DISPATCH_VERIFICATION_KEY '{verification_key}'") - raise ValueError(f"invalid verification key '{verification_key}'") - - class _ConnectResponse(fastapi.Response): media_type = "application/grpc+proto" @@ -246,39 +181,39 @@ async def execute(request: fastapi.Request): raise _ConnectError( 500, "internal", f"function '{req.function}' fatal error" ) - else: - response = output._message - status = Status(response.status) - if response.HasField("poll"): - logger.debug( - "function '%s' polling with %d call(s)", - req.function, - len(response.poll.calls), - ) - elif response.HasField("exit"): - exit = response.exit - if not exit.HasField("result"): - logger.debug("function '%s' exiting with no result", req.function) - else: - result = exit.result - if result.HasField("output"): - logger.debug( - "function '%s' exiting with output value", req.function - ) - elif result.HasField("error"): - err = result.error - logger.debug( - "function '%s' exiting with error: %s (%s)", - req.function, - err.message, - err.type, - ) - if exit.HasField("tail_call"): + response = output._message + status = Status(response.status) + + if response.HasField("poll"): + logger.debug( + "function '%s' polling with %d call(s)", + req.function, + len(response.poll.calls), + ) + elif response.HasField("exit"): + exit = response.exit + if not exit.HasField("result"): + logger.debug("function '%s' exiting with no result", req.function) + else: + result = exit.result + if result.HasField("output"): logger.debug( - "function '%s' tail calling function '%s'", - exit.tail_call.function, + "function '%s' 
exiting with output value", req.function ) + elif result.HasField("error"): + err = result.error + logger.debug( + "function '%s' exiting with error: %s (%s)", + req.function, + err.message, + err.type, + ) + if exit.HasField("tail_call"): + logger.debug( + "function '%s' tail calling function '%s'", + exit.tail_call.function, + ) logger.debug("finished handling run request with status %s", status.name) return fastapi.Response( diff --git a/src/dispatch/function.py b/src/dispatch/function.py index a711880d..44169a8d 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -176,7 +176,7 @@ class Registry: def __init__( self, - endpoint: str, + endpoint: Optional[str] = None, api_key: Optional[str] = None, api_url: Optional[str] = None, ): @@ -184,6 +184,8 @@ def __init__( Args: endpoint: URL of the endpoint that the function is accessible from. + Uses the value of the DISPATCH_ENDPOINT_URL environment variable + by default. api_key: Dispatch API key to use for authentication when dispatching calls to functions. Uses the value of the @@ -193,7 +195,24 @@ def __init__( to functions. Uses the value of the DISPATCH_API_URL environment variable if set, otherwise defaults to the public Dispatch API (DEFAULT_API_URL). + + Raises: + ValueError: If any of the required arguments are missing. 
""" + endpoint_from = "endpoint argument" + if not endpoint: + endpoint = os.getenv("DISPATCH_ENDPOINT_URL") + endpoint_from = "DISPATCH_ENDPOINT_URL" + if not endpoint: + raise ValueError( + "missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable" + ) + parsed_url = urlparse(endpoint) + if not parsed_url.netloc or not parsed_url.scheme: + raise ValueError( + f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)" + ) + logger.info("configuring Dispatch endpoint %s", endpoint) self.functions: Dict[str, PrimitiveFunction] = {} self.endpoint = endpoint self.client = Client(api_key=api_key, api_url=api_url) diff --git a/src/dispatch/http.py b/src/dispatch/http.py new file mode 100644 index 00000000..1d5b1bd8 --- /dev/null +++ b/src/dispatch/http.py @@ -0,0 +1,196 @@ +"""Integration of Dispatch functions with http.""" + +import logging +import os +from datetime import timedelta +from http.server import BaseHTTPRequestHandler +from typing import Optional, Union + +from http_message_signatures import InvalidSignature + +from dispatch.function import Registry +from dispatch.proto import Input +from dispatch.sdk.v1 import function_pb2 as function_pb +from dispatch.signature import ( + CaseInsensitiveDict, + Ed25519PublicKey, + Request, + parse_verification_key, + verify_request, +) +from dispatch.status import Status + +logger = logging.getLogger(__name__) + + +class Dispatch: + """A Dispatch instance to be serviced by a http server. The Dispatch class + acts as a factory for DispatchHandler objects, by capturing the variables + that would be shared between all DispatchHandler instances it created.""" + + def __init__( + self, + registry: Registry, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + ): + """Initialize a Dispatch http handler. + + Args: + registry: The registry of functions to be serviced. 
class FunctionService(BaseHTTPRequestHandler):
    """HTTP request handler that runs registered functions.

    The handler serves POST requests on the
    /dispatch.sdk.v1.FunctionService/Run path: it optionally verifies the
    request signature, decodes the RunRequest from the body, invokes the
    matching function of the registry and writes the serialized response.
    Errors are reported as JSON objects with "code" and "message" fields.
    """

    def __init__(
        self,
        request,
        client_address,
        server,
        registry: Registry,
        verification_key: Optional[Ed25519PublicKey] = None,
    ):
        """Initialize the handler and process the request.

        Args:
            request: The request object handed over by the http server.
            client_address: Address of the remote peer.
            server: The server instance that accepted the request.
            registry: The registry of functions to be serviced.
            verification_key: Key used to verify request signatures; when
                None, requests are not verified.
        """
        # The base class handles the request from within its constructor, so
        # the state read by do_POST must be assigned before calling it.
        self.registry = registry
        self.verification_key = verification_key
        self.error_content_type = "application/json"
        super().__init__(request, client_address, server)

    def send_error_response_invalid_argument(self, message: str):
        """Send a 400 response with the "invalid_argument" code."""
        self.send_error_response(400, "invalid_argument", message)

    def send_error_response_not_found(self, message: str):
        """Send a 404 response with the "not_found" code."""
        self.send_error_response(404, "not_found", message)

    def send_error_response_unauthenticated(self, message: str):
        """Send a 401 response with the "unauthenticated" code."""
        self.send_error_response(401, "unauthenticated", message)

    def send_error_response_permission_denied(self, message: str):
        """Send a 403 response with the "permission_denied" code."""
        self.send_error_response(403, "permission_denied", message)

    def send_error_response_internal(self, message: str):
        """Send a 500 response with the "internal" code."""
        self.send_error_response(500, "internal", message)

    def send_error_response(self, status: int, code: str, message: str):
        """Send an error response with a JSON body.

        Args:
            status: HTTP status code of the response.
            code: Short machine-readable error code.
            message: Human-readable description of the error.
        """
        import json

        # Messages may embed text taken from the request (e.g. the function
        # name), so the body is produced by a JSON encoder rather than by
        # string interpolation, which would break on quotes or backslashes.
        # Compact separators keep the output identical to the previous format
        # for messages that need no escaping.
        body = json.dumps(
            {"code": code, "message": message},
            separators=(",", ":"),
            ensure_ascii=False,
        ).encode()
        self.send_response(status)
        self.send_header("Content-Type", self.error_content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):
        """Handle a run request sent by Dispatch."""
        if self.path != "/dispatch.sdk.v1.FunctionService/Run":
            self.send_error_response_not_found("path not found")
            return

        try:
            content_length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            # A non-numeric header is a client error, not a server crash.
            self.send_error_response_invalid_argument("content length is invalid")
            return
        if content_length == 0:
            self.send_error_response_invalid_argument("content length is required")
            return
        if content_length < 0:
            self.send_error_response_invalid_argument("content length is negative")
            return
        if content_length > 16_000_000:
            self.send_error_response_invalid_argument("content length is too large")
            return

        data: bytes = self.rfile.read(content_length)
        logger.debug("handling run request with %d byte body", len(data))

        if self.verification_key is not None:
            signed_request = Request(
                method="POST",
                url=self.requestline,  # TODO: need full URL
                headers=CaseInsensitiveDict(self.headers),
                body=data,
            )
            max_age = timedelta(minutes=5)
            try:
                verify_request(signed_request, self.verification_key, max_age)
            except ValueError as e:
                self.send_error_response_unauthenticated(str(e))
                return
            except InvalidSignature as e:
                # The http_message_signatures package sometimes does not
                # attach a message to the exception, so we set a default to
                # have some context about the reason for the error.
                message = str(e) or "invalid signature"
                self.send_error_response_permission_denied(message)
                return

        req = function_pb.RunRequest.FromString(data)
        if not req.function:
            self.send_error_response_invalid_argument("function is required")
            return

        try:
            func = self.registry.functions[req.function]
        except KeyError:
            logger.debug("function '%s' not found", req.function)
            self.send_error_response_not_found(
                f"function '{req.function}' does not exist"
            )
            return

        try:
            output = func._primitive_call(Input(req))
        except Exception:
            # This indicates that an exception was raised in a primitive
            # function. Primitive functions must catch exceptions, categorize
            # them in order to derive a Status, and then return a RunResponse
            # that carries the Status and the error details. A failure to do
            # so indicates a problem, and we return a 500 rather than attempt
            # to catch and categorize the error here.
            logger.error("function '%s' fatal error", req.function, exc_info=True)
            self.send_error_response_internal(f"function '{req.function}' fatal error")
            return

        response = output._message
        status = Status(response.status)

        if response.HasField("poll"):
            logger.debug(
                "function '%s' polling with %d call(s)",
                req.function,
                len(response.poll.calls),
            )
        elif response.HasField("exit"):
            exit = response.exit
            if not exit.HasField("result"):
                logger.debug("function '%s' exiting with no result", req.function)
            else:
                result = exit.result
                if result.HasField("output"):
                    logger.debug(
                        "function '%s' exiting with output value", req.function
                    )
                elif result.HasField("error"):
                    err = result.error
                    logger.debug(
                        "function '%s' exiting with error: %s (%s)",
                        req.function,
                        err.message,
                        err.type,
                    )
            if exit.HasField("tail_call"):
                # Both placeholders need a value: the calling function first,
                # then the function being tail called.
                logger.debug(
                    "function '%s' tail calling function '%s'",
                    req.function,
                    exit.tail_call.function,
                )

        logger.debug("finished handling run request with status %s", status.name)
        self.send_response(200)
        self.send_header("Content-Type", "application/proto")
        self.end_headers()
        self.wfile.write(response.SerializeToString())
Optional[Ed25519PublicKey]: + # This function depends a lot on global context like enviornment variables + # and logging configuration. It's not ideal for testing, but it's useful to + # unify the behavior of the Dispatch class everywhere the signature module + # is used. + if isinstance(verification_key, Ed25519PublicKey): + return verification_key + + # Keep track of whether the key was obtained from the environment, so that + # we can tweak the error messages accordingly. + from_env = False + if not verification_key: + try: + verification_key = os.environ["DISPATCH_VERIFICATION_KEY"] + except KeyError: + return None + from_env = verification_key is not None + + if isinstance(verification_key, bytes): + verification_key = verification_key.decode() + + # Be forgiving when accepting keys in PEM format, which may span + # multiple lines. Users attempting to pass a PEM key via an environment + # variable may accidentally include literal "\n" bytes rather than a + # newline char (0xA). + public_key: Optional[Ed25519PublicKey] = None + try: + public_key = public_key_from_pem(verification_key.replace("\\n", "\n")) + except ValueError: + pass + + # If the key is not in PEM format, try to decode it as base64 string. + if not public_key: + try: + public_key = public_key_from_bytes( + base64.b64decode(verification_key.encode()) + ) + except ValueError: + if from_env: + raise ValueError( + f"invalid DISPATCH_VERIFICATION_KEY '{verification_key}'" + ) + raise ValueError(f"invalid verification key '{verification_key}'") + + # Print diagostic information about the key, this is useful for debugging. 
+ url_scheme = "" + if endpoint: + try: + parsed_url = urlparse(endpoint) + url_scheme = parsed_url.scheme + except: + pass + + if public_key: + base64_key = base64.b64encode(public_key.public_bytes_raw()).decode() + logger.info("verifying request signatures using key %s", base64_key) + elif url_scheme != "bridge": + logger.warning( + "request verification is disabled because DISPATCH_VERIFICATION_KEY is not set" + ) + return public_key diff --git a/tests/dispatch/signature/test_signature.py b/tests/dispatch/signature/test_signature.py index fc64b8df..12d98298 100644 --- a/tests/dispatch/signature/test_signature.py +++ b/tests/dispatch/signature/test_signature.py @@ -1,13 +1,18 @@ +import base64 +import os import unittest from datetime import datetime, timedelta +from unittest import mock from http_message_signatures import HTTPMessageSigner from http_message_signatures._algorithms import ED25519 from dispatch.signature import ( CaseInsensitiveDict, + Ed25519PublicKey, InvalidSignature, Request, + parse_verification_key, sign_request, verify_request, ) @@ -33,6 +38,18 @@ """ ) +public_key2_pem = """-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs= +-----END PUBLIC KEY----- +""" +public_key2_pem2 = """-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs= +-----END PUBLIC KEY----- +""" +public_key2 = public_key_from_pem(public_key2_pem) +public_key2_bytes = public_key2.public_bytes_raw() +public_key2_b64 = base64.b64encode(public_key2_bytes) + class TestSignature(unittest.TestCase): def setUp(self): @@ -125,3 +142,70 @@ def test_known_signature(self): ValueError, "public key 'test-key-ed25519' not available" ): verify_request(request, public_key, max_age=timedelta(weeks=9000)) + + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem}) + def test_parse_verification_key_env_pem_str(self): + verification_key = parse_verification_key(None) + 
self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem2}) + def test_parse_verification_key_env_pem_escaped_newline_str(self): + verification_key = parse_verification_key(None) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + @mock.patch.dict( + os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_b64.decode()} + ) + def test_parse_verification_key_env_b64_str(self): + verification_key = parse_verification_key(None) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_none(self): + # The verification key is optional. Both Dispatch(verification_key=...) and + # DISPATCH_VERIFICATION_KEY may be omitted/None. + verification_key = parse_verification_key(None) + self.assertIsNone(verification_key) + + def test_parse_verification_key_ed25519publickey(self): + verification_key = parse_verification_key(public_key2) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_pem_str(self): + verification_key = parse_verification_key(public_key2_pem) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_pem_escaped_newline_str(self): + verification_key = parse_verification_key(public_key2_pem2) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_pem_bytes(self): + verification_key = parse_verification_key(public_key2_pem.encode()) + self.assertIsInstance(verification_key, 
Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_b64_str(self): + verification_key = parse_verification_key(public_key2_b64.decode()) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_b64_bytes(self): + verification_key = parse_verification_key(public_key2_b64) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes) + + def test_parse_verification_key_invalid(self): + with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"): + parse_verification_key("foo") + + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"}) + def test_parse_verification_key_invalid_env(self): + with self.assertRaisesRegex( + ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'" + ): + parse_verification_key(None) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 1bbac55c..c4fd58a9 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -14,21 +14,15 @@ from fastapi.testclient import TestClient from dispatch.experimental.durable.registry import clear_functions -from dispatch.fastapi import Dispatch, parse_verification_key +from dispatch.fastapi import Dispatch from dispatch.function import Arguments, Error, Function, Input, Output from dispatch.proto import _any_unpickle as any_unpickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.signature import public_key_from_pem +from dispatch.signature import parse_verification_key, public_key_from_pem from dispatch.status import Status from dispatch.test import EndpointClient -public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" -public_key_pem2 = "-----BEGIN PUBLIC 
KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" -public_key = public_key_from_pem(public_key_pem) -public_key_bytes = public_key.public_bytes_raw() -public_key_b64 = base64.b64encode(public_key_bytes) - def create_dispatch_instance(app, endpoint): return Dispatch( @@ -107,71 +101,6 @@ def my_function(input: Input) -> Output: self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") - @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem}) - def test_parse_verification_key_env_pem_str(self): - verification_key = parse_verification_key(None) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem2}) - def test_parse_verification_key_env_pem_escaped_newline_str(self): - verification_key = parse_verification_key(None) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_b64.decode()}) - def test_parse_verification_key_env_b64_str(self): - verification_key = parse_verification_key(None) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_none(self): - # The verification key is optional. Both Dispatch(verification_key=...) and - # DISPATCH_VERIFICATION_KEY may be omitted/None. 
- verification_key = parse_verification_key(None) - self.assertIsNone(verification_key) - - def test_parse_verification_key_ed25519publickey(self): - verification_key = parse_verification_key(public_key) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_pem_str(self): - verification_key = parse_verification_key(public_key_pem) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_pem_escaped_newline_str(self): - verification_key = parse_verification_key(public_key_pem2) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_pem_bytes(self): - verification_key = parse_verification_key(public_key_pem.encode()) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_b64_str(self): - verification_key = parse_verification_key(public_key_b64.decode()) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_b64_bytes(self): - verification_key = parse_verification_key(public_key_b64) - self.assertIsInstance(verification_key, Ed25519PublicKey) - self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) - - def test_parse_verification_key_invalid(self): - with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"): - parse_verification_key("foo") - - @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"}) - def test_parse_verification_key_invalid_env(self): - with self.assertRaisesRegex( - ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'" - ): - 
parse_verification_key(None) - def response_output(resp: function_pb.RunResponse) -> Any: return any_unpickle(resp.exit.result.output) diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 00000000..f79c8d72 --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,110 @@ +import base64 +import os +import pickle +import struct +import threading +import unittest +from http.server import HTTPServer +from typing import Any +from unittest import mock + +import fastapi +import google.protobuf.any_pb2 +import google.protobuf.wrappers_pb2 +import httpx +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey + +from dispatch.experimental.durable.registry import clear_functions +from dispatch.function import Arguments, Error, Function, Input, Output, Registry +from dispatch.http import Dispatch +from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.sdk.v1 import call_pb2 as call_pb +from dispatch.sdk.v1 import function_pb2 as function_pb +from dispatch.signature import parse_verification_key, public_key_from_pem +from dispatch.status import Status +from dispatch.test import EndpointClient + +public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" +public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" +public_key = public_key_from_pem(public_key_pem) +public_key_bytes = public_key.public_bytes_raw() +public_key_b64 = base64.b64encode(public_key_bytes) + +from datetime import datetime + + +def create_dispatch_instance(endpoint: str): + return Dispatch( + Registry( + endpoint=endpoint, + api_key="0000000000000000", + api_url="http://127.0.0.1:10000", + ), + ) + + +class TestHTTP(unittest.TestCase): + def setUp(self): + self.server_address = ("127.0.0.1", 9999) + self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}" + self.dispatch = 
create_dispatch_instance(self.endpoint) + self.client = httpx.Client(timeout=1.0) + self.server = HTTPServer(self.server_address, self.dispatch) + self.thread = threading.Thread( + target=lambda: self.server.serve_forever(poll_interval=0.05) + ) + self.thread.start() + + def tearDown(self): + self.server.shutdown() + self.thread.join(timeout=1.0) + self.client.close() + self.server.server_close() + + def test_content_length_missing(self): + resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") + body = resp.read() + self.assertEqual(resp.status_code, 400) + self.assertEqual( + body, b'{"code":"invalid_argument","message":"content length is required"}' + ) + + def test_content_length_too_large(self): + resp = self.client.post( + f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run", + data=b"a" * 16_000_001, + ) + body = resp.read() + self.assertEqual(resp.status_code, 400) + self.assertEqual( + body, b'{"code":"invalid_argument","message":"content length is too large"}' + ) + + def test_simple_request(self): + @self.dispatch.registry.primitive_function + def my_function(input: Input) -> Output: + return Output.value( + f"You told me: '{input.input}' ({len(input.input)} characters)" + ) + + client = EndpointClient.from_url(self.endpoint) + + pickled = pickle.dumps("Hello World!") + input_any = google.protobuf.any_pb2.Any() + input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) + + req = function_pb.RunRequest( + function=my_function.name, + input=input_any, + ) + + resp = client.run(req) + + self.assertIsInstance(resp, function_pb.RunResponse) + + resp.exit.result.output.Unpack( + output_bytes := google.protobuf.wrappers_pb2.BytesValue() + ) + output = pickle.loads(output_bytes.value) + + self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")