From 900e8b2070d9856b5a6c36434e8037a437001979 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 20 Oct 2022 22:48:27 +0100 Subject: [PATCH] Fix test_add_worker --- distributed/tests/test_scheduler.py | 2 +- distributed/tests/test_utils_test.py | 5 ++++- distributed/utils_test.py | 18 +++++++++++++----- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7d5619c380..bd7a0c7c82 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -610,7 +610,7 @@ async def test_clear_events_client_removal(c, s, a, b): @gen_cluster(client=True, nthreads=[]) async def test_add_worker(c, s): x = c.submit(inc, 1, key="x") - await wait_for_state("x", "no-worker", s) + await wait_for_state("x", ("queued", "no-worker"), s) s.validate_state() async with Worker(s.address) as w: diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 270275a9aa..78f7ae7107 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -987,16 +987,19 @@ async def test_wait_for_state(c, s, a, capsys): await asyncio.gather( wait_for_state("x", "memory", s), - wait_for_state("x", "memory", a), + wait_for_state("x", {"memory", "other"}, a), c.run(wait_for_state, "x", "memory"), ) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(wait_for_state("x", "bad_state", s), timeout=0.1) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(wait_for_state("x", ("this", "that"), s), timeout=0.1) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(wait_for_state("y", "memory", s), timeout=0.1) assert capsys.readouterr().out == ( f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n" + f"tasks[x].state='memory' on {s.address}; expected state=('this', 'that')\n" f"tasks[y] not found on {s.address}\n" ) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index c5415cb2ef..3bc471de19 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -24,7 +24,7 @@ import warnings import weakref from collections import defaultdict -from collections.abc import Callable, Mapping +from collections.abc import Callable, Collection, Mapping from contextlib import contextmanager, nullcontext, suppress from itertools import count from time import sleep @@ -2353,10 +2353,14 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]: async def wait_for_state( - key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01 + key: str, + state: str | Collection[str], + dask_worker: Worker | Scheduler, + *, + interval: float = 0.01, ) -> None: """Wait for a task to appear on a Worker or on the Scheduler and to be in a specific - state. + state or one of a set of possible states. """ tasks: Mapping[str, SchedulerTaskState | WorkerTaskState] @@ -2367,14 +2371,18 @@ async def wait_for_state( else: raise TypeError(dask_worker) # pragma: nocover + if isinstance(state, str): + state = (state,) + state_str = repr(next(iter(state))) if len(state) == 1 else str(state) + try: - while key not in tasks or tasks[key].state != state: + while key not in tasks or tasks[key].state not in state: await asyncio.sleep(interval) except (asyncio.CancelledError, asyncio.TimeoutError): if key in tasks: msg = ( f"tasks[{key}].state={tasks[key].state!r} on {dask_worker.address}; " - f"expected {state=}" + f"expected state={state_str}" ) else: msg = f"tasks[{key}] not found on {dask_worker.address}"