Enum for worker TaskState names #5444

Closed · wants to merge 5 commits
2 changes: 2 additions & 0 deletions distributed/diagnostics/plugin.py
@@ -131,6 +131,8 @@ class WorkerPlugin:
>>> client.register_worker_plugin(plugin) # doctest: +SKIP
"""

enum_task_state_names = False
Member Author
I added a backwards-compatibility toggle for the plugin. There is now a class attribute controlling this behaviour, and it defaults to the old behaviour. The old behaviour will then issue a deprecation warning instructing the user how to migrate.
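
A minimal sketch of how such a toggle might be wired up on the worker side. Only `WTSS` and `enum_task_state_names` (and the "no longer receive string values" warning wording the tests match against) come from this PR; the helper name `_state_for_plugin`, the choice of enum members shown, and the exact dispatch logic are assumptions for illustration, not the actual implementation:

```python
import warnings
from enum import Enum


class WTSS(Enum):
    """Illustrative subset of the worker task-state enum introduced by this PR."""

    released = "released"
    waiting = "waiting"
    executing = "executing"
    memory = "memory"


def _state_for_plugin(plugin, state):
    """Hypothetical helper: hand WTSS members to opted-in plugins, strings otherwise."""
    if getattr(plugin, "enum_task_state_names", False):
        # Plugin opted in to the new behaviour: pass the enum member through.
        return state
    # Legacy behaviour: keep passing plain strings, but warn about the migration.
    warnings.warn(
        "WorkerPlugin.transition will no longer receive string values for "
        "start and finish; set `enum_task_state_names = True` on your plugin "
        "to receive WTSS members instead.",
        DeprecationWarning,
    )
    return state.value
```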


def setup(self, worker):
"""
Run when the plugin is attached to a worker. This happens when the plugin is registered
102 changes: 71 additions & 31 deletions distributed/diagnostics/tests/test_worker_plugin.py
@@ -1,12 +1,15 @@
import asyncio
import contextlib

import pytest

from distributed import Worker, WorkerPlugin
from distributed.utils_test import async_wait_for, gen_cluster, inc
from distributed.worker import WTSS


class MyPlugin(WorkerPlugin):
enum_task_state_names = True
name = "MyPlugin"

def __init__(self, data, expected_notifications=None):
@@ -106,12 +109,12 @@ async def test_create_on_construction(c, s, a, b):
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "task", "start": WTSS.released, "finish": WTSS.waiting},
{"key": "task", "start": WTSS.waiting, "finish": WTSS.ready},
{"key": "task", "start": WTSS.ready, "finish": WTSS.executing},
{"key": "task", "start": WTSS.executing, "finish": WTSS.memory},
{"key": "task", "start": WTSS.memory, "finish": WTSS.released},
{"key": "task", "start": WTSS.released, "finish": WTSS.forgotten},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -127,12 +130,12 @@ def failing(x):
raise Exception()

expected_notifications = [
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "start": "error", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "task", "start": WTSS.released, "finish": WTSS.waiting},
{"key": "task", "start": WTSS.waiting, "finish": WTSS.ready},
{"key": "task", "start": WTSS.ready, "finish": WTSS.executing},
{"key": "task", "start": WTSS.executing, "finish": WTSS.error},
{"key": "task", "start": WTSS.error, "finish": WTSS.released},
{"key": "task", "start": WTSS.released, "finish": WTSS.forgotten},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -148,12 +151,12 @@ def failing(x):
)
async def test_superseding_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "constrained"},
{"key": "task", "start": "constrained", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "task", "start": WTSS.released, "finish": WTSS.waiting},
{"key": "task", "start": WTSS.waiting, "finish": WTSS.constrained},
{"key": "task", "start": WTSS.constrained, "finish": WTSS.executing},
{"key": "task", "start": WTSS.executing, "finish": WTSS.memory},
{"key": "task", "start": WTSS.memory, "finish": WTSS.released},
{"key": "task", "start": WTSS.released, "finish": WTSS.forgotten},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -168,18 +171,18 @@ async def test_dependent_tasks(c, s, w):
dsk = {"dep": 1, "task": (inc, "dep")}

expected_notifications = [
{"key": "dep", "start": "released", "finish": "waiting"},
{"key": "dep", "start": "waiting", "finish": "ready"},
{"key": "dep", "start": "ready", "finish": "executing"},
{"key": "dep", "start": "executing", "finish": "memory"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "dep", "start": "memory", "finish": "released"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "dep", "start": "released", "finish": "forgotten"},
{"key": "dep", "start": WTSS.released, "finish": WTSS.waiting},
{"key": "dep", "start": WTSS.waiting, "finish": WTSS.ready},
{"key": "dep", "start": WTSS.ready, "finish": WTSS.executing},
{"key": "dep", "start": WTSS.executing, "finish": WTSS.memory},
{"key": "task", "start": WTSS.released, "finish": WTSS.waiting},
{"key": "task", "start": WTSS.waiting, "finish": WTSS.ready},
{"key": "task", "start": WTSS.ready, "finish": WTSS.executing},
{"key": "task", "start": WTSS.executing, "finish": WTSS.memory},
{"key": "dep", "start": WTSS.memory, "finish": WTSS.released},
{"key": "task", "start": WTSS.memory, "finish": WTSS.released},
{"key": "task", "start": WTSS.released, "finish": WTSS.forgotten},
{"key": "dep", "start": WTSS.released, "finish": WTSS.forgotten},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -215,7 +218,7 @@ def __init__(self):
def release_key(self, key, state, cause, reason, report):
# Ensure that the handler still works
self._called = True
assert state == "memory"
assert state == WTSS.memory
assert key == "task"

def teardown(self, worker):
@@ -235,6 +238,43 @@ async def test(c, s, a):
test()


@pytest.mark.parametrize("enum", [True, False])
def test_transition_enum_deprecation(enum):
class TransitionEnumStates(WorkerPlugin):
enum_task_state_names = enum

def __init__(self):
self._called = False

def transition(self, key, start, finish, **kwargs):
self._called = True
if enum:
assert isinstance(start, WTSS)
assert isinstance(finish, WTSS)
else:
assert isinstance(start, str)
assert isinstance(finish, str)
return super().transition(key, start, finish, **kwargs)

def teardown(self, worker):
assert self._called
return super().teardown(worker)

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):
await c.register_worker_plugin(TransitionEnumStates())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

if enum:
ctx = contextlib.nullcontext()
else:
ctx = pytest.deprecated_call(match="no longer receive string values for start")

with ctx:
test()


def test_assert_no_warning_no_overload():
"""Assert we do not receive a deprecation warning if we do not overload any
methods
22 changes: 12 additions & 10 deletions distributed/stealing.py
@@ -13,6 +13,8 @@
import dask
from dask.utils import parse_timedelta

from distributed.worker import WTSS

from .comm.addressing import get_address_host
from .core import CommClosedError
from .diagnostics.plugin import SchedulerPlugin
@@ -30,21 +32,21 @@
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

_WORKER_STATE_CONFIRM = {
"ready",
"constrained",
"waiting",
WTSS.ready,
WTSS.constrained,
WTSS.waiting,
}

_WORKER_STATE_REJECT = {
"memory",
"executing",
"long-running",
"cancelled",
"resumed",
WTSS.memory,
WTSS.executing,
WTSS.long_running,
WTSS.cancelled,
WTSS.resumed,
}
_WORKER_STATE_UNDEFINED = {
"released",
None,
WTSS.released,
WTSS.forgotten,
}


49 changes: 29 additions & 20 deletions distributed/tests/test_cancelled_state.py
@@ -4,19 +4,25 @@
import pytest

import distributed
from distributed import Worker
from distributed.core import CommClosedError
from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc
from distributed.utils_test import (
_LockedCommPool,
assert_worker_story,
gen_cluster,
inc,
slowinc,
)
from distributed.worker import WTSS, Worker


async def wait_for_state(key, state, dask_worker):
async def wait_for_state(key: str, state: WTSS, dask_worker: Worker) -> None:
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)


async def wait_for_cancelled(key, dask_worker):
async def wait_for_cancelled(key: str, dask_worker: Worker) -> None:
while key in dask_worker.tasks:
if dask_worker.tasks[key].state == "cancelled":
if dask_worker.tasks[key].state == WTSS.cancelled:
return
await asyncio.sleep(0.005)
assert False
@@ -25,15 +31,15 @@ async def wait_for_cancelled(key, dask_worker):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_release(c, s, a):
fut = c.submit(slowinc, 1, delay=1)
await wait_for_state(fut.key, "executing", a)
await wait_for_state(fut.key, WTSS.executing, a)
fut.release()
await wait_for_cancelled(fut.key, a)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_reschedule(c, s, a):
fut = c.submit(slowinc, 1, delay=1)
await wait_for_state(fut.key, "executing", a)
await wait_for_state(fut.key, WTSS.executing, a)
fut.release()
await wait_for_cancelled(fut.key, a)
fut = c.submit(slowinc, 1, delay=0.1)
@@ -43,7 +49,7 @@ async def test_abort_execution_reschedule(c, s, a):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_add_as_dependency(c, s, a):
fut = c.submit(slowinc, 1, delay=1)
await wait_for_state(fut.key, "executing", a)
await wait_for_state(fut.key, WTSS.executing, a)
fut.release()
await wait_for_cancelled(fut.key, a)

@@ -55,7 +61,7 @@ async def test_abort_execution_add_as_dependency(c, s, a):
@gen_cluster(client=True)
async def test_abort_execution_to_fetch(c, s, a, b):
fut = c.submit(slowinc, 1, delay=2, key="f1", workers=[a.worker_address])
await wait_for_state(fut.key, "executing", a)
await wait_for_state(fut.key, WTSS.executing, a)
fut.release()
await wait_for_cancelled(fut.key, a)

@@ -120,11 +126,11 @@ async def wait_and_raise(*args, **kwargs):
fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
fut2 = c.submit(inc, fut1, workers=[b.address])

await wait_for_state(fut1.key, "flight", b)
await wait_for_state(fut1.key, WTSS.flight, b)

# Close in scheduler to ensure we transition and reschedule task properly
await s.close_worker(worker=a.address)
await wait_for_state(fut1.key, "resumed", b)
await wait_for_state(fut1.key, WTSS.resumed, b)

lock.release()
assert await fut2 == 3
@@ -149,10 +155,10 @@ async def wait_and_raise(*args, **kwargs):
raise RuntimeError()

fut = c.submit(wait_and_raise)
await wait_for_state(fut.key, "executing", w)
await wait_for_state(fut.key, WTSS.executing, w)

fut.release()
await wait_for_state(fut.key, "cancelled", w)
await wait_for_state(fut.key, WTSS.cancelled, w)
await lock.release()

# At this point we do not fetch the result of the future since the future
@@ -169,10 +175,13 @@ async def wait_and_raise(*args, **kwargs):
# refactoring. Below verifies some implementation specific test assumptions

story = w.story(fut.key)
start_finish = [(msg[1], msg[2], msg[3]) for msg in story if len(msg) == 7]
assert ("executing", "released", "cancelled") in start_finish
assert ("cancelled", "error", "error") in start_finish
assert ("error", "released", "released") in start_finish
expected = [
(fut.key, "executing", "released", "cancelled", {}),
(fut.key, "cancelled", "error", "error", {fut.key: "released"}),
(fut.key, "release-key"),
(fut.key, "error", "released", "released", {fut.key: "forgotten"}),
]
assert_worker_story(story, expected)


@gen_cluster(client=True)
@@ -194,10 +203,10 @@ async def wait_and_raise(*args, **kwargs):
fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
fut2 = c.submit(inc, fut1, workers=[b.address])

await wait_for_state(fut1.key, "flight", b)
await wait_for_state(fut1.key, WTSS.flight, b)
fut2.release()
fut1.release()
await wait_for_state(fut1.key, "cancelled", b)
await wait_for_state(fut1.key, WTSS.cancelled, b)

lock.release()
# At this point we do not fetch the result of the future since the future
@@ -253,7 +262,7 @@ def produce_evil_data():

fut = c.submit(produce_evil_data)

await wait_for_state(fut.key, "error", w)
await wait_for_state(fut.key, WTSS.error, w)

with pytest.raises(
TypeError,