Notify worker plugins when a task is released (#3817)
nre authored Jun 18, 2020
1 parent 920af0f commit 5172678
Showing 4 changed files with 116 additions and 32 deletions.
13 changes: 9 additions & 4 deletions distributed/client.py
@@ -4093,10 +4093,11 @@ def register_worker_plugin(self, plugin=None, name=None):
on all currently connected workers. It will also be run on any worker
that connects in the future.
The plugin may include methods ``setup``, ``teardown``, and
``transition``. See the ``dask.distributed.WorkerPlugin`` class or the
examples below for the interface and docstrings. It must be
serializable with the pickle or cloudpickle modules.
The plugin may include methods ``setup``, ``teardown``, ``transition``,
``release_key``, and ``release_dep``. See the
``dask.distributed.WorkerPlugin`` class or the examples below for the
interface and docstrings. It must be serializable with the pickle or
cloudpickle modules.
If the plugin has a ``name`` attribute, or if the ``name=`` keyword is
used then that will control idempotency. If a plugin with that name has
@@ -4124,6 +4125,10 @@ def register_worker_plugin(self, plugin=None, name=None):
... pass
... def transition(self, key: str, start: str, finish: str, **kwargs):
... pass
... def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool):
... pass
... def release_dep(self, dep: str, state: str, report: bool):
... pass
>>> plugin = MyPlugin(1, 2, 3)
>>> client.register_worker_plugin(plugin)
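For orientation, here is a minimal sketch, not part of this commit, of what a plugin using the new hooks could look like when registered from the client. The ReleaseLogger class, its name, and the logging calls are illustrative assumptions; only Client.register_worker_plugin and dask.distributed.WorkerPlugin come from the documented API.

import logging
from typing import Optional

from dask.distributed import Client, WorkerPlugin

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("release-logger")  # illustrative logger name


class ReleaseLogger(WorkerPlugin):
    # The ``name`` attribute makes registration idempotent, as described above.
    name = "release-logger"

    def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool):
        # Invoked once for each task the worker releases.
        logger.info("released task %s (state=%s, cause=%s)", key, state, cause)

    def release_dep(self, dep: str, state: str, report: bool):
        # Invoked once for each dependency the worker releases.
        logger.info("released dependency %s (state=%s)", dep, state)


if __name__ == "__main__":
    client = Client(processes=False)  # in-process workers so the log records show up here
    client.register_worker_plugin(ReleaseLogger())
    client.submit(lambda x: x + 1, 1, key="task").result()
    client.close()
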
35 changes: 34 additions & 1 deletion distributed/diagnostics/plugin.py
@@ -90,7 +90,8 @@ class WorkerPlugin:
""" Interface to extend the Worker
A worker plugin enables custom code to run at different stages of the Workers'
lifecycle: at setup, during task state transitions and at teardown.
lifecycle: at setup, during task state transitions, when a task or dependency
is released, and at teardown.
A plugin enables custom code to run at each step of a Worker's life. Whenever such
an event happens, the corresponding method on this class will be called. Note that the
@@ -147,3 +148,35 @@ def transition(self, key, start, finish, **kwargs):
Final state of the transition.
kwargs: More options passed when transitioning
"""

def release_key(self, key, state, cause, reason, report):
"""
Called when the worker releases a task.
Parameters
----------
key: string
state: string
State of the released task.
One of waiting, ready, executing, long-running, memory, error.
cause: string or None
Additional information on what triggered the release of the task.
reason: None
Not used.
report: bool
Whether the worker should report the released task to the scheduler.
"""

def release_dep(self, dep, state, report):
"""
Called when the worker releases a dependency.
Parameters
----------
dep: string
state: string
State of the released dependency.
One of waiting, flight, memory.
report: bool
Whether the worker should report the released dependency to the scheduler.
"""
88 changes: 65 additions & 23 deletions distributed/diagnostics/tests/test_worker_plugin.py
@@ -1,36 +1,48 @@
import pytest

from distributed import Worker, WorkerPlugin
from distributed.utils_test import gen_cluster
from distributed.utils_test import async_wait_for, gen_cluster, inc


class MyPlugin(WorkerPlugin):
name = "MyPlugin"

def __init__(self, data, expected_transitions=None):
def __init__(self, data, expected_notifications=None):
self.data = data
self.expected_transitions = expected_transitions
self.expected_notifications = expected_notifications

def setup(self, worker):
assert isinstance(worker, Worker)
self.worker = worker
self.worker._my_plugin_status = "setup"
self.worker._my_plugin_data = self.data

self.observed_transitions = []
self.observed_notifications = []

def teardown(self, worker):
self.worker._my_plugin_status = "teardown"

if self.expected_transitions is not None:
assert len(self.observed_transitions) == len(self.expected_transitions)
if self.expected_notifications is not None:
assert len(self.observed_notifications) == len(self.expected_notifications)
for expected, real in zip(
self.expected_transitions, self.observed_transitions
self.expected_notifications, self.observed_notifications
):
assert expected == real

def transition(self, key, start, finish, **kwargs):
self.observed_transitions.append((key, start, finish))
self.observed_notifications.append(
{"key": key, "start": start, "finish": finish,}
)

def release_key(self, key, state, cause, reason, report):
self.observed_notifications.append(
{"key": key, "state": state,}
)

def release_dep(self, dep, state, report):
self.observed_notifications.append(
{"dep": dep, "state": state,}
)


@gen_cluster(client=True, nthreads=[])
@@ -54,30 +66,32 @@ async def test_create_on_construction(c, s, a, b):

@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_transitions = [
("task", "waiting", "ready"),
("task", "ready", "executing"),
("task", "executing", "memory"),
expected_notifications = [
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
]

plugin = MyPlugin(1, expected_transitions=expected_transitions)
plugin = MyPlugin(1, expected_notifications=expected_notifications)

await c.register_worker_plugin(plugin)
await c.submit(lambda x: x, 1, key="task")
await async_wait_for(lambda: not w.task_state, timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_failing_task_transitions_called(c, s, w):
def failing(x):
raise Exception()

expected_transitions = [
("task", "waiting", "ready"),
("task", "ready", "executing"),
("task", "executing", "error"),
expected_notifications = [
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
]

plugin = MyPlugin(1, expected_transitions=expected_transitions)
plugin = MyPlugin(1, expected_notifications=expected_notifications)

await c.register_worker_plugin(plugin)

Expand All @@ -89,16 +103,44 @@ def failing(x):
nthreads=[("127.0.0.1", 1)], client=True, worker_kwargs={"resources": {"X": 1}},
)
async def test_superseding_task_transitions_called(c, s, w):
expected_transitions = [
("task", "waiting", "constrained"),
("task", "constrained", "executing"),
("task", "executing", "memory"),
expected_notifications = [
{"key": "task", "start": "waiting", "finish": "constrained"},
{"key": "task", "start": "constrained", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
]

plugin = MyPlugin(1, expected_transitions=expected_transitions)
plugin = MyPlugin(1, expected_notifications=expected_notifications)

await c.register_worker_plugin(plugin)
await c.submit(lambda x: x, 1, key="task", resources={"X": 1})
await async_wait_for(lambda: not w.task_state, timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_release_dep_called(c, s, w):
dsk = {
"dep": 1,
"task": (inc, "dep"),
}

expected_notifications = [
{"key": "dep", "start": "waiting", "finish": "ready"},
{"key": "dep", "start": "ready", "finish": "executing"},
{"key": "dep", "start": "executing", "finish": "memory"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "dep", "state": "memory"},
{"dep": "dep", "state": "memory"},
{"key": "task", "state": "memory"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)

await c.register_worker_plugin(plugin)
await c.get(dsk, "task", sync=False)
await async_wait_for(lambda: not (w.task_state or w.dep_state), timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
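Outside the test suite, roughly the same scenario as test_release_dep_called can be reproduced on a local cluster. The following usage sketch assumes the hypothetical ReleaseCounter plugin from the earlier sketch and swaps the test helper inc for operator.add:

from operator import add

from dask.distributed import Client

with Client(n_workers=1, threads_per_worker=1, processes=False) as client:
    client.register_worker_plugin(ReleaseCounter())  # hypothetical plugin from the sketch above
    dsk = {"dep": 1, "task": (add, "dep", 1)}
    # Gathering "task" lets the worker release both keys afterwards,
    # which triggers release_key and release_dep on the plugin.
    assert client.get(dsk, "task") == 2
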
12 changes: 8 additions & 4 deletions distributed/worker.py
@@ -1601,7 +1601,7 @@ def transition(self, key, finish, **kwargs):
self.task_state[key] = state or finish
if self.validate:
self.validate_key(key)
self._notify_transition(key, start, state or finish, **kwargs)
self._notify_plugins("transition", key, start, state or finish, **kwargs)

def transition_waiting_ready(self, key):
try:
@@ -2249,6 +2249,8 @@ def release_key(self, key, cause=None, reason=None, report=True):

if report and state in PROCESSING: # not finished
self.batched_stream.send({"op": "release", "key": key, "cause": cause})

self._notify_plugins("release_key", key, state, cause, reason, report)
except CommClosedError:
pass
except Exception as e:
@@ -2292,6 +2294,8 @@ def release_dep(self, dep, report=False):

if report and state == "memory":
self.batched_stream.send({"op": "release-worker-data", "keys": [dep]})

self._notify_plugins("release_dep", dep, state, report)
except Exception as e:
logger.exception(e)
if LOG_PDB:
@@ -2833,11 +2837,11 @@ def get_call_stack(self, comm=None, keys=None):
result = {k: profile.call_stack(frame) for k, frame in frames.items()}
return result

def _notify_transition(self, key, start, finish, **kwargs):
def _notify_plugins(self, method_name, *args, **kwargs):
for name, plugin in self.plugins.items():
if hasattr(plugin, "transition"):
if hasattr(plugin, method_name):
try:
plugin.transition(key, start, finish, **kwargs)
getattr(plugin, method_name)(*args, **kwargs)
except Exception:
logger.info(
"Plugin '%s' failed with exception" % name, exc_info=True
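The last hunk replaces the transition-specific callback with a generic dispatcher. A standalone sketch of that pattern follows; plugins and logger stand in for the worker's attributes, and the function name is hypothetical.

import logging

logger = logging.getLogger(__name__)


def notify_plugins(plugins, method_name, *args, **kwargs):
    # Call ``method_name`` on every plugin that defines it; a plugin that raises
    # is logged and skipped so it cannot break the worker.
    for name, plugin in plugins.items():
        if hasattr(plugin, method_name):
            try:
                getattr(plugin, method_name)(*args, **kwargs)
            except Exception:
                logger.info("Plugin '%s' failed with exception" % name, exc_info=True)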
