diff --git a/distributed/dashboard/components/worker.py b/distributed/dashboard/components/worker.py index c5ba33bdc3..358e5e0750 100644 --- a/distributed/dashboard/components/worker.py +++ b/distributed/dashboard/components/worker.py @@ -90,7 +90,7 @@ def update(self): "Stored": [len(w.data)], "Executing": ["%d / %d" % (w.state.executing_count, w.state.nthreads)], "Ready": [len(w.state.ready)], - "Waiting": [w.state.waiting_for_data_count], + "Waiting": [len(w.state.waiting)], "Connections": [w.state.transfer_incoming_count], "Serving": [len(w._comms)], } diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 9670e688bd..8e66a14dfb 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -536,12 +536,10 @@ async def test_WorkerTable_custom_metric_overlap_with_core_metric(c, s, a, b): def metric(worker): return -999 - a.metrics["executing"] = metric a.metrics["cpu"] = metric a.metrics["metric"] = metric await asyncio.gather(a.heartbeat(), b.heartbeat()) - assert s.workers[a.address].metrics["executing"] != -999 assert s.workers[a.address].metrics["cpu"] != -999 assert s.workers[a.address].metrics["metric"] == -999 diff --git a/distributed/http/worker/prometheus/core.py b/distributed/http/worker/prometheus/core.py index 2702814ca7..93a0310efe 100644 --- a/distributed/http/worker/prometheus/core.py +++ b/distributed/http/worker/prometheus/core.py @@ -9,6 +9,8 @@ class WorkerMetricCollector(PrometheusCollector): + server: Worker + def __init__(self, server: Worker): super().__init__(server) self.logger = logging.getLogger("distributed.dask_worker") @@ -26,15 +28,15 @@ def __init__(self, server: Worker): def collect(self): from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily + ws = self.server.state + tasks = GaugeMetricFamily( self.build_name("tasks"), "Number of tasks at worker.", labels=["state"], ) - tasks.add_metric(["stored"], len(self.server.data)) - tasks.add_metric(["executing"], self.server.state.executing_count) - tasks.add_metric(["ready"], len(self.server.state.ready)) - tasks.add_metric(["waiting"], self.server.state.waiting_for_data_count) + for k, n in ws.task_counts.items(): + tasks.add_metric([k], n) yield tasks yield GaugeMetricFamily( @@ -43,13 +45,13 @@ def collect(self): "[Deprecated: This metric has been renamed to transfer_incoming_count.] " "Number of open fetch requests to other workers." ), - value=self.server.state.transfer_incoming_count, + value=ws.transfer_incoming_count, ) yield GaugeMetricFamily( self.build_name("threads"), "Number of worker threads.", - value=self.server.state.nthreads, + value=ws.nthreads, ) yield GaugeMetricFamily( @@ -63,7 +65,7 @@ def collect(self): except AttributeError: spilled_memory, spilled_disk = 0, 0 # spilling is disabled process_memory = self.server.monitor.get_process_memory() - managed_memory = min(process_memory, self.server.state.nbytes - spilled_memory) + managed_memory = min(process_memory, ws.nbytes - spilled_memory) memory = GaugeMetricFamily( self.build_name("memory_bytes"), @@ -78,12 +80,12 @@ def collect(self): yield GaugeMetricFamily( self.build_name("transfer_incoming_bytes"), "Total size of open data transfers from other workers.", - value=self.server.state.transfer_incoming_bytes, + value=ws.transfer_incoming_bytes, ) yield GaugeMetricFamily( self.build_name("transfer_incoming_count"), "Number of open data transfers from other workers.", - value=self.server.state.transfer_incoming_count, + value=ws.transfer_incoming_count, ) yield CounterMetricFamily( @@ -92,7 +94,7 @@ def collect(self): "Total number of data transfers from other workers " "since the worker was started." ), - value=self.server.state.transfer_incoming_count_total, + value=ws.transfer_incoming_count_total, ) yield GaugeMetricFamily( diff --git a/distributed/http/worker/tests/test_worker_http.py b/distributed/http/worker/tests/test_worker_http.py index 31fff00314..cef3d2df18 100644 --- a/distributed/http/worker/tests/test_worker_http.py +++ b/distributed/http/worker/tests/test_worker_http.py @@ -71,8 +71,14 @@ async def fetch_state_metrics(): assert not a.state.tasks active_metrics = await fetch_state_metrics() assert active_metrics == { - "stored": 0.0, + "constrained": 0.0, "executing": 0.0, + "fetch": 0.0, + "flight": 0.0, + "long-running": 0.0, + "memory": 0.0, + "missing": 0.0, + "other": 0.0, "ready": 0.0, "waiting": 0.0, } @@ -86,8 +92,14 @@ async def fetch_state_metrics(): active_metrics = await fetch_state_metrics() assert active_metrics == { - "stored": 0.0, + "constrained": 0.0, "executing": 1.0, + "fetch": 0.0, + "flight": 0.0, + "long-running": 0.0, + "memory": 0.0, + "missing": 0.0, + "other": 0.0, "ready": 0.0, "waiting": 0.0, } @@ -102,8 +114,14 @@ async def fetch_state_metrics(): active_metrics = await fetch_state_metrics() assert active_metrics == { - "stored": 0.0, + "constrained": 0.0, "executing": 0.0, + "fetch": 0.0, + "flight": 0.0, + "long-running": 0.0, + "memory": 0.0, + "missing": 0.0, + "other": 0.0, "ready": 0.0, "waiting": 0.0, } diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 8a4696aafd..bd7a0c7c82 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -61,6 +61,7 @@ slowinc, tls_only_security, varying, + wait_for_state, ) from distributed.worker import dumps_function, dumps_task, get_worker @@ -606,26 +607,17 @@ async def test_clear_events_client_removal(c, s, a, b): assert time() < start + 2 -@gen_cluster() -async def test_add_worker(s, a, b): - w = Worker(s.address, nthreads=3) - w.data["x-5"] = 6 - w.data["y"] = 1 - - dsk = {("x-%d" % i): (inc, i) for i in range(10)} - s.update_graph( - tasks=valmap(dumps_task, dsk), - keys=list(dsk), - client="client", - dependencies={k: set() for k in dsk}, - ) - s.validate_state() - await w +@gen_cluster(client=True, nthreads=[]) +async def test_add_worker(c, s): + x = c.submit(inc, 1, key="x") + await wait_for_state("x", ("queued", "no-worker"), s) s.validate_state() - assert w.ip in s.host_info - assert s.host_info[w.ip]["addresses"] == {a.address, b.address, w.address} - await w.close() + async with Worker(s.address) as w: + s.validate_state() + assert w.ip in s.host_info + assert s.host_info[w.ip]["addresses"] == {w.address} + assert await x == 2 @gen_cluster(scheduler_kwargs={"blocked_handlers": ["feed"]}) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 270275a9aa..78f7ae7107 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -987,16 +987,19 @@ async def test_wait_for_state(c, s, a, capsys): await asyncio.gather( wait_for_state("x", "memory", s), - wait_for_state("x", "memory", a), + wait_for_state("x", {"memory", "other"}, a), c.run(wait_for_state, "x", "memory"), ) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(wait_for_state("x", "bad_state", s), timeout=0.1) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(wait_for_state("x", ("this", "that"), s), timeout=0.1) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(wait_for_state("y", "memory", s), timeout=0.1) assert capsys.readouterr().out == ( f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n" + f"tasks[x].state='memory' on {s.address}; expected state=('this', 'that')\n" f"tasks[y] not found on {s.address}\n" ) diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index ec13d69ec8..399cdd5e9e 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -703,10 +703,9 @@ async def test_override_data_worker(s): async with Worker(s.address, data=UserDict) as w: assert type(w.data) is UserDict - data = UserDict({"x": 1}) + data = UserDict() async with Worker(s.address, data=data) as w: assert w.data is data - assert w.data == {"x": 1} @gen_cluster( diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index b2c983ea9b..6436a6ddc4 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1014,6 +1014,8 @@ async def test_deprecated_worker_attributes(s, a, b): with pytest.warns(FutureWarning, match="attribute has been removed"): assert a.data_needed == set() + with pytest.warns(FutureWarning, match="attribute has been removed"): + assert a.waiting_for_data_count == 0 @pytest.mark.parametrize("n_remote_workers", [1, 2]) @@ -1620,3 +1622,101 @@ def test_worker_nbytes(ws_with_running_task): # memory -> released by RemoveReplicasEvent ws.handle_stimulus(RemoveReplicasEvent(keys=["x", "y", "w"], stimulus_id="s7")) assert ws.nbytes == 0 + + +def test_fetch_count(ws): + ws.transfer_incoming_count_limit = 0 + ws2 = "127.0.0.1:2" + ws3 = "127.0.0.1:3" + assert ws.fetch_count == 0 + # Saturate comms + # released->fetch->flight + ws.handle_stimulus( + AcquireReplicasEvent(who_has={"a": [ws2]}, nbytes={"a": 1}, stimulus_id="s1"), + AcquireReplicasEvent( + who_has={"b": [ws2, ws3]}, nbytes={"b": 1}, stimulus_id="s2" + ), + ) + assert ws.tasks["b"].coming_from == ws3 + assert ws.fetch_count == 0 + + # released->fetch + # d is in two data_needed heaps + ws.handle_stimulus( + AcquireReplicasEvent( + who_has={"c": [ws2], "d": [ws2, ws3]}, + nbytes={"c": 1, "d": 1}, + stimulus_id="s3", + ) + ) + assert ws.fetch_count == 2 + + # fetch->released + ws.handle_stimulus(FreeKeysEvent(keys={"c", "d"}, stimulus_id="s4")) + assert ws.fetch_count == 0 + + # flight->missing + ws.handle_stimulus( + GatherDepSuccessEvent(worker=ws2, data={}, total_nbytes=0, stimulus_id="s5") + ) + assert ws.tasks["a"].state == "missing" + print(ws.tasks) + assert ws.fetch_count == 0 + assert len(ws.missing_dep_flight) == 1 + + # flight->fetch + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "clog", who_has={"clog_dep": [ws2]}, priority=(-1,), stimulus_id="s6" + ), + GatherDepSuccessEvent(worker=ws3, data={}, total_nbytes=0, stimulus_id="s7"), + ) + assert ws.tasks["b"].state == "fetch" + assert ws.fetch_count == 1 + assert len(ws.missing_dep_flight) == 1 + + +def test_task_counts(ws): + assert ws.task_counts == { + "constrained": 0, + "executing": 0, + "fetch": 0, + "flight": 0, + "long-running": 0, + "memory": 0, + "missing": 0, + "other": 0, + "ready": 0, + "waiting": 0, + } + + +def test_task_counts_with_actors(ws): + ws.handle_stimulus(ComputeTaskEvent.dummy("x", actor=True, stimulus_id="s1")) + assert ws.actors == {"x": None} + assert ws.task_counts == { + "constrained": 0, + "executing": 1, + "fetch": 0, + "flight": 0, + "long-running": 0, + "memory": 0, + "missing": 0, + "other": 0, + "ready": 0, + "waiting": 0, + } + ws.handle_stimulus(ExecuteSuccessEvent.dummy("x", value=123, stimulus_id="s2")) + assert ws.actors == {"x": 123} + assert ws.task_counts == { + "constrained": 0, + "executing": 0, + "fetch": 0, + "flight": 0, + "long-running": 0, + "memory": 1, + "missing": 0, + "other": 0, + "ready": 0, + "waiting": 0, + } diff --git a/distributed/utils_test.py b/distributed/utils_test.py index c5415cb2ef..3bc471de19 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -24,7 +24,7 @@ import warnings import weakref from collections import defaultdict -from collections.abc import Callable, Mapping +from collections.abc import Callable, Collection, Mapping from contextlib import contextmanager, nullcontext, suppress from itertools import count from time import sleep @@ -2353,10 +2353,14 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]: async def wait_for_state( - key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01 + key: str, + state: str | Collection[str], + dask_worker: Worker | Scheduler, + *, + interval: float = 0.01, ) -> None: """Wait for a task to appear on a Worker or on the Scheduler and to be in a specific - state. + state or one of a set of possible states. """ tasks: Mapping[str, SchedulerTaskState | WorkerTaskState] @@ -2367,14 +2371,18 @@ async def wait_for_state( else: raise TypeError(dask_worker) # pragma: nocover + if isinstance(state, str): + state = (state,) + state_str = repr(next(iter(state))) if len(state) == 1 else str(state) + try: - while key not in tasks or tasks[key].state != state: + while key not in tasks or tasks[key].state not in state: await asyncio.sleep(interval) except (asyncio.CancelledError, asyncio.TimeoutError): if key in tasks: msg = ( f"tasks[{key}].state={tasks[key].state!r} on {dask_worker.address}; " - f"expected {state=}" + f"expected state={state_str}" ) else: msg = f"tasks[{key}] not found on {dask_worker.address}" diff --git a/distributed/worker.py b/distributed/worker.py index 5ba27f95cf..f31d91d4c9 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -902,7 +902,6 @@ def data(self) -> MutableMapping[str, Any]: transition_counter_max = DeprecatedWorkerStateAttribute() validate = DeprecatedWorkerStateAttribute() validate_task = DeprecatedWorkerStateAttribute() - waiting_for_data_count = DeprecatedWorkerStateAttribute() @property def data_needed(self) -> set[TaskState]: @@ -913,11 +912,20 @@ def data_needed(self) -> set[TaskState]: ) return {ts for tss in self.state.data_needed.values() for ts in tss} + @property + def waiting_for_data_count(self) -> int: + warnings.warn( + "The `Worker.waiting_for_data_count` attribute has been removed; " + "use `len(Worker.state.waiting)`", + FutureWarning, + ) + return len(self.state.waiting) + ################## # Administrative # ################## - def __repr__(self): + def __repr__(self) -> str: name = f", name: {self.name}" if self.name != self.address_safe else "" return ( f"<{self.__class__.__name__} {self.address_safe!r}{name}, " @@ -926,7 +934,7 @@ def __repr__(self): f"running: {self.state.executing_count}/{self.state.nthreads}, " f"ready: {len(self.state.ready)}, " f"comm: {self.state.in_flight_tasks_count}, " - f"waiting: {self.state.waiting_for_data_count}>" + f"waiting: {len(self.state.waiting)}>" ) @property @@ -989,10 +997,7 @@ async def get_metrics(self) -> dict: spilled_memory, spilled_disk = 0, 0 out = dict( - executing=self.state.executing_count, - in_memory=len(self.data), - ready=len(self.state.ready), - in_flight=self.state.in_flight_tasks_count, + task_counts=self.state.task_counts, bandwidth={ "total": self.bandwidth, "workers": dict(self.bandwidth_workers), diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 1df4e3bbcd..8eb0f6480e 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1113,8 +1113,8 @@ class WorkerState: #: with the Worker: ``WorkerState.running == (Worker.status is Status.running)``. running: bool - #: A count of how many tasks are currently waiting for data - waiting_for_data_count: int + #: Tasks that are currently waiting for data + waiting: set[TaskState] #: ``{worker address: {ts.key, ...}``. #: The data that we care about that we think a worker has @@ -1126,6 +1126,10 @@ class WorkerState: #: multiple entries in :attr:`~TaskState.who_has` will appear multiple times here. data_needed: defaultdict[str, HeapSet[TaskState]] + #: Total number of tasks in fetch state. If a task is in more than one data_needed + #: heap, it's only counted once. + fetch_count: int + #: Number of bytes to gather from the same worker in a single call to #: :meth:`BaseWorker.gather_dep`. Multiple small tasks that can be gathered from the #: same worker will be batched in a single instruction as long as their combined @@ -1279,11 +1283,12 @@ def __init__( self.validate = validate self.tasks = {} self.running = True - self.waiting_for_data_count = 0 + self.waiting = set() self.has_what = defaultdict(set) self.data_needed = defaultdict( partial(HeapSet[TaskState], key=operator.attrgetter("priority")) ) + self.fetch_count = 0 self.in_flight_workers = {} self.busy_workers = set() self.transfer_incoming_count_limit = transfer_incoming_count_limit @@ -1479,6 +1484,7 @@ def _purge_state(self, ts: TaskState) -> None: self.executing.discard(ts) self.long_running.discard(ts) self.in_flight_tasks.discard(ts) + self.waiting.discard(ts) def _should_throttle_incoming_transfers(self) -> bool: """Decides whether the WorkerState should throttle data transfers from other workers. @@ -1822,7 +1828,6 @@ def _put_key_in_memory( for dep in ts.dependents: dep.waiting_for_data.discard(ts) if not dep.waiting_for_data and dep.state == "waiting": - self.waiting_for_data_count -= 1 recommendations[dep] = "ready" self.log.append((ts.key, "put-in-memory", stimulus_id, time())) @@ -1838,6 +1843,7 @@ def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInst ts.state = "fetch" ts.done = False + self.fetch_count += 1 assert ts.priority for w in ts.who_has: self.data_needed[w].add(ts) @@ -1928,12 +1934,11 @@ def _transition_released_waiting( dep_ts.waiters.add(ts) recommendations[dep_ts] = "fetch" - if ts.waiting_for_data: - self.waiting_for_data_count += 1 - else: + if not ts.waiting_for_data: recommendations[ts] = "ready" ts.state = "waiting" + self.waiting.add(ts) return recommendations, [] def _transition_fetch_flight( @@ -1950,8 +1955,21 @@ def _transition_fetch_flight( ts.state = "flight" ts.coming_from = worker self.in_flight_tasks.add(ts) + self.fetch_count -= 1 return {}, [] + def _transition_fetch_missing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + self.fetch_count -= 1 + return self._transition_generic_missing(ts, stimulus_id=stimulus_id) + + def _transition_fetch_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + self.fetch_count -= 1 + return self._transition_generic_released(ts, stimulus_id=stimulus_id) + def _transition_memory_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: @@ -1977,6 +1995,7 @@ def _transition_waiting_constrained( assert ts not in self.ready assert ts not in self.constrained ts.state = "constrained" + self.waiting.remove(ts) self.constrained.add(ts) return self._ensure_computing() @@ -2012,6 +2031,7 @@ def _transition_waiting_ready( ts.state = "ready" assert ts.priority is not None + self.waiting.remove(ts) self.ready.add(ts) return self._ensure_computing() @@ -2504,8 +2524,8 @@ def _transition_released_forgotten( ("executing", "released"): _transition_executing_released, ("executing", "rescheduled"): _transition_executing_rescheduled, ("fetch", "flight"): _transition_fetch_flight, - ("fetch", "missing"): _transition_generic_missing, - ("fetch", "released"): _transition_generic_released, + ("fetch", "missing"): _transition_fetch_missing, + ("fetch", "released"): _transition_fetch_released, ("flight", "error"): _transition_generic_error, ("flight", "fetch"): _transition_flight_fetch, ("flight", "memory"): _transition_flight_memory, @@ -3244,6 +3264,33 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: info = {k: v for k, v in info.items() if k not in exclude} return recursive_to_dict(info, exclude=exclude) + @property + def task_counts(self) -> dict[TaskStateState | Literal["other"], int]: + # Actors can be in any state other than {fetch, flight, missing} + n_actors_in_memory = sum( + self.tasks[key].state == "memory" for key in self.actors + ) + + out: dict[TaskStateState | Literal["other"], int] = { + # Key measure for occupancy. + # Also includes cancelled(executing) and resumed(executing->fetch) + "executing": len(self.executing), + # Also includes cancelled(long-running) and resumed(long-running->fetch) + "long-running": len(self.long_running), + "memory": len(self.data) + n_actors_in_memory, + "ready": len(self.ready), + "constrained": len(self.constrained), + "waiting": len(self.waiting), + "fetch": self.fetch_count, + "missing": len(self.missing_dep_flight), + # Also includes cancelled(flight) and resumed(flight->waiting) + "flight": len(self.in_flight_tasks), + } + # released | error + out["other"] = other = len(self.tasks) - sum(out.values()) + assert other >= 0 + return out + ############## # Validation # ############## @@ -3275,6 +3322,10 @@ def _validate_task_executing(self, ts: TaskState) -> None: assert ts.key not in self.data assert not ts.waiting_for_data + for dep in ts.dependents: + assert dep not in self.ready + assert dep not in self.constrained + # FIXME https://github.com/dask/distributed/issues/6893 # This assertion can be false for # - cancelled or resumed tasks @@ -3309,8 +3360,16 @@ def _validate_task_ready(self, ts: TaskState) -> None: def _validate_task_waiting(self, ts: TaskState) -> None: assert ts.key not in self.data assert not ts.done - if ts.dependencies and ts.run_spec: - assert not all(dep.key in self.data for dep in ts.dependencies) + assert ts in self.waiting + assert ts.waiting_for_data + assert ts.waiting_for_data == { + dep + for dep in ts.dependencies + if dep.key not in self.data and dep.key not in self.actors + } + for dep in ts.dependents: + assert dep not in self.ready + assert dep not in self.constrained def _validate_task_flight(self, ts: TaskState) -> None: """Validate tasks: @@ -3320,6 +3379,7 @@ def _validate_task_flight(self, ts: TaskState) -> None: - ts.state == resumed, ts.previous == flight, ts.next == waiting """ assert ts.key not in self.data + assert ts.key not in self.actors assert ts in self.in_flight_tasks for dep in ts.dependents: assert dep not in self.ready @@ -3330,19 +3390,27 @@ def _validate_task_flight(self, ts: TaskState) -> None: def _validate_task_fetch(self, ts: TaskState) -> None: assert ts.key not in self.data + assert ts.key not in self.actors assert self.address not in ts.who_has assert not ts.done assert ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] assert ts in self.data_needed[w] + for dep in ts.dependents: + assert dep not in self.ready + assert dep not in self.constrained def _validate_task_missing(self, ts: TaskState) -> None: assert ts.key not in self.data + assert ts.key not in self.actors assert not ts.who_has assert not ts.done assert not any(ts.key in has_what for has_what in self.has_what.values()) assert ts in self.missing_dep_flight + for dep in ts.dependents: + assert dep not in self.ready + assert dep not in self.constrained def _validate_task_cancelled(self, ts: TaskState) -> None: assert ts.next is None @@ -3360,9 +3428,13 @@ def _validate_task_resumed(self, ts: TaskState) -> None: assert ts.previous == "flight" assert ts.next == "waiting" self._validate_task_flight(ts) + for dep in ts.dependents: + assert dep not in self.ready + assert dep not in self.constrained def _validate_task_released(self, ts: TaskState) -> None: assert ts.key not in self.data + assert ts.key not in self.actors assert not ts.next assert not ts.previous for tss in self.data_needed.values(): @@ -3434,11 +3506,6 @@ def validate_state(self) -> None: assert self.tasks[ts_wait.key] is ts_wait assert ts_wait.state in WAITING_FOR_DATA, ts_wait - # FIXME https://github.com/dask/distributed/issues/6319 - # assert self.waiting_for_data_count == sum( - # bool(ts.waiting_for_data) for ts in self.tasks.values() - # ) - for worker, keys in self.has_what.items(): assert worker != self.address for k in keys: @@ -3446,10 +3513,14 @@ def validate_state(self) -> None: assert worker in self.tasks[k].who_has # Test contents of the various sets of TaskState objects + fetch_tss = set() for worker, tss in self.data_needed.items(): for ts in tss: + fetch_tss.add(ts) assert ts.state == "fetch" assert worker in ts.who_has + assert len(fetch_tss) == self.fetch_count + for ts in self.missing_dep_flight: assert ts.state == "missing" for ts in self.ready: @@ -3468,6 +3539,8 @@ def validate_state(self) -> None: assert ts.state == "flight" or ( ts.state in ("cancelled", "resumed") and ts.previous == "flight" ), ts + for ts in self.waiting: + assert ts.state == "waiting" # Test that there aren't multiple TaskState objects with the same key in any # Set[TaskState]. See note in TaskState.__hash__. @@ -3479,6 +3552,7 @@ def validate_state(self) -> None: self.in_flight_tasks, self.executing, self.long_running, + self.waiting, ): assert self.tasks[ts.key] is ts @@ -3487,6 +3561,11 @@ def validate_state(self) -> None: ) assert self.nbytes == expect_nbytes, f"{self.nbytes=}; expected {expect_nbytes}" + for key in self.data: + assert key in self.tasks, key + for key in self.actors: + assert key in self.tasks, key + for ts in self.tasks.values(): self.validate_task(ts)