Skip to content

Commit

Permalink
Fix remaining mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
yrd committed Aug 25, 2022
1 parent 9c4e30c commit b45324d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 27 deletions.
1 change: 1 addition & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def handle_request(
)
assert not client.authenticated
assert client.request("get", "notaroute")["ping"] == "pong"
client = client
assert client.authenticated

# The following statement would be unreachable because of the two
Expand Down
3 changes: 1 addition & 2 deletions tests/test_records.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Sequence # noqa: F401
from typing import cast
from typing import Callable, Sequence, cast # noqa: F401
from unittest.mock import MagicMock

import pytest
Expand Down
6 changes: 4 additions & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypedDict
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypedDict, cast
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -222,7 +222,9 @@ def handle(method: str, url: str, params: JsonMapping) -> Optional[JsonMapping]:
offset = int(params["offset"])
assert offset >= 0

return {"records": record_data[offset : offset + max_num]}
return {
"records": cast(JsonMapping, record_data[offset : offset + max_num])
}
elif (method, url) == ("get", "Demo/count"):
return {"record_count": len(record_data)}
return None
Expand Down
52 changes: 29 additions & 23 deletions zucker/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,66 +480,66 @@ async def request(
# https://github.com/python/typeshed/blob/master/stdlib/asyncio/tasks.pyi#L44-L47

@overload
async def bulk(self, action_1: Coroutine[Any, Any, _T1], /) -> Tuple[_T1]:
async def bulk(self, action_1: Awaitable[_T1], /) -> Tuple[_T1]:
...

@overload
async def bulk(
self, action_1: Coroutine[Any, Any, _T1], action_2: Coroutine[Any, Any, _T2], /
self, action_1: Awaitable[_T1], action_2: Awaitable[_T2], /
) -> Tuple[_T1, _T2]:
...

@overload
async def bulk(
self,
action_1: Coroutine[Any, Any, _T1],
action_2: Coroutine[Any, Any, _T2],
action_3: Coroutine[Any, Any, _T3],
action_1: Awaitable[_T1],
action_2: Awaitable[_T2],
action_3: Awaitable[_T3],
/,
) -> Tuple[_T1, _T2, _T3]:
...

@overload
async def bulk(
self,
action_1: Coroutine[Any, Any, _T1],
action_2: Coroutine[Any, Any, _T2],
action_3: Coroutine[Any, Any, _T3],
action_4: Coroutine[Any, Any, _T4],
action_1: Awaitable[_T1],
action_2: Awaitable[_T2],
action_3: Awaitable[_T3],
action_4: Awaitable[_T4],
/,
) -> Tuple[_T1, _T2, _T3, _T4]:
...

@overload
async def bulk(
self,
action_1: Coroutine[Any, Any, _T1],
action_2: Coroutine[Any, Any, _T2],
action_3: Coroutine[Any, Any, _T3],
action_4: Coroutine[Any, Any, _T4],
action_5: Coroutine[Any, Any, _T5],
action_1: Awaitable[_T1],
action_2: Awaitable[_T2],
action_3: Awaitable[_T3],
action_4: Awaitable[_T4],
action_5: Awaitable[_T5],
/,
) -> Tuple[_T1, _T2, _T3, _T4, _T5]:
...

@overload
async def bulk(
self,
action_1: Coroutine[Any, Any, _T1],
action_2: Coroutine[Any, Any, _T2],
action_3: Coroutine[Any, Any, _T3],
action_4: Coroutine[Any, Any, _T4],
action_5: Coroutine[Any, Any, _T5],
action_6: Coroutine[Any, Any, _T6],
action_1: Awaitable[_T1],
action_2: Awaitable[_T2],
action_3: Awaitable[_T3],
action_4: Awaitable[_T4],
action_5: Awaitable[_T5],
action_6: Awaitable[_T6],
/,
) -> Tuple[_T1, _T2, _T3, _T4, _T5, _T6]:
...

@overload
async def bulk(self, *actions: Coroutine[Any, Any, Any]) -> Tuple[Any, ...]:
async def bulk(self, *actions: Awaitable[Any]) -> Tuple[Any, ...]:
...

async def bulk(self, *actions: Coroutine[Any, Any, Any]) -> Tuple[Any, ...]:
async def bulk(self, *actions: Awaitable[Any]) -> Tuple[Any, ...]:
"""Run a sequence of actions that require server communication together.
This will use Sugar's `Bulk API`_ to batch all actions together and send them
Expand Down Expand Up @@ -593,8 +593,14 @@ async def handle_bulk(
finally:
counting_event.set()

# Actions are wrapped in this method because asyncio.create_task below expects
# a coroutine, but our actions are only awaitables (which are a supertype of
# coroutines).
async def run_action(action: Awaitable[_T]) -> _T:
return await action

self._handle_bulk = handle_bulk
action_tasks = [asyncio.create_task(action) for action in actions]
action_tasks = [asyncio.create_task(run_action(action)) for action in actions]
for task in action_tasks:
task.add_done_callback(lambda *_: counting_event.set())

Expand Down

0 comments on commit b45324d

Please sign in to comment.