Skip to content

Commit

Permalink
Merge pull request #170 from dispatchrun/http-tests
Browse files Browse the repository at this point in the history
test: decouple test components from httpx
  • Loading branch information
chriso authored May 20, 2024
2 parents fd42f89 + b197b74 commit fc85617
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 47 deletions.
9 changes: 4 additions & 5 deletions examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import unittest
from unittest import mock

from fastapi.testclient import TestClient

from dispatch import Client
from dispatch.sdk.v1 import status_pb2 as status_pb
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


class TestAutoRetry(unittest.TestCase):
Expand All @@ -25,14 +24,14 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient(TestClient(app))
app_client = http_client(app)
endpoint_client = EndpointClient(app_client)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
response = app_client.get("/")
self.assertEqual(response.status_code, 200)

dispatch_service.dispatch_calls()
Expand Down
9 changes: 4 additions & 5 deletions examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import unittest
from unittest import mock

from fastapi.testclient import TestClient

from dispatch import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


class TestGettingStarted(unittest.TestCase):
Expand All @@ -24,14 +23,14 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient(TestClient(app))
app_client = http_client(app)
endpoint_client = EndpointClient(app_client)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
response = app_client.get("/")
self.assertEqual(response.status_code, 200)

dispatch_service.dispatch_calls()
Expand Down
9 changes: 4 additions & 5 deletions examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import unittest
from unittest import mock

from fastapi.testclient import TestClient

from dispatch.function import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


class TestGithubStats(unittest.TestCase):
Expand All @@ -24,14 +23,14 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient(TestClient(app))
app_client = http_client(app)
endpoint_client = EndpointClient(app_client)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
response = app_client.get("/")
self.assertEqual(response.status_code, 200)

while dispatch_service.queue:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ dependencies = [
"grpc-stubs >= 1.53.0.5",
"http-message-signatures >= 0.4.4",
"tblib >= 3.0.0",
"httpx >= 0.27.0",
"typing_extensions >= 4.10"
]

[project.optional-dependencies]
fastapi = ["fastapi", "httpx"]
flask = ["flask"]
lambda = ["awslambdaric"]

dev = [
"httpx >= 0.27.0",
"black >= 24.1.0",
"isort >= 5.13.2",
"mypy >= 1.10.0",
Expand Down
26 changes: 11 additions & 15 deletions src/dispatch/test/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import datetime
from typing import Optional
from typing import Mapping, Optional, Protocol, Union

import grpc
import httpx

from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.sdk.v1 import function_pb2_grpc as function_grpc
Expand All @@ -12,6 +11,7 @@
Request,
sign_request,
)
from dispatch.test.http import HttpClient


class EndpointClient:
Expand All @@ -24,15 +24,15 @@ class EndpointClient:
"""

def __init__(
self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None
self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None
):
"""Initialize the client.
Args:
http_client: Client to use to make HTTP requests.
signing_key: Optional Ed25519 private key to use to sign requests.
"""
channel = _HttpxGrpcChannel(http_client, signing_key=signing_key)
channel = _HttpGrpcChannel(http_client, signing_key=signing_key)
self._stub = function_grpc.FunctionServiceStub(channel)

def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse:
Expand All @@ -46,16 +46,10 @@ def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse:
"""
return self._stub.Run(request)

@classmethod
def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None):
"""Returns an EndpointClient for a Dispatch endpoint URL."""
http_client = httpx.Client(base_url=url)
return EndpointClient(http_client, signing_key)


class _HttpxGrpcChannel(grpc.Channel):
class _HttpGrpcChannel(grpc.Channel):
def __init__(
self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None
self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None
):
self.http_client = http_client
self.signing_key = signing_key
Expand Down Expand Up @@ -120,9 +114,11 @@ def __call__(
wait_for_ready=None,
compression=None,
):
url = self.client.url_for(self.method) # note: method==path in gRPC parlance

request = Request(
method="POST",
url=str(httpx.URL(self.client.base_url).join(self.method)),
url=url,
body=self.request_serializer(request),
headers=CaseInsensitiveDict({"Content-Type": "application/grpc+proto"}),
)
Expand All @@ -131,10 +127,10 @@ def __call__(
sign_request(request, self.signing_key, datetime.now())

response = self.client.post(
request.url, content=request.body, headers=request.headers
request.url, body=request.body, headers=request.headers
)
response.raise_for_status()
return self.response_deserializer(response.content)
return self.response_deserializer(response.body)

def with_call(
self,
Expand Down
10 changes: 10 additions & 0 deletions src/dispatch/test/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient

import dispatch.test.httpx
from dispatch.test.client import HttpClient


def http_client(app: FastAPI) -> HttpClient:
"""Build a client for a FastAPI app."""
return dispatch.test.httpx.Client(TestClient(app))
30 changes: 30 additions & 0 deletions src/dispatch/test/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from dataclasses import dataclass
from typing import Mapping, Protocol


@dataclass
class HttpResponse(Protocol):
status_code: int
body: bytes

def raise_for_status(self):
"""Raise an exception on non-2xx responses."""
...


class HttpClient(Protocol):
"""Protocol for HTTP clients."""

def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse:
"""Make a GET request."""
...

def post(
self, url: str, body: bytes, headers: Mapping[str, str] = {}
) -> HttpResponse:
"""Make a POST request."""
...

def url_for(self, path: str) -> str:
"""Get the fully-qualified URL for a path."""
...
39 changes: 39 additions & 0 deletions src/dispatch/test/httpx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Mapping

import httpx

from dispatch.test.http import HttpClient, HttpResponse


class Client(HttpClient):
def __init__(self, client: httpx.Client):
self.client = client

def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse:
response = self.client.get(url, headers=headers)
return Response(response)

def post(
self, url: str, body: bytes, headers: Mapping[str, str] = {}
) -> HttpResponse:
response = self.client.post(url, content=body, headers=headers)
return Response(response)

def url_for(self, path: str) -> str:
return str(httpx.URL(self.client.base_url).join(path))


class Response(HttpResponse):
def __init__(self, response: httpx.Response):
self.response = response

@property
def status_code(self):
return self.response.status_code

@property
def body(self):
return self.response.content

def raise_for_status(self):
self.response.raise_for_status()
12 changes: 0 additions & 12 deletions src/dispatch/test/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Dict, List, Optional, Set, Tuple

import grpc
import httpx
from typing_extensions import TypeAlias

import dispatch.sdk.v1.call_pb2 as call_pb
Expand Down Expand Up @@ -325,17 +324,6 @@ def _dispatch_continuously(self):

try:
self.dispatch_calls()
except httpx.HTTPStatusError as e:
if e.response.status_code == 403:
logger.error(
"error dispatching function call to endpoint (403). Is the endpoint's DISPATCH_VERIFICATION_KEY correct?"
)
else:
logger.exception(e)
except httpx.ConnectError as e:
logger.error(
"error connecting to the endpoint. Is it running and accessible from DISPATCH_ENDPOINT_URL?"
)
except Exception as e:
logger.exception(e)

Expand Down
8 changes: 7 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
import unittest
from unittest import mock

import httpx

import dispatch.test.httpx
from dispatch import Call, Client
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.test import DispatchServer, DispatchService, EndpointClient


class TestClient(unittest.TestCase):
def setUp(self):
endpoint_client = EndpointClient.from_url("http://function-service")
http_client = dispatch.test.httpx.Client(
httpx.Client(base_url="http://function-service")
)
endpoint_client = EndpointClient(http_client)

api_key = "0000000000000000"
self.dispatch_service = DispatchService(endpoint_client, api_key)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from dispatch.status import Status
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.fastapi import http_client


def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str):
Expand All @@ -44,8 +45,7 @@ def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str):
def create_endpoint_client(
app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
):
http_client = TestClient(app)
return EndpointClient(http_client, signing_key)
return EndpointClient(http_client(app), signing_key)


class TestFastAPI(unittest.TestCase):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import httpx
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

import dispatch.test.httpx
from dispatch.experimental.durable.registry import clear_functions
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.http import Dispatch
Expand Down Expand Up @@ -87,7 +88,8 @@ def my_function(input: Input) -> Output:
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

client = EndpointClient.from_url(self.endpoint)
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
client = EndpointClient(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
Expand Down

0 comments on commit fc85617

Please sign in to comment.