diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 7d10fbcdb15..4956974fa00 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -20,7 +20,7 @@ from .metrics import time from .utils import import_term, log_errors -if TYPE_CHECKING: # pragma: nocover +if TYPE_CHECKING: from .client import Client from .scheduler import Scheduler, TaskState, WorkerState diff --git a/distributed/batched.py b/distributed/batched.py index 3e6cbcfd30b..0b1fc1da0f5 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -128,7 +128,7 @@ def _background_send(self): self.stopped.set() self.abort() - def send(self, *msgs): + def send(self, *msgs: dict) -> None: """Schedule a message for sending to the other side This completes quickly and synchronously diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 28cd996fe23..b19010448a7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1209,6 +1209,8 @@ def _to_dict_no_nest(self, *, exclude: "Container[str]" = ()) -> dict: class TaskState: """ A simple object holding information about a task. + Not to be confused with :class:`distributed.worker_state_machine.TaskState`, which + holds similar information on the Worker side. .. attribute:: key: str diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index e8dc492aab1..9912f77aa1c 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -25,6 +25,7 @@ slowadd, slowinc, ) +from distributed.worker_state_machine import TaskState pytestmark = pytest.mark.ci1 @@ -494,19 +495,13 @@ async def test_worker_time_to_live(c, s, a, b): @gen_cluster() async def test_forget_data_not_supposed_to_have(s, a, b): - """ - If a depednecy fetch finishes on a worker after the scheduler already - released everything, the worker might be stuck with a redundant replica - which is never cleaned up. + """If a dependency fetch finishes on a worker after the scheduler already released + everything, the worker might be stuck with a redundant replica which is never + cleaned up. """ # FIXME: Replace with "blackbox test" which shows an actual example where - # this situation is provoked if this is even possible. - # If this cannot be constructed, the entire superfuous_data handler and its - # corresponding pieces on the scheduler side may be removed - from distributed.worker import TaskState - - ts = TaskState("key") - ts.state = "flight" + # this situation is provoked if this is even possible. 
+ ts = TaskState("key", state="flight") a.tasks["key"] = ts recommendations = {ts: ("memory", 123)} a.transitions(recommendations, stimulus_id="test") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 75cf9161119..7eee89a3c0e 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -57,7 +57,7 @@ slowinc, slowsum, ) -from distributed.worker import TaskState, UniqueTaskHeap, Worker, error_message, logger +from distributed.worker import Worker, error_message, logger pytestmark = pytest.mark.ci1 @@ -3308,67 +3308,6 @@ async def test_Worker__to_dict(c, s, a): assert d["data"] == ["x"] -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_TaskState__to_dict(c, s, a): - """tasks that are listed as dependencies of other tasks are dumped as a short repr - and always appear in full under Worker.tasks - """ - x = c.submit(inc, 1, key="x") - y = c.submit(inc, x, key="y") - z = c.submit(inc, 2, key="z") - await wait([x, y, z]) - - tasks = a._to_dict()["tasks"] - - assert isinstance(tasks["x"], dict) - assert isinstance(tasks["y"], dict) - assert isinstance(tasks["z"], dict) - assert tasks["x"]["dependents"] == [""] - assert tasks["y"]["dependencies"] == [""] - - -def test_unique_task_heap(): - heap = UniqueTaskHeap() - - for x in range(10): - ts = TaskState(f"f{x}") - ts.priority = (0, 0, 1, x % 3) - heap.push(ts) - - heap_list = list(heap) - # iteration does not empty heap - assert len(heap) == 10 - assert heap_list == sorted(heap_list, key=lambda ts: ts.priority) - - seen = set() - last_prio = (0, 0, 0, 0) - while heap: - peeked = heap.peek() - ts = heap.pop() - assert peeked == ts - seen.add(ts.key) - assert ts.priority - assert last_prio <= ts.priority - last_prio = last_prio - - ts = TaskState("foo") - heap.push(ts) - heap.push(ts) - assert len(heap) == 1 - - assert repr(heap) == "" - - assert heap.pop() == ts - assert not heap - - # Test that we're cleaning the seen set on pop - heap.push(ts) - assert len(heap) == 1 - assert heap.pop() == ts - - assert repr(heap) == "" - - @gen_cluster(nthreads=[]) async def test_do_not_block_event_loop_during_shutdown(s): loop = asyncio.get_running_loop() diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py new file mode 100644 index 00000000000..2c97b217aab --- /dev/null +++ b/distributed/tests/test_worker_state_machine.py @@ -0,0 +1,94 @@ +import pytest + +from distributed.utils import recursive_to_dict +from distributed.worker_state_machine import ( + ReleaseWorkerDataMsg, + SendMessageToScheduler, + TaskState, + UniqueTaskHeap, +) + + +def test_TaskState_get_nbytes(): + assert TaskState("x", nbytes=123).get_nbytes() == 123 + # Default to distributed.scheduler.default-data-size + assert TaskState("y").get_nbytes() == 1024 + + +def test_TaskState__to_dict(): + """Tasks that are listed as dependencies or dependents of other tasks are dumped as + a short repr and always appear in full directly under Worker.tasks. Uninteresting + fields are omitted. 
+ """ + x = TaskState("x", state="memory", done=True) + y = TaskState("y", priority=(0,), dependencies={x}) + x.dependents.add(y) + actual = recursive_to_dict([x, y]) + assert actual == [ + { + "key": "x", + "state": "memory", + "done": True, + "dependents": [""], + }, + { + "key": "y", + "state": "released", + "dependencies": [""], + "priority": [0], + }, + ] + + +def test_unique_task_heap(): + heap = UniqueTaskHeap() + + for x in range(10): + ts = TaskState(f"f{x}", priority=(0,)) + ts.priority = (0, 0, 1, x % 3) + heap.push(ts) + + heap_list = list(heap) + # iteration does not empty heap + assert len(heap) == 10 + assert heap_list == sorted(heap_list, key=lambda ts: ts.priority) + + seen = set() + last_prio = (0, 0, 0, 0) + while heap: + peeked = heap.peek() + ts = heap.pop() + assert peeked == ts + seen.add(ts.key) + assert ts.priority + assert last_prio <= ts.priority + last_prio = last_prio + + ts = TaskState("foo", priority=(0,)) + heap.push(ts) + heap.push(ts) + assert len(heap) == 1 + + assert repr(heap) == "" + + assert heap.pop() == ts + assert not heap + + # Test that we're cleaning the seen set on pop + heap.push(ts) + assert len(heap) == 1 + assert heap.pop() == ts + + assert repr(heap) == "" + + +@pytest.mark.parametrize("cls", SendMessageToScheduler.__subclasses__()) +def test_sendmsg_slots(cls): + smsg = cls(**dict.fromkeys(cls.__annotations__)) + assert not hasattr(smsg, "__dict__") + + +def test_sendmsg_to_dict(): + # Arbitrary sample class + smsg = ReleaseWorkerDataMsg(key="x") + assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} diff --git a/distributed/worker.py b/distributed/worker.py index 85464a9e50d..e12b83e21b4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -18,7 +18,6 @@ Collection, Container, Iterable, - Iterator, Mapping, MutableMapping, ) @@ -27,7 +26,7 @@ from datetime import timedelta from inspect import isawaitable from pickle import PicklingError -from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from tlz import first, keymap, merge, pluck # noqa: F401 from tornado.ioloop import IOLoop, PeriodicCallback @@ -100,18 +99,36 @@ DeprecatedMemoryMonitor, WorkerMemoryManager, ) +from .worker_state_machine import Instruction # noqa: F401 +from .worker_state_machine import ( + PROCESSING, + READY, + AddKeysMsg, + InvalidTransition, + LongRunningMsg, + ReleaseWorkerDataMsg, + RescheduleMsg, + SendMessageToScheduler, + SerializedTask, + TaskErredMsg, + TaskFinishedMsg, + TaskState, + UniqueTaskHeap, +) if TYPE_CHECKING: + # TODO move to typing (requires Python >=3.10) from typing_extensions import TypeAlias from .actor import Actor from .client import Client from .diagnostics.plugin import WorkerPlugin from .nanny import Nanny + from .worker_state_machine import TaskStateState - # {TaskState -> finish: str | (finish: str, *args)} - Recs: TypeAlias = "dict[TaskState, str | tuple]" - Smsgs: TypeAlias = "list[dict[str, Any]]" + # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} + Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]" + Instructions: TypeAlias = "list[Instruction]" logger = logging.getLogger(__name__) @@ -120,241 +137,12 @@ no_value = "--no-value-sentinel--" -# TaskState.state subsets -PROCESSING = { - "waiting", - "ready", - "constrained", - "executing", - "long-running", - "cancelled", - "resumed", -} -READY = {"ready", "constrained"} - DEFAULT_EXTENSIONS: list[type] = 
[PubSubWorkerExtension, ShuffleWorkerExtension] DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {} DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {} -DEFAULT_DATA_SIZE = parse_bytes( - dask.config.get("distributed.scheduler.default-data-size") -) - - -class SerializedTask(NamedTuple): - function: Callable - args: tuple - kwargs: dict[str, Any] - task: object # distributed.scheduler.TaskState.run_spec - - -class StartStop(TypedDict, total=False): - action: str - start: float - stop: float - source: str # optional - - -class InvalidTransition(Exception): - pass - - -class TaskState: - """Holds volatile state relating to an individual Dask task - - - * **dependencies**: ``set(TaskState instances)`` - The data needed by this key to run - * **dependents**: ``set(TaskState instances)`` - The keys that use this dependency. - * **duration**: ``float`` - Expected duration the a task - * **priority**: ``tuple`` - The priority this task given by the scheduler. Determines run order. - * **state**: ``str`` - The current state of the task. One of ["waiting", "ready", "executing", - "fetch", "memory", "flight", "long-running", "rescheduled", "error"] - * **who_has**: ``set(worker)`` - Workers that we believe have this data - * **coming_from**: ``str`` - The worker that current task data is coming from if task is in flight - * **waiting_for_data**: ``set(keys of dependencies)`` - A dynamic version of dependencies. All dependencies that we still don't - have for a particular key. - * **resource_restrictions**: ``{str: number}`` - Abstract resources required to run a task - * **exception**: ``str`` - The exception caused by running a task if it erred - * **traceback**: ``str`` - The exception caused by running a task if it erred - * **type**: ``type`` - The type of a particular piece of data - * **suspicious_count**: ``int`` - The number of times a dependency has not been where we expected it - * **startstops**: ``[{startstop}]`` - Log of transfer, load, and compute times for a task - * **start_time**: ``float`` - Time at which task begins running - * **stop_time**: ``float`` - Time at which task finishes running - * **metadata**: ``dict`` - Metadata related to task. Stored metadata should be msgpack - serializable (e.g. int, string, list, dict). - * **nbytes**: ``int`` - The size of a particular piece of data - * **annotations**: ``dict`` - Task annotations - - Parameters - ---------- - key: str - run_spec: SerializedTask - A named tuple containing the ``function``, ``args``, ``kwargs`` and - ``task`` associated with this `TaskState` instance. This defaults to - ``None`` and can remain empty if it is a dependency that this worker - will receive from another worker. - - """ - - key: str - run_spec: SerializedTask | None - dependencies: set[TaskState] - dependents: set[TaskState] - duration: float | None - priority: tuple[int, ...] 
| None - state: str - who_has: set[str] - coming_from: str | None - waiting_for_data: set[TaskState] - waiters: set[TaskState] - resource_restrictions: dict[str, float] - exception: Exception | None - exception_text: str | None - traceback: object | None - traceback_text: str | None - type: type | None - suspicious_count: int - startstops: list[StartStop] - start_time: float | None - stop_time: float | None - metadata: dict - nbytes: float | None - annotations: dict | None - done: bool - _previous: str | None - _next: str | None - - def __init__(self, key: str, run_spec: SerializedTask | None = None): - assert key is not None - self.key = key - self.run_spec = run_spec - self.dependencies = set() - self.dependents = set() - self.duration = None - self.priority = None - self.state = "released" - self.who_has = set() - self.coming_from = None - self.waiting_for_data = set() - self.waiters = set() - self.resource_restrictions = {} - self.exception = None - self.exception_text = "" - self.traceback = None - self.traceback_text = "" - self.type = None - self.suspicious_count = 0 - self.startstops = [] - self.start_time = None - self.stop_time = None - self.metadata = {} - self.nbytes = None - self.annotations = None - self.done = False - self._previous = None - self._next = None - - def __repr__(self) -> str: - return f"" - - def get_nbytes(self) -> int: - nbytes = self.nbytes - return nbytes if nbytes is not None else DEFAULT_DATA_SIZE - - def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: - """Dictionary representation for debugging purposes. - Not type stable and not intended for roundtrips. - - See also - -------- - Client.dump_cluster_state - distributed.utils.recursive_to_dict - - Notes - ----- - This class uses ``_to_dict_no_nest`` instead of ``_to_dict``. - When a task references another task, just print the task repr. All tasks - should neatly appear under Worker.tasks. This also prevents a RecursionError - during particularly heavy loads, which have been observed to happen whenever - there's an acyclic dependency chain of ~200+ tasks. - """ - return recursive_to_dict(self, exclude=exclude, members=True) - - def is_protected(self) -> bool: - return self.state in PROCESSING or any( - dep_ts.state in PROCESSING for dep_ts in self.dependents - ) - - -class UniqueTaskHeap(Collection): - """A heap of TaskState objects ordered by TaskState.priority - Ties are broken by string comparison of the key. Keys are guaranteed to be - unique. Iterating over this object returns the elements in priority order. - """ - - def __init__(self, collection: Collection[TaskState] = ()): - self._known = {ts.key for ts in collection} - self._heap = [(ts.priority, ts.key, ts) for ts in collection] - heapq.heapify(self._heap) - - def push(self, ts: TaskState) -> None: - """Add a new TaskState instance to the heap. If the key is already - known, no object is added. - - Note: This does not update the priority / heap order in case priority - changes. 
- """ - assert isinstance(ts, TaskState) - if ts.key not in self._known: - heapq.heappush(self._heap, (ts.priority, ts.key, ts)) - self._known.add(ts.key) - - def pop(self) -> TaskState: - """Pop the task with highest priority from the heap.""" - _, key, ts = heapq.heappop(self._heap) - self._known.remove(key) - return ts - - def peek(self) -> TaskState: - """Get the highest priority TaskState without removing it from the heap""" - return self._heap[0][2] - - def __contains__(self, x: object) -> bool: - if isinstance(x, TaskState): - x = x.key - return x in self._known - - def __iter__(self) -> Iterator[TaskState]: - return (ts for _, _, ts in sorted(self._heap)) - - def __len__(self) -> int: - return len(self._known) - - def __repr__(self) -> str: - return f"<{type(self).__name__}: {len(self)} items>" - class Worker(ServerNode): """Worker node in a Dask distributed cluster @@ -1870,7 +1658,7 @@ def update_data( if stimulus_id is None: stimulus_id = f"update-data-{time()}" recommendations: Recs = {} - scheduler_messages = [] + instructions: Instructions = [] for key, value in data.items(): try: ts = self.tasks[key] @@ -1889,13 +1677,10 @@ def update_data( self.log.append((key, "receive-from-scatter", stimulus_id, time())) if report: - scheduler_messages.append( - {"op": "add-keys", "keys": list(data), "stimulus_id": stimulus_id} - ) + instructions.append(AddKeysMsg(keys=list(data), stimulus_id=stimulus_id)) self.transitions(recommendations, stimulus_id=stimulus_id) - for msg in scheduler_messages: - self.batched_stream.send(msg) + self._handle_instructions(instructions) return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None: @@ -1955,9 +1740,8 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: if rejected: self.log.append(("remove-replica-rejected", rejected, stimulus_id, time())) - self.batched_stream.send( - {"op": "add-keys", "keys": rejected, "stimulus_id": stimulus_id} - ) + smsg = AddKeysMsg(keys=rejected, stimulus_id=stimulus_id) + self._handle_instructions([smsg]) self.transitions(recommendations, stimulus_id=stimulus_id) @@ -2082,7 +1866,7 @@ def handle_compute_task( ts.annotations = annotations recommendations: Recs = {} - scheduler_msgs: Smsgs = [] + instructions: Instructions = [] for dependency in who_has: dep_ts = self.ensure_task_exists( key=dependency, @@ -2098,7 +1882,7 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - scheduler_msgs.append(self._get_task_finished_msg(ts)) + instructions.append(self._get_task_finished_msg(ts)) elif ts.state in { "released", "fetch", @@ -2111,9 +1895,7 @@ def handle_compute_task( else: # pragma: no cover raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") - for msg in scheduler_msgs: - self.batched_stream.send(msg) - + self._handle_instructions(instructions) self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) @@ -2123,7 +1905,7 @@ def handle_compute_task( def transition_missing_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "missing" assert ts.priority is not None @@ -2136,15 +1918,17 @@ def transition_missing_fetch( def transition_missing_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self._missing_dep_flight.discard(ts) - recs, smsgs = 
self.transition_generic_released(ts, stimulus_id=stimulus_id) + recs, instructions = self.transition_generic_released( + ts, stimulus_id=stimulus_id + ) assert ts.key in self.tasks - return recs, smsgs + return recs, instructions def transition_flight_missing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: assert ts.done ts.state = "missing" self._missing_dep_flight.add(ts) @@ -2153,7 +1937,7 @@ def transition_flight_missing( def transition_released_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "released" assert ts.priority is not None @@ -2166,7 +1950,7 @@ def transition_released_fetch( def transition_generic_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self.release_key(ts.key, stimulus_id=stimulus_id) recs: Recs = {} for dependency in ts.dependencies: @@ -2183,7 +1967,7 @@ def transition_generic_released( def transition_released_waiting( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "released" assert all(d.key in self.tasks for d in ts.dependencies) @@ -2191,7 +1975,7 @@ def transition_released_waiting( recommendations: Recs = {} ts.waiting_for_data.clear() for dep_ts in ts.dependencies: - if not dep_ts.state == "memory": + if dep_ts.state != "memory": ts.waiting_for_data.add(dep_ts) dep_ts.waiters.add(ts) if dep_ts.state not in {"fetch", "flight"}: @@ -2209,7 +1993,7 @@ def transition_released_waiting( def transition_fetch_flight( self, ts: TaskState, worker, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "fetch" assert ts.who_has @@ -2222,14 +2006,16 @@ def transition_fetch_flight( def transition_memory_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: - recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id) - smsgs.append({"op": "release-worker-data", "key": ts.key}) - return recs, smsgs + ) -> tuple[Recs, Instructions]: + recs, instructions = self.transition_generic_released( + ts, stimulus_id=stimulus_id + ) + instructions.append(ReleaseWorkerDataMsg(ts.key)) + return recs, instructions def transition_waiting_constrained( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "waiting" assert not ts.waiting_for_data @@ -2245,25 +2031,25 @@ def transition_waiting_constrained( def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: recs: Recs = {ts: "released"} - smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - return recs, smsgs + smsg = RescheduleMsg(key=ts.key, worker=self.address) + return recs, [smsg] def transition_executing_rescheduled( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) recs: Recs = {ts: "released"} - smsgs: Smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - return recs, smsgs + smsg = RescheduleMsg(key=ts.key, worker=self.address) + return recs, [smsg] def transition_waiting_ready( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, 
Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "waiting" assert ts.key not in self.ready @@ -2287,11 +2073,11 @@ def transition_cancelled_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: recs: Recs = {} - smsgs: Smsgs = [] + instructions: Instructions = [] if ts._previous == "executing": - recs, smsgs = self.transition_executing_error( + recs, instructions = self.transition_executing_error( ts, exception, traceback, @@ -2300,7 +2086,7 @@ def transition_cancelled_error( stimulus_id=stimulus_id, ) elif ts._previous == "flight": - recs, smsgs = self.transition_flight_error( + recs, instructions = self.transition_flight_error( ts, exception, traceback, @@ -2310,36 +2096,32 @@ def transition_cancelled_error( ) if ts._next: recs[ts] = ts._next - return recs, smsgs + return recs, instructions def transition_generic_error( self, ts: TaskState, - exception, - traceback, - exception_text, - traceback_text, + exception: Exception, + traceback: object, + exception_text: str, + traceback_text: str, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: ts.exception = exception ts.traceback = traceback ts.exception_text = exception_text ts.traceback_text = traceback_text ts.state = "error" - smsg = { - "op": "task-erred", - "status": "error", - "key": ts.key, - "thread": self.threads.get(ts.key), - "exception": ts.exception, - "traceback": ts.traceback, - "exception_text": ts.exception_text, - "traceback_text": ts.traceback_text, - } - - if ts.startstops: - smsg["startstops"] = ts.startstops + smsg = TaskErredMsg( + key=ts.key, + exception=ts.exception, + traceback=ts.traceback, + exception_text=ts.exception_text, + traceback_text=ts.traceback_text, + thread=self.threads.get(ts.key), + startstops=ts.startstops, + ) return {}, [smsg] @@ -2352,7 +2134,7 @@ def transition_executing_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) @@ -2366,8 +2148,8 @@ def transition_executing_error( ) def _transition_from_resumed( - self, ts: TaskState, finish: str, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str + ) -> tuple[Recs, Instructions]: """`resumed` is an intermediate degenerate state which splits further up into two states depending on what the last signal / next state is intended to be. 
There are only two viable choices depending on whether @@ -2388,24 +2170,24 @@ def _transition_from_resumed( See also `transition_resumed_waiting` """ recs: Recs = {} - smsgs: Smsgs = [] + instructions: Instructions = [] if ts.done: next_state = ts._next # if the next state is already intended to be waiting or if the # coro/thread is still running (ts.done==False), this is a noop if ts._next != finish: - recs, smsgs = self.transition_generic_released( + recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) assert next_state recs[ts] = next_state else: ts._next = finish - return recs, smsgs + return recs, instructions def transition_resumed_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: """ See Worker._transition_from_resumed """ @@ -2413,7 +2195,7 @@ def transition_resumed_fetch( def transition_resumed_missing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: """ See Worker._transition_from_resumed """ @@ -2427,7 +2209,7 @@ def transition_resumed_waiting(self, ts: TaskState, *, stimulus_id: str): def transition_cancelled_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if ts.done: return {ts: "released"}, [] elif ts._previous == "flight": @@ -2438,15 +2220,15 @@ def transition_cancelled_fetch( return {ts: ("resumed", "fetch")}, [] def transition_cancelled_resumed( - self, ts: TaskState, next: str, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + self, ts: TaskState, next: TaskStateState, *, stimulus_id: str + ) -> tuple[Recs, Instructions]: ts._next = next ts.state = "resumed" return {}, [] def transition_cancelled_waiting( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if ts.done: return {ts: "released"}, [] elif ts._previous == "executing": @@ -2458,7 +2240,7 @@ def transition_cancelled_waiting( def transition_cancelled_forgotten( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: ts._next = "forgotten" if not ts.done: return {}, [] @@ -2466,7 +2248,7 @@ def transition_cancelled_forgotten( def transition_cancelled_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if not ts.done: ts._next = "released" return {}, [] @@ -2477,14 +2259,16 @@ def transition_cancelled_released( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id) + recs, instructions = self.transition_generic_released( + ts, stimulus_id=stimulus_id + ) if next_state != "released": recs[ts] = next_state - return recs, smsgs + return recs, instructions def transition_executing_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: ts._previous = ts.state ts._next = "released" # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 @@ -2494,13 +2278,13 @@ def transition_executing_released( def transition_long_running_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self.executed_count += 1 return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) def transition_generic_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) 
-> tuple[Recs, Instructions]: if value is no_value and ts.key not in self.data: raise RuntimeError( f"Tried to transition task {ts} to `memory` without data available" @@ -2519,13 +2303,14 @@ def transition_generic_memory( msg = error_message(e) recs = {ts: tuple(msg.values())} return recs, [] - assert ts.key in self.data or ts.key in self.actors - smsgs = [self._get_task_finished_msg(ts)] - return recs, smsgs + if self.validate: + assert ts.key in self.data or ts.key in self.actors + smsg = self._get_task_finished_msg(ts) + return recs, [smsg] def transition_executing_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "executing" or ts.key in self.long_running assert not ts.waiting_for_data @@ -2537,7 +2322,7 @@ def transition_executing_memory( def transition_constrained_executing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert not ts.waiting_for_data assert ts.key not in self.data @@ -2555,7 +2340,7 @@ def transition_constrained_executing( def transition_ready_executing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert not ts.waiting_for_data assert ts.key not in self.data @@ -2573,7 +2358,7 @@ def transition_ready_executing( def transition_flight_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: # If this transition is called after the flight coroutine has finished, # we can reset the task and transition to fetch again. If it is not yet # finished, this should be a no-op @@ -2601,7 +2386,7 @@ def transition_flight_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self._in_flight_tasks.discard(ts) ts.coming_from = None return self.transition_generic_error( @@ -2615,7 +2400,7 @@ def transition_flight_error( def transition_flight_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if ts.done: # FIXME: Is this even possible? Would an assert instead be more # sensible? 
@@ -2629,70 +2414,49 @@ def transition_flight_released( def transition_cancelled_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: assert ts._next return {ts: ts._next}, [] def transition_executing_long_running( - self, ts: TaskState, compute_duration, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + self, ts: TaskState, compute_duration: float, *, stimulus_id: str + ) -> tuple[Recs, Instructions]: ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - smsgs = [ - { - "op": "long-running", - "key": ts.key, - "compute_duration": compute_duration, - } - ] - + smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration) self.io_loop.add_callback(self.ensure_computing) - return {}, smsgs + return {}, [smsg] def transition_released_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: - recs: Recs = {} + ) -> tuple[Recs, Instructions]: try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: msg = error_message(e) - recs[ts] = ( - "error", - msg["exception"], - msg["traceback"], - msg["exception_text"], - msg["traceback_text"], - ) + recs = {ts: tuple(msg.values())} return recs, [] - smsgs = [{"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}] - return recs, smsgs + smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + return recs, [smsg] def transition_flight_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self._in_flight_tasks.discard(ts) ts.coming_from = None - recs: Recs = {} try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: msg = error_message(e) - recs[ts] = ( - "error", - msg["exception"], - msg["traceback"], - msg["exception_text"], - msg["traceback_text"], - ) + recs = {ts: tuple(msg.values())} return recs, [] - smsgs = [{"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}] - return recs, smsgs + smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + return recs, [smsg] def transition_released_forgotten( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: recommendations: Recs = {} # Dependents _should_ be released by the scheduler before this if self.validate: @@ -2709,7 +2473,7 @@ def transition_released_forgotten( def _transition( self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple assert not args @@ -2723,13 +2487,15 @@ def _transition( if func is not None: self._transition_counter += 1 - recs, smsgs = func(ts, *args, stimulus_id=stimulus_id, **kwargs) + recs, instructions = func(ts, *args, stimulus_id=stimulus_id, **kwargs) self._notify_plugins("transition", ts.key, start, finish, **kwargs) elif "released" not in (start, finish): # start -> "released" -> finish try: - recs, smsgs = self._transition(ts, "released", stimulus_id=stimulus_id) + recs, instructions = self._transition( + ts, "released", stimulus_id=stimulus_id + ) v = recs.get(ts, (finish, *args)) v_state: str v_args: list | tuple @@ -2737,11 +2503,11 @@ def _transition( v_state, *v_args = v else: v_state, v_args = v, () - b_recs, b_smsgs = self._transition( + b_recs, b_instructions = self._transition( ts, v_state, *v_args, stimulus_id=stimulus_id ) recs.update(b_recs) - smsgs 
+= b_smsgs + instructions += b_instructions except InvalidTransition: raise InvalidTransition( f"Impossible transition from {start} to {finish} for {ts.key}" @@ -2768,7 +2534,7 @@ def _transition( time(), ) ) - return recs, smsgs + return recs, instructions def transition( self, ts: TaskState, finish: str, *, stimulus_id: str, **kwargs @@ -2788,9 +2554,10 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recs, smsgs = self._transition(ts, finish, stimulus_id=stimulus_id, **kwargs) - for msg in smsgs: - self.batched_stream.send(msg) + recs, instructions = self._transition( + ts, finish, stimulus_id=stimulus_id, **kwargs + ) + self._handle_instructions(instructions) self.transitions(recs, stimulus_id=stimulus_id) def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: @@ -2799,34 +2566,44 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: This includes feedback from previous transitions and continues until we reach a steady state """ - smsgs = [] + instructions = [] remaining_recs = recommendations.copy() tasks = set() while remaining_recs: ts, finish = remaining_recs.popitem() tasks.add(ts) - a_recs, a_smsgs = self._transition(ts, finish, stimulus_id=stimulus_id) + a_recs, a_instructions = self._transition( + ts, finish, stimulus_id=stimulus_id + ) remaining_recs.update(a_recs) - smsgs += a_smsgs + instructions += a_instructions if self.validate: # Full state validation is very expensive for ts in tasks: self.validate_task(ts) - if not self.batched_stream.closed(): - for msg in smsgs: - self.batched_stream.send(msg) - else: + if self.batched_stream.closed(): logger.debug( "BatchedSend closed while transitioning tasks. %d tasks not sent.", - len(smsgs), + len(instructions), ) + else: + self._handle_instructions(instructions) + + def _handle_instructions(self, instructions: list[Instruction]) -> None: + # TODO this method is temporary. + # See final design: https://github.com/dask/distributed/issues/5894 + for inst in instructions: + if isinstance(inst, SendMessageToScheduler): + self.batched_stream.send(inst.to_dict()) + else: + raise TypeError(inst) # pragma: nocover def maybe_transition_long_running( - self, ts: TaskState, *, stimulus_id: str, compute_duration=None + self, ts: TaskState, *, compute_duration: float, stimulus_id: str ): if ts.state == "executing": self.transition( @@ -2918,7 +2695,7 @@ def ensure_communicating(self) -> None: for el in skipped_worker_in_flight: self.data_needed.push(el) - def _get_task_finished_msg(self, ts: TaskState) -> dict[str, Any]: + def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2936,19 +2713,15 @@ def _get_task_finished_msg(self, ts: TaskState) -> dict[str, Any]: # Some types fail pickling (example: _thread.lock objects), # send their name as a best effort. 
typ_serialized = pickle.dumps(typ.__name__, protocol=4) - d = { - "op": "task-finished", - "status": "OK", - "key": ts.key, - "nbytes": ts.nbytes, - "thread": self.threads.get(ts.key), - "type": typ_serialized, - "typename": typename(typ), - "metadata": ts.metadata, - } - if ts.startstops: - d["startstops"] = ts.startstops - return d + return TaskFinishedMsg( + key=ts.key, + nbytes=ts.nbytes, + type=typ_serialized, + typename=typename(typ), + metadata=ts.metadata, + thread=self.threads.get(ts.key), + startstops=ts.startstops, + ) def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: """ @@ -2962,13 +2735,14 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: Raises ------ - TypeError: - In case the data is put into the in memory buffer and an exception - occurs during spilling, this raises an exception. This has to be - handled by the caller since most callers generate scheduler messages - on success (see comment above) but we need to signal that this was - not successful. - Can only trigger if spill to disk is enabled and the task is not an + Exception: + In case the data is put into the in memory buffer and a serialization error + occurs during spilling, this raises that error. This has to be handled by + the caller since most callers generate scheduler messages on success (see + comment above) but we need to signal that this was not successful. + + Can only trigger if distributed.worker.memory.target is enabled, the value + is individually larger than target * memory_limit, and the task is not an actor. """ if ts.key in self.data: diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py new file mode 100644 index 00000000000..0133afab965 --- /dev/null +++ b/distributed/worker_state_machine.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import heapq +import sys +from collections.abc import Callable, Container, Iterator +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Collection # TODO move to collections.abc (requires Python >=3.9) +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict + +import dask +from dask.utils import parse_bytes + +from .utils import recursive_to_dict + +if TYPE_CHECKING: + # TODO move to typing (requires Python >=3.10) + from typing_extensions import TypeAlias + + TaskStateState: TypeAlias = Literal[ + "cancelled", + "constrained", + "error", + "executing", + "fetch", + "flight", + "forgotten", + "long-running", + "memory", + "missing", + "ready", + "released", + "rescheduled", + "resumed", + "waiting", + ] + + +# TaskState.state subsets +PROCESSING: set[TaskStateState] = { + "waiting", + "ready", + "constrained", + "executing", + "long-running", + "cancelled", + "resumed", +} +READY: set[TaskStateState] = {"ready", "constrained"} + + +class SerializedTask(NamedTuple): + function: Callable + args: tuple + kwargs: dict[str, Any] + task: object # distributed.scheduler.TaskState.run_spec + + +class StartStop(TypedDict, total=False): + action: str + start: float + stop: float + source: str # optional + + +class InvalidTransition(Exception): + pass + + +@lru_cache +def _default_data_size() -> int: + return parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) + + +# Note: can't specify __slots__ manually to enable slots in Python <3.10 in a @dataclass +# that defines any default values +dc_slots = {"slots": True} if sys.version_info >= (3, 10) else {} + + +@dataclass(repr=False, 
eq=False, **dc_slots) +class TaskState: + """Holds volatile state relating to an individual Dask task. + + Not to be confused with :class:`distributed.scheduler.TaskState`, which holds + similar information on the scheduler side. + """ + + #: Task key. Mandatory. + key: str + #: A named tuple containing the ``function``, ``args``, ``kwargs`` and ``task`` + #: associated with this `TaskState` instance. This defaults to ``None`` and can + #: remain empty if it is a dependency that this worker will receive from another + #: worker. + run_spec: SerializedTask | None = None + + #: The data needed by this key to run + dependencies: set[TaskState] = field(default_factory=set) + #: The keys that use this dependency + dependents: set[TaskState] = field(default_factory=set) + #: Subset of dependencies that are not in memory + waiting_for_data: set[TaskState] = field(default_factory=set) + #: Subset of dependents that are not in memory + waiters: set[TaskState] = field(default_factory=set) + + #: The current state of the task + state: TaskStateState = "released" + #: The previous state of the task. This is a state machine implementation detail. + _previous: TaskStateState | None = None + #: The next state of the task. This is a state machine implementation detail. + _next: TaskStateState | None = None + + #: Expected duration of the task + duration: float | None = None + #: The priority this task given by the scheduler. Determines run order. + priority: tuple[int, ...] | None = None + #: Addresses of workers that we believe have this data + who_has: set[str] = field(default_factory=set) + #: The worker that current task data is coming from if task is in flight + coming_from: str | None = None + #: Abstract resources required to run a task + resource_restrictions: dict[str, float] = field(default_factory=dict) + #: The exception caused by running a task if it erred + exception: Exception | None = None + #: string representation of exception + exception_text: str = "" + #: The traceback caused by running a task if it erred + traceback: object | None = None + #: string representation of traceback + traceback_text: str = "" + #: The type of a particular piece of data + type: type | None = None + #: The number of times a dependency has not been where we expected it + suspicious_count: int = 0 + #: Log of transfer, load, and compute times for a task + startstops: list[StartStop] = field(default_factory=list) + #: Time at which task begins running + start_time: float | None = None + #: Time at which task finishes running + stop_time: float | None = None + #: Metadata related to the task. + #: Stored metadata should be msgpack serializable (e.g. int, string, list, dict). + metadata: dict = field(default_factory=dict) + #: The size of the value of the task, if in memory + nbytes: int | None = None + #: Arbitrary task annotations + annotations: dict | None = None + #: True if the task is in memory or erred; False otherwise + done: bool = False + + # Support for weakrefs to a class with __slots__ + __weakref__: Any = field(init=False) + + def __repr__(self) -> str: + return f"" + + def get_nbytes(self) -> int: + nbytes = self.nbytes + return nbytes if nbytes is not None else _default_data_size() + + def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: + """Dictionary representation for debugging purposes. + Not type stable and not intended for roundtrips. 
+ + See also + -------- + Client.dump_cluster_state + distributed.utils.recursive_to_dict + + Notes + ----- + This class uses ``_to_dict_no_nest`` instead of ``_to_dict``. + When a task references another task, just print the task repr. All tasks + should neatly appear under Worker.tasks. This also prevents a RecursionError + during particularly heavy loads, which have been observed to happen whenever + there's an acyclic dependency chain of ~200+ tasks. + """ + out = recursive_to_dict(self, exclude=exclude, members=True) + # Remove all Nones and empty containers + return {k: v for k, v in out.items() if v} + + def is_protected(self) -> bool: + return self.state in PROCESSING or any( + dep_ts.state in PROCESSING for dep_ts in self.dependents + ) + + +class UniqueTaskHeap(Collection[TaskState]): + """A heap of TaskState objects ordered by TaskState.priority. + Ties are broken by string comparison of the key. Keys are guaranteed to be + unique. Iterating over this object returns the elements in priority order. + """ + + __slots__ = ("_known", "_heap") + _known: set[str] + _heap: list[tuple[tuple[int, ...], str, TaskState]] + + def __init__(self): + self._known = set() + self._heap = [] + + def push(self, ts: TaskState) -> None: + """Add a new TaskState instance to the heap. If the key is already + known, no object is added. + + Note: This does not update the priority / heap order in case priority + changes. + """ + assert isinstance(ts, TaskState) + if ts.key not in self._known: + assert ts.priority + heapq.heappush(self._heap, (ts.priority, ts.key, ts)) + self._known.add(ts.key) + + def pop(self) -> TaskState: + """Pop the task with highest priority from the heap.""" + _, key, ts = heapq.heappop(self._heap) + self._known.remove(key) + return ts + + def peek(self) -> TaskState: + """Get the highest priority TaskState without removing it from the heap""" + return self._heap[0][2] + + def __contains__(self, x: object) -> bool: + if isinstance(x, TaskState): + x = x.key + return x in self._known + + def __iter__(self) -> Iterator[TaskState]: + return (ts for _, _, ts in sorted(self._heap)) + + def __len__(self) -> int: + return len(self._known) + + def __repr__(self) -> str: + return f"<{type(self).__name__}: {len(self)} items>" + + +class Instruction: + """Command from the worker state machine to the Worker, in response to an event""" + + __slots__ = () + + +# TODO https://github.com/dask/distributed/issues/5736 + +# @dataclass +# class GatherDep(Instruction): +# __slots__ = ("worker", "to_gather") +# worker: str +# to_gather: set[str] + + +# @dataclass +# class FindMissing(Instruction): +# __slots__ = () + + +# @dataclass +# class Execute(Instruction): +# __slots__ = ("key", "stimulus_id") +# key: str +# stimulus_id: str + + +class SendMessageToScheduler(Instruction): + __slots__ = () + #: Matches a key in Scheduler.stream_handlers + op: ClassVar[str] + + def __init_subclass__(cls, op: str): + cls.op = op + + def to_dict(self) -> dict[str, Any]: + """Convert object to dict so that it can be serialized with msgpack""" + d = {k: getattr(self, k) for k in self.__annotations__} + d["op"] = self.op + return d + + +# Note: as of Python 3.10.2, @dataclass(slots=True) doesn't work with __init__subclass__ +# https://bugs.python.org/issue46970 +@dataclass +class TaskFinishedMsg(SendMessageToScheduler, op="task-finished"): + key: str + nbytes: int | None + type: bytes # serialized class + typename: str + metadata: dict + thread: int | None + startstops: list[StartStop] + __slots__ = 
tuple(__annotations__) # type: ignore + + def to_dict(self) -> dict[str, Any]: + d = super().to_dict() + d["status"] = "OK" + return d + + +@dataclass +class TaskErredMsg(SendMessageToScheduler, op="task-erred"): + key: str + exception: Exception + exception_text: str + traceback: object + traceback_text: str + thread: int | None + startstops: list[StartStop] + __slots__ = tuple(__annotations__) # type: ignore + + def to_dict(self) -> dict[str, Any]: + d = super().to_dict() + d["status"] = "error" + return d + + +@dataclass +class ReleaseWorkerDataMsg(SendMessageToScheduler, op="release-worker-data"): + __slots__ = ("key",) + key: str + + +@dataclass +class RescheduleMsg(SendMessageToScheduler, op="reschedule"): + # Not to be confused with the distributed.Reschedule Exception + __slots__ = ("key", "worker") + key: str + worker: str + + +@dataclass +class LongRunningMsg(SendMessageToScheduler, op="long-running"): + __slots__ = ("key", "compute_duration") + key: str + compute_duration: float + + +@dataclass +class AddKeysMsg(SendMessageToScheduler, op="add-keys"): + __slots__ = ("keys", "stimulus_id") + keys: list[str] + stimulus_id: str diff --git a/docs/source/worker.rst b/docs/source/worker.rst index bcc71783dcd..91cac09947a 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -94,19 +94,19 @@ more details on the command line options, please have a look at the Internal Scheduling ------------------- -Internally tasks that come to the scheduler proceed through the following -pipeline as :py:class:`distributed.worker.TaskState` objects. Tasks which -follow this path have a :py:attr:`distributed.worker.TaskState.runspec` defined -which instructs the worker how to execute them. +Internally tasks that come to the scheduler proceed through the following pipeline as +:class:`distributed.worker_state_machine.TaskState` objects. Tasks which follow this +path have a :attr:`~distributed.worker_state_machine.TaskState.runspec` defined which +instructs the worker how to execute them. .. image:: images/worker-task-state.svg :alt: Dask worker task states Data dependencies are also represented as -:py:class:`distributed.worker.TaskState` objects and follow a simpler path -through the execution pipeline. These tasks do not have a -:py:attr:`distributed.worker.TaskState.runspec` defined and instead contain a -listing of workers to collect their result from. +:class:`~distributed.worker_state_machine.TaskState` objects and follow a simpler path +through the execution pipeline. These tasks do not have a +:attr:`~distributed.worker_state_machine.TaskState.runspec` defined and instead contain +a listing of workers to collect their result from. .. image:: images/worker-dep-state.svg @@ -156,10 +156,14 @@ Dask workers are by default launched, monitored, and managed by a small Nanny process. .. autoclass:: distributed.nanny.Nanny + :members: API Documentation ----------------- -.. autoclass:: distributed.worker.TaskState +.. autoclass:: distributed.worker_state_machine.TaskState + :members: + .. autoclass:: distributed.worker.Worker + :members:
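
The new `distributed/worker_state_machine.py` replaces the ad-hoc scheduler message dicts with dataclasses that carry their own `op`. The sketch below condenses that pattern to a single message class so the mechanics are easier to see in isolation: `__init_subclass__` pins the op as a `ClassVar`, and `to_dict()` rebuilds the msgpack-friendly payload from the subclass annotations. It mirrors the classes in this diff but is a trimmed, standalone illustration, not the module itself.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, ClassVar


class Instruction:
    """Command from the worker state machine to the Worker"""

    __slots__ = ()


class SendMessageToScheduler(Instruction):
    __slots__ = ()
    #: Matches a key in Scheduler.stream_handlers
    op: ClassVar[str]

    def __init_subclass__(cls, op: str):
        cls.op = op

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dict so it can be serialized with msgpack"""
        d = {k: getattr(self, k) for k in self.__annotations__}
        d["op"] = self.op
        return d


@dataclass
class ReleaseWorkerDataMsg(SendMessageToScheduler, op="release-worker-data"):
    __slots__ = ("key",)
    key: str


smsg = ReleaseWorkerDataMsg(key="x")
assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}
assert not hasattr(smsg, "__dict__")  # __slots__ are effective, no instance dict
```

The hand-written `__slots__` (rather than `@dataclass(slots=True)`) mirrors the diff's own workaround: slotted dataclasses need Python 3.10, and even there `slots=True` does not currently cooperate with `__init_subclass__` (bpo-46970), hence the `dc_slots` shim for `TaskState` and manual slots on the message classes.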
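
Transition methods now return `(recommendations, instructions)` instead of `(recs, smsgs)`, and raw dicts only reappear at the send boundary inside the temporary `Worker._handle_instructions`. The snippet below is a rough standalone imitation of that dispatch step, assuming this branch of `distributed` is importable; `FakeBatchedSend` is a hypothetical stub used only to capture what would be written to `batched_stream`.

```python
from __future__ import annotations

from distributed.worker_state_machine import (
    AddKeysMsg,
    Instruction,
    SendMessageToScheduler,
)


class FakeBatchedSend:
    """Hypothetical stand-in for BatchedSend that just records payloads"""

    def __init__(self) -> None:
        self.buffer: list[dict] = []

    def send(self, msg: dict) -> None:
        self.buffer.append(msg)


def handle_instructions(
    stream: FakeBatchedSend, instructions: list[Instruction]
) -> None:
    # Same shape as the temporary Worker._handle_instructions in this diff:
    # scheduler-bound messages are flattened to dicts right before sending.
    for inst in instructions:
        if isinstance(inst, SendMessageToScheduler):
            stream.send(inst.to_dict())
        else:
            raise TypeError(inst)


stream = FakeBatchedSend()
handle_instructions(stream, [AddKeysMsg(keys=["x"], stimulus_id="s1")])
assert stream.buffer == [{"op": "add-keys", "keys": ["x"], "stimulus_id": "s1"}]
```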
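
Because `TaskState` is now a slotted dataclass with keyword defaults, test fixtures can be built in one expression (as the updated `test_forget_data_not_supposed_to_have` and the new `test_worker_state_machine.py` do) rather than constructing an object and mutating its attributes afterwards. A small usage sketch, again assuming this branch is importable; the keys and priorities are arbitrary examples:

```python
from distributed.worker_state_machine import TaskState, UniqueTaskHeap

# Build fixtures directly via dataclass kwargs
ts = TaskState("key", state="flight")
assert ts.state == "flight"
assert TaskState("x", nbytes=123).get_nbytes() == 123

# UniqueTaskHeap orders by (priority, key) and silently drops duplicate keys
heap = UniqueTaskHeap()
for i, key in enumerate(["c", "a", "b"]):
    heap.push(TaskState(key, priority=(i,)))
heap.push(TaskState("a", priority=(99,)))  # key already known: ignored
assert len(heap) == 3
assert [t.key for t in heap] == ["c", "a", "b"]  # priorities (0,), (1,), (2,)
assert heap.pop().key == "c"
```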