diff --git a/distributed/client.py b/distributed/client.py index d04ba8a679..3e6b1e7367 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4093,10 +4093,11 @@ def register_worker_plugin(self, plugin=None, name=None): on all currently connected workers. It will also be run on any worker that connects in the future. - The plugin may include methods ``setup``, ``teardown``, and - ``transition``. See the ``dask.distributed.WorkerPlugin`` class or the - examples below for the interface and docstrings. It must be - serializable with the pickle or cloudpickle modules. + The plugin may include methods ``setup``, ``teardown``, ``transition``, + ``release_key``, and ``release_dep``. See the + ``dask.distributed.WorkerPlugin`` class or the examples below for the + interface and docstrings. It must be serializable with the pickle or + cloudpickle modules. If the plugin has a ``name`` attribute, or if the ``name=`` keyword is used then that will control idempotency. A a plugin with that name has @@ -4124,6 +4125,10 @@ def register_worker_plugin(self, plugin=None, name=None): ... pass ... def transition(self, key: str, start: str, finish: str, **kwargs): ... pass + ... def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool): + ... pass + ... def release_dep(self, dep: str, state: str, report: bool): + ... pass >>> plugin = MyPlugin(1, 2, 3) >>> client.register_worker_plugin(plugin) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 12e7ad6ec3..fb3b2afe20 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -90,7 +90,8 @@ class WorkerPlugin: """ Interface to extend the Worker A worker plugin enables custom code to run at different stages of the Workers' - lifecycle: at setup, during task state transitions and at teardown. + lifecycle: at setup, during task state transitions, when a task or dependency + is released, and at teardown. A plugin enables custom code to run at each of step of a Workers's life. Whenever such an event happens, the corresponding method on this class will be called. Note that the @@ -147,3 +148,35 @@ def transition(self, key, start, finish, **kwargs): Final state of the transition. kwargs: More options passed when transitioning """ + + def release_key(self, key, state, cause, reason, report): + """ + Called when the worker releases a task. + + Parameters + ---------- + key: string + state: string + State of the released task. + One of waiting, ready, executing, long-running, memory, error. + cause: string or None + Additional information on what triggered the release of the task. + reason: None + Not used. + report: bool + Whether the worker should report the released task to the scheduler. + """ + + def release_dep(self, dep, state, report): + """ + Called when the worker releases a dependency. + + Parameters + ---------- + dep: string + state: string + State of the released dependency. + One of waiting, flight, memory. + report: bool + Whether the worker should report the released dependency to the scheduler. + """ diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 038924853a..c9cdbed784 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -1,15 +1,15 @@ import pytest from distributed import Worker, WorkerPlugin -from distributed.utils_test import gen_cluster +from distributed.utils_test import async_wait_for, gen_cluster, inc class MyPlugin(WorkerPlugin): name = "MyPlugin" - def __init__(self, data, expected_transitions=None): + def __init__(self, data, expected_notifications=None): self.data = data - self.expected_transitions = expected_transitions + self.expected_notifications = expected_notifications def setup(self, worker): assert isinstance(worker, Worker) @@ -17,20 +17,32 @@ def setup(self, worker): self.worker._my_plugin_status = "setup" self.worker._my_plugin_data = self.data - self.observed_transitions = [] + self.observed_notifications = [] def teardown(self, worker): self.worker._my_plugin_status = "teardown" - if self.expected_transitions is not None: - assert len(self.observed_transitions) == len(self.expected_transitions) + if self.expected_notifications is not None: + assert len(self.observed_notifications) == len(self.expected_notifications) for expected, real in zip( - self.expected_transitions, self.observed_transitions + self.expected_notifications, self.observed_notifications ): assert expected == real def transition(self, key, start, finish, **kwargs): - self.observed_transitions.append((key, start, finish)) + self.observed_notifications.append( + {"key": key, "start": start, "finish": finish,} + ) + + def release_key(self, key, state, cause, reason, report): + self.observed_notifications.append( + {"key": key, "state": state,} + ) + + def release_dep(self, dep, state, report): + self.observed_notifications.append( + {"dep": dep, "state": state,} + ) @gen_cluster(client=True, nthreads=[]) @@ -54,16 +66,18 @@ async def test_create_on_construction(c, s, a, b): @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) async def test_normal_task_transitions_called(c, s, w): - expected_transitions = [ - ("task", "waiting", "ready"), - ("task", "ready", "executing"), - ("task", "executing", "memory"), + expected_notifications = [ + {"key": "task", "start": "waiting", "finish": "ready"}, + {"key": "task", "start": "ready", "finish": "executing"}, + {"key": "task", "start": "executing", "finish": "memory"}, + {"key": "task", "state": "memory"}, ] - plugin = MyPlugin(1, expected_transitions=expected_transitions) + plugin = MyPlugin(1, expected_notifications=expected_notifications) await c.register_worker_plugin(plugin) await c.submit(lambda x: x, 1, key="task") + await async_wait_for(lambda: not w.task_state, timeout=10) @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) @@ -71,13 +85,13 @@ async def test_failing_task_transitions_called(c, s, w): def failing(x): raise Exception() - expected_transitions = [ - ("task", "waiting", "ready"), - ("task", "ready", "executing"), - ("task", "executing", "error"), + expected_notifications = [ + {"key": "task", "start": "waiting", "finish": "ready"}, + {"key": "task", "start": "ready", "finish": "executing"}, + {"key": "task", "start": "executing", "finish": "error"}, ] - plugin = MyPlugin(1, expected_transitions=expected_transitions) + plugin = MyPlugin(1, expected_notifications=expected_notifications) await c.register_worker_plugin(plugin) @@ -89,16 +103,44 @@ def failing(x): nthreads=[("127.0.0.1", 1)], client=True, worker_kwargs={"resources": {"X": 1}}, ) async def test_superseding_task_transitions_called(c, s, w): - expected_transitions = [ - ("task", "waiting", "constrained"), - ("task", "constrained", "executing"), - ("task", "executing", "memory"), + expected_notifications = [ + {"key": "task", "start": "waiting", "finish": "constrained"}, + {"key": "task", "start": "constrained", "finish": "executing"}, + {"key": "task", "start": "executing", "finish": "memory"}, + {"key": "task", "state": "memory"}, ] - plugin = MyPlugin(1, expected_transitions=expected_transitions) + plugin = MyPlugin(1, expected_notifications=expected_notifications) await c.register_worker_plugin(plugin) await c.submit(lambda x: x, 1, key="task", resources={"X": 1}) + await async_wait_for(lambda: not w.task_state, timeout=10) + + +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) +async def test_release_dep_called(c, s, w): + dsk = { + "dep": 1, + "task": (inc, "dep"), + } + + expected_notifications = [ + {"key": "dep", "start": "waiting", "finish": "ready"}, + {"key": "dep", "start": "ready", "finish": "executing"}, + {"key": "dep", "start": "executing", "finish": "memory"}, + {"key": "task", "start": "waiting", "finish": "ready"}, + {"key": "task", "start": "ready", "finish": "executing"}, + {"key": "task", "start": "executing", "finish": "memory"}, + {"key": "dep", "state": "memory"}, + {"dep": "dep", "state": "memory"}, + {"key": "task", "state": "memory"}, + ] + + plugin = MyPlugin(1, expected_notifications=expected_notifications) + + await c.register_worker_plugin(plugin) + await c.get(dsk, "task", sync=False) + await async_wait_for(lambda: not (w.task_state or w.dep_state), timeout=10) @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) diff --git a/distributed/worker.py b/distributed/worker.py index e51cd06473..59cd285d49 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1601,7 +1601,7 @@ def transition(self, key, finish, **kwargs): self.task_state[key] = state or finish if self.validate: self.validate_key(key) - self._notify_transition(key, start, state or finish, **kwargs) + self._notify_plugins("transition", key, start, state or finish, **kwargs) def transition_waiting_ready(self, key): try: @@ -2249,6 +2249,8 @@ def release_key(self, key, cause=None, reason=None, report=True): if report and state in PROCESSING: # not finished self.batched_stream.send({"op": "release", "key": key, "cause": cause}) + + self._notify_plugins("release_key", key, state, cause, reason, report) except CommClosedError: pass except Exception as e: @@ -2292,6 +2294,8 @@ def release_dep(self, dep, report=False): if report and state == "memory": self.batched_stream.send({"op": "release-worker-data", "keys": [dep]}) + + self._notify_plugins("release_dep", dep, state, report) except Exception as e: logger.exception(e) if LOG_PDB: @@ -2833,11 +2837,11 @@ def get_call_stack(self, comm=None, keys=None): result = {k: profile.call_stack(frame) for k, frame in frames.items()} return result - def _notify_transition(self, key, start, finish, **kwargs): + def _notify_plugins(self, method_name, *args, **kwargs): for name, plugin in self.plugins.items(): - if hasattr(plugin, "transition"): + if hasattr(plugin, method_name): try: - plugin.transition(key, start, finish, **kwargs) + getattr(plugin, method_name)(*args, **kwargs) except Exception: logger.info( "Plugin '%s' failed with exception" % name, exc_info=True