diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index c9dddf1ec5..1465c1ddf9 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -2113,7 +2113,7 @@ def patch_updates(self): self.edge_source.patch({"visible": updates}) def __del__(self): - self.scheduler.remove_plugin(self.layout) + self.scheduler.remove_plugin(name=self.layout.name) class TaskGroupGraph(DashboardComponent): diff --git a/distributed/diagnostics/eventstream.py b/distributed/diagnostics/eventstream.py index 1f4c096f7e..4ace48282e 100644 --- a/distributed/diagnostics/eventstream.py +++ b/distributed/diagnostics/eventstream.py @@ -29,7 +29,7 @@ def swap_buffer(scheduler, es): def teardown(scheduler, es): - scheduler.remove_plugin(es) + scheduler.remove_plugin(name=es.name) async def eventstream(address, interval): diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index abdfc2cf96..22ac2cb564 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -240,6 +240,8 @@ def format_time(t): class AllProgress(SchedulerPlugin): """Keep track of all keys, grouped by key_split""" + name = "all-progress" + def __init__(self, scheduler): self.all = defaultdict(set) self.nbytes = defaultdict(lambda: 0) diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 57b0cb3839..a9c02846e1 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -1,4 +1,5 @@ import logging +from functools import partial from tlz import merge, valmap @@ -21,10 +22,10 @@ def counts(scheduler, allprogress): ) -def remove_plugin(*args, **kwargs): +def remove_plugin(**kwargs): # Wrapper function around `Scheduler.remove_plugin` to avoid raising a # `PicklingError` when using a cythonized scheduler - return Scheduler.remove_plugin(*args, **kwargs) + return Scheduler.remove_plugin(**kwargs) async def progress_stream(address, interval): @@ -53,7 +54,7 @@ async def progress_stream(address, interval): "setup": dumps_function(AllProgress), "function": dumps_function(counts), "interval": interval, - "teardown": dumps_function(remove_plugin), + "teardown": dumps_function(partial(remove_plugin, name=AllProgress.name)), } ) return comm diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 95b74f252e..f895f2ae85 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -67,7 +67,7 @@ def remove_worker(self, worker, scheduler): ] events[:] = [] - s.remove_plugin(plugin) + s.remove_plugin(name=plugin.name) a = await Worker(s.address) await a.close() assert events == [] @@ -104,7 +104,7 @@ async def remove_worker(self, worker, scheduler): } events[:] = [] - s.remove_plugin(plugin) + s.remove_plugin(name=plugin.name) async with Worker(s.address): pass assert events == [] @@ -116,8 +116,9 @@ async def start(self, scheduler): plugin = UnnamedPlugin() s.add_plugin(plugin) s.add_plugin(plugin, name="another") - with pytest.raises(ValueError) as excinfo: - s.remove_plugin(plugin) + with pytest.warns(FutureWarning, match="Removing scheduler plugins by value"): + with pytest.raises(ValueError) as excinfo: + s.remove_plugin(plugin) msg = str(excinfo.value) assert "Multiple instances of" in msg diff --git a/distributed/diagnostics/websocket.py b/distributed/diagnostics/websocket.py index 3796f77603..84aed34360 100644 --- a/distributed/diagnostics/websocket.py +++ b/distributed/diagnostics/websocket.py @@ -4,6 +4,9 @@ class WebsocketPlugin(SchedulerPlugin): + + name = "websocket" + def __init__(self, socket, scheduler): self.socket = socket self.scheduler = scheduler diff --git a/distributed/http/scheduler/info.py b/distributed/http/scheduler/info.py index 44197141e9..096180e195 100644 --- a/distributed/http/scheduler/info.py +++ b/distributed/http/scheduler/info.py @@ -204,7 +204,7 @@ def on_message(self, message): self.send("pong", {"timestamp": str(datetime.now())}) def on_close(self): - self.server.remove_plugin(self.plugin) + self.server.remove_plugin(name=self.plugin.name) routes = [ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f89e486356..a4011338c1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -7037,7 +7037,7 @@ def stop_task_metadata(self, comm=None, name=None): ) plugin = plugins[0] - self.remove_plugin(plugin) + self.remove_plugin(name=plugin.name) return {"metadata": plugin.metadata, "state": plugin.state} async def register_worker_plugin(self, comm, plugin, name=None): @@ -8150,6 +8150,8 @@ class WorkerStatusPlugin(SchedulerPlugin): scheduler. """ + name = "worker-status" + def __init__(self, scheduler, comm): self.bcomm = BatchedSend(interval="5ms") self.bcomm.start(comm) @@ -8164,13 +8166,13 @@ def add_worker(self, worker=None, **kwargs): try: self.bcomm.send(["add", {"workers": {worker: ident}}]) except CommClosedError: - self.scheduler.remove_plugin(self) + self.scheduler.remove_plugin(name=self.name) def remove_worker(self, worker=None, **kwargs): try: self.bcomm.send(["remove", worker]) except CommClosedError: - self.scheduler.remove_plugin(self) + self.scheduler.remove_plugin(name=self.name) def teardown(self): self.bcomm.close()