diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index c3a7441a7b091d..7aafcea4d885eb 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -69,7 +69,11 @@ async def staggered_race(coro_fns, delay, *, loop=None): exceptions = [] running_tasks = [] - async def run_one_coro(previous_failed) -> None: + async def run_one_coro(ok_to_start, previous_failed) -> None: + # in eager tasks this waits for the calling task to append this task + # to running_tasks, in regular tasks this wait is a no-op that does + # not yield a future. See gh-124309. + await ok_to_start.wait() # Wait for the previous task to finish, or for delay seconds if previous_failed is not None: with contextlib.suppress(exceptions_mod.TimeoutError): @@ -85,8 +89,12 @@ async def run_one_coro(previous_failed) -> None: return # Start task that will run the next coroutine this_failed = locks.Event() - next_task = loop.create_task(run_one_coro(this_failed)) + next_ok_to_start = locks.Event() + next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed)) running_tasks.append(next_task) + # next_task has been appended to running_tasks so next_task is ok to + # start. + next_ok_to_start.set() assert len(running_tasks) == this_index + 2 # Prepare place to put this coroutine's exceptions if not won exceptions.append(None) @@ -116,8 +124,11 @@ async def run_one_coro(previous_failed) -> None: if i != this_index: t.cancel() - first_task = loop.create_task(run_one_coro(None)) + ok_to_start = locks.Event() + first_task = loop.create_task(run_one_coro(ok_to_start, None)) running_tasks.append(first_task) + # first_task has been appended to running_tasks so first_task is ok to start. + ok_to_start.set() try: # Wait for a growing list of tasks to all finish: poor man's version of # curio's TaskGroup or trio's nursery diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 58c06287bc3c5d..b06832e02f00d6 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -218,6 +218,52 @@ async def run(): self.run_coro(run()) + def test_staggered_race_with_eager_tasks(self): + # See https://github.com/python/cpython/issues/124309 + + async def fail(): + await asyncio.sleep(0) + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: asyncio.sleep(2, result="sleep2"), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: fail() + ], + delay=0.25 + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIsInstance(excs[2], ValueError) + + self.run_coro(run()) + + def test_staggered_race_with_eager_tasks_no_delay(self): + # See https://github.com/python/cpython/issues/124309 + async def fail(): + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: fail(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: asyncio.sleep(0, result="sleep0"), + ], + delay=None + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(len(excs), 2) + + self.run_coro(run()) + class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index e6e32f7dbbbcba..74941f704c4890 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -95,3 +95,30 @@ async def coro(index): self.assertEqual(len(excs), 2) self.assertIsInstance(excs[0], ValueError) self.assertIsInstance(excs[1], ValueError) + + + async def test_multiple_winners(self): + event = asyncio.Event() + + async def coro(index): + await event.wait() + return index + + async def do_set(): + event.set() + await asyncio.Event().wait() + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + do_set, + ], + delay=0.1, + ) + self.assertIs(winner, 0) + self.assertIs(index, 0) + self.assertEqual(len(excs), 3) + self.assertIsNone(excs[0], None) + self.assertIsInstance(excs[1], asyncio.CancelledError) + self.assertIsInstance(excs[2], asyncio.CancelledError) diff --git a/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst b/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst new file mode 100644 index 00000000000000..89610fa44bf743 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst @@ -0,0 +1 @@ +Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.