Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote endpoints #134

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dispatch.integrations
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import DEFAULT_API_URL, Client
from dispatch.function import DEFAULT_API_URL, Client, Registry
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
from dispatch.status import Status
Expand All @@ -23,4 +23,5 @@
"all",
"any",
"race",
"Registry",
]
5 changes: 2 additions & 3 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def __init__(
"request verification is disabled because DISPATCH_VERIFICATION_KEY is not set"
)

self.client = Client(api_key=api_key, api_url=api_url)
super().__init__(endpoint, self.client)
super().__init__(endpoint, api_key=api_key, api_url=api_url)

function_service = _new_app(self, verification_key)
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
Expand Down Expand Up @@ -225,7 +224,7 @@ async def execute(request: fastapi.Request):
raise _ConnectError(400, "invalid_argument", "function is required")

try:
func = function_registry._functions[req.function]
func = function_registry.functions[req.function]
except KeyError:
logger.debug("function '%s' not found", req.function)
raise _ConnectError(
Expand Down
44 changes: 25 additions & 19 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def __init__(
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
):
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)

Expand Down Expand Up @@ -158,21 +157,30 @@ def build_call(


class Registry:
"""Registry of local functions."""
"""Registry of functions."""

__slots__ = ("_functions", "_endpoint", "_client")
__slots__ = ("functions", "endpoint", "client")

def __init__(self, endpoint: str, client: Client):
"""Initialize a local function registry.
def __init__(
self, endpoint: str, api_key: str | None = None, api_url: str | None = None
):
"""Initialize a function registry.

Args:
endpoint: URL of the endpoint that the function is accessible from.
client: Client for the Dispatch API. Used to dispatch calls to
local functions.

api_key: Dispatch API key to use for authentication when
dispatching calls to functions. Uses the value of the
DISPATCH_API_KEY environment variable by default.

api_url: The URL of the Dispatch API to use when dispatching calls
to functions. Uses the value of the DISPATCH_API_URL environment
variable if set, otherwise defaults to the public Dispatch API
(DEFAULT_API_URL).
"""
self._functions: Dict[str, PrimitiveFunction] = {}
self._endpoint = endpoint
self._client = client
self.functions: Dict[str, PrimitiveFunction] = {}
self.endpoint = endpoint
self.client = Client(api_key=api_key, api_url=api_url)

@overload
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
Expand Down Expand Up @@ -215,9 +223,7 @@ def primitive_func(input: Input) -> Output:
primitive_func.__qualname__ = f"{name}_primitive"
primitive_func = durable(primitive_func)

wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func
)
wrapped_func = Function[P, T](self.endpoint, self.client, name, primitive_func)
self._register(name, wrapped_func)
return wrapped_func

Expand All @@ -228,20 +234,20 @@ def primitive_function(
name = primitive_func.__qualname__
logger.info("registering primitive function: %s", name)
wrapped_func = PrimitiveFunction(
self._endpoint, self._client, name, primitive_func
self.endpoint, self.client, name, primitive_func
)
self._register(name, wrapped_func)
return wrapped_func

def _register(self, name: str, wrapped_func: PrimitiveFunction):
if name in self._functions:
if name in self.functions:
raise ValueError(f"function already registered with name '{name}'")
self._functions[name] = wrapped_func
self.functions[name] = wrapped_func

def set_client(self, client: Client):
"""Set the Client instance used to dispatch calls to local functions."""
self._client = client
for fn in self._functions.values():
"""Set the Client instance used to dispatch calls to registered functions."""
self.client = client
for fn in self.functions.values():
fn._client = client


Expand Down
7 changes: 5 additions & 2 deletions tests/dispatch/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

class TestFunction(unittest.TestCase):
def setUp(self):
self.client = Client(api_url="http://dispatch.com", api_key="foobar")
self.dispatch = Registry(endpoint="http://example.com", client=self.client)
self.dispatch = Registry(
endpoint="http://example.com",
api_url="http://dispatch.com",
api_key="foobar",
)

def test_serializable(self):
@self.dispatch.function
Expand Down