Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed cancellation propagation when task group host is in a shielded scope #648

Merged
merged 8 commits into from
Dec 14, 2023
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
from Egor Blagov)
- Fixed ``loop_factory`` and ``use_uvloop`` options not being used on the asyncio
backend (`#643 <https://github.com/agronholm/anyio/issues/643>`_)
- Fixed cancellation propagating on asyncio from a task group to child tasks if the task
hosting the task group is in a shielded cancel scope
(`#642 <https://github.com/agronholm/anyio/issues/642>`_)

**4.1.0**

Expand Down
110 changes: 57 additions & 53 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope: CancelScope | None = None
self._child_scopes: set[CancelScope] = set()
self._cancel_called = False
self._cancelled_caught = False
self._active = False
Expand All @@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
else:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
if self._parent_scope is not None:
self._parent_scope._child_scopes.add(self)
self._parent_scope._tasks.remove(host_task)

self._timeout()
self._active = True
Expand All @@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:

# Start cancelling the host task if the scope was cancelled before entering
if self._cancel_called:
self._deliver_cancellation()
self._deliver_cancellation(self)

return self

Expand Down Expand Up @@ -409,13 +413,15 @@ def __exit__(
self._timeout_handle = None

self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Restart the cancellation effort in the farthest directly cancelled parent
# Restart the cancellation effort in the closest directly cancelled parent
# scope if this one was shielded
agronholm marked this conversation as resolved.
Show resolved Hide resolved
if self._shield:
self._deliver_cancellation_to_parent()
self._restart_cancellation_in_parent()

if self._cancel_called and exc_val is not None:
for exc in iterate_exceptions(exc_val):
Expand Down Expand Up @@ -451,64 +457,56 @@ def _timeout(self) -> None:
else:
self._timeout_handle = loop.call_at(self._deadline, self._timeout)

def _deliver_cancellation(self) -> None:
def _deliver_cancellation(self, origin: CancelScope) -> bool:
"""
Deliver cancellation to directly contained tasks and nested cancel scopes.

Schedule another run at the end if we still have tasks eligible for
cancellation.

:param origin: the cancel scope that originated the cancellation
:return: ``True`` if the delivery needs to be retried on the next cycle

"""
should_retry = False
current = current_task()
for task in self._tasks:
if task._must_cancel: # type: ignore[attr-defined]
continue

# The task is eligible for cancellation if it has started and is not in a
# cancel scope shielded from this one
cancel_scope = _task_states[task].cancel_scope
while cancel_scope is not self:
if cancel_scope is None or cancel_scope._shield:
break
else:
cancel_scope = cancel_scope._parent_scope
else:
should_retry = True
if task is not current and (
task is self._host_task or _task_started(task)
):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
if sys.version_info >= (3, 9):
task.cancel(f"Cancelled by cancel scope {id(self):x}")
else:
task.cancel()

# Schedule another callback if there are still tasks left
if should_retry:
self._cancel_handle = get_running_loop().call_soon(
self._deliver_cancellation
)
else:
self._cancel_handle = None
# The task is eligible for cancellation if it has started
should_retry = True
if task is not current and (task is self._host_task or _task_started(task)):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
if sys.version_info >= (3, 9):
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
else:
task.cancel()

def _deliver_cancellation_to_parent(self) -> None:
"""Start cancellation effort in the farthest directly cancelled parent scope"""
scope = self._parent_scope
scope_to_cancel: CancelScope | None = None
while scope is not None:
if scope._cancel_called and scope._cancel_handle is None:
scope_to_cancel = scope
# Deliver cancellation to child scopes that aren't shielded or running their own
# cancellation callbacks
for scope in self._child_scopes:
if not scope._shield and not scope.cancel_called:
should_retry = scope._deliver_cancellation(origin) or should_retry

# No point in looking beyond any shielded scope
if scope._shield:
break
# Schedule another callback if there are still tasks left
if origin is self:
if should_retry:
self._cancel_handle = get_running_loop().call_soon(
self._deliver_cancellation, origin
)
else:
self._cancel_handle = None

scope = scope._parent_scope
return should_retry

if scope_to_cancel is not None:
scope_to_cancel._deliver_cancellation()
def _restart_cancellation_in_parent(self) -> None:
"""Start cancellation effort in the closest directly cancelled parent scope"""
agronholm marked this conversation as resolved.
Show resolved Hide resolved
agronholm marked this conversation as resolved.
Show resolved Hide resolved
scope = self._parent_scope
if scope is not None and scope._cancel_called and scope._cancel_handle is None:
scope._deliver_cancellation(scope)

def _parent_cancelled(self) -> bool:
# Check whether any parent has been cancelled
Expand All @@ -529,7 +527,7 @@ def cancel(self) -> None:

self._cancel_called = True
if self._host_task is not None:
self._deliver_cancellation()
self._deliver_cancellation(self)

@property
def deadline(self) -> float:
Expand Down Expand Up @@ -562,7 +560,7 @@ def shield(self, value: bool) -> None:
if self._shield != value:
self._shield = value
if not value:
self._deliver_cancellation_to_parent()
self._restart_cancellation_in_parent()


#
Expand Down Expand Up @@ -623,6 +621,7 @@ def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
self._active = False
self._exceptions: list[BaseException] = []
self._tasks: set[asyncio.Task] = set()

async def __aenter__(self) -> TaskGroup:
self.cancel_scope.__enter__()
Expand All @@ -642,9 +641,9 @@ async def __aexit__(
self._exceptions.append(exc_val)

cancelled_exc_while_waiting_tasks: CancelledError | None = None
while self.cancel_scope._tasks:
while self._tasks:
try:
await asyncio.wait(self.cancel_scope._tasks)
await asyncio.wait(self._tasks)
except CancelledError as exc:
# This task was cancelled natively; reraise the CancelledError later
# unless this task was already interrupted by another exception
Expand Down Expand Up @@ -676,8 +675,11 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
assert _task in self.cancel_scope._tasks
self.cancel_scope._tasks.remove(_task)
task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
self._tasks.remove(task)
del _task_states[_task]

try:
Expand All @@ -693,7 +695,8 @@ def task_done(_task: asyncio.Task) -> None:
if not isinstance(exc, CancelledError):
self._exceptions.append(exc)

self.cancel_scope.cancel()
if not self.cancel_scope._parent_cancelled():
self.cancel_scope.cancel()
else:
task_status_future.set_exception(exc)
elif task_status_future is not None and not task_status_future.done():
Expand Down Expand Up @@ -732,6 +735,7 @@ def task_done(_task: asyncio.Task) -> None:
parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)
return task

def start_soon(
Expand Down
23 changes: 23 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,29 @@ def handler(excgrp: BaseExceptionGroup) -> None:
await anyio.sleep_forever()


async def test_cancel_child_task_when_host_is_shielded() -> None:
# Regression test for #642
# Tests that cancellation propagates to a child task even if the host task is within
# a shielded cancel scope.
cancelled = anyio.Event()

async def wait_cancel() -> None:
try:
await anyio.sleep_forever()
except anyio.get_cancelled_exc_class():
cancelled.set()
raise

with CancelScope() as parent_scope:
async with anyio.create_task_group() as task_group:
task_group.start_soon(wait_cancel)
await wait_all_tasks_blocked()

with CancelScope(shield=True), fail_after(1):
parent_scope.cancel()
await cancelled.wait()


class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked
Expand Down