Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TimerContext not uncancelling the current task #9326

Merged
merged 10 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/9326.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed cancellation leaking upwards on timeout -- by :user:`bdraco`.
23 changes: 20 additions & 3 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: List[asyncio.Task[Any]] = []
self._cancelled = False
self._cancelling = 0

def assert_timeout(self) -> None:
"""Raise TimeoutError if timer has already been cancelled."""
Expand All @@ -707,10 +708,15 @@ def assert_timeout(self) -> None:

def __enter__(self) -> BaseTimerContext:
task = asyncio.current_task(loop=self._loop)

if task is None:
raise RuntimeError("Timeout context manager should be used inside a task")

if sys.version_info >= (3, 11):
# Remember if the task was already cancelling
# so when we __exit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = task.cancelling()

if self._cancelled:
raise asyncio.TimeoutError from None

Expand All @@ -723,11 +729,22 @@ def __exit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
enter_task: Optional[asyncio.Task[Any]] = None
if self._tasks:
self._tasks.pop() # type: ignore[unused-awaitable]
enter_task = self._tasks.pop()

if exc_type is asyncio.CancelledError and self._cancelled:
raise asyncio.TimeoutError from None
assert enter_task is not None
# The timeout was hit, and the task was cancelled
# so we need to uncancel the last task that entered the context manager
# since the cancellation should not leak out of the context manager
if sys.version_info >= (3, 11):
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
if enter_task.uncancel() > self._cancelling:
return None
raise asyncio.TimeoutError from exc_val
return None

def timeout(self) -> None:
Expand Down
54 changes: 54 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,60 @@ def test_timer_context_not_cancelled() -> None:
assert not m_asyncio.current_task.return_value.cancel.called


@pytest.mark.skipif(
sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()"
)
async def test_timer_context_timeout_does_not_leak_upward() -> None:
"""Verify that the TimerContext does not leak cancellation outside the context manager."""
loop = asyncio.get_running_loop()
ctx = helpers.TimerContext(loop)
current_task = asyncio.current_task()
assert current_task is not None
with pytest.raises(asyncio.TimeoutError):
with ctx:
assert current_task.cancelling() == 0
loop.call_soon(ctx.timeout)
await asyncio.sleep(1)

# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0


@pytest.mark.skipif(
sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()"
)
async def test_timer_context_timeout_does_swallow_cancellation() -> None:
"""Verify that the TimerContext does not swallow cancellation."""
loop = asyncio.get_running_loop()
current_task = asyncio.current_task()
assert current_task is not None
ctx = helpers.TimerContext(loop)

async def task_with_timeout() -> None:
nonlocal ctx
new_task = asyncio.current_task()
assert new_task is not None
with pytest.raises(asyncio.TimeoutError):
with ctx:
assert new_task.cancelling() == 0
await asyncio.sleep(1)

task = asyncio.create_task(task_with_timeout())
await asyncio.sleep(0)
task.cancel()
assert task.cancelling() == 1
ctx.timeout()

# Cancellation should not leak into the current task
assert current_task.cancelling() == 0
# Cancellation should not be swallowed if the task is cancelled
# and it also times out
await asyncio.sleep(0)
with pytest.raises(asyncio.CancelledError):
await task
assert task.cancelling() == 1


def test_timer_context_no_task(loop: asyncio.AbstractEventLoop) -> None:
with pytest.raises(RuntimeError):
with helpers.TimerContext(loop):
Expand Down
Loading