Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cyclic garbage that keeps traceback frames alive in taskgroup exceptions #806

Merged
merged 13 commits into from
Oct 13, 2024
202 changes: 115 additions & 87 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,58 +425,62 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
if not self._active:
raise RuntimeError("This cancel scope is not active")
if current_task() is not self._host_task:
raise RuntimeError(
"Attempted to exit cancel scope in a different task than it was "
"entered in"
)

assert self._host_task is not None
host_task_state = _task_states.get(self._host_task)
if host_task_state is None or host_task_state.cancel_scope is not self:
raise RuntimeError(
"Attempted to exit a cancel scope that isn't the current tasks's "
"current cancel scope"
)

self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)
del exc_tb
try:
if not self._active:
raise RuntimeError("This cancel scope is not active")
if current_task() is not self._host_task:
raise RuntimeError(
"Attempted to exit cancel scope in a different task than it was "
"entered in"
)

host_task_state.cancel_scope = self._parent_scope
assert self._host_task is not None
host_task_state = _task_states.get(self._host_task)
if host_task_state is None or host_task_state.cancel_scope is not self:
raise RuntimeError(
"Attempted to exit a cancel scope that isn't the current tasks's "
"current cancel scope"
)

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break
self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1

# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
finally:
del self._host_task, exc_val

@property
def _effectively_cancelled(self) -> bool:
Expand Down Expand Up @@ -683,6 +687,27 @@ def started(self, value: T_contra | None = None) -> None:
_task_states[task].parent_id = self._parent_id


async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
tasks = set(tasks)
loop = get_running_loop()
waiter = loop.create_future()

def on_completion(task: asyncio.Task[object]) -> None:
tasks.discard(task)
if not tasks and not waiter.done():
waiter.set_result(None)

for task in tasks:
task.add_done_callback(on_completion)
del task

try:
await waiter
finally:
while tasks:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tasks.pop().remove_done_callback(on_completion)


class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
Expand All @@ -701,50 +726,53 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
try:
await asyncio.wait(self._tasks)
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
wait_scope.shield = True
self.cancel_scope.cancel()

# Set exc_val from the cancellation exception if it was
# previously unset. However, we should not replace a native
# cancellation exception with one raise by a cancel scope.
if exc_val is None or (
isinstance(exc_val, CancelledError)
and not is_anyio_cancellation(exc)
):
exc_val = exc
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
await AsyncIOBackend.cancel_shielded_checkpoint()
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

self._active = False
if self._exceptions:
raise BaseExceptionGroup(
"unhandled errors in a TaskGroup", self._exceptions
)
elif exc_val:
raise exc_val
except BaseException as exc:
if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
return True
try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
try:
await _wait(self._tasks)
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
wait_scope.shield = True
self.cancel_scope.cancel()

# Set exc_val from the cancellation exception if it was
# previously unset. However, we should not replace a native
# cancellation exception with one raise by a cancel scope.
if exc_val is None or (
isinstance(exc_val, CancelledError)
and not is_anyio_cancellation(exc)
):
exc_val = exc
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
await AsyncIOBackend.cancel_shielded_checkpoint()

raise
self._active = False
if self._exceptions:
raise BaseExceptionGroup(
"unhandled errors in a TaskGroup", self._exceptions
)
elif exc_val:
raise exc_val
except BaseException as exc:
if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
return True

return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
raise

return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
finally:
del exc_val, exc_tb, self._exceptions

def _spawn(
self,
Expand Down
7 changes: 4 additions & 3 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,17 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
rest = None
try:
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
except BaseExceptionGroup as exc:
_, rest = exc.split(trio.Cancelled)
rest = exc.split(trio.Cancelled)[1]
if not rest:
graingert marked this conversation as resolved.
Show resolved Hide resolved
cancelled_exc = trio.Cancelled._create()
raise cancelled_exc from exc
raise trio.Cancelled._create() from exc

raise
finally:
del exc_val, exc_tb, rest
self._active = False

def start_soon(
Expand Down
108 changes: 107 additions & 1 deletion tests/test_taskgroups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import gc
import math
import sys
import time
Expand All @@ -9,7 +10,7 @@
from typing import Any, NoReturn, cast

import pytest
from exceptiongroup import catch
from exceptiongroup import ExceptionGroup, catch
from pytest_mock import MockerFixture

import anyio
Expand Down Expand Up @@ -1548,6 +1549,111 @@ async def in_task_group(task_status: TaskStatus[None]) -> None:
assert not tg.cancel_scope.cancel_called


if sys.version_info <= (3, 11):

def no_other_refs() -> list[object]:
return [sys._getframe(1)]
else:

def no_other_refs() -> list[object]:
return []


async def test_exception_refcycles_direct() -> None:
"""Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
tg = create_task_group()
exc = None

class _Done(Exception):
pass

try:
async with tg:
raise _Done
except ExceptionGroup as e:
exc = e

assert exc is not None
assert gc.get_referrers(exc) == no_other_refs()


async def test_exception_refcycles_errors() -> None:
"""Test that TaskGroup deletes self._errors, and __aexit__ args"""
tg = create_task_group()
exc = None

class _Done(Exception):
pass

try:
async with tg:
raise _Done
except ExceptionGroup as excs:
exc = excs.exceptions[0]

assert isinstance(exc, _Done)
assert gc.get_referrers(exc) == no_other_refs()


async def test_exception_refcycles_parent_task() -> None:
"""Test that TaskGroup deletes self._parent_task"""
graingert marked this conversation as resolved.
Show resolved Hide resolved
tg = create_task_group()
exc = None

class _Done(Exception):
pass

async def coro_fn() -> None:
async with tg:
raise _Done

try:
async with anyio.create_task_group() as tg2:
tg2.start_soon(coro_fn)
except ExceptionGroup as excs:
exc = excs.exceptions[0].exceptions[0]

assert isinstance(exc, _Done)
assert gc.get_referrers(exc) == no_other_refs()


async def test_exception_refcycles_propagate_cancellation_error() -> None:
"""Test that TaskGroup deletes propagate_cancellation_error"""
tg = anyio.create_task_group()
exc = None

with CancelScope() as cs:
cs.cancel()
try:
async with tg:
await checkpoint()
except get_cancelled_exc_class() as e:
exc = e
raise

assert isinstance(exc, get_cancelled_exc_class())
assert gc.get_referrers(exc) == no_other_refs()


async def test_exception_refcycles_base_error() -> None:
"""Test that TaskGroup deletes self._base_error"""

class MyKeyboardInterrupt(KeyboardInterrupt):
pass

tg = create_task_group()
exc = None

try:
async with tg:
raise MyKeyboardInterrupt
except BaseExceptionGroup as excs:
exc = excs.exceptions[0]

assert isinstance(exc, MyKeyboardInterrupt)
assert gc.get_referrers(exc) == no_other_refs()


class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked
Expand Down
Loading