Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typehint the public API #1

Merged
merged 9 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/types.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Run pyright.

on: [push]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install
run: pipx install flit

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"

- name: Install dependencies
run: flit install

- uses: jakebailey/pyright-action@v1
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [{name = "Thomas Grainger", email = "taskgroup@graingert.co.uk"}]
license = {file = "LICENSE"}
classifiers = ["License :: OSI Approved :: MIT License"]
dynamic = ["version", "description"]
dependencies = ["exceptiongroup"]
dependencies = ["exceptiongroup", "typing_extensions>=4.8,<5"]

[project.urls]
Home = "https://github.com/graingert/taskgroup"
4 changes: 0 additions & 4 deletions taskgroup/install.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import contextlib

import asyncio
import collections.abc
import contextlib
import functools
import types

from .tasks import task_factory as _task_factory, Task as _Task
from . import timeouts


UNCANCEL_DONE = object()
Expand Down
Empty file added taskgroup/py.typed
Empty file.
42 changes: 28 additions & 14 deletions taskgroup/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@

__all__ = ('Runner', 'run')

import collections.abc
import contextvars
import enum
import functools
import threading
import signal
from asyncio import coroutines
from asyncio import events
from asyncio import exceptions
from asyncio import tasks
from . tasks import task_factory as _task_factory
import threading
from asyncio import AbstractEventLoop, coroutines, events, exceptions, tasks
from typing import Any, TypeVar, final

from typing_extensions import Self

from .tasks import task_factory as _task_factory


class _State(enum.Enum):
Expand All @@ -23,6 +25,9 @@ class _State(enum.Enum):
CLOSED = "closed"


_T = TypeVar("_T")

@final
class Runner:
"""A context manager that controls event loop life cycle.

Expand Down Expand Up @@ -51,7 +56,12 @@ class Runner:

# Note: the class is final, it is not intended for inheritance.

def __init__(self, *, debug=None, loop_factory=None):
def __init__(
self,
*,
debug: bool | None = None,
loop_factory: collections.abc.Callable[[], AbstractEventLoop] | None = None
) -> None:
self._state = _State.CREATED
self._debug = debug
self._loop_factory = loop_factory
Expand All @@ -60,19 +70,21 @@ def __init__(self, *, debug=None, loop_factory=None):
self._interrupt_count = 0
self._set_event_loop = False

def __enter__(self):
def __enter__(self) -> Self:
self._lazy_init()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def close(self):
def close(self) -> None:
"""Shutdown and close event loop."""
if self._state is not _State.INITIALIZED:
return

loop = self._loop
assert loop is not None
try:
loop = self._loop
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
Expand All @@ -83,12 +95,13 @@ def close(self):
self._loop = None
self._state = _State.CLOSED

def get_loop(self):
def get_loop(self) -> AbstractEventLoop:
"""Return embedded event loop."""
self._lazy_init()
assert self._loop is not None
return self._loop

def run(self, coro, *, context=None):
def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: contextvars.Context | None = None) -> _T:
"""Run a coroutine inside the embedded event loop."""
if not coroutines.iscoroutine(coro):
raise ValueError("a coroutine was expected, got {!r}".format(coro))
Expand All @@ -99,6 +112,7 @@ def run(self, coro, *, context=None):
"Runner.run() cannot be called from a running event loop")

self._lazy_init()
assert self._loop is not None

if context is None:
context = self._context
Expand Down Expand Up @@ -134,7 +148,7 @@ def run(self, coro, *, context=None):
):
signal.signal(signal.SIGINT, signal.default_int_handler)

def _lazy_init(self):
def _lazy_init(self) -> None:
if self._state is _State.CLOSED:
raise RuntimeError("Runner is closed")
if self._state is _State.INITIALIZED:
Expand All @@ -160,7 +174,7 @@ def _on_sigint(self, signum, frame, main_task):
raise KeyboardInterrupt()


def run(main, *, debug=None):
def run(main: collections.abc.Coroutine[Any, Any, _T], *, debug: bool | None = None) -> _T:
"""Execute the coroutine and return the result.

This function runs the passed coroutine, taking care of
Expand Down
28 changes: 18 additions & 10 deletions taskgroup/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@
# Copyright © 2001-2022 Python Software Foundation; All Rights Reserved
# modified to support working on 3.10

from __future__ import annotations
from contextvars import Context

__all__ = ["TaskGroup"]

import sys
from asyncio import events
from asyncio import exceptions
from asyncio import tasks
from collections.abc import AsyncGenerator, Coroutine
from typing import TYPE_CHECKING, Any, TypeVar

from exceptiongroup import BaseExceptionGroup
import contextlib
from .tasks import task_factory as _task_factory
from .tasks import task_factory as _task_factory, Task
from . import install as _install

from typing_extensions import Self, Literal

_T = TypeVar("_T")


class TaskGroup:

def __init__(self):
def __init__(self) -> None:
self._entered = False
self._exiting = False
self._aborting = False
Expand All @@ -30,7 +39,7 @@ def __init__(self):
self._on_completed_fut = None
self._cmgr = self._cmgr_factory()

def __repr__(self):
def __repr__(self) -> str:
info = ['']
if self._tasks:
info.append(f'tasks={len(self._tasks)}')
Expand All @@ -45,7 +54,7 @@ def __repr__(self):
return f'<TaskGroup{info_str}>'

@contextlib.asynccontextmanager
async def _cmgr_factory(self):
async def _cmgr_factory(self) -> AsyncGenerator[Self, None]:
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
Expand Down Expand Up @@ -144,15 +153,14 @@ async def _cmgr_factory(self):
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
raise me from None

async def __aenter__(self):
async def __aenter__(self) -> Self:
return await self._cmgr.__aenter__()


async def __aexit__(self, *exc_info):
return await self._cmgr.__aexit__(*exc_info)
async def __aexit__(self, *exc_info) -> bool | None:
return await self._cmgr.__aexit__(*exc_info) # type: ignore


def create_task(self, coro, *, name=None, context=None):
def create_task(self, coro: Coroutine[Any, Any, _T], *, name: str | None = None, context: Context | None = None) -> Task[_T]:
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and not self._tasks:
Expand All @@ -176,7 +184,7 @@ def _is_base_error(self, exc: BaseException) -> bool:
assert isinstance(exc, BaseException)
return isinstance(exc, (SystemExit, KeyboardInterrupt))

def _abort(self):
def _abort(self) -> None:
self._aborting = True

for t in self._tasks:
Expand Down
50 changes: 36 additions & 14 deletions taskgroup/tasks.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,70 @@
import asyncio
import collections.abc
import contextvars
from typing import Any, Awaitable, TypeVar, cast

_YieldT = TypeVar("_YieldT")
_SendT = TypeVar("_SendT")
_ReturnT = TypeVar("_ReturnT", covariant=True)

@collections.abc.Coroutine.register
class _Interceptor:
def __init__(self, coro, context):
class _Interceptor(collections.abc.Coroutine[_YieldT, _SendT, _ReturnT]):
def __init__(
self,
coro: (
collections.abc.Coroutine[_YieldT, _SendT, _ReturnT]
| collections.abc.Generator[_YieldT, _SendT, _ReturnT]
),
context: contextvars.Context,
):
self.__coro = coro
self.__context = context

def send(self, v):
def send(self, v: _SendT):
return self.__context.run(self.__coro.send, v)

def throw(self, e):
def throw(self, e: BaseException):
return self.__context.run(self.__coro.throw, e)

def __getattr__(self, name):
return getattr(self.__coro, name)


class Task(asyncio.Task):
def __init__(self, coro, *args, context=None, **kwargs):
class Task(asyncio.Task[_ReturnT]):
def __init__(
self,
coro: (
Awaitable[_ReturnT]
| collections.abc.Coroutine[_YieldT, _SendT, _ReturnT]
| collections.abc.Generator[_YieldT, _SendT, _ReturnT]
),
*args,
context=None,
**kwargs
) -> None:
self._num_cancels_requested = 0
if context is not None:
assert isinstance(coro, (collections.abc.Coroutine, collections.abc.Generator))
coro = _Interceptor(coro, context)
super().__init__(coro, *args, **kwargs)
super().__init__(coro, *args, **kwargs) # type: ignore

def cancel(self, *args, **kwargs):
def cancel(self, *args: Any, **kwargs: Any) -> bool:
if not self.done():
self._num_cancels_requested += 1
return super().cancel(*args, **kwargs)

def cancelling(self):
def cancelling(self) -> int:
return self._num_cancels_requested

def uncancel(self):
def uncancel(self) -> int:
if self._num_cancels_requested > 0:
self._num_cancels_requested -= 1
return self._num_cancels_requested

def get_coro(self):
def get_coro(self) -> collections.abc.Generator[Any, Any, _ReturnT] | collections.abc.Awaitable[_ReturnT]:
coro = super().get_coro()
if isinstance(coro, _Interceptor):
return coro._Interceptor__coro
return coro._Interceptor__coro # type: ignore
return coro

def task_factory(loop, coro, **kwargs):
def task_factory(loop: asyncio.AbstractEventLoop, coro: collections.abc.Coroutine[Any, Any, _ReturnT], **kwargs: Any) -> Task[_ReturnT]:
return Task(coro, loop=loop, **kwargs)
6 changes: 4 additions & 2 deletions taskgroup/timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
# Copyright © 2001-2022 Python Software Foundation; All Rights Reserved
# modified to support working on 3.10 (basically just the imports changed here)

import collections.abc
import contextlib
import enum
import sys
from types import TracebackType
from typing import final, Optional, Type
from typing import final, Optional, Type, TYPE_CHECKING

from asyncio import events
from asyncio import exceptions
from asyncio import tasks
from . import install as _install

from typing_extensions import Self
Gobot1234 marked this conversation as resolved.
Show resolved Hide resolved

__all__ = (
"Timeout",
Expand Down Expand Up @@ -77,7 +79,7 @@ def __repr__(self) -> str:
return f"<Timeout [{self._state.value}]{info_str}>"

@contextlib.asynccontextmanager
async def _cmgr_factory(self):
async def _cmgr_factory(self) -> collections.abc.AsyncGenerator[Self, None]:
self._state = _State.ENTERED
async with _install.install_uncancel():
self._task = tasks.current_task()
Expand Down