Skip to content

Commit

Permalink
Merge pull request #156 from stealthrocket/mypy-1.10.0
Browse files Browse the repository at this point in the history
mypy: upgrade to 1.10.0
  • Loading branch information
achille-roussel authored Apr 25, 2024
2 parents bfdd64b + 8a92dbf commit a47ae12
Show file tree
Hide file tree
Showing 18 changed files with 167 additions and 168 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fmt-check:
exit $$((isort_status + black_status))

typecheck:
$(PYTHON) -m mypy src tests
$(PYTHON) -m mypy --check-untyped-defs src tests examples

unittest:
$(PYTHON) -m pytest tests
Expand Down
2 changes: 1 addition & 1 deletion examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
endpoint_client = EndpointClient(TestClient(app))
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
Expand Down
2 changes: 1 addition & 1 deletion examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
endpoint_client = EndpointClient(TestClient(app))
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
Expand Down
2 changes: 1 addition & 1 deletion examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
endpoint_client = EndpointClient(TestClient(app))
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ lambda = ["awslambdaric"]
dev = [
"black >= 24.1.0",
"isort >= 5.13.2",
"mypy >= 1.8.0",
"mypy >= 1.10.0",
"pytest==8.0.0",
"fastapi >= 0.109.0",
"coverage >= 7.4.1",
Expand Down Expand Up @@ -58,7 +58,15 @@ src_paths = ["src"]
omit = ["*_pb2_grpc.py", "*_pb2.py", "tests/*", "examples/*", "src/buf/*"]

[tool.mypy]
exclude = ['^src/buf', '^tests/examples']
exclude = [
'^src/buf',
'^tests/examples',
# mypy 1.10.0 reports false positives for these two files:
# src/dispatch/sdk/v1/function_pb2_grpc.py:74: error: Module has no attribute "experimental" [attr-defined]
# src/dispatch/sdk/v1/dispatch_pb2_grpc.py:80: error: Module has no attribute "experimental" [attr-defined]
'^src/dispatch/sdk/v1/function_pb2_grpc.py',
'^src/dispatch/sdk/v1/dispatch_pb2_grpc.py',
]

[tool.pytest.ini_options]
testpaths = ['tests']
6 changes: 4 additions & 2 deletions src/dispatch/experimental/durable/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,15 @@ def __getstate__(self):
if frame_state < FRAME_CLEARED:
print(f"IP = {ip}")
print(f"SP = {sp}")
for i, (is_null, value) in enumerate(stack):
for i, (is_null, value) in enumerate(
stack if stack is not None else []
):
if is_null:
print(f"stack[{i}] = NULL")
else:
print(f"stack[{i}] = {value}")
print(f"BP = {bp}")
for i, block in enumerate(blocks):
for i, block in enumerate(blocks if blocks is not None else []):
print(f"block[{i}] = {block}")
print()

Expand Down
4 changes: 3 additions & 1 deletion src/dispatch/experimental/durable/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __setstate__(self, state):
f"hash mismatch for function {key}: {code_hash} vs. expected {rfn.hash}"
)

self.fn = rfn.fn
# mypy 1.10.0 seems to report a false positive here:
# error: Incompatible types in assignment (expression has type "FunctionType", variable has type "MethodType") [assignment]
self.fn = rfn.fn # type: ignore
self.key = key
self.filename = filename
self.lineno = lineno
Expand Down
6 changes: 4 additions & 2 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,11 @@ def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)

primitive_func.__qualname__ = f"{name}_primitive"
primitive_func = durable(primitive_func)
durable_primitive_func = durable(primitive_func)

wrapped_func = Function[P, T](self.endpoint, self.client, name, primitive_func)
wrapped_func = Function[P, T](
self.endpoint, self.client, name, durable_primitive_func
)
self._register(name, wrapped_func)
return wrapped_func

Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def do_POST(self):
)
max_age = timedelta(minutes=5)
try:
verify_request(signed_request, verification_key, max_age)
verify_request(signed_request, self.verification_key, max_age)
except ValueError as e:
self.send_error_response_unauthenticated(str(e))
return
Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __str__(self):
_OUTPUT_TYPES: Dict[Type[Any], Callable[[Any], Status]] = {}


def status_for_error(error: Exception) -> Status:
def status_for_error(error: BaseException) -> Status:
"""Returns a Status that corresponds to the specified error."""
# See if the error matches one of the registered types.
handler = _find_handler(error, _ERROR_TYPES)
Expand Down
13 changes: 1 addition & 12 deletions src/dispatch/test/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from datetime import datetime
from typing import Optional

import fastapi
import grpc
import httpx
from fastapi.testclient import TestClient

from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.sdk.v1 import function_pb2_grpc as function_grpc
Expand All @@ -22,7 +20,7 @@ class EndpointClient:
Note that this is different from dispatch.Client, which is a client
for the Dispatch API. The EndpointClient is a client similar to the one
that Dispatch itself would use to interact with an endpoint that provides
functions, for example a FastAPI app.
functions.
"""

def __init__(
Expand Down Expand Up @@ -54,15 +52,6 @@ def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None):
http_client = httpx.Client(base_url=url)
return EndpointClient(http_client, signing_key)

@classmethod
def from_app(
cls, app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
):
"""Returns an EndpointClient for a Dispatch endpoint bound to a
FastAPI app instance."""
http_client = TestClient(app)
return EndpointClient(http_client, signing_key)


class _HttpxGrpcChannel(grpc.Channel):
def __init__(
Expand Down
9 changes: 4 additions & 5 deletions src/dispatch/test/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ def __init__(
self.pollers: Dict[DispatchID, Poller] = {}
self.parents: Dict[DispatchID, Poller] = {}

self.roundtrips: Optional[OrderedDict[DispatchID, List[RoundTrip]]] = None
if collect_roundtrips:
self.roundtrips = OrderedDict()
self.roundtrips: OrderedDict[DispatchID, List[RoundTrip]] = OrderedDict()
self.collect_roundtrips = collect_roundtrips

self._thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
Expand Down Expand Up @@ -142,7 +141,7 @@ def _make_dispatch_id(self) -> DispatchID:
def dispatch_calls(self):
"""Synchronously dispatch pending function calls to the
configured endpoint."""
_next_queue = []
_next_queue: List[Tuple[DispatchID, function_pb.RunRequest, CallType]] = []
while self.queue:
dispatch_id, request, call_type = self.queue.pop(0)

Expand All @@ -161,7 +160,7 @@ def dispatch_calls(self):
self.queue.append((dispatch_id, request, CallType.RETRY))
raise

if self.roundtrips is not None:
if self.collect_roundtrips:
try:
roundtrips = self.roundtrips[dispatch_id]
except KeyError:
Expand Down
38 changes: 19 additions & 19 deletions tests/dispatch/signature/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,58 +146,58 @@ def test_known_signature(self):
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem})
def test_parse_verification_key_env_pem_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw(), public_key2_bytes

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem2})
def test_parse_verification_key_env_pem_escaped_newline_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw(), public_key2_bytes

@mock.patch.dict(
os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_b64.decode()}
)
def test_parse_verification_key_env_b64_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_none(self):
# The verification key is optional. Both Dispatch(verification_key=...) and
# DISPATCH_VERIFICATION_KEY may be omitted/None.
verification_key = parse_verification_key(None)
self.assertIsNone(verification_key)
assert verification_key is None

def test_parse_verification_key_ed25519publickey(self):
verification_key = parse_verification_key(public_key2)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_pem_str(self):
verification_key = parse_verification_key(public_key2_pem)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_pem_escaped_newline_str(self):
verification_key = parse_verification_key(public_key2_pem2)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_pem_bytes(self):
verification_key = parse_verification_key(public_key2_pem.encode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_b64_str(self):
verification_key = parse_verification_key(public_key2_b64.decode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_b64_bytes(self):
verification_key = parse_verification_key(public_key2_b64)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
assert isinstance(verification_key, Ed25519PublicKey)
assert verification_key.public_bytes_raw() == public_key2_bytes

def test_parse_verification_key_invalid(self):
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
Expand Down
2 changes: 1 addition & 1 deletion tests/dispatch/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_conversion_without_traceback(self):
original_exception = e

error = Error.from_exception(original_exception)
error.traceback = ""
error.traceback = b""

reconstructed_exception = error.to_exception()
assert type(reconstructed_exception) is type(original_exception)
Expand Down
4 changes: 2 additions & 2 deletions tests/dispatch/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from typing import Any, Callable, List, Optional, Type
from typing import Any, Callable, List, Optional, Set, Type

from dispatch.coroutine import AnyException, any, call, gather, race
from dispatch.experimental.durable import durable
Expand Down Expand Up @@ -255,7 +255,7 @@ async def main():
result3 = await call_concurrently("g", "h")
return [result1, result2, result3]

correlation_ids = set()
correlation_ids: Set[int] = set()

output = self.start(main)
# a, b, c, d are called first. e is not because it depends on a.
Expand Down
Loading

0 comments on commit a47ae12

Please sign in to comment.