From a8aa4b331915841ea66e5427b2d9210e207c1d1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 14 Dec 2023 13:31:40 +0200 Subject: [PATCH 01/19] Used TypeVarTuple and ParamSpec in several places to improve type annotation accuracy --- docs/versionhistory.rst | 6 ++++ src/anyio/_backends/_asyncio.py | 42 +++++++++++++++++-------- src/anyio/_backends/_trio.py | 36 ++++++++++++++-------- src/anyio/_core/_eventloop.py | 11 +++++-- src/anyio/_core/_fileio.py | 9 ++---- src/anyio/abc/_tasks.py | 11 +++++-- src/anyio/from_thread.py | 54 ++++++++++++++++++++++++--------- src/anyio/to_process.py | 11 +++++-- src/anyio/to_thread.py | 11 +++++-- tests/test_from_thread.py | 4 +-- 10 files changed, 140 insertions(+), 55 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 94eaf15f..6f2ba5fa 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -8,6 +8,11 @@ This library adheres to `Semantic Versioning 2.0 `_. - Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``, ``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by Lura Skye) +- Improved type annotations of numerous methods and functions including ``anyio.run()``, + ``TaskGroup.start_soon()``, ``anyio.from_thread.run()``, + ``anyio.to_thread.run_sync()`` and ``anyio.to_process.run_sync()`` by making use of + PEP 646 ``TypeVarTuple`` to allow the positional arguments to be validated by static + type checkers - Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help from Egor Blagov) @@ -16,6 +21,7 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed cancellation propagating on asyncio from a task group to child tasks if the task hosting the task group is in a shielded cancel scope (`#642 `_) +- Fixed the type annotation of ``anyio.Path.samefile()`` to match Typeshed **4.1.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 95b8e556..874169f7 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -82,8 +82,14 @@ from ..lowlevel import RunVar from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + if sys.version_info >= (3, 11): from asyncio import Runner + from typing import TypeVarTuple, Unpack else: import contextvars import enum @@ -91,6 +97,7 @@ from asyncio import coroutines, events, exceptions, tasks from exceptiongroup import BaseExceptionGroup + from typing_extensions import TypeVarTuple, Unpack class _State(enum.Enum): CREATED = "created" @@ -271,6 +278,8 @@ def _do_shutdown(future: asyncio.futures.Future) -> None: T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") +P = ParamSpec("P") _root_task: RunVar[asyncio.Task | None] = RunVar("_root_task") @@ -682,8 +691,8 @@ async def __aexit__( def _spawn( self, - func: Callable[..., Awaitable[Any]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + args: tuple[Unpack[PosArgsT]], name: object, task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: @@ -752,7 +761,10 @@ def task_done(_task: asyncio.Task) -> None: return task def start_soon( - self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, ) -> None: self._spawn(func, args, name) @@ -875,11 +887,11 @@ def __init__(self) -> None: def _spawn_task_from_thread( self, - func: Callable, - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: AsyncIOBackend.run_sync_from_thread( partial(self._task_group.start_soon, name=name), @@ -1883,7 +1895,10 @@ async def _run_tests_and_fixtures( future.set_result(retval) async def _call_in_runner_task( - self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object + self, + func: Callable[P, Awaitable[T_Retval]], + *args: P.args, + **kwargs: P.kwargs, ) -> T_Retval: if not self._runner_task: self._send_stream, receive_stream = create_memory_object_stream[ @@ -2062,8 +2077,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter: @classmethod async def run_sync_in_worker_thread( cls, - func: Callable[..., T_Retval], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: @@ -2133,8 +2148,8 @@ def check_cancelled(cls) -> None: @classmethod def run_async_from_thread( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: async def task_wrapper(scope: CancelScope) -> T_Retval: @@ -2160,7 +2175,10 @@ async def task_wrapper(scope: CancelScope) -> T_Retval: @classmethod def run_sync_from_thread( - cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, ) -> T_Retval: @wraps(func) def wrapper() -> None: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index eb891d22..35010900 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -62,12 +62,16 @@ from ..abc._eventloop import AsyncBackend from ..streams.memory import MemoryObjectSendStream -if sys.version_info < (3, 11): +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: from exceptiongroup import BaseExceptionGroup + from typing_extensions import TypeVarTuple, Unpack T = TypeVar("T") T_Retval = TypeVar("T_Retval") T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) +PosArgsT = TypeVarTuple("PosArgsT") # @@ -167,7 +171,12 @@ async def __aexit__( finally: self._active = False - def start_soon(self, func: Callable, *args: object, name: object = None) -> None: + def start_soon( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, + ) -> None: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." @@ -201,11 +210,11 @@ def __init__(self) -> None: def _spawn_task_from_thread( self, - func: Callable, - args: tuple, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: trio.from_thread.run_sync( partial(self._task_group.start_soon, name=name), @@ -806,8 +815,8 @@ class TrioBackend(AsyncBackend): @classmethod def run( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: @@ -866,8 +875,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: @classmethod async def run_sync_in_worker_thread( cls, - func: Callable[..., T_Retval], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: @@ -889,15 +898,18 @@ def check_cancelled(cls) -> None: @classmethod def run_async_from_thread( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: return trio.from_thread.run(func, *args) @classmethod def run_sync_from_thread( - cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, ) -> T_Retval: return trio.from_thread.run_sync(func, *args) diff --git a/src/anyio/_core/_eventloop.py b/src/anyio/_core/_eventloop.py index b74d02b0..a9c6e825 100644 --- a/src/anyio/_core/_eventloop.py +++ b/src/anyio/_core/_eventloop.py @@ -10,6 +10,11 @@ import sniffio +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + if TYPE_CHECKING: from ..abc import AsyncBackend @@ -17,12 +22,14 @@ BACKENDS = "asyncio", "trio" T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + threadlocals = threading.local() def run( - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], backend: str = "asyncio", backend_options: dict[str, Any] | None = None, ) -> T_Retval: diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py index f51bf450..53f32339 100644 --- a/src/anyio/_core/_fileio.py +++ b/src/anyio/_core/_fileio.py @@ -15,7 +15,6 @@ AsyncIterator, Final, Generic, - cast, overload, ) @@ -211,7 +210,7 @@ async def __anext__(self) -> Path: if nextval is None: raise StopAsyncIteration from None - return Path(cast("PathLike[str]", nextval)) + return Path(nextval) class Path: @@ -518,7 +517,7 @@ def relative_to(self, *other: str | PathLike[str]) -> Path: async def readlink(self) -> Path: target = await to_thread.run_sync(os.readlink, self._path) - return Path(cast(str, target)) + return Path(target) async def rename(self, target: str | pathlib.PurePath | Path) -> Path: if isinstance(target, Path): @@ -545,9 +544,7 @@ def rglob(self, pattern: str) -> AsyncIterator[Path]: async def rmdir(self) -> None: await to_thread.run_sync(self._path.rmdir) - async def samefile( - self, other_path: str | bytes | int | pathlib.Path | Path - ) -> bool: + async def samefile(self, other_path: str | PathLike[str]) -> bool: if isinstance(other_path, Path): other_path = other_path._path diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 9ea3608e..7ad4938c 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -1,15 +1,22 @@ from __future__ import annotations +import sys from abc import ABCMeta, abstractmethod from collections.abc import Awaitable, Callable from types import TracebackType from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + if TYPE_CHECKING: from .._core._tasks import CancelScope T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") class TaskStatus(Protocol[T_contra]): @@ -42,8 +49,8 @@ class TaskGroup(metaclass=ABCMeta): @abstractmethod def start_soon( self, - func: Callable[..., Awaitable[Any]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], name: object = None, ) -> None: """ diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 63716496..31b09de1 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import threading from collections.abc import Awaitable, Callable, Generator from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait @@ -24,11 +25,19 @@ from .abc import AsyncBackend from .abc._tasks import TaskStatus +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") T_co = TypeVar("T_co") +PosArgsT = TypeVarTuple("PosArgsT") -def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: +def run( + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT] +) -> T_Retval: """ Call a coroutine function from a worker thread. @@ -48,7 +57,9 @@ def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: return async_backend.run_async_from_thread(func, args, token=token) -def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: +def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] +) -> T_Retval: """ Call a function in the event loop thread from a worker thread. @@ -182,7 +193,11 @@ async def stop(self, cancel_remaining: bool = False) -> None: self._task_group.cancel_scope.cancel() async def _call_func( - self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future + self, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + future: Future[T_Retval], ) -> None: def callback(f: Future) -> None: if f.cancelled() and self._event_loop_thread_id not in ( @@ -219,11 +234,11 @@ def callback(f: Future) -> None: def _spawn_task_from_thread( self, - func: Callable, - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: """ Spawn a new task using the given callable. @@ -241,17 +256,23 @@ def _spawn_task_from_thread( raise NotImplementedError @overload - def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + def call( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], + ) -> T_Retval: ... @overload - def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: + def call( + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] + ) -> T_Retval: ... def call( self, - func: Callable[..., Awaitable[T_Retval] | T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], ) -> T_Retval: """ Call the given function in the event loop thread. @@ -268,22 +289,25 @@ def call( @overload def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: ... @overload def start_task_soon( - self, func: Callable[..., T_Retval], *args: object, name: object = None + self, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + name: object = None, ) -> Future[T_Retval]: ... def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval] | T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: """ diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py index 2867d42d..1ff06f0b 100644 --- a/src/anyio/to_process.py +++ b/src/anyio/to_process.py @@ -18,9 +18,16 @@ from .lowlevel import RunVar, checkpoint_if_cancelled from .streams.buffered import BufferedByteReceiveStream +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + WORKER_MAX_IDLE_TIME = 300 # 5 minutes T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + _process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers") _process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar( "_process_pool_idle_workers" @@ -29,8 +36,8 @@ async def run_sync( - func: Callable[..., T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], cancellable: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py index d9a632e8..5070516e 100644 --- a/src/anyio/to_thread.py +++ b/src/anyio/to_thread.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from collections.abc import Callable from typing import TypeVar from warnings import warn @@ -7,12 +8,18 @@ from ._core._eventloop import get_async_backend from .abc import CapacityLimiter +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") async def run_sync( - func: Callable[..., T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], abandon_on_cancel: bool = False, cancellable: bool | None = None, limiter: CapacityLimiter | None = None, diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 0e580462..ea041f8f 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -206,8 +206,8 @@ async def test_run_sync_from_thread_exception(self) -> None: exc.match("unsupported operand type") async def test_run_anyio_async_func_from_thread(self) -> None: - def worker(*args: int) -> Literal[True]: - from_thread.run(sleep, *args) + def worker(delay: float) -> Literal[True]: + from_thread.run(sleep, delay) return True assert await to_thread.run_sync(worker, 0) From 2c5c94d1c22fdc12d65106da443181a37c265972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 14 Dec 2023 13:53:21 +0200 Subject: [PATCH 02/19] Added conditional dependency on typing_extensions --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b167e7e3..d2e22c0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "exceptiongroup >= 1.0.2; python_version < '3.11'", "idna >= 2.8", "sniffio >= 1.1", + "typing_extensions; python_version < '3.11'", ] dynamic = ["version"] From 2f876cbb55cc8ffdf408a4b5a1f63d039db8ac0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 14 Dec 2023 14:48:47 +0200 Subject: [PATCH 03/19] Ignored mypy error caused by a mypy bug Ref: https://github.com/python/mypy/issues/16662 --- src/anyio/from_thread.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 31b09de1..1053986a 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -421,7 +421,7 @@ async def run_portal() -> None: future: Future[BlockingPortal] = Future() with ThreadPoolExecutor(1) as executor: run_future = executor.submit( - _eventloop.run, + _eventloop.run, # type: ignore[arg-type] run_portal, backend=backend, backend_options=backend_options, From 9617c359ea57f3dd547c2f1b10bc80448b917ea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 15 Dec 2023 13:19:57 +0200 Subject: [PATCH 04/19] Update pyproject.toml Co-authored-by: Ganden Schaffner --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d2e22c0d..93db79ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "exceptiongroup >= 1.0.2; python_version < '3.11'", "idna >= 2.8", "sniffio >= 1.1", - "typing_extensions; python_version < '3.11'", + "typing_extensions >= 4.1; python_version < '3.11'", ] dynamic = ["version"] From a25514fa024ee06fbc47b8c29f98caf4b4b7ac39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 15 Dec 2023 14:40:02 +0200 Subject: [PATCH 05/19] Converted some local functions in tests to coroutine functions --- tests/test_from_thread.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index ea041f8f..f387e755 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -507,29 +507,29 @@ async def run_in_context() -> AsyncGenerator[None, None]: def test_start_no_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started() with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, value = portal.start_task(taskfunc) # type: ignore[arg-type] + future, value = portal.start_task(taskfunc) assert value is None assert future.result() is None def test_start_with_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started("foo") with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, value = portal.start_task(taskfunc) # type: ignore[arg-type] + future, value = portal.start_task(taskfunc) assert value == "foo" assert future.result() is None def test_start_crash_before_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: object) -> NoReturn: + async def taskfunc(*, task_status: object) -> NoReturn: raise Exception("foo") with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: @@ -539,7 +539,7 @@ def taskfunc(*, task_status: object) -> NoReturn: def test_start_crash_after_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> NoReturn: + async def taskfunc(*, task_status: TaskStatus) -> NoReturn: task_status.started(2) raise Exception("foo") @@ -552,24 +552,21 @@ def taskfunc(*, task_status: TaskStatus) -> NoReturn: def test_start_no_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: pass with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: with pytest.raises(RuntimeError, match="Task exited"): - portal.start_task(taskfunc) # type: ignore[arg-type] + portal.start_task(taskfunc) def test_start_with_name( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started(get_current_task().name) with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, start_value = portal.start_task( - taskfunc, # type: ignore[arg-type] - name="testname", - ) + future, start_value = portal.start_task(taskfunc, name="testname") assert start_value == "testname" def test_contextvar_propagation_sync( From 4cf038bc4e693c55bf3c9988af2fc0baca895baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 15 Dec 2023 14:42:20 +0200 Subject: [PATCH 06/19] Improved annotations in _BlockingAsyncContextManager --- src/anyio/from_thread.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 1053986a..b054b8eb 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -80,8 +80,8 @@ def run_sync( class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): - _enter_future: Future - _exit_future: Future + _enter_future: Future[T_co] + _exit_future: Future[bool | None] _exit_event: Event _exit_exc_info: tuple[ type[BaseException] | None, BaseException | None, TracebackType | None @@ -117,8 +117,7 @@ async def run_async_cm(self) -> bool | None: def __enter__(self) -> T_co: self._enter_future = Future() self._exit_future = self._portal.start_task_soon(self.run_async_cm) - cm = self._enter_future.result() - return cast(T_co, cm) + return self._enter_future.result() def __exit__( self, From 6bfae847e7323d7f5344f799ffdf3ea4fbd8c097 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 15 Dec 2023 14:51:27 +0200 Subject: [PATCH 07/19] Improved annotations in BlockingPortal --- src/anyio/_backends/_asyncio.py | 2 +- src/anyio/_backends/_trio.py | 2 +- src/anyio/from_thread.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 874169f7..0e6fca48 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -887,7 +887,7 @@ def __init__(self) -> None: def _spawn_task_from_thread( self, - func: Callable[[Unpack[PosArgsT]], T_Retval], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index d700eb7b..3f933684 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -210,7 +210,7 @@ def __init__(self) -> None: def _spawn_task_from_thread( self, - func: Callable[[Unpack[PosArgsT]], T_Retval], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index b054b8eb..e9d79c61 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -233,7 +233,7 @@ def callback(f: Future) -> None: def _spawn_task_from_thread( self, - func: Callable[[Unpack[PosArgsT]], T_Retval], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, @@ -328,7 +328,7 @@ def start_task_soon( """ self._check_running() - f: Future = Future() + f: Future[T_Retval] = Future() self._spawn_task_from_thread(func, args, {}, name, f) return f From 83df7e8a3664c1dad155181307550cfbf466fbd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 11:54:14 +0200 Subject: [PATCH 08/19] Added more uses of TypeVarTuple --- src/anyio/_backends/_asyncio.py | 4 ++-- src/anyio/_backends/_trio.py | 5 ++++- src/anyio/streams/tls.py | 9 ++++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 0e6fca48..8d42ead6 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1964,8 +1964,8 @@ class AsyncIOBackend(AsyncBackend): @classmethod def run( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 3f933684..eb4472f0 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -763,7 +763,10 @@ def _main_task_finished(self, outcome: object) -> None: self._send_stream = None def _call_in_runner_task( - self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], + **kwargs: object, ) -> T_Retval: if self._send_stream is None: trio.lowlevel.start_guest_run( diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index 8468f33d..e913eedb 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -3,6 +3,7 @@ import logging import re import ssl +import sys from collections.abc import Callable, Mapping from dataclasses import dataclass from functools import wraps @@ -17,7 +18,13 @@ from .._core._typedattr import TypedAttributeSet, typed_attribute from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") _PCTRTT = Tuple[Tuple[str, str], ...] _PCTRTTT = Tuple[_PCTRTT, ...] @@ -126,7 +133,7 @@ async def wrap( return wrapper async def _call_sslobject_method( - self, func: Callable[..., T_Retval], *args: object + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] ) -> T_Retval: while True: try: From 5f3e2745096b8cf39a15e7cad82d1c89df3660fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 12:12:04 +0200 Subject: [PATCH 09/19] Used TypeVarTuple in AsyncBackend.run() --- src/anyio/abc/_eventloop.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index 9f1660c9..c5f8e13e 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import sys from abc import ABCMeta, abstractmethod from collections.abc import AsyncIterator, Awaitable, Mapping from os import PathLike @@ -17,6 +18,11 @@ overload, ) +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + if TYPE_CHECKING: from typing import Literal @@ -39,6 +45,7 @@ from ._testing import TestRunner T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") class AsyncBackend(metaclass=ABCMeta): @@ -46,8 +53,8 @@ class AsyncBackend(metaclass=ABCMeta): @abstractmethod def run( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: From b21dc7dc54262d46c75db90f3a695064c60eab86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 12:28:30 +0200 Subject: [PATCH 10/19] Improved annotations on BlockingPortal._call_func() --- src/anyio/from_thread.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index e9d79c61..d558b366 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -193,12 +193,12 @@ async def stop(self, cancel_remaining: bool = False) -> None: async def _call_func( self, - func: Callable[[Unpack[PosArgsT]], T_Retval], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], future: Future[T_Retval], ) -> None: - def callback(f: Future) -> None: + def callback(f: Future[T_Retval]) -> None: if f.cancelled() and self._event_loop_thread_id not in ( None, threading.get_ident(), @@ -206,15 +206,17 @@ def callback(f: Future) -> None: self.call(scope.cancel) try: - retval = func(*args, **kwargs) - if isawaitable(retval): + retval_or_awaitable = func(*args, **kwargs) + if isawaitable(retval_or_awaitable): with CancelScope() as scope: if future.cancelled(): scope.cancel() else: future.add_done_callback(callback) - retval = await retval + retval = await retval_or_awaitable + else: + retval = retval_or_awaitable except self._cancelled_exc_class: future.cancel() future.set_running_or_notify_cancel() From 8fbbf44eb528c51568178ff3ccc4dcf2d2015e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 12:48:12 +0200 Subject: [PATCH 11/19] Used ParamSpec in Trio's _call_in_runner_task() --- src/anyio/_backends/_trio.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index eb4472f0..71e257a2 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -62,6 +62,11 @@ from ..abc._eventloop import AsyncBackend from ..streams.memory import MemoryObjectSendStream +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + if sys.version_info >= (3, 11): from typing import TypeVarTuple, Unpack else: @@ -72,6 +77,7 @@ T_Retval = TypeVar("T_Retval") T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) PosArgsT = TypeVarTuple("PosArgsT") +P = ParamSpec("P") # @@ -764,9 +770,9 @@ def _main_task_finished(self, outcome: object) -> None: def _call_in_runner_task( self, - func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], - *args: Unpack[PosArgsT], - **kwargs: object, + func: Callable[P, Awaitable[T_Retval]], + *args: P.args, + **kwargs: P.kwargs, ) -> T_Retval: if self._send_stream is None: trio.lowlevel.start_guest_run( From ef77d25a5c878418c5c973259939e5742a5300c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 12:48:41 +0200 Subject: [PATCH 12/19] Link to #560 in the changelog --- docs/versionhistory.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 9f2f7d6b..e3752422 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -14,7 +14,7 @@ This library adheres to `Semantic Versioning 2.0 `_. ``TaskGroup.start_soon()``, ``anyio.from_thread.run()``, ``anyio.to_thread.run_sync()`` and ``anyio.to_process.run_sync()`` by making use of PEP 646 ``TypeVarTuple`` to allow the positional arguments to be validated by static - type checkers + type checkers (`#560 `_) - Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help from Egor Blagov) From 72bebdf0c48ef5e32069da1cbed90c3b9f22475a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 12:53:54 +0200 Subject: [PATCH 13/19] Added more missing TypeVarTuple uses to AsyncBackend --- src/anyio/abc/_eventloop.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index c5f8e13e..4470d83d 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -176,8 +176,8 @@ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: @abstractmethod async def run_sync_in_worker_thread( cls, - func: Callable[..., T_Retval], - args: tuple[Any, ...], + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: @@ -192,8 +192,8 @@ def check_cancelled(cls) -> None: @abstractmethod def run_async_from_thread( cls, - func: Callable[..., Awaitable[T_Retval]], - args: tuple[Any], + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: pass @@ -201,7 +201,10 @@ def run_async_from_thread( @classmethod @abstractmethod def run_sync_from_thread( - cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, ) -> T_Retval: pass From 6cad8fcbae81b633f9414500ebf2c3051af8d2ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:04:31 +0200 Subject: [PATCH 14/19] Made T_co actually covariant --- src/anyio/from_thread.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index d558b366..5e6e6e69 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -31,7 +31,7 @@ from typing_extensions import TypeVarTuple, Unpack T_Retval = TypeVar("T_Retval") -T_co = TypeVar("T_co") +T_co = TypeVar("T_co", covariant=True) PosArgsT = TypeVarTuple("PosArgsT") From 19766dc4a300958761f612cc06a317bed495f25d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:06:44 +0200 Subject: [PATCH 15/19] Fixed return type annotation for TaskGroup.start() --- src/anyio/_backends/_asyncio.py | 2 +- src/anyio/_backends/_trio.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 8d42ead6..e884f564 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -770,7 +770,7 @@ def start_soon( async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> None: + ) -> Any: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 71e257a2..b18586c8 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -192,7 +192,7 @@ def start_soon( async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> object: + ) -> Any: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." From c8d703214db5610cac86d3b80487401637861358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:13:29 +0200 Subject: [PATCH 16/19] Improved type annotations of `BlockingPortal.start_task()` --- src/anyio/from_thread.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 5e6e6e69..4a987031 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -336,10 +336,10 @@ def start_task_soon( def start_task( self, - func: Callable[..., Awaitable[Any]], + func: Callable[..., Awaitable[T_Retval]], *args: object, name: object = None, - ) -> tuple[Future[Any], Any]: + ) -> tuple[Future[T_Retval], Any]: """ Start a task in the portal's task group and wait until it signals for readiness. @@ -351,13 +351,13 @@ def start_task( :return: a tuple of (future, task_status_value) where the ``task_status_value`` is the value passed to ``task_status.started()`` from within the target function - :rtype: tuple[concurrent.futures.Future[Any], Any] + :rtype: tuple[concurrent.futures.Future[T_Retval], Any] .. versionadded:: 3.0 """ - def task_done(future: Future) -> None: + def task_done(future: Future[T_Retval]) -> None: if not task_status_future.done(): if future.cancelled(): task_status_future.cancel() From 246c7dcedb2c7b22a55c3c5e7fd4eaff65ef4f40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:18:52 +0200 Subject: [PATCH 17/19] Narrowed down the type of _call_queue in Trio TestRunner --- src/anyio/_backends/_trio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index b18586c8..a0d14c74 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -739,7 +739,7 @@ class TestRunner(abc.TestRunner): def __init__(self, **options: Any) -> None: from queue import Queue - self._call_queue: Queue[Callable[..., object]] = Queue() + self._call_queue: Queue[Callable[[], object]] = Queue() self._send_stream: MemoryObjectSendStream | None = None self._options = options From d64b2117139992673ad8c4616e6958bee616de53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:45:48 +0200 Subject: [PATCH 18/19] Updated the changelog entry to match the scope of the changes --- docs/versionhistory.rst | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index e3752422..ba1efdb3 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -10,11 +10,20 @@ This library adheres to `Semantic Versioning 2.0 `_. Lura Skye) - Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an event loop thread -- Improved type annotations of numerous methods and functions including ``anyio.run()``, - ``TaskGroup.start_soon()``, ``anyio.from_thread.run()``, - ``anyio.to_thread.run_sync()`` and ``anyio.to_process.run_sync()`` by making use of - PEP 646 ``TypeVarTuple`` to allow the positional arguments to be validated by static - type checkers (`#560 `_) +- Broadly improved/fixed the type annotations. Among other things, many functions and + methods that take variadic positional arguments now make use of PEP 646 + ``TypeVarTuple`` to allow the positional arguments to be validated by static type + checkers. These changes affected numerous methods and functions, including: + + * ``anyio.run()`` + * ``TaskGroup.start_soon()`` + * ``anyio.from_thread.run()`` + * ``anyio.to_thread.run_sync()`` + * ``anyio.to_process.run_sync()`` + * ``BlockingPortal.start_task_soon()`` + * ``BlockingPortal.start_task()`` + + (`#560 `_) - Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help from Egor Blagov) From 14f150eca9d2e985d6f2efb5ae6fcbdd6563b97f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 16 Dec 2023 13:50:08 +0200 Subject: [PATCH 19/19] Also mentioned BlockingPortal.call() in the changelog --- docs/versionhistory.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index ba1efdb3..f8c0b293 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -20,6 +20,7 @@ This library adheres to `Semantic Versioning 2.0 `_. * ``anyio.from_thread.run()`` * ``anyio.to_thread.run_sync()`` * ``anyio.to_process.run_sync()`` + * ``BlockingPortal.call()`` * ``BlockingPortal.start_task_soon()`` * ``BlockingPortal.start_task()``