Refactor worker scheduler messages and TaskState
crusaderky committed Mar 10, 2022
1 parent 936fba5 commit ff4987c
Showing 8 changed files with 612 additions and 471 deletions.
2 changes: 1 addition & 1 deletion distributed/active_memory_manager.py
@@ -14,7 +14,7 @@
from .metrics import time
from .utils import import_term, log_errors

-if TYPE_CHECKING: # pragma: nocover
+if TYPE_CHECKING:
from .client import Client
from .scheduler import Scheduler, TaskState, WorkerState

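Side note on the hunk above: `TYPE_CHECKING` is false at runtime, so the guarded imports are evaluated only by static type checkers and never at import time; dropping the `# pragma: nocover` comment therefore has no runtime effect. A minimal standalone sketch of the idiom (illustrative only, not code from this commit; `Decimal` is a hypothetical stand-in for a heavy or circular import):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers; never imported at runtime, which is why
        # this block is a common place to break import cycles.
        from decimal import Decimal

    def to_float(x: Decimal) -> float:
        # With ``from __future__ import annotations`` the annotation stays a string
        # at runtime, so the missing runtime import of Decimal is harmless.
        return float(x)

    print(to_float(1))  # the annotation is never resolved at runtime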
2 changes: 1 addition & 1 deletion distributed/batched.py
@@ -128,7 +128,7 @@ def _background_send(self):
self.stopped.set()
self.abort()

-def send(self, *msgs):
+def send(self, *msgs: dict) -> None:
"""Schedule a message for sending to the other side
This completes quickly and synchronously
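For readers less familiar with variadic annotations, the new signature above says that every positional argument passed to `send` must be a message dict and that nothing is returned. A standalone sketch of the same typing pattern (a toy class, not the real `BatchedSend`):

    class Buffer:
        """Toy stand-in used only to illustrate the ``*msgs: dict`` annotation."""

        def __init__(self) -> None:
            self.pending: list = []

        def send(self, *msgs: dict) -> None:
            # ``*msgs: dict`` types each element of ``msgs``; ``msgs`` itself is a
            # ``tuple[dict, ...]`` as far as a type checker is concerned.
            self.pending.extend(msgs)

    buf = Buffer()
    buf.send({"op": "task-finished", "key": "x"}, {"op": "add-keys", "keys": ["y"]})
    assert len(buf.pending) == 2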
14 changes: 10 additions & 4 deletions distributed/scheduler.py
@@ -1209,6 +1209,8 @@ def _to_dict_no_nest(self, *, exclude: "Container[str]" = ()) -> dict:
class TaskState:
"""
A simple object holding information about a task.
+Not to be confused with :class:`distributed.worker_state_machine.TaskState`, which
+holds similar information on the Worker side.
.. attribute:: key: str
@@ -5505,7 +5507,9 @@ def handle_task_finished(self, key=None, worker=None, **msg):
client_msgs: dict
worker_msgs: dict

-r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg)
+r: tuple = self.stimulus_task_finished(
+    key=key, worker=worker, status="OK", **msg
+)
recommendations, client_msgs, worker_msgs = r
parent._transitions(recommendations, client_msgs, worker_msgs)

@@ -5516,7 +5520,7 @@ def handle_task_erred(self, key=None, **msg):
recommendations: dict
client_msgs: dict
worker_msgs: dict
-r: tuple = self.stimulus_task_erred(key=key, **msg)
+r: tuple = self.stimulus_task_erred(key=key, status="error", **msg)
recommendations, client_msgs, worker_msgs = r
parent._transitions(recommendations, client_msgs, worker_msgs)

@@ -7025,7 +7029,9 @@ async def _track_retire_worker(
logger.info("Retired worker %s", ws._address)
return ws._address, ws.identity()

-def add_keys(self, worker=None, keys=(), stimulus_id=None):
+def add_keys(
+    self, worker: str, keys: "Iterable[str]" = (), stimulus_id: "str | None" = None
+) -> str:
"""
Learn that a worker has certain keys
@@ -7038,7 +7044,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None):
ws: WorkerState = parent._workers_dv[worker]
redundant_replicas = []
for key in keys:
-ts: TaskState = parent._tasks.get(key)
+ts: TaskState = parent._tasks.get(key)  # type: ignore
if ts is not None and ts._state == "memory":
if ws not in ts._who_has:
parent.add_replica(ts, ws)
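A note on the pattern in `handle_task_finished` / `handle_task_erred` above: the handler now pins `status` explicitly while forwarding the rest of the payload as `**msg`, so the stimulus method can require the keyword, and a message that smuggled in its own `status` would fail loudly with a duplicate-keyword TypeError rather than silently overriding the hard-coded value. A self-contained sketch of that calling convention (simplified names, not the Scheduler implementation):

    def stimulus_task_finished(*, key: str, worker: str, status: str, **msg) -> tuple:
        # The real method computes recommendations plus client/worker messages;
        # this sketch just returns a canned recommendation for the key.
        assert status == "OK"
        return ({key: "memory"}, {}, {})

    def handle_task_finished(key: str, worker: str, **msg) -> None:
        recommendations, client_msgs, worker_msgs = stimulus_task_finished(
            key=key, worker=worker, status="OK", **msg
        )
        assert recommendations == {key: "memory"}

    handle_task_finished("x", "tcp://127.0.0.1:1234", nbytes=8)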
70 changes: 1 addition & 69 deletions distributed/tests/test_worker.py
@@ -59,14 +59,7 @@
slowinc,
slowsum,
)
-from distributed.worker import (
-    TaskState,
-    UniqueTaskHeap,
-    Worker,
-    error_message,
-    logger,
-    parse_memory_limit,
-)
+from distributed.worker import Worker, error_message, logger, parse_memory_limit

pytestmark = pytest.mark.ci1

@@ -3746,67 +3739,6 @@ async def test_Worker__to_dict(c, s, a):
assert d["tasks"]["x"]["key"] == "x"


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_TaskState__to_dict(c, s, a):
"""tasks that are listed as dependencies of other tasks are dumped as a short repr
and always appear in full under Worker.tasks
"""
x = c.submit(inc, 1, key="x")
y = c.submit(inc, x, key="y")
z = c.submit(inc, 2, key="z")
await wait([x, y, z])

tasks = a._to_dict()["tasks"]

assert isinstance(tasks["x"], dict)
assert isinstance(tasks["y"], dict)
assert isinstance(tasks["z"], dict)
assert tasks["x"]["dependents"] == ["<TaskState 'y' memory>"]
assert tasks["y"]["dependencies"] == ["<TaskState 'x' memory>"]


def test_unique_task_heap():
heap = UniqueTaskHeap()

for x in range(10):
ts = TaskState(f"f{x}")
ts.priority = (0, 0, 1, x % 3)
heap.push(ts)

heap_list = list(heap)
# iteration does not empty heap
assert len(heap) == 10
assert heap_list == sorted(heap_list, key=lambda ts: ts.priority)

seen = set()
last_prio = (0, 0, 0, 0)
while heap:
peeked = heap.peek()
ts = heap.pop()
assert peeked == ts
seen.add(ts.key)
assert ts.priority
assert last_prio <= ts.priority
last_prio = ts.priority

ts = TaskState("foo")
heap.push(ts)
heap.push(ts)
assert len(heap) == 1

assert repr(heap) == "<UniqueTaskHeap: 1 items>"

assert heap.pop() == ts
assert not heap

# Test that we're cleaning the seen set on pop
heap.push(ts)
assert len(heap) == 1
assert heap.pop() == ts

assert repr(heap) == "<UniqueTaskHeap: 0 items>"


@gen_cluster(nthreads=[])
async def test_do_not_block_event_loop_during_shutdown(s):
loop = asyncio.get_running_loop()
90 changes: 90 additions & 0 deletions distributed/tests/test_worker_state_machine.py
@@ -0,0 +1,90 @@
from distributed.utils import recursive_to_dict
from distributed.worker_state_machine import (
ReleaseWorkerDataMsg,
TaskState,
UniqueTaskHeap,
)


def test_TaskState_get_nbytes():
assert TaskState("x", nbytes=123).get_nbytes() == 123
# Default to distributed.scheduler.default-data-size
assert TaskState("y").get_nbytes() == 1024


def test_TaskState__to_dict():
"""Tasks that are listed as dependencies or dependents of other tasks are dumped as
a short repr and always appear in full directly under Worker.tasks. Uninteresting
fields are omitted.
"""
x = TaskState("x", state="memory", done=True)
y = TaskState("y", priority=(0,), dependencies={x})
x.dependents.add(y)
actual = recursive_to_dict([x, y])
assert actual == [
{
"key": "x",
"state": "memory",
"done": True,
"dependents": ["<TaskState 'y' released>"],
},
{
"key": "y",
"state": "released",
"dependencies": ["<TaskState 'x' memory>"],
"priority": [0],
},
]


def test_unique_task_heap():
heap = UniqueTaskHeap()

for x in range(10):
ts = TaskState(f"f{x}", priority=(0,))
ts.priority = (0, 0, 1, x % 3)
heap.push(ts)

heap_list = list(heap)
# iteration does not empty heap
assert len(heap) == 10
assert heap_list == sorted(heap_list, key=lambda ts: ts.priority)

seen = set()
last_prio = (0, 0, 0, 0)
while heap:
peeked = heap.peek()
ts = heap.pop()
assert peeked == ts
seen.add(ts.key)
assert ts.priority
assert last_prio <= ts.priority
last_prio = ts.priority

ts = TaskState("foo", priority=(0,))
heap.push(ts)
heap.push(ts)
assert len(heap) == 1

assert repr(heap) == "<UniqueTaskHeap: 1 items>"

assert heap.pop() == ts
assert not heap

# Test that we're cleaning the seen set on pop
heap.push(ts)
assert len(heap) == 1
assert heap.pop() == ts

assert repr(heap) == "<UniqueTaskHeap: 0 items>"


def test_sendmsg_slots():
# Sample test on one of the subclasses
smsg = ReleaseWorkerDataMsg(key="x")
assert not hasattr(smsg, "__dict__")


def test_sendmsg_to_dict():
smsg = ReleaseWorkerDataMsg(key="x")
assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}
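The two sendmsg tests above pin down the contract of the new message classes: instances are slotted (no per-instance `__dict__`) and serialize themselves via `to_dict()` into a plain `{"op": ..., ...}` payload. A minimal sketch of a class that would satisfy those tests (hypothetical; the real `ReleaseWorkerDataMsg` lives in `distributed.worker_state_machine` and may be structured differently):

    class ReleaseWorkerDataMsgSketch:
        # ``__slots__`` prevents the creation of a per-instance ``__dict__``,
        # which is what test_sendmsg_slots asserts.
        __slots__ = ("key",)
        op = "release-worker-data"  # class-level op name, shared by all instances

        def __init__(self, key: str):
            self.key = key

        def to_dict(self) -> dict:
            # Wire format checked by test_sendmsg_to_dict.
            return {"op": self.op, "key": self.key}

    smsg = ReleaseWorkerDataMsgSketch(key="x")
    assert not hasattr(smsg, "__dict__")
    assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}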