diff --git a/CHANGES/9326.bugfix.rst b/CHANGES/9326.bugfix.rst new file mode 100644 index 00000000000..4689941708f --- /dev/null +++ b/CHANGES/9326.bugfix.rst @@ -0,0 +1 @@ +Fixed cancellation leaking upwards on timeout -- by :user:`bdraco`. diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 40705b16d71..ee2a91cec46 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -686,6 +686,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._tasks: List[asyncio.Task[Any]] = [] self._cancelled = False + self._cancelling = 0 def assert_timeout(self) -> None: """Raise TimeoutError if timer has already been cancelled.""" @@ -694,12 +695,17 @@ def assert_timeout(self) -> None: def __enter__(self) -> BaseTimerContext: task = asyncio.current_task(loop=self._loop) - if task is None: raise RuntimeError( "Timeout context manager should be used " "inside a task" ) + if sys.version_info >= (3, 11): + # Remember if the task was already cancelling + # so when we __exit__ we can decide if we should + # raise asyncio.TimeoutError or let the cancellation propagate + self._cancelling = task.cancelling() + if self._cancelled: raise asyncio.TimeoutError from None @@ -712,11 +718,22 @@ def __exit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> Optional[bool]: + enter_task: Optional[asyncio.Task[Any]] = None if self._tasks: - self._tasks.pop() + enter_task = self._tasks.pop() if exc_type is asyncio.CancelledError and self._cancelled: - raise asyncio.TimeoutError from None + assert enter_task is not None + # The timeout was hit, and the task was cancelled + # so we need to uncancel the last task that entered the context manager + # since the cancellation should not leak out of the context manager + if sys.version_info >= (3, 11): + # If the task was already cancelling don't raise + # asyncio.TimeoutError and instead return None + # to allow the cancellation to propagate + if enter_task.uncancel() > self._cancelling: + return None + raise asyncio.TimeoutError from exc_val return None def timeout(self) -> None: diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 2d6e098aae5..f79f9bebe09 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -397,7 +397,61 @@ def test_timer_context_not_cancelled() -> None: assert not m_asyncio.current_task.return_value.cancel.called -def test_timer_context_no_task(loop) -> None: +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()" +) +async def test_timer_context_timeout_does_not_leak_upward() -> None: + """Verify that the TimerContext does not leak cancellation outside the context manager.""" + loop = asyncio.get_running_loop() + ctx = helpers.TimerContext(loop) + current_task = asyncio.current_task() + assert current_task is not None + with pytest.raises(asyncio.TimeoutError): + with ctx: + assert current_task.cancelling() == 0 + loop.call_soon(ctx.timeout) + await asyncio.sleep(1) + + # After the context manager exits, the task should no longer be cancelling + assert current_task.cancelling() == 0 + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()" +) +async def test_timer_context_timeout_does_swallow_cancellation() -> None: + """Verify that the TimerContext does not swallow cancellation.""" + loop = asyncio.get_running_loop() + current_task = asyncio.current_task() + assert current_task is not None + ctx = helpers.TimerContext(loop) + + async def task_with_timeout() -> None: + nonlocal ctx + new_task = asyncio.current_task() + assert new_task is not None + with pytest.raises(asyncio.TimeoutError): + with ctx: + assert new_task.cancelling() == 0 + await asyncio.sleep(1) + + task = asyncio.create_task(task_with_timeout()) + await asyncio.sleep(0) + task.cancel() + assert task.cancelling() == 1 + ctx.timeout() + + # Cancellation should not leak into the current task + assert current_task.cancelling() == 0 + # Cancellation should not be swallowed if the task is cancelled + # and it also times out + await asyncio.sleep(0) + with pytest.raises(asyncio.CancelledError): + await task + assert task.cancelling() == 1 + + +def test_timer_context_no_task(loop: asyncio.AbstractEventLoop) -> None: with pytest.raises(RuntimeError): with helpers.TimerContext(loop): pass