From b45324d5fd3bd20bae0592645d4eec5f5785e8b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yannik=20R=C3=B6del?= Date: Thu, 25 Aug 2022 12:35:36 +0200 Subject: [PATCH] Fix remaining mypy errors --- tests/test_client.py | 1 + tests/test_records.py | 3 +-- tests/test_views.py | 6 +++-- zucker/client/base.py | 52 ++++++++++++++++++++++++------------------- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 28b5b4c..3b00937 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -175,6 +175,7 @@ def handle_request( ) assert not client.authenticated assert client.request("get", "notaroute")["ping"] == "pong" + client = client assert client.authenticated # The following statement would be unreachable because of the two diff --git a/tests/test_records.py b/tests/test_records.py index 561c64c..205372c 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -1,5 +1,4 @@ -from typing import Sequence # noqa: F401 -from typing import cast +from typing import Callable, Sequence, cast # noqa: F401 from unittest.mock import MagicMock import pytest diff --git a/tests/test_views.py b/tests/test_views.py index 9be9074..f93aa70 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable, List, Optional, Tuple, TypedDict +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypedDict, cast from uuid import uuid4 import pytest @@ -222,7 +222,9 @@ def handle(method: str, url: str, params: JsonMapping) -> Optional[JsonMapping]: offset = int(params["offset"]) assert offset >= 0 - return {"records": record_data[offset : offset + max_num]} + return { + "records": cast(JsonMapping, record_data[offset : offset + max_num]) + } elif (method, url) == ("get", "Demo/count"): return {"record_count": len(record_data)} return None diff --git a/zucker/client/base.py b/zucker/client/base.py index 79d1a47..73d1c6a 100644 --- a/zucker/client/base.py +++ b/zucker/client/base.py @@ -480,21 +480,21 @@ async def request( # https://github.com/python/typeshed/blob/master/stdlib/asyncio/tasks.pyi#L44-L47 @overload - async def bulk(self, action_1: Coroutine[Any, Any, _T1], /) -> Tuple[_T1]: + async def bulk(self, action_1: Awaitable[_T1], /) -> Tuple[_T1]: ... @overload async def bulk( - self, action_1: Coroutine[Any, Any, _T1], action_2: Coroutine[Any, Any, _T2], / + self, action_1: Awaitable[_T1], action_2: Awaitable[_T2], / ) -> Tuple[_T1, _T2]: ... @overload async def bulk( self, - action_1: Coroutine[Any, Any, _T1], - action_2: Coroutine[Any, Any, _T2], - action_3: Coroutine[Any, Any, _T3], + action_1: Awaitable[_T1], + action_2: Awaitable[_T2], + action_3: Awaitable[_T3], /, ) -> Tuple[_T1, _T2, _T3]: ... @@ -502,10 +502,10 @@ async def bulk( @overload async def bulk( self, - action_1: Coroutine[Any, Any, _T1], - action_2: Coroutine[Any, Any, _T2], - action_3: Coroutine[Any, Any, _T3], - action_4: Coroutine[Any, Any, _T4], + action_1: Awaitable[_T1], + action_2: Awaitable[_T2], + action_3: Awaitable[_T3], + action_4: Awaitable[_T4], /, ) -> Tuple[_T1, _T2, _T3, _T4]: ... @@ -513,11 +513,11 @@ async def bulk( @overload async def bulk( self, - action_1: Coroutine[Any, Any, _T1], - action_2: Coroutine[Any, Any, _T2], - action_3: Coroutine[Any, Any, _T3], - action_4: Coroutine[Any, Any, _T4], - action_5: Coroutine[Any, Any, _T5], + action_1: Awaitable[_T1], + action_2: Awaitable[_T2], + action_3: Awaitable[_T3], + action_4: Awaitable[_T4], + action_5: Awaitable[_T5], /, ) -> Tuple[_T1, _T2, _T3, _T4, _T5]: ... @@ -525,21 +525,21 @@ async def bulk( @overload async def bulk( self, - action_1: Coroutine[Any, Any, _T1], - action_2: Coroutine[Any, Any, _T2], - action_3: Coroutine[Any, Any, _T3], - action_4: Coroutine[Any, Any, _T4], - action_5: Coroutine[Any, Any, _T5], - action_6: Coroutine[Any, Any, _T6], + action_1: Awaitable[_T1], + action_2: Awaitable[_T2], + action_3: Awaitable[_T3], + action_4: Awaitable[_T4], + action_5: Awaitable[_T5], + action_6: Awaitable[_T6], /, ) -> Tuple[_T1, _T2, _T3, _T4, _T5, _T6]: ... @overload - async def bulk(self, *actions: Coroutine[Any, Any, Any]) -> Tuple[Any, ...]: + async def bulk(self, *actions: Awaitable[Any]) -> Tuple[Any, ...]: ... - async def bulk(self, *actions: Coroutine[Any, Any, Any]) -> Tuple[Any, ...]: + async def bulk(self, *actions: Awaitable[Any]) -> Tuple[Any, ...]: """Run a sequence of actions that require server communication together. This will use Sugar's `Bulk API`_ to batch all actions together and send them @@ -593,8 +593,14 @@ async def handle_bulk( finally: counting_event.set() + # Actions are wrapped in this method because asyncio.create_task below expects + # a coroutine, but our actions are only awaitables (which are a supertype of + # coroutines). + async def run_action(action: Awaitable[_T]) -> _T: + return await action + self._handle_bulk = handle_bulk - action_tasks = [asyncio.create_task(action) for action in actions] + action_tasks = [asyncio.create_task(run_action(action)) for action in actions] for task in action_tasks: task.add_done_callback(lambda *_: counting_event.set())