diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1acf51d1..c5fc2f47 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ['3.10', '3.11', '3.12'] + python: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python }} diff --git a/pyproject.toml b/pyproject.toml index 546014b5..95ab6ed6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "dispatch-py" description = "Develop reliable distributed systems on the Dispatch platform." readme = "README.md" dynamic = ["version"] -requires-python = ">= 3.10" +requires-python = ">= 3.9" dependencies = [ "grpcio >= 1.60.0", "protobuf >= 4.24.0", @@ -17,7 +17,8 @@ dependencies = [ "tblib >= 3.0.0", "docopt >= 0.6.2", "types-docopt >= 0.6.11.4", - "httpx >= 0.27.0" + "httpx >= 0.27.0", + "typing_extensions >= 4.10" ] [project.optional-dependencies] diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index cf9d4c93..79701a73 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -48,17 +48,17 @@ def race(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc] return (yield RaceDirective(awaitables)) -@dataclass(slots=True) +@dataclass class AllDirective: awaitables: tuple[Awaitable[Any], ...] -@dataclass(slots=True) +@dataclass class AnyDirective: awaitables: tuple[Awaitable[Any], ...] -@dataclass(slots=True) +@dataclass class RaceDirective: awaitables: tuple[Awaitable[Any], ...] diff --git a/src/dispatch/experimental/durable/frame.c b/src/dispatch/experimental/durable/frame.c index 1cfb6c6a..b3bb4517 100644 --- a/src/dispatch/experimental/durable/frame.c +++ b/src/dispatch/experimental/durable/frame.c @@ -6,11 +6,12 @@ #define PY_SSIZE_T_CLEAN #include -#if PY_MAJOR_VERSION != 3 || (PY_MINOR_VERSION < 10 || PY_MINOR_VERSION > 13) -# error Python 3.10-3.13 is required +#if PY_MAJOR_VERSION != 3 || (PY_MINOR_VERSION < 9 || PY_MINOR_VERSION > 13) +# error Python 3.9-3.13 is required #endif -// This is a redefinition of the private PyTryBlock from 3.10. +// This is a redefinition of the private PyTryBlock from <= 3.10. +// https://github.com/python/cpython/blob/3.9/Include/cpython/frameobject.h#L11 // https://github.com/python/cpython/blob/3.10/Include/cpython/frameobject.h#L22 typedef struct { int b_type; @@ -18,7 +19,8 @@ typedef struct { int b_level; } PyTryBlock; -// This is a redefinition of the private PyCoroWrapper from 3.10-3.13. +// This is a redefinition of the private PyCoroWrapper from 3.9-3.13. +// https://github.com/python/cpython/blob/3.9/Objects/genobject.c#L830 // https://github.com/python/cpython/blob/3.10/Objects/genobject.c#L884 // https://github.com/python/cpython/blob/3.11/Objects/genobject.c#L1016 // https://github.com/python/cpython/blob/3.12/Objects/genobject.c#L1003 @@ -51,7 +53,9 @@ static int get_frame_iblock(Frame *frame); static void set_frame_iblock(Frame *frame, int iblock); static PyTryBlock *get_frame_blockstack(Frame *frame); -#if PY_MINOR_VERSION == 10 +#if PY_MINOR_VERSION == 9 +#include "frame309.h" +#elif PY_MINOR_VERSION == 10 #include "frame310.h" #elif PY_MINOR_VERSION == 11 #include "frame311.h" @@ -78,7 +82,7 @@ static const char *get_type_name(PyObject *obj) { static PyGenObject *get_generator_like_object(PyObject *obj) { if (PyGen_Check(obj) || PyCoro_CheckExact(obj) || PyAsyncGen_CheckExact(obj)) { - // Note: In Python 3.10-3.13, the PyGenObject, PyCoroObject and PyAsyncGenObject + // Note: In Python 3.9-3.13, the PyGenObject, PyCoroObject and PyAsyncGenObject // have the same layout, they just have different field prefixes (gi_, cr_, ag_). // We cast to PyGenObject here so that the remainder of the code can use the gi_ // prefix for all three cases. @@ -386,7 +390,7 @@ static PyObject *ext_set_frame_stack_at(PyObject *self, PyObject *args) { } PyObject **localsplus = get_frame_localsplus(frame); PyObject *prev = localsplus[index]; - if (Py_IsTrue(unset)) { + if (PyObject_IsTrue(unset)) { localsplus[index] = NULL; } else { Py_INCREF(stack_obj); diff --git a/src/dispatch/experimental/durable/frame.pyi b/src/dispatch/experimental/durable/frame.pyi index e701afd0..ec3e50e0 100644 --- a/src/dispatch/experimental/durable/frame.pyi +++ b/src/dispatch/experimental/durable/frame.pyi @@ -1,31 +1,37 @@ from types import FrameType -from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple +from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple, Union -def get_frame_ip(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int: +def get_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get instruction pointer of a generator or coroutine.""" -def set_frame_ip(frame: FrameType | Coroutine | Generator | AsyncGenerator, ip: int): +def set_frame_ip( + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ip: int +): """Set instruction pointer of a generator or coroutine.""" -def get_frame_sp(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int: +def get_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get stack pointer of a generator or coroutine.""" -def set_frame_sp(frame: FrameType | Coroutine | Generator | AsyncGenerator, sp: int): +def set_frame_sp( + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], sp: int +): """Set stack pointer of a generator or coroutine.""" -def get_frame_bp(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int: +def get_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get block pointer of a generator or coroutine.""" -def set_frame_bp(frame: FrameType | Coroutine | Generator | AsyncGenerator, bp: int): +def set_frame_bp( + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], bp: int +): """Set block pointer of a generator or coroutine.""" def get_frame_stack_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, index: int + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int ) -> Tuple[bool, Any]: """Get an object from a generator or coroutine's stack, as an (is_null, obj) tuple.""" def set_frame_stack_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int, unset: bool, value: Any, @@ -33,23 +39,23 @@ def set_frame_stack_at( """Set or unset an object on the stack of a generator or coroutine.""" def get_frame_block_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, index: int + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int ) -> Tuple[int, int, int]: """Get a block from a generator or coroutine.""" def set_frame_block_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int, value: Tuple[int, int, int], ): """Restore a block of a generator or coroutine.""" def get_frame_state( - frame: FrameType | Coroutine | Generator | AsyncGenerator, + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ) -> int: """Get frame state of a generator or coroutine.""" def set_frame_state( - frame: FrameType | Coroutine | Generator | AsyncGenerator, state: int + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], state: int ): """Set frame state of a generator or coroutine.""" diff --git a/src/dispatch/experimental/durable/frame309.h b/src/dispatch/experimental/durable/frame309.h new file mode 100644 index 00000000..31de5fdf --- /dev/null +++ b/src/dispatch/experimental/durable/frame309.h @@ -0,0 +1,144 @@ +// This is a redefinition of the private/opaque frame object. +// https://github.com/python/cpython/blob/3.9/Include/cpython/frameobject.h#L17 +// +// In Python <= 3.10, `struct _frame` is both the PyFrameObject and +// PyInterpreterFrame. From Python 3.11 onwards, the two were split with the +// PyFrameObject (struct _frame) pointing to struct _PyInterpreterFrame. +struct Frame { + PyObject_VAR_HEAD + struct Frame *f_back; // struct _frame + PyCodeObject *f_code; + PyObject *f_builtins; + PyObject *f_globals; + PyObject *f_locals; + PyObject **f_valuestack; + PyObject **f_stacktop; + PyObject *f_trace; + char f_trace_lines; + char f_trace_opcodes; + PyObject *f_gen; + int f_lasti; + int f_lineno; + int f_iblock; + char f_executing; + PyTryBlock f_blockstack[CO_MAXBLOCKS]; + PyObject *f_localsplus[1]; +}; + +// Python 3.9 and prior didn't have an explicit enum of frame states, +// but we can derive them based on the presence of a frame, and other +// information found on the frame, for compatibility with later versions. +typedef enum _framestate { + FRAME_CREATED = -2, + FRAME_EXECUTING = 0, + FRAME_CLEARED = 4 +} FrameState; + +/* +// This is the definition of PyGenObject for reference to developers +// working on the extension. +// +// Note that PyCoroObject and PyAsyncGenObject have the same layout as +// PyGenObject, however the struct fields have a cr_ and ag_ prefix +// (respectively) rather than a gi_ prefix. In Python <= 3.10, PyCoroObject +// and PyAsyncGenObject have extra fields compared to PyGenObject. In Python +// 3.11 onwards, the three objects are identical (except for field name +// prefixes). The extra fields in Python <= 3.10 are not applicable to the +// extension at this time. +// +// https://github.com/python/cpython/blob/3.9/Include/genobject.h#L15 +typedef struct { + PyObject_HEAD + PyFrameObject *gi_frame; + char gi_running; + PyObject *gi_code; + PyObject *gi_weakreflist; + PyObject *gi_name; + PyObject *gi_qualname; + _PyErr_StackItem gi_exc_state; +} PyGenObject; +*/ + +static Frame *get_frame(PyGenObject *gen_like) { + Frame *frame = (Frame *)(gen_like->gi_frame); + assert(frame); + return frame; +} + +static PyCodeObject *get_frame_code(Frame *frame) { + PyCodeObject *code = frame->f_code; + assert(code); + return code; +} + +static int get_frame_lasti(Frame *frame) { + return frame->f_lasti; +} + +static void set_frame_lasti(Frame *frame, int lasti) { + frame->f_lasti = lasti; +} + +static int get_frame_state(PyGenObject *gen_like) { + // Python 3.9 doesn't have frame states, but we can derive + // some for compatibility with later versions and to simplify + // the extension. + Frame *frame = (Frame *)(gen_like->gi_frame); + if (!frame) { + return FRAME_CLEARED; + } + return frame->f_executing ? FRAME_EXECUTING : FRAME_CREATED; +} + +static void set_frame_state(PyGenObject *gen_like, int fs) { + Frame *frame = get_frame(gen_like); + frame->f_executing = (fs == FRAME_EXECUTING); +} + +static int valid_frame_state(int fs) { + return fs == FRAME_CREATED || fs == FRAME_EXECUTING || fs == FRAME_CLEARED; +} + +static int get_frame_stacktop_limit(Frame *frame) { + PyCodeObject *code = get_frame_code(frame); + return code->co_stacksize + code->co_nlocals; +} + +static int get_frame_stacktop(Frame *frame) { + assert(frame->f_localsplus); + int stacktop = (int)(frame->f_stacktop - frame->f_localsplus); + assert(stacktop >= 0 && stacktop < get_frame_stacktop_limit(frame)); + return stacktop; +} + +static void set_frame_stacktop(Frame *frame, int stacktop) { + assert(stacktop >= 0 && stacktop < get_frame_stacktop_limit(frame)); + assert(frame->f_localsplus); + frame->f_stacktop = frame->f_localsplus + stacktop; +} + +static PyObject **get_frame_localsplus(Frame *frame) { + PyObject **localsplus = frame->f_localsplus; + assert(localsplus); + return localsplus; +} + +static int get_frame_iblock_limit(Frame *frame) { + return CO_MAXBLOCKS; +} + +static int get_frame_iblock(Frame *frame) { + return frame->f_iblock; +} + +static void set_frame_iblock(Frame *frame, int iblock) { + assert(iblock >= 0 && iblock < get_frame_iblock_limit(frame)); + frame->f_iblock = iblock; +} + +static PyTryBlock *get_frame_blockstack(Frame *frame) { + PyTryBlock *blockstack = frame->f_blockstack; + assert(blockstack); + return blockstack; +} + diff --git a/src/dispatch/experimental/durable/function.py b/src/dispatch/experimental/durable/function.py index a740be41..925b0ee5 100644 --- a/src/dispatch/experimental/durable/function.py +++ b/src/dispatch/experimental/durable/function.py @@ -9,7 +9,7 @@ MethodType, TracebackType, ) -from typing import Any, Callable, Coroutine, Generator, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Generator, Optional, TypeVar, Union, cast from . import frame as ext from .registry import RegisteredFunction, lookup_function, register_function @@ -75,7 +75,7 @@ class Serializable: "__qualname__", ) - g: GeneratorType | CoroutineType + g: Union[GeneratorType, CoroutineType] registered_fn: RegisteredFunction wrapped_coroutine: Union["DurableCoroutine", None] args: tuple[Any, ...] @@ -83,7 +83,7 @@ class Serializable: def __init__( self, - g: GeneratorType | CoroutineType, + g: Union[GeneratorType, CoroutineType], registered_fn: RegisteredFunction, wrapped_coroutine: Union["DurableCoroutine", None], *args: Any, @@ -243,7 +243,7 @@ def __await__(self) -> Generator[Any, None, _ReturnT]: def send(self, send: _SendT) -> _YieldT: return self.coroutine.send(send) - def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT: + def throw(self, typ, val=None, tb: Optional[TracebackType] = None) -> _YieldT: return self.coroutine.throw(typ, val, tb) def close(self) -> None: @@ -270,11 +270,11 @@ def cr_frame(self) -> FrameType: return self.coroutine.cr_frame @property - def cr_await(self) -> Any | None: + def cr_await(self) -> Any: return self.coroutine.cr_await @property - def cr_origin(self) -> tuple[tuple[str, int, str], ...] | None: + def cr_origin(self) -> Optional[tuple[tuple[str, int, str], ...]]: return self.coroutine.cr_origin def __repr__(self) -> str: @@ -291,7 +291,7 @@ def __init__( self, generator: GeneratorType, registered_fn: RegisteredFunction, - coroutine: DurableCoroutine | None, + coroutine: Optional[DurableCoroutine], *args: Any, **kwargs: Any, ): @@ -309,7 +309,7 @@ def __next__(self) -> _YieldT: def send(self, send: _SendT) -> _YieldT: return self.generator.send(send) - def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT: + def throw(self, typ, val=None, tb: Optional[TracebackType] = None) -> _YieldT: return self.generator.throw(typ, val, tb) def close(self) -> None: @@ -336,7 +336,7 @@ def gi_frame(self) -> FrameType: return self.generator.gi_frame @property - def gi_yieldfrom(self) -> GeneratorType | None: + def gi_yieldfrom(self) -> Optional[GeneratorType]: return self.generator.gi_yieldfrom def __repr__(self) -> str: diff --git a/src/dispatch/experimental/durable/registry.py b/src/dispatch/experimental/durable/registry.py index 3a5d9765..da8f2c28 100644 --- a/src/dispatch/experimental/durable/registry.py +++ b/src/dispatch/experimental/durable/registry.py @@ -3,7 +3,7 @@ from types import FunctionType -@dataclass(slots=True) +@dataclass class RegisteredFunction: """A function that can be referenced in durable state.""" diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 239c3b36..8f1c8094 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -21,6 +21,7 @@ def read_root(): import logging import os from datetime import timedelta +from typing import Optional, Union from urllib.parse import urlparse import fastapi @@ -51,10 +52,10 @@ class Dispatch(Registry): def __init__( self, app: fastapi.FastAPI, - endpoint: str | None = None, - verification_key: Ed25519PublicKey | str | bytes | None = None, - api_key: str | None = None, - api_url: str | None = None, + endpoint: Optional[str] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + api_key: Optional[str] = None, + api_url: Optional[str] = None, ): """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. @@ -122,8 +123,8 @@ def __init__( def parse_verification_key( - verification_key: Ed25519PublicKey | str | bytes | None, -) -> Ed25519PublicKey | None: + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]], +) -> Optional[Ed25519PublicKey]: if isinstance(verification_key, Ed25519PublicKey): return verification_key @@ -169,7 +170,7 @@ def __init__(self, status, code, message): self.message = message -def _new_app(function_registry: Dispatch, verification_key: Ed25519PublicKey | None): +def _new_app(function_registry: Dispatch, verification_key: Optional[Ed25519PublicKey]): app = fastapi.FastAPI() @app.exception_handler(_ConnectError) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 97f18176..0caeb087 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -12,14 +12,14 @@ Dict, Generic, Iterable, - ParamSpec, - TypeAlias, + Optional, TypeVar, overload, ) from urllib.parse import urlparse import grpc +from typing_extensions import ParamSpec, TypeAlias import dispatch.coroutine import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb @@ -73,7 +73,7 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID: return dispatch_id def _build_primitive_call( - self, input: Any, correlation_id: int | None = None + self, input: Any, correlation_id: Optional[int] = None ) -> Call: return Call( correlation_id=correlation_id, @@ -137,7 +137,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: return self._primitive_dispatch(Arguments(args, kwargs)) def build_call( - self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs + self, *args: P.args, correlation_id: Optional[int] = None, **kwargs: P.kwargs ) -> Call: """Create a Call for this function with the provided input. Useful to generate calls when using the Client. @@ -162,7 +162,10 @@ class Registry: __slots__ = ("functions", "endpoint", "client") def __init__( - self, endpoint: str, api_key: str | None = None, api_url: str | None = None + self, + endpoint: str, + api_key: Optional[str] = None, + api_url: Optional[str] = None, ): """Initialize a function registry. @@ -261,7 +264,7 @@ class Client: __slots__ = ("api_url", "api_key", "_stub", "api_key_from") - def __init__(self, api_key: None | str = None, api_url: None | str = None): + def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None): """Create a new Dispatch client. Args: @@ -308,13 +311,12 @@ def __setstate__(self, state): def _init_stub(self): result = urlparse(self.api_url) - match result.scheme: - case "http": - creds = grpc.local_channel_credentials() - case "https": - creds = grpc.ssl_channel_credentials() - case _: - raise ValueError(f"Invalid API scheme: '{result.scheme}'") + if result.scheme == "http": + creds = grpc.local_channel_credentials() + elif result.scheme == "https": + creds = grpc.ssl_channel_credentials() + else: + raise ValueError(f"Invalid API scheme: '{result.scheme}'") call_creds = grpc.access_token_call_credentials(self.api_key) creds = grpc.composite_channel_credentials(creds, call_creds) @@ -344,11 +346,10 @@ def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]: resp = self._stub.Dispatch(req) except grpc.RpcError as e: status_code = e.code() - match status_code: - case grpc.StatusCode.UNAUTHENTICATED: - raise PermissionError( - f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" - ) from e + if status_code == grpc.StatusCode.UNAUTHENTICATED: + raise PermissionError( + f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" + ) from e raise dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] diff --git a/src/dispatch/id.py b/src/dispatch/id.py index ee3cce2a..d5f669be 100644 --- a/src/dispatch/id.py +++ b/src/dispatch/id.py @@ -1,4 +1,4 @@ -from typing import TypeAlias +from typing_extensions import TypeAlias DispatchID: TypeAlias = str """Unique identifier in Dispatch. diff --git a/src/dispatch/integrations/http.py b/src/dispatch/integrations/http.py index 6846deac..19c3c263 100644 --- a/src/dispatch/integrations/http.py +++ b/src/dispatch/integrations/http.py @@ -4,33 +4,31 @@ def http_response_code_status(code: int) -> Status: """Returns a Status that's broadly equivalent to an HTTP response status code.""" - match code: - case 400: # Bad Request - return Status.INVALID_ARGUMENT - case 401: # Unauthorized - return Status.UNAUTHENTICATED - case 403: # Forbidden - return Status.PERMISSION_DENIED - case 404: # Not Found - return Status.NOT_FOUND - case 408: # Request Timeout - return Status.TIMEOUT - case 429: # Too Many Requests - return Status.THROTTLED - case 501: # Not Implemented - return Status.PERMANENT_ERROR + if code == 400: # Bad Request + return Status.INVALID_ARGUMENT + elif code == 401: # Unauthorized + return Status.UNAUTHENTICATED + elif code == 403: # Forbidden + return Status.PERMISSION_DENIED + elif code == 404: # Not Found + return Status.NOT_FOUND + elif code == 408: # Request Timeout + return Status.TIMEOUT + elif code == 429: # Too Many Requests + return Status.THROTTLED + elif code == 501: # Not Implemented + return Status.PERMANENT_ERROR category = code // 100 - match category: - case 1: # 1xx informational - return Status.PERMANENT_ERROR - case 2: # 2xx success - return Status.OK - case 3: # 3xx redirection - return Status.PERMANENT_ERROR - case 4: # 4xx client error - return Status.PERMANENT_ERROR - case 5: # 5xx server error - return Status.TEMPORARY_ERROR + if category == 1: # 1xx informational + return Status.PERMANENT_ERROR + elif category == 2: # 2xx success + return Status.OK + elif category == 3: # 3xx redirection + return Status.PERMANENT_ERROR + elif category == 4: # 4xx client error + return Status.PERMANENT_ERROR + elif category == 5: # 5xx server error + return Status.TEMPORARY_ERROR return Status.UNSPECIFIED diff --git a/src/dispatch/integrations/httpx.py b/src/dispatch/integrations/httpx.py index 3d60a65a..64a07588 100644 --- a/src/dispatch/integrations/httpx.py +++ b/src/dispatch/integrations/httpx.py @@ -6,15 +6,14 @@ def httpx_error_status(error: Exception) -> Status: # See https://www.python-httpx.org/exceptions/ - match error: - case httpx.HTTPStatusError(): - return httpx_response_status(error.response) - case httpx.InvalidURL(): - return Status.INVALID_ARGUMENT - case httpx.UnsupportedProtocol(): - return Status.INVALID_ARGUMENT - case httpx.TimeoutException(): - return Status.TIMEOUT + if isinstance(error, httpx.HTTPStatusError): + return httpx_response_status(error.response) + elif isinstance(error, httpx.InvalidURL): + return Status.INVALID_ARGUMENT + elif isinstance(error, httpx.UnsupportedProtocol): + return Status.INVALID_ARGUMENT + elif isinstance(error, httpx.TimeoutException): + return Status.TIMEOUT return Status.TEMPORARY_ERROR diff --git a/src/dispatch/integrations/openai.py b/src/dispatch/integrations/openai.py index 0d781fe4..533133d4 100644 --- a/src/dispatch/integrations/openai.py +++ b/src/dispatch/integrations/openai.py @@ -6,11 +6,10 @@ def openai_error_status(error: Exception) -> Status: # See https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py - match error: - case openai.APITimeoutError(): - return Status.TIMEOUT - case openai.APIStatusError(): - return http_response_code_status(error.status_code) + if isinstance(error, openai.APITimeoutError): + return Status.TIMEOUT + elif isinstance(error, openai.APIStatusError): + return http_response_code_status(error.status_code) return Status.TEMPORARY_ERROR diff --git a/src/dispatch/integrations/requests.py b/src/dispatch/integrations/requests.py index b61ed21c..89de804f 100644 --- a/src/dispatch/integrations/requests.py +++ b/src/dispatch/integrations/requests.py @@ -7,14 +7,15 @@ def requests_error_status(error: Exception) -> Status: # See https://requests.readthedocs.io/en/latest/api/#exceptions # and https://requests.readthedocs.io/en/latest/_modules/requests/exceptions/ - match error: - case requests.HTTPError(): - if error.response is not None: - return requests_response_status(error.response) - case requests.Timeout(): - return Status.TIMEOUT - case ValueError(): # base class of things like requests.InvalidURL, etc. - return Status.INVALID_ARGUMENT + if isinstance(error, requests.HTTPError): + if error.response is not None: + return requests_response_status(error.response) + elif isinstance(error, requests.Timeout): + return Status.TIMEOUT + elif isinstance( + error, ValueError + ): # base class of things like requests.InvalidURL, etc. + return Status.INVALID_ARGUMENT return Status.TEMPORARY_ERROR diff --git a/src/dispatch/integrations/slack.py b/src/dispatch/integrations/slack.py index 78040945..28e73718 100644 --- a/src/dispatch/integrations/slack.py +++ b/src/dispatch/integrations/slack.py @@ -8,10 +8,8 @@ def slack_error_status(error: Exception) -> Status: # See https://github.com/slackapi/python-slack-sdk/blob/main/slack/errors.py - match error: - case slack_sdk.errors.SlackApiError(): - if error.response is not None: - return slack_response_status(error.response) + if isinstance(error, slack_sdk.errors.SlackApiError) and error.response is not None: + return slack_response_status(error.response) return Status.TEMPORARY_ERROR diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 3e0d90e0..ef70e84c 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from traceback import format_exception from types import TracebackType -from typing import Any +from typing import Any, Optional import google.protobuf.any_pb2 import google.protobuf.message @@ -97,7 +97,7 @@ def call_results(self) -> list[CallResult]: return self._call_results @property - def poll_error(self) -> Error | None: + def poll_error(self) -> Optional[Error]: self._assert_resume() return self._poll_error @@ -125,7 +125,7 @@ def from_poll_results( function: str, coroutine_state: Any, call_results: list[CallResult], - error: Error | None = None, + error: Optional[Error] = None, ): return Input( req=function_pb.RunRequest( @@ -139,7 +139,7 @@ def from_poll_results( ) -@dataclass(slots=True) +@dataclass class Arguments: """A container for positional and keyword arguments.""" @@ -147,7 +147,7 @@ class Arguments: kwargs: dict[str, Any] -@dataclass(slots=True) +@dataclass class Output: """The output of a primitive function. @@ -163,7 +163,7 @@ def __init__(self, proto: function_pb.RunResponse): self._message = proto @classmethod - def value(cls, value: Any, status: Status | None = None) -> Output: + def value(cls, value: Any, status: Optional[Status] = None) -> Output: """Terminally exit the function with the provided return value.""" if status is None: status = status_for_output(value) @@ -183,8 +183,8 @@ def tail_call(cls, tail_call: Call) -> Output: @classmethod def exit( cls, - result: CallResult | None = None, - tail_call: Call | None = None, + result: Optional[CallResult] = None, + tail_call: Optional[Call] = None, status: Status = Status.OK, ) -> Output: """Terminally exit the function.""" @@ -201,10 +201,10 @@ def exit( def poll( cls, state: Any, - calls: None | list[Call] = None, + calls: Optional[list[Call]] = None, min_results: int = 1, max_results: int = 10, - max_wait_seconds: int | None = None, + max_wait_seconds: Optional[int] = None, ) -> Output: """Suspend the function with a set of Calls, instructing the orchestrator to resume the function with the provided state when @@ -240,7 +240,7 @@ def poll( # the current Python process. -@dataclass(slots=True) +@dataclass class Call: """Instruction to call a function. @@ -249,9 +249,9 @@ class Call: """ function: str - input: Any | None = None - endpoint: str | None = None - correlation_id: int | None = None + input: Optional[Any] = None + endpoint: Optional[str] = None + correlation_id: Optional[int] = None def _as_proto(self) -> call_pb.Call: input_bytes = _pb_any_pickle(self.input) @@ -263,13 +263,13 @@ def _as_proto(self) -> call_pb.Call: ) -@dataclass(slots=True) +@dataclass class CallResult: """Result of a Call.""" - correlation_id: int | None = None - output: Any | None = None - error: Error | None = None + correlation_id: Optional[int] = None + output: Optional[Any] = None + error: Optional[Error] = None def _as_proto(self) -> call_pb.CallResult: output_any = None @@ -297,15 +297,19 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult: ) @classmethod - def from_value(cls, output: Any, correlation_id: int | None = None) -> CallResult: + def from_value( + cls, output: Any, correlation_id: Optional[int] = None + ) -> CallResult: return CallResult(correlation_id=correlation_id, output=output) @classmethod - def from_error(cls, error: Error, correlation_id: int | None = None) -> CallResult: + def from_error( + cls, error: Error, correlation_id: Optional[int] = None + ) -> CallResult: return CallResult(correlation_id=correlation_id, error=error) -@dataclass(slots=True) +@dataclass class Error: """Error when running a function. @@ -316,16 +320,16 @@ class Error: status: Status type: str message: str - value: Exception | None = None - traceback: bytes | None = None + value: Optional[Exception] = None + traceback: Optional[bytes] = None def __init__( self, status: Status, type: str, message: str, - value: Exception | None = None, - traceback: bytes | None = None, + value: Optional[Exception] = None, + traceback: Optional[bytes] = None, ): """Create a new Error. @@ -352,10 +356,12 @@ def __init__( self.value = value self.traceback = traceback if not traceback and value: - self.traceback = "".join(format_exception(value)).encode("utf-8") + self.traceback = "".join( + format_exception(value.__class__, value, value.__traceback__) + ).encode("utf-8") @classmethod - def from_exception(cls, ex: Exception, status: Status | None = None) -> Error: + def from_exception(cls, ex: Exception, status: Optional[Status] = None) -> Error: """Create an Error from a Python exception, using its class qualified named as type. diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 89c4d7af..42915450 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -2,7 +2,9 @@ import pickle import sys from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Protocol, TypeAlias +from typing import Any, Awaitable, Callable, Optional, Protocol, Union + +from typing_extensions import TypeAlias from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective from dispatch.error import IncompatibleStateError @@ -17,44 +19,44 @@ CorrelationID: TypeAlias = int -@dataclass(slots=True) +@dataclass class CoroutineResult: """The result from running a coroutine to completion.""" coroutine_id: CoroutineID - value: Any | None = None - error: Exception | None = None + value: Optional[Any] = None + error: Optional[Exception] = None -@dataclass(slots=True) +@dataclass class CallResult: """The result of an asynchronous function call.""" call_id: CallID - value: Any | None = None - error: Exception | None = None + value: Optional[Any] = None + error: Optional[Exception] = None class Future(Protocol): - def add_result(self, result: CallResult | CoroutineResult): ... + def add_result(self, result: Union[CallResult, CoroutineResult]): ... def add_error(self, error: Exception): ... def ready(self) -> bool: ... - def error(self) -> Exception | None: ... + def error(self) -> Optional[Exception]: ... def value(self) -> Any: ... -@dataclass(slots=True) +@dataclass class CallFuture: """A future result of a dispatch.coroutine.call() operation.""" - result: CallResult | None = None - first_error: Exception | None = None + result: Optional[CallResult] = None + first_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CallResult) if self.result is None: self.result = result @@ -68,7 +70,7 @@ def add_error(self, error: Exception): def ready(self) -> bool: return self.first_error is not None or self.result is not None - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() return self.first_error @@ -78,16 +80,16 @@ def value(self) -> Any: return self.result.value -@dataclass(slots=True) +@dataclass class AllFuture: """A future result of a dispatch.coroutine.all() operation.""" order: list[CoroutineID] = field(default_factory=list) waiting: set[CoroutineID] = field(default_factory=set) results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict) - first_error: Exception | None = None + first_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CoroutineResult) try: @@ -109,7 +111,7 @@ def add_error(self, error: Exception): def ready(self) -> bool: return self.first_error is not None or len(self.waiting) == 0 - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() return self.first_error @@ -120,17 +122,17 @@ def value(self) -> list[Any]: return [self.results[id].value for id in self.order] -@dataclass(slots=True) +@dataclass class AnyFuture: """A future result of a dispatch.coroutine.any() operation.""" order: list[CoroutineID] = field(default_factory=list) waiting: set[CoroutineID] = field(default_factory=set) - first_result: CoroutineResult | None = None + first_result: Optional[CoroutineResult] = None errors: dict[CoroutineID, Exception] = field(default_factory=dict) - generic_error: Exception | None = None + generic_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CoroutineResult) try: @@ -156,19 +158,18 @@ def ready(self) -> bool: or len(self.waiting) == 0 ) - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() if self.generic_error is not None: return self.generic_error if self.first_result is not None or len(self.errors) == 0: return None - match len(self.errors): - case 0: - return None - case 1: - return self.errors[self.order[0]] - case _: - return AnyException([self.errors[id] for id in self.order]) + if len(self.errors) == 0: + return None + elif len(self.errors) == 1: + return self.errors[self.order[0]] + else: + return AnyException([self.errors[id] for id in self.order]) def value(self) -> Any: assert self.ready() @@ -178,15 +179,15 @@ def value(self) -> Any: return self.first_result.value -@dataclass(slots=True) +@dataclass class RaceFuture: """A future result of a dispatch.coroutine.race() operation.""" waiting: set[CoroutineID] = field(default_factory=set) - first_result: CoroutineResult | None = None - first_error: Exception | None = None + first_result: Optional[CoroutineResult] = None + first_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CoroutineResult) if result.error is not None: @@ -209,7 +210,7 @@ def ready(self) -> bool: or len(self.waiting) == 0 ) - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() return self.first_error @@ -218,14 +219,14 @@ def value(self) -> Any: return self.first_result.value if self.first_result else None -@dataclass(slots=True) +@dataclass class Coroutine: """An in-flight coroutine.""" id: CoroutineID - parent_id: CoroutineID | None - coroutine: DurableCoroutine | DurableGenerator - result: Future | None = None + parent_id: Optional[CoroutineID] + coroutine: Union[DurableCoroutine, DurableGenerator] + result: Optional[Future] = None def run(self) -> Any: if self.result is None: @@ -242,7 +243,7 @@ def __repr__(self): return f"Coroutine({self.id}, {self.coroutine.__qualname__})" -@dataclass(slots=True) +@dataclass class State: """State of the scheduler and the coroutines it's managing.""" @@ -279,7 +280,7 @@ def __init__( version: str = sys.version, poll_min_results: int = 1, poll_max_results: int = 10, - poll_max_wait_seconds: int | None = None, + poll_max_wait_seconds: Optional[int] = None, ): """Initialize the scheduler. @@ -423,7 +424,7 @@ def _run(self, input: Input) -> Output: assert coroutine.id not in state.suspended coroutine_yield = None - coroutine_result: CoroutineResult | None = None + coroutine_result: Optional[CoroutineResult] = None try: coroutine_yield = coroutine.run() except StopIteration as e: @@ -470,60 +471,47 @@ def _run(self, input: Input) -> Output: # Handle coroutines that yield. logger.debug("%s yielded %s", coroutine, coroutine_yield) - match coroutine_yield: - case Call(): - call = coroutine_yield - call_id = state.next_call_id - state.next_call_id += 1 - call.correlation_id = correlation_id(coroutine.id, call_id) - logger.debug( - "enqueuing call %d (%s) for %s", - call_id, - call.function, - coroutine, - ) - pending_calls.append(call) - coroutine.result = CallFuture() - state.suspended[coroutine.id] = coroutine - state.prev_callers.append(coroutine) - state.outstanding_calls += 1 - - case AllDirective(): - children = spawn_children( - state, coroutine, coroutine_yield.awaitables - ) - - child_ids = [child.id for child in children] - coroutine.result = AllFuture( - order=child_ids, waiting=set(child_ids) - ) - state.suspended[coroutine.id] = coroutine - - case AnyDirective(): - children = spawn_children( - state, coroutine, coroutine_yield.awaitables - ) - - child_ids = [child.id for child in children] - coroutine.result = AnyFuture( - order=child_ids, waiting=set(child_ids) - ) - state.suspended[coroutine.id] = coroutine - - case RaceDirective(): - children = spawn_children( - state, coroutine, coroutine_yield.awaitables - ) - - coroutine.result = RaceFuture( - waiting={child.id for child in children} - ) - state.suspended[coroutine.id] = coroutine - - case _: - raise RuntimeError( - f"coroutine unexpectedly yielded '{coroutine_yield}'" - ) + if isinstance(coroutine_yield, Call): + call = coroutine_yield + call_id = state.next_call_id + state.next_call_id += 1 + call.correlation_id = correlation_id(coroutine.id, call_id) + logger.debug( + "enqueuing call %d (%s) for %s", + call_id, + call.function, + coroutine, + ) + pending_calls.append(call) + coroutine.result = CallFuture() + state.suspended[coroutine.id] = coroutine + state.prev_callers.append(coroutine) + state.outstanding_calls += 1 + + elif isinstance(coroutine_yield, AllDirective): + children = spawn_children(state, coroutine, coroutine_yield.awaitables) + + child_ids = [child.id for child in children] + coroutine.result = AllFuture(order=child_ids, waiting=set(child_ids)) + state.suspended[coroutine.id] = coroutine + + elif isinstance(coroutine_yield, AnyDirective): + children = spawn_children(state, coroutine, coroutine_yield.awaitables) + + child_ids = [child.id for child in children] + coroutine.result = AnyFuture(order=child_ids, waiting=set(child_ids)) + state.suspended[coroutine.id] = coroutine + + elif isinstance(coroutine_yield, RaceDirective): + children = spawn_children(state, coroutine, coroutine_yield.awaitables) + + coroutine.result = RaceFuture(waiting={child.id for child in children}) + state.suspended[coroutine.id] = coroutine + + else: + raise RuntimeError( + f"coroutine unexpectedly yielded '{coroutine_yield}'" + ) # Serialize coroutines and scheduler state. logger.debug("serializing state") diff --git a/src/dispatch/signature/digest.py b/src/dispatch/signature/digest.py index 4bc264a3..c5602126 100644 --- a/src/dispatch/signature/digest.py +++ b/src/dispatch/signature/digest.py @@ -1,11 +1,12 @@ import hashlib import hmac +from typing import Union import http_sfv from http_message_signatures import InvalidSignature -def generate_content_digest(body: str | bytes) -> str: +def generate_content_digest(body: Union[str, bytes]) -> str: """Returns a SHA-512 Content-Digest header, according to https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-digest-headers-13 """ @@ -16,7 +17,7 @@ def generate_content_digest(body: str | bytes) -> str: return str(http_sfv.Dictionary({"sha-512": digest})) -def verify_content_digest(digest_header: str | bytes, body: str | bytes): +def verify_content_digest(digest_header: Union[str, bytes], body: Union[str, bytes]): """Verify a SHA-256 or SHA-512 Content-Digest header matches a request body.""" if isinstance(body, str): diff --git a/src/dispatch/signature/key.py b/src/dispatch/signature/key.py index 7fc1cee0..5cce28e5 100644 --- a/src/dispatch/signature/key.py +++ b/src/dispatch/signature/key.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional, Union from cryptography.hazmat.primitives.asymmetric.ed25519 import ( Ed25519PrivateKey, @@ -11,7 +12,7 @@ from http_message_signatures import HTTPSignatureKeyResolver -def public_key_from_pem(pem: str | bytes) -> Ed25519PublicKey: +def public_key_from_pem(pem: Union[str, bytes]) -> Ed25519PublicKey: """Returns an Ed25519 public key given a PEM representation.""" if isinstance(pem, str): pem = pem.encode() @@ -28,7 +29,7 @@ def public_key_from_bytes(key: bytes) -> Ed25519PublicKey: def private_key_from_pem( - pem: str | bytes, password: bytes | None = None + pem: Union[str, bytes], password: Optional[bytes] = None ) -> Ed25519PrivateKey: """Returns an Ed25519 private key given a PEM representation and optional password.""" @@ -48,7 +49,7 @@ def private_key_from_bytes(key: bytes) -> Ed25519PrivateKey: return Ed25519PrivateKey.from_private_bytes(key) -@dataclass(slots=True) +@dataclass class KeyResolver(HTTPSignatureKeyResolver): """KeyResolver provides public and private keys. @@ -57,8 +58,8 @@ class KeyResolver(HTTPSignatureKeyResolver): """ key_id: str - public_key: Ed25519PublicKey | None = None - private_key: Ed25519PrivateKey | None = None + public_key: Optional[Ed25519PublicKey] = None + private_key: Optional[Ed25519PrivateKey] = None def resolve_public_key(self, key_id: str): if key_id != self.key_id or self.public_key is None: diff --git a/src/dispatch/signature/request.py b/src/dispatch/signature/request.py index 15ac24c9..ee6f13fc 100644 --- a/src/dispatch/signature/request.py +++ b/src/dispatch/signature/request.py @@ -1,13 +1,14 @@ from dataclasses import dataclass +from typing import Union from http_message_signatures.structures import CaseInsensitiveDict -@dataclass(slots=True) +@dataclass class Request: """A framework-agnostic representation of an HTTP request.""" method: str url: str headers: CaseInsensitiveDict - body: str | bytes + body: Union[str, bytes] diff --git a/src/dispatch/status.py b/src/dispatch/status.py index 5a413802..1a8f34d2 100644 --- a/src/dispatch/status.py +++ b/src/dispatch/status.py @@ -92,27 +92,31 @@ def status_for_error(error: Exception) -> Status: # If not, resort to standard error categorization. # # See https://docs.python.org/3/library/exceptions.html - match error: - case IncompatibleStateError(): - return Status.INCOMPATIBLE_STATE - case TimeoutError(): - return Status.TIMEOUT - case TypeError() | ValueError(): - return Status.INVALID_ARGUMENT - case ConnectionError(): - return Status.TCP_ERROR - case PermissionError(): - return Status.PERMISSION_DENIED - case FileNotFoundError(): - return Status.NOT_FOUND - case EOFError() | InterruptedError() | KeyboardInterrupt() | OSError(): - # For OSError, we might want to categorize the values of errnon - # to determine whether the error is temporary or permanent. - # - # In general, permanent errors from the OS are rare because they - # tend to be caused by invalid use of syscalls, which are - # unlikely at higher abstraction levels. - return Status.TEMPORARY_ERROR + if isinstance(error, IncompatibleStateError): + return Status.INCOMPATIBLE_STATE + elif isinstance(error, TimeoutError): + return Status.TIMEOUT + elif isinstance(error, TypeError) or isinstance(error, ValueError): + return Status.INVALID_ARGUMENT + elif isinstance(error, ConnectionError): + return Status.TCP_ERROR + elif isinstance(error, PermissionError): + return Status.PERMISSION_DENIED + elif isinstance(error, FileNotFoundError): + return Status.NOT_FOUND + elif ( + isinstance(error, EOFError) + or isinstance(error, InterruptedError) + or isinstance(error, KeyboardInterrupt) + or isinstance(error, OSError) + ): + # For OSError, we might want to categorize the values of errnon + # to determine whether the error is temporary or permanent. + # + # In general, permanent errors from the OS are rare because they + # tend to be caused by invalid use of syscalls, which are + # unlikely at higher abstraction levels. + return Status.TEMPORARY_ERROR return Status.PERMANENT_ERROR diff --git a/src/dispatch/test/client.py b/src/dispatch/test/client.py index 0b3f4539..01078aec 100644 --- a/src/dispatch/test/client.py +++ b/src/dispatch/test/client.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional import fastapi import grpc @@ -25,7 +26,7 @@ class EndpointClient: """ def __init__( - self, http_client: httpx.Client, signing_key: Ed25519PrivateKey | None = None + self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None ): """Initialize the client. @@ -48,14 +49,14 @@ def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse: return self._stub.Run(request) @classmethod - def from_url(cls, url: str, signing_key: Ed25519PrivateKey | None = None): + def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None): """Returns an EndpointClient for a Dispatch endpoint URL.""" http_client = httpx.Client(base_url=url) return EndpointClient(http_client, signing_key) @classmethod def from_app( - cls, app: fastapi.FastAPI, signing_key: Ed25519PrivateKey | None = None + cls, app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None ): """Returns an EndpointClient for a Dispatch endpoint bound to a FastAPI app instance.""" @@ -65,7 +66,7 @@ def from_app( class _HttpxGrpcChannel(grpc.Channel): def __init__( - self, http_client: httpx.Client, signing_key: Ed25519PrivateKey | None = None + self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None ): self.http_client = http_client self.signing_key = signing_key @@ -113,7 +114,7 @@ def __init__( method, request_serializer, response_deserializer, - signing_key: Ed25519PrivateKey | None = None, + signing_key: Optional[Ed25519PrivateKey] = None, ): self.client = client self.method = method diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index 09729396..d711a4ed 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -5,10 +5,11 @@ import time from collections import OrderedDict from dataclasses import dataclass -from typing import TypeAlias +from typing import Optional import grpc import httpx +from typing_extensions import TypeAlias import dispatch.sdk.v1.call_pb2 as call_pb import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb @@ -52,8 +53,8 @@ class DispatchService(dispatch_grpc.DispatchServiceServicer): def __init__( self, endpoint_client: EndpointClient, - api_key: str | None = None, - retry_on_status: set[Status] | None = None, + api_key: Optional[str] = None, + retry_on_status: Optional[set[Status]] = None, collect_roundtrips: bool = False, ): """Initialize the Dispatch service. @@ -86,11 +87,11 @@ def __init__( self.pollers: dict[DispatchID, Poller] = {} self.parents: dict[DispatchID, Poller] = {} - self.roundtrips: OrderedDict[DispatchID, list[RoundTrip]] | None = None + self.roundtrips: Optional[OrderedDict[DispatchID, list[RoundTrip]]] = None if collect_roundtrips: self.roundtrips = OrderedDict() - self._thread: threading.Thread | None = None + self._thread: Optional[threading.Thread] = None self._stop_event = threading.Event() self._work_signal = threading.Condition() @@ -143,13 +144,12 @@ def dispatch_calls(self): while self.queue: dispatch_id, request, call_type = self.queue.pop(0) - match call_type: - case CallType.CALL: - logger.info("calling function %s", request.function) - case CallType.RESUME: - logger.info("resuming function %s", request.function) - case CallType.RETRY: - logger.info("retrying function %s", request.function) + if call_type == CallType.CALL: + logger.info("calling function %s", request.function) + elif call_type == CallType.RESUME: + logger.info("resuming function %s", request.function) + elif call_type == CallType.RETRY: + logger.info("retrying function %s", request.function) try: response = self.endpoint_client.run(request) diff --git a/tests/dispatch/experimental/durable/test_frame.py b/tests/dispatch/experimental/durable/test_frame.py index b4e9c2c1..ea2f6287 100644 --- a/tests/dispatch/experimental/durable/test_frame.py +++ b/tests/dispatch/experimental/durable/test_frame.py @@ -28,14 +28,6 @@ async def coroutine(a): await Yields(a) -async def async_generator(a): - await Yields(a) - a += 1 - yield a - a += 1 - await Yields(a) - - class TestFrame(unittest.TestCase): def test_generator_copy(self): # Create an instance and run it to the first yield point. @@ -74,39 +66,6 @@ def test_coroutine_copy(self): assert next(g) == 2 assert next(g) == 3 - def test_async_generator_copy(self): - # Create an instance and run it to the first yield point. - ag = async_generator(1) - next_awaitable = anext(ag) - g = next_awaitable.__await__() - assert next(g) == 1 - - # Copy the async generator. - ag2 = async_generator(1) - self.copy_to(ag, ag2) - next_awaitable2 = anext(ag2) - g2 = next_awaitable2.__await__() - - # The copy should start from where the previous generator was suspended. - try: - next(g2) - raise RuntimeError - except StopIteration as e: - assert e.value == 2 - next_awaitable2 = anext(ag2) - g2 = next_awaitable2.__await__() - assert next(g2) == 3 - - # Original generator is not affected. - try: - next(g) - raise RuntimeError - except StopIteration as e: - assert e.value == 2 - next_awaitable = anext(ag) - g = next_awaitable.__await__() - assert next(g) == 3 - def copy_to(self, from_obj, to_obj): ext.set_frame_state(to_obj, ext.get_frame_state(from_obj)) ext.set_frame_ip(to_obj, ext.get_frame_ip(from_obj)) diff --git a/tests/dispatch/test_error.py b/tests/dispatch/test_error.py index f036df06..df78436b 100644 --- a/tests/dispatch/test_error.py +++ b/tests/dispatch/test_error.py @@ -12,7 +12,13 @@ def test_conversion_between_exception_and_error(self): except Exception as e: original_exception = e error = Error.from_exception(e) - original_traceback = "".join(traceback.format_exception(original_exception)) + original_traceback = "".join( + traceback.format_exception( + original_exception.__class__, + original_exception, + original_exception.__traceback__, + ) + ) # For some reasons traceback.format_exception does not include the caret # (^) in the original traceback, but it does in the reconstructed one, @@ -24,7 +30,13 @@ def strip_caret(s): reconstructed_exception = error.to_exception() reconstructed_traceback = strip_caret( - "".join(traceback.format_exception(reconstructed_exception)) + "".join( + traceback.format_exception( + reconstructed_exception.__class__, + reconstructed_exception, + reconstructed_exception.__traceback__, + ) + ) ) assert type(reconstructed_exception) is type(original_exception) @@ -34,7 +46,13 @@ def strip_caret(s): error2 = Error.from_exception(reconstructed_exception) reconstructed_exception2 = error2.to_exception() reconstructed_traceback2 = strip_caret( - "".join(traceback.format_exception(reconstructed_exception2)) + "".join( + traceback.format_exception( + reconstructed_exception2.__class__, + reconstructed_exception2, + reconstructed_exception2.__traceback__, + ) + ) ) assert type(reconstructed_exception2) is type(original_exception) diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 2bfc079a..c5189de2 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -1,5 +1,5 @@ import unittest -from typing import Any, Callable +from typing import Any, Callable, Optional from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable @@ -414,7 +414,7 @@ def resume( main: Callable, prev_output: Output, call_results: list[CallResult], - poll_error: Exception | None = None, + poll_error: Optional[Exception] = None, ): poll = self.assert_poll(prev_output) input = Input.from_poll_results( @@ -444,7 +444,7 @@ def assert_exit_result_value(self, output: Output, expect: Any): self.assertEqual(expect, any_unpickle(result.output)) def assert_exit_result_error( - self, output: Output, expect: type[Exception], message: str | None = None + self, output: Output, expect: type[Exception], message: Optional[str] = None ): result = self.assert_exit_result(output) self.assertFalse(result.HasField("output"))