Skip to content

Commit

Permalink
Fix test_add_worker
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Oct 20, 2022
1 parent 1d2fca9 commit 900e8b2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ async def test_clear_events_client_removal(c, s, a, b):
@gen_cluster(client=True, nthreads=[])
async def test_add_worker(c, s):
x = c.submit(inc, 1, key="x")
await wait_for_state("x", "no-worker", s)
await wait_for_state("x", ("queued", "no-worker"), s)
s.validate_state()

async with Worker(s.address) as w:
Expand Down
5 changes: 4 additions & 1 deletion distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,16 +987,19 @@ async def test_wait_for_state(c, s, a, capsys):

await asyncio.gather(
wait_for_state("x", "memory", s),
wait_for_state("x", "memory", a),
wait_for_state("x", {"memory", "other"}, a),
c.run(wait_for_state, "x", "memory"),
)

with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(wait_for_state("x", "bad_state", s), timeout=0.1)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(wait_for_state("x", ("this", "that"), s), timeout=0.1)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(wait_for_state("y", "memory", s), timeout=0.1)
assert capsys.readouterr().out == (
f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n"
f"tasks[x].state='memory' on {s.address}; expected state=('this', 'that')\n"
f"tasks[y] not found on {s.address}\n"
)

Expand Down
18 changes: 13 additions & 5 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable, Mapping
from collections.abc import Callable, Collection, Mapping
from contextlib import contextmanager, nullcontext, suppress
from itertools import count
from time import sleep
Expand Down Expand Up @@ -2353,10 +2353,14 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]:


async def wait_for_state(
key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01
key: str,
state: str | Collection[str],
dask_worker: Worker | Scheduler,
*,
interval: float = 0.01,
) -> None:
"""Wait for a task to appear on a Worker or on the Scheduler and to be in a specific
state.
state or one of a set of possible states.
"""
tasks: Mapping[str, SchedulerTaskState | WorkerTaskState]

Expand All @@ -2367,14 +2371,18 @@ async def wait_for_state(
else:
raise TypeError(dask_worker) # pragma: nocover

if isinstance(state, str):
state = (state,)
state_str = repr(next(iter(state))) if len(state) == 1 else str(state)

try:
while key not in tasks or tasks[key].state != state:
while key not in tasks or tasks[key].state not in state:
await asyncio.sleep(interval)
except (asyncio.CancelledError, asyncio.TimeoutError):
if key in tasks:
msg = (
f"tasks[{key}].state={tasks[key].state!r} on {dask_worker.address}; "
f"expected {state=}"
f"expected state={state_str}"
)
else:
msg = f"tasks[{key}] not found on {dask_worker.address}"
Expand Down

0 comments on commit 900e8b2

Please sign in to comment.