Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: decouple test components from httpx #170

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import unittest
from unittest import mock

from fastapi.testclient import TestClient

from dispatch import Client
from dispatch.sdk.v1 import status_pb2 as status_pb
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


class TestAutoRetry(unittest.TestCase):
Expand All @@ -25,14 +24,14 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient(TestClient(app))
app_client = http_client(app)
endpoint_client = EndpointClient(app_client)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
response = app_client.get("/")
self.assertEqual(response.status_code, 200)

dispatch_service.dispatch_calls()
Expand Down
9 changes: 4 additions & 5 deletions examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import unittest
from unittest import mock

from fastapi.testclient import TestClient

from dispatch import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


class TestGettingStarted(unittest.TestCase):
Expand All @@ -24,14 +23,14 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient(TestClient(app))
app_client = http_client(app)
endpoint_client = EndpointClient(app_client)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
response = app_client.get("/")
self.assertEqual(response.status_code, 200)

dispatch_service.dispatch_calls()
Expand Down
9 changes: 4 additions & 5 deletions examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import unittest
from unittest import mock

from fastapi.testclient import TestClient

from dispatch.function import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


class TestGithubStats(unittest.TestCase):
Expand All @@ -24,14 +23,14 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient(TestClient(app))
app_client = http_client(app)
endpoint_client = EndpointClient(app_client)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
response = app_client.get("/")
self.assertEqual(response.status_code, 200)

while dispatch_service.queue:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ dependencies = [
"grpc-stubs >= 1.53.0.5",
"http-message-signatures >= 0.4.4",
"tblib >= 3.0.0",
"httpx >= 0.27.0",
"typing_extensions >= 4.10"
]

[project.optional-dependencies]
fastapi = ["fastapi", "httpx"]
flask = ["flask"]
lambda = ["awslambdaric"]

dev = [
"httpx >= 0.27.0",
"black >= 24.1.0",
"isort >= 5.13.2",
"mypy >= 1.10.0",
Expand Down
26 changes: 11 additions & 15 deletions src/dispatch/test/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import datetime
from typing import Optional
from typing import Mapping, Optional, Protocol, Union

import grpc
import httpx

from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.sdk.v1 import function_pb2_grpc as function_grpc
Expand All @@ -12,6 +11,7 @@
Request,
sign_request,
)
from dispatch.test.http import HttpClient


class EndpointClient:
Expand All @@ -24,15 +24,15 @@ class EndpointClient:
"""

def __init__(
self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None
self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None
):
"""Initialize the client.

Args:
http_client: Client to use to make HTTP requests.
signing_key: Optional Ed25519 private key to use to sign requests.
"""
channel = _HttpxGrpcChannel(http_client, signing_key=signing_key)
channel = _HttpGrpcChannel(http_client, signing_key=signing_key)
self._stub = function_grpc.FunctionServiceStub(channel)

def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse:
Expand All @@ -46,16 +46,10 @@ def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse:
"""
return self._stub.Run(request)

@classmethod
def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None):
"""Returns an EndpointClient for a Dispatch endpoint URL."""
http_client = httpx.Client(base_url=url)
return EndpointClient(http_client, signing_key)


class _HttpxGrpcChannel(grpc.Channel):
class _HttpGrpcChannel(grpc.Channel):
def __init__(
self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None
self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None
):
self.http_client = http_client
self.signing_key = signing_key
Expand Down Expand Up @@ -120,9 +114,11 @@ def __call__(
wait_for_ready=None,
compression=None,
):
url = self.client.url_for(self.method) # note: method==path in gRPC parlance

request = Request(
method="POST",
url=str(httpx.URL(self.client.base_url).join(self.method)),
url=url,
body=self.request_serializer(request),
headers=CaseInsensitiveDict({"Content-Type": "application/grpc+proto"}),
)
Expand All @@ -131,10 +127,10 @@ def __call__(
sign_request(request, self.signing_key, datetime.now())

response = self.client.post(
request.url, content=request.body, headers=request.headers
request.url, body=request.body, headers=request.headers
)
response.raise_for_status()
return self.response_deserializer(response.content)
return self.response_deserializer(response.body)

def with_call(
self,
Expand Down
10 changes: 10 additions & 0 deletions src/dispatch/test/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient

import dispatch.test.httpx
from dispatch.test.client import HttpClient


def http_client(app: FastAPI) -> HttpClient:
"""Build a client for a FastAPI app."""
return dispatch.test.httpx.Client(TestClient(app))
30 changes: 30 additions & 0 deletions src/dispatch/test/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from dataclasses import dataclass
from typing import Mapping, Protocol


@dataclass
class HttpResponse(Protocol):
status_code: int
body: bytes

def raise_for_status(self):
"""Raise an exception on non-2xx responses."""
...


class HttpClient(Protocol):
"""Protocol for HTTP clients."""

def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse:
"""Make a GET request."""
...

def post(
self, url: str, body: bytes, headers: Mapping[str, str] = {}
) -> HttpResponse:
"""Make a POST request."""
...

def url_for(self, path: str) -> str:
"""Get the fully-qualified URL for a path."""
...
39 changes: 39 additions & 0 deletions src/dispatch/test/httpx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Mapping

import httpx

from dispatch.test.http import HttpClient, HttpResponse


class Client(HttpClient):
def __init__(self, client: httpx.Client):
self.client = client

def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse:
response = self.client.get(url, headers=headers)
return Response(response)

def post(
self, url: str, body: bytes, headers: Mapping[str, str] = {}
) -> HttpResponse:
response = self.client.post(url, content=body, headers=headers)
return Response(response)

def url_for(self, path: str) -> str:
return str(httpx.URL(self.client.base_url).join(path))


class Response(HttpResponse):
def __init__(self, response: httpx.Response):
self.response = response

@property
def status_code(self):
return self.response.status_code

@property
def body(self):
return self.response.content

def raise_for_status(self):
self.response.raise_for_status()
12 changes: 0 additions & 12 deletions src/dispatch/test/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Dict, List, Optional, Set, Tuple

import grpc
import httpx
from typing_extensions import TypeAlias

import dispatch.sdk.v1.call_pb2 as call_pb
Expand Down Expand Up @@ -325,17 +324,6 @@ def _dispatch_continuously(self):

try:
self.dispatch_calls()
except httpx.HTTPStatusError as e:
if e.response.status_code == 403:
logger.error(
"error dispatching function call to endpoint (403). Is the endpoint's DISPATCH_VERIFICATION_KEY correct?"
)
else:
logger.exception(e)
except httpx.ConnectError as e:
logger.error(
"error connecting to the endpoint. Is it running and accessible from DISPATCH_ENDPOINT_URL?"
)
except Exception as e:
logger.exception(e)

Expand Down
8 changes: 7 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
import unittest
from unittest import mock

import httpx

import dispatch.test.httpx
from dispatch import Call, Client
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.test import DispatchServer, DispatchService, EndpointClient


class TestClient(unittest.TestCase):
def setUp(self):
endpoint_client = EndpointClient.from_url("http://function-service")
http_client = dispatch.test.httpx.Client(
httpx.Client(base_url="http://function-service")
)
endpoint_client = EndpointClient(http_client)

api_key = "0000000000000000"
self.dispatch_service = DispatchService(endpoint_client, api_key)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from dispatch.status import Status
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str):
Expand All @@ -44,8 +45,7 @@ def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str):
def create_endpoint_client(
app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
):
http_client = TestClient(app)
return EndpointClient(http_client, signing_key)
return EndpointClient(http_client(app), signing_key)


class TestFastAPI(unittest.TestCase):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import httpx
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

import dispatch.test.httpx
from dispatch.experimental.durable.registry import clear_functions
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.http import Dispatch
Expand Down Expand Up @@ -87,7 +88,8 @@ def my_function(input: Input) -> Output:
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

client = EndpointClient.from_url(self.endpoint)
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
client = EndpointClient(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
Expand Down