Skip to content

Commit

Permalink
Fixed cancellation propagation when task group host is in a shielded …
Browse files Browse the repository at this point in the history
…scope

Fixes #642.
  • Loading branch information
agronholm committed Dec 10, 2023
1 parent 84c1bb0 commit 238a340
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 45 deletions.
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
from Egor Blagov)
- Fixed ``loop_factory`` and ``use_uvloop`` options not being used on the asyncio
backend (`#643 <https://github.com/agronholm/anyio/issues/643>`_)
- Fixed cancellation propagating on asyncio from a task group to child tasks if the task
hosting the task group is in a shielded cancel scope
(`#642 <https://github.com/agronholm/anyio/issues/642>`_)

**4.1.0**

Expand Down
104 changes: 59 additions & 45 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._deadline = deadline
self._shield = shield
self._parent_scope: CancelScope | None = None
self._child_scopes: set[CancelScope] = set()
self._cancel_called = False
self._cancelled_caught = False
self._active = False
Expand All @@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
else:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
if self._parent_scope is not None:
self._parent_scope._child_scopes.add(self)
self._parent_scope._tasks.remove(host_task)

self._timeout()
self._active = True
Expand All @@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:

# Start cancelling the host task if the scope was cancelled before entering
if self._cancel_called:
self._deliver_cancellation()
self._deliver_cancellation(self)

return self

Expand Down Expand Up @@ -409,13 +413,15 @@ def __exit__(
self._timeout_handle = None

self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Restart the cancellation effort in the farthest directly cancelled parent
# scope if this one was shielded
if self._shield:
self._deliver_cancellation_to_parent()
self._restart_cancellation_in_parent()

if self._cancel_called and exc_val is not None:
for exc in iterate_exceptions(exc_val):
Expand Down Expand Up @@ -451,65 +457,67 @@ def _timeout(self) -> None:
else:
self._timeout_handle = loop.call_at(self._deadline, self._timeout)

def _deliver_cancellation(self) -> None:
def _deliver_cancellation(self, origin: CancelScope) -> bool:
"""
Deliver cancellation to directly contained tasks and nested cancel scopes.
Schedule another run at the end if we still have tasks eligible for
cancellation.
:param origin: the cancel scope that originated the cancellation
:return: ``True`` if the delivery needs to be retried on the next cycle
"""
should_retry = False
current = current_task()
for task in self._tasks:
if task._must_cancel: # type: ignore[attr-defined]
continue

# The task is eligible for cancellation if it has started and is not in a
# cancel scope shielded from this one
cancel_scope = _task_states[task].cancel_scope
while cancel_scope is not self:
if cancel_scope is None or cancel_scope._shield:
break
else:
cancel_scope = cancel_scope._parent_scope
else:
should_retry = True
if task is not current and (
task is self._host_task or _task_started(task)
):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
if sys.version_info >= (3, 9):
task.cancel(f"Cancelled by cancel scope {id(self):x}")
else:
task.cancel()
# The task is eligible for cancellation if it has started
should_retry = True
if task is not current and (task is self._host_task or _task_started(task)):
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
if sys.version_info >= (3, 9):
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
else:
task.cancel()

# Deliver cancellation to child scopes that aren't shielded or running their own
# cancellation callbacks
for scope in self._child_scopes:
if not scope._shield and not scope.cancel_called:
should_retry = scope._deliver_cancellation(origin) or should_retry

# Schedule another callback if there are still tasks left
if should_retry:
self._cancel_handle = get_running_loop().call_soon(
self._deliver_cancellation
)
else:
self._cancel_handle = None
if origin is self:
if should_retry:
self._cancel_handle = get_running_loop().call_soon(
self._deliver_cancellation, origin
)
else:
self._cancel_handle = None

return should_retry

def _deliver_cancellation_to_parent(self) -> None:
"""Start cancellation effort in the farthest directly cancelled parent scope"""
def _restart_cancellation_in_parent(self) -> None:
"""Start cancellation effort in the closest directly cancelled parent scope"""
scope = self._parent_scope
scope_to_cancel: CancelScope | None = None
while scope is not None:
if scope._cancel_called and scope._cancel_handle is None:
scope_to_cancel = scope
if scope._cancel_called:
if scope._cancel_handle is None:
scope._deliver_cancellation(scope)

break

# No point in looking beyond any shielded scope
if scope._shield:
break

scope = scope._parent_scope

if scope_to_cancel is not None:
scope_to_cancel._deliver_cancellation()

def _parent_cancelled(self) -> bool:
# Check whether any parent has been cancelled
cancel_scope = self._parent_scope
Expand All @@ -529,7 +537,7 @@ def cancel(self) -> None:

self._cancel_called = True
if self._host_task is not None:
self._deliver_cancellation()
self._deliver_cancellation(self)

@property
def deadline(self) -> float:
Expand Down Expand Up @@ -562,7 +570,7 @@ def shield(self, value: bool) -> None:
if self._shield != value:
self._shield = value
if not value:
self._deliver_cancellation_to_parent()
self._restart_cancellation_in_parent()


#
Expand Down Expand Up @@ -623,6 +631,7 @@ def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
self._active = False
self._exceptions: list[BaseException] = []
self._tasks: set[asyncio.Task] = set()

async def __aenter__(self) -> TaskGroup:
self.cancel_scope.__enter__()
Expand All @@ -642,9 +651,9 @@ async def __aexit__(
self._exceptions.append(exc_val)

cancelled_exc_while_waiting_tasks: CancelledError | None = None
while self.cancel_scope._tasks:
while self._tasks:
try:
await asyncio.wait(self.cancel_scope._tasks)
await asyncio.wait(self._tasks)
except CancelledError as exc:
# This task was cancelled natively; reraise the CancelledError later
# unless this task was already interrupted by another exception
Expand Down Expand Up @@ -676,8 +685,11 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
assert _task in self.cancel_scope._tasks
self.cancel_scope._tasks.remove(_task)
task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
self._tasks.remove(task)
del _task_states[_task]

try:
Expand All @@ -693,7 +705,8 @@ def task_done(_task: asyncio.Task) -> None:
if not isinstance(exc, CancelledError):
self._exceptions.append(exc)

self.cancel_scope.cancel()
if not self.cancel_scope._parent_cancelled():
self.cancel_scope.cancel()
else:
task_status_future.set_exception(exc)
elif task_status_future is not None and not task_status_future.done():
Expand Down Expand Up @@ -732,6 +745,7 @@ def task_done(_task: asyncio.Task) -> None:
parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)
return task

def start_soon(
Expand Down
23 changes: 23 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,29 @@ def handler(excgrp: BaseExceptionGroup) -> None:
await anyio.sleep_forever()


async def test_cancel_child_task_when_host_is_shielded() -> None:
# Regression test for #642
# Tests that cancellation propagates to a child task even if the host task is within
# a shielded cancel scope.
cancelled = anyio.Event()

async def wait_cancel() -> None:
try:
await anyio.sleep_forever()
except anyio.get_cancelled_exc_class():
cancelled.set()
raise

with CancelScope() as parent_scope:
async with anyio.create_task_group() as task_group:
task_group.start_soon(wait_cancel)
await wait_all_tasks_blocked()

with CancelScope(shield=True), fail_after(1):
parent_scope.cancel()
await cancelled.wait()


class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked
Expand Down

0 comments on commit 238a340

Please sign in to comment.