Skip to content

Commit

Permalink
Merge pull request #11 from graingert/fix-ci
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert authored Dec 14, 2024
2 parents 46f8693 + 4f3cd4c commit 51c9b27
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/types.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- name: Install
run: pip install .

- uses: jakebailey/pyright-action@v1
- uses: jakebailey/pyright-action@v2
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [{name = "Thomas Grainger", email = "taskgroup@graingert.co.uk"}]
license = {file = "LICENSE"}
classifiers = ["License :: OSI Approved :: MIT License"]
dynamic = ["version", "description"]
dependencies = ["exceptiongroup", "typing_extensions>=4.8,<5"]
dependencies = ["exceptiongroup", "typing_extensions>=4.12.2,<5"]

[project.urls]
Home = "https://github.com/graingert/taskgroup"
29 changes: 21 additions & 8 deletions taskgroup/install.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import contextvars
import asyncio
import collections.abc
import contextlib
import types
from typing import Any, cast

from .tasks import task_factory as _task_factory, Task as _Task

from typing_extensions import Self, TypeVar


UNCANCEL_DONE = object()

Expand Down Expand Up @@ -36,33 +40,42 @@ def add_done_callback(self, fn, *, context):
def _async_yield(v):
return (yield v)

_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)


class WrapCoro(collections.abc.Coroutine):
def __init__(self, coro, context):
class WrapCoro(collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]):
def __init__(self, coro: collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], context: contextvars.Context):
self._coro = coro
self._context = context

def __await__(self):
def __await__(self) -> Self:
return self

def __iter__(self):
def __iter__(self) -> Self:
return self

def __next__(self):
return self.send(None)
def __next__(self) -> _YieldT_co:
return self.send(cast(_SendT_contra_nd, None))

def throw(self, *exc_info):
def throw(self, *exc_info) -> _YieldT_co:
result = self._context.run(self._coro.throw, *exc_info)
if result is UNCANCEL_DONE:
raise StopIteration
return result

def send(self, v):
def send(self, v: _SendT_contra_nd) -> _YieldT_co:
result = self._context.run(self._coro.send, v)
if result is UNCANCEL_DONE:
raise StopIteration
return result

def close(self) -> None:
super().close()


@contextlib.asynccontextmanager
async def install_uncancel():
Expand Down
50 changes: 31 additions & 19 deletions taskgroup/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,61 @@
import asyncio
import collections.abc
import contextvars
from typing import Any, Awaitable, TypeVar, cast
from typing import Any, cast, TypeAlias
from typing_extensions import TypeVar
import sys

_YieldT = TypeVar("_YieldT")
_SendT = TypeVar("_SendT")
_ReturnT = TypeVar("_ReturnT", covariant=True)
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_TaskYieldType: TypeAlias = asyncio.Future[object] | None

if sys.version_info >= (3, 12):
_TaskCompatibleCoro: TypeAlias = collections.abc.Coroutine[Any, Any, _T_co]
elif sys.version_info >= (3, 9):
_TaskCompatibleCoro: TypeAlias = collectiona.abc.Generator[_TaskYieldType, None, _T_co] | Coroutine[Any, Any, _T_co]

class _Interceptor(collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]):

class _Interceptor(collections.abc.Coroutine[_YieldT, _SendT, _ReturnT]):
def __init__(
self,
coro: (
collections.abc.Coroutine[_YieldT, _SendT, _ReturnT]
| collections.abc.Generator[_YieldT, _SendT, _ReturnT]
collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
| collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
),
context: contextvars.Context,
):
self.__coro = coro
self.__context = context

def send(self, v: _SendT):
def send(self, v: _SendT_contra_nd) -> _YieldT_co:
return self.__context.run(self.__coro.send, v)

def throw(self, e: BaseException):
return self.__context.run(self.__coro.throw, e)
def throw(self, *exc_info) -> _YieldT_co:
return self.__context.run(self.__coro.throw, *exc_info)

def __getattr__(self, name):
return getattr(self.__coro, name)

def close(self) -> None:
super().close()

class Task(asyncio.Task[_ReturnT]):

class Task(asyncio.Task[_T_co]):
def __init__(
self,
coro: (
Awaitable[_ReturnT]
| collections.abc.Coroutine[_YieldT, _SendT, _ReturnT]
| collections.abc.Generator[_YieldT, _SendT, _ReturnT]
),
coro: _TaskCompatibleCoro[_T_co],
*args,
context=None,
**kwargs
) -> None:
self._num_cancels_requested = 0
if context is not None:
assert isinstance(coro, (collections.abc.Coroutine, collections.abc.Generator))
coro = _Interceptor(coro, context)
super().__init__(coro, *args, **kwargs) # type: ignore

Expand All @@ -62,11 +74,11 @@ def uncancel(self) -> int:
self._num_cancels_requested -= 1
return self._num_cancels_requested

def get_coro(self) -> collections.abc.Generator[Any, Any, _ReturnT] | collections.abc.Awaitable[_ReturnT]:
def get_coro(self) -> _TaskCompatibleCoro[_T_co] | None:
coro = super().get_coro()
if isinstance(coro, _Interceptor):
return coro._Interceptor__coro # type: ignore
return coro

def task_factory(loop: asyncio.AbstractEventLoop, coro: collections.abc.Coroutine[Any, Any, _ReturnT] | collections.abc.Generator[Any, Any, _ReturnT], **kwargs: Any) -> Task[_ReturnT]:
def task_factory(loop: asyncio.AbstractEventLoop, coro: _TaskCompatibleCoro[_T_co], **kwargs: Any) -> Task[_T_co]:
return Task(coro, loop=loop, **kwargs)

0 comments on commit 51c9b27

Please sign in to comment.