From 4582d6ca4852ef128d627512268f38dca33b6f17 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 12:18:46 -0700 Subject: [PATCH 01/12] Select queued tasks in stimuli, not transitions TODO: clean up implementation. This feels ugly, mostly because the stimuli are ugly and inconsistent. --- distributed/scheduler.py | 58 +++++++++++++++++---------- distributed/tests/test_scheduler.py | 61 +++++++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 24 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dbaa7cfa1c..ef6eaa9995 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2317,18 +2317,13 @@ def transition_processing_memory( ############################ # Update State Information # ############################ - recommendations: Recs = {} - client_msgs: Msgs = {} - if nbytes is not None: ts.set_nbytes(nbytes) - # NOTE: recommendations for queued tasks are added first, so they'll be popped - # last, allowing higher-priority downstream tasks to be transitioned first. - # FIXME: this would be incorrect if queued tasks are user-annotated as higher - # priority. - self._exit_processing_common(ts, recommendations) + self._exit_processing_common(ts) + recommendations: Recs = {} + client_msgs: Msgs = {} self._add_to_memory( ts, ws, recommendations, client_msgs, type=type, typename=typename ) @@ -2507,7 +2502,7 @@ def transition_processing_released(self, key: str, stimulus_id: str) -> RecsMsgs assert not ts.waiting_on assert ts.state == "processing" - ws = self._exit_processing_common(ts, recommendations) + ws = self._exit_processing_common(ts) if ws: worker_msgs[ws.address] = [ { @@ -2574,7 +2569,7 @@ def transition_processing_erred( assert ws ws.actors.remove(ts) - self._exit_processing_common(ts, recommendations) + self._exit_processing_common(ts) ts.erred_on.add(worker) if exception is not None: @@ -3108,9 +3103,7 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs: return {ws.address: [self._task_to_msg(ts)]} - def _exit_processing_common( - self, ts: TaskState, recommendations: Recs - ) -> WorkerState | None: + def _exit_processing_common(self, ts: TaskState) -> WorkerState | None: """Remove *ts* from the set of processing tasks. Returns @@ -3133,11 +3126,6 @@ def _exit_processing_common( self.check_idle_saturated(ws) self.release_resources(ts, ws) - for qts in self._next_queued_tasks_for_worker(ws): - if self.validate: - assert qts.key not in recommendations, recommendations[qts.key] - recommendations[qts.key] = "processing" - return ws def _next_queued_tasks_for_worker(self, ws: WorkerState) -> Iterator[TaskState]: @@ -4944,6 +4932,17 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None): recommendations: Recs = {} self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) + potential_open_workers = { + ws for k in recommendations.keys() if (ws := self.tasks[k].processing_on) + } + + self.transitions(recommendations, stimulus_id) + + recommendations: Recs = {} + for ws in potential_open_workers: + for qts in self._next_queued_tasks_for_worker(ws): + recommendations[qts.key] = "processing" + self.transitions(recommendations, stimulus_id) def client_heartbeat(self, client=None): @@ -5291,12 +5290,31 @@ def handle_task_finished( recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) + ws = self.workers[worker] + recommendations = { + qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws) + } + if self.validate: + assert len(recommendations) <= 1, (ws, recommendations) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) + self.send_all(client_msgs, worker_msgs) - def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None: - r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg) + def handle_task_erred(self, key: str, worker: str, stimulus_id: str, **msg) -> None: + r: tuple = self.stimulus_task_erred( + key=key, worker=worker, stimulus_id=stimulus_id, **msg + ) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) + + ws = self.workers[worker] + recommendations = { + qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws) + } + if self.validate: + assert len(recommendations) <= 1, (ws, recommendations) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) + self.send_all(client_msgs, worker_msgs) def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7dc53b822c..fff169b59b 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import itertools import json import logging import math @@ -86,14 +87,14 @@ async def test_administration(s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) async def test_respect_data_in_memory(c, s, a): - x = delayed(inc)(1) - y = delayed(inc)(x) + x = delayed(inc)(1, dask_key_name="x") + y = delayed(inc)(x, dask_key_name="y") f = c.persist(y) await wait([f]) assert s.tasks[y.key].who_has == {s.workers[a.address]} - z = delayed(operator.add)(x, y) + z = delayed(operator.add)(x, y, dask_key_name="z") f2 = c.persist(z) while f2.key not in s.tasks or not s.tasks[f2.key]: assert s.tasks[y.key].who_has @@ -371,6 +372,58 @@ def __del__(self): assert max(Refcount.log) <= s.total_nthreads +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_forget_tasks_while_processing(c, s, a, b): + events = [Event() for _ in range(10)] + + futures = c.map(Event.wait, events) + await events[0].set() + await futures[0] + await c.close() + assert not s.tasks + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_queued_release(c, s, a): + event = Event() + + rootish_threshold = s.total_nthreads * 2 + 1 + + first_batch = c.map( + lambda i: event.wait(), + range(rootish_threshold), + key=[f"first-{i}" for i in range(rootish_threshold)], + ) + await async_wait_for(lambda: s.queued, 5) + + second_batch = c.map( + lambda i: event.wait(), + range(rootish_threshold), + key=[f"second-{i}" for i in range(rootish_threshold)], + fifo_timeout=0, + ) + await async_wait_for(lambda: second_batch[0].key in s.tasks, 5) + + # All of the second batch should be queued after the first batch + assert [ts.key for ts in s.queued.sorted()] == [ + f.key for f in itertools.chain(first_batch[1:], second_batch) + ] + + # Cancel the first batch + del first_batch + await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5) + + # Second batch should move up the queue and start processing + assert len(s.queued) == len(second_batch) - 1, list(s.queued.sorted()) + + await event.set() + await c.gather(second_batch) + + @gen_cluster( client=True, nthreads=[("", 2)] * 2, @@ -4237,7 +4290,7 @@ def assert_rootish(): await asyncio.sleep(0.005) assert_rootish() if rootish: - assert all(s.tasks[k] in s.queued for k in keys) + assert all(s.tasks[k] in s.queued for k in keys), [s.tasks[k] for k in keys] await block.set() # At this point we need/want to wait for the task-finished message to # arrive on the scheduler. There is no proper hook to wait, therefore we From 93f3c9e81b9773818d7f1c60bd2252c9391037c3 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 13:05:20 -0700 Subject: [PATCH 02/12] update client cancel test for multiple workers --- distributed/tests/test_scheduler.py | 62 ++++++++++++++++------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index fff169b59b..8343917f03 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -385,43 +385,49 @@ async def test_forget_tasks_while_processing(c, s, a, b): @gen_cluster( client=True, - nthreads=[("", 1)], + nthreads=[("", 1)] * 3, config={"distributed.scheduler.worker-saturation": 1.0}, ) -async def test_queued_release(c, s, a): - event = Event() +async def test_queued_release_multiple_workers(c, s, *workers): + async with Client(s.address, asynchronous=True) as c2: + event = Event(client=c2) - rootish_threshold = s.total_nthreads * 2 + 1 + rootish_threshold = s.total_nthreads * 2 + 1 - first_batch = c.map( - lambda i: event.wait(), - range(rootish_threshold), - key=[f"first-{i}" for i in range(rootish_threshold)], - ) - await async_wait_for(lambda: s.queued, 5) + first_batch = c.map( + lambda i: event.wait(), + range(rootish_threshold), + key=[f"first-{i}" for i in range(rootish_threshold)], + ) + await async_wait_for(lambda: s.queued, 5) - second_batch = c.map( - lambda i: event.wait(), - range(rootish_threshold), - key=[f"second-{i}" for i in range(rootish_threshold)], - fifo_timeout=0, - ) - await async_wait_for(lambda: second_batch[0].key in s.tasks, 5) + second_batch = c2.map( + lambda i: event.wait(), + range(rootish_threshold), + key=[f"second-{i}" for i in range(rootish_threshold)], + fifo_timeout=0, + ) + await async_wait_for(lambda: second_batch[0].key in s.tasks, 5) - # All of the second batch should be queued after the first batch - assert [ts.key for ts in s.queued.sorted()] == [ - f.key for f in itertools.chain(first_batch[1:], second_batch) - ] + # All of the second batch should be queued after the first batch + assert [ts.key for ts in s.queued.sorted()] == [ + f.key for f in itertools.chain(first_batch[3:], second_batch) + ] - # Cancel the first batch - del first_batch - await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5) + # Cancel the first batch. + # Use `Client.close` instead of `del first_batch` because deleting futures sends cancellation + # messages one at a time. We're testing here that when multiple workers have open slots, we don't + # recommend the same queued tasks for every worker, so we need a bulk cancellation operation. + await c.close() + del c, first_batch - # Second batch should move up the queue and start processing - assert len(s.queued) == len(second_batch) - 1, list(s.queued.sorted()) + await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5) - await event.set() - await c.gather(second_batch) + # Second batch should move up the queue and start processing + assert len(s.queued) == len(second_batch) - 3, list(s.queued.sorted()) + + await event.set() + await c2.gather(second_batch) @gen_cluster( From 55021bcba7de548abadd65cdee1f98df9704882f Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 13:07:35 -0700 Subject: [PATCH 03/12] WIP `stimulus_queue_slots_maybe_opened` Bit tidier. Broken for `client_releases_keys` (which is also still untidy). --- distributed/scheduler.py | 66 ++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ef6eaa9995..1567af95dd 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3139,8 +3139,9 @@ def _next_queued_tasks_for_worker(self, ws: WorkerState) -> Iterator[TaskState]: for qts in self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION)): if self.validate: assert qts.state == "queued", qts.state - assert not qts.processing_on - assert not qts.waiting_on + assert not qts.processing_on, (qts, qts.processing_on) + assert not qts.waiting_on, (qts, qts.processing_on) + assert qts.who_wants or qts.waiters, qts yield qts def _add_to_memory( @@ -4608,6 +4609,24 @@ def update_graph( # TODO: balance workers + def stimulus_queue_slots_maybe_opened( + self, *workers: WorkerState, stimulus_id: str + ) -> None: + """Respond to an event which may have opened a spot on the threadpool of a worker + + Selects the appropriate number of tasks from the front of the queue (potentially + 0), and transitions them to ``processing``. + """ + # FIXME we'll pick the same tasks for each worker!!! (because this doesn't pop off the queue) + recommendations: Recs = { + qts.key: "processing" + for ws in workers + for qts in self._next_queued_tasks_for_worker(ws) + } + # TODO we already know the worker, pass `ws` as an argument to + # `transition_queued_procssing` and bypass `decide_worker` + self.transitions(recommendations, stimulus_id) + def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4938,12 +4957,9 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None): self.transitions(recommendations, stimulus_id) - recommendations: Recs = {} - for ws in potential_open_workers: - for qts in self._next_queued_tasks_for_worker(ws): - recommendations[qts.key] = "processing" - - self.transitions(recommendations, stimulus_id) + self.stimulus_queue_slots_maybe_opened( + *potential_open_workers, stimulus_id=stimulus_id + ) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" @@ -5289,34 +5305,24 @@ def handle_task_finished( ) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) - - ws = self.workers[worker] - recommendations = { - qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws) - } - if self.validate: - assert len(recommendations) <= 1, (ws, recommendations) - self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) - self.send_all(client_msgs, worker_msgs) + self.stimulus_queue_slots_maybe_opened( + self.workers[worker], stimulus_id=stimulus_id + ) + def handle_task_erred(self, key: str, worker: str, stimulus_id: str, **msg) -> None: r: tuple = self.stimulus_task_erred( key=key, worker=worker, stimulus_id=stimulus_id, **msg ) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) - - ws = self.workers[worker] - recommendations = { - qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws) - } - if self.validate: - assert len(recommendations) <= 1, (ws, recommendations) - self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) - self.send_all(client_msgs, worker_msgs) + self.stimulus_queue_slots_maybe_opened( + self.workers[worker], stimulus_id=stimulus_id + ) + def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None: ts = self.tasks.get(key) ws = self.workers.get(worker) @@ -5358,13 +5364,7 @@ def handle_long_running( ws.add_to_long_running(ts) self.check_idle_saturated(ws) - recommendations: Recs = { - qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws) - } - if self.validate: - assert len(recommendations) <= 1, (ws, recommendations) - - self.transitions(recommendations, stimulus_id) + self.stimulus_queue_slots_maybe_opened(ws, stimulus_id=stimulus_id) def handle_worker_status_change( self, status: str | Status, worker: str | WorkerState, stimulus_id: str From 013588211fcd59e1eed69343d20e37efa2dad0e4 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 13:11:31 -0700 Subject: [PATCH 04/12] don't take workers, just use `idle_task_count` set --- distributed/scheduler.py | 41 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1567af95dd..5d43367d3c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4609,23 +4609,22 @@ def update_graph( # TODO: balance workers - def stimulus_queue_slots_maybe_opened( - self, *workers: WorkerState, stimulus_id: str - ) -> None: + def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened a spot on the threadpool of a worker Selects the appropriate number of tasks from the front of the queue (potentially 0), and transitions them to ``processing``. """ - # FIXME we'll pick the same tasks for each worker!!! (because this doesn't pop off the queue) - recommendations: Recs = { - qts.key: "processing" - for ws in workers - for qts in self._next_queued_tasks_for_worker(ws) - } - # TODO we already know the worker, pass `ws` as an argument to - # `transition_queued_procssing` and bypass `decide_worker` - self.transitions(recommendations, stimulus_id) + if self.idle_task_count: + # FIXME we'll pick the same tasks for each worker!!! (because this doesn't pop off the queue) + recommendations: Recs = { + qts.key: "processing" + for ws in self.idle_task_count + for qts in self._next_queued_tasks_for_worker(ws) + } + # TODO we already know the worker, pass `ws` as an argument to + # `transition_queued_procssing` and bypass `decide_worker` + self.transitions(recommendations, stimulus_id) def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" @@ -4951,15 +4950,9 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None): recommendations: Recs = {} self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) - potential_open_workers = { - ws for k in recommendations.keys() if (ws := self.tasks[k].processing_on) - } - self.transitions(recommendations, stimulus_id) - self.stimulus_queue_slots_maybe_opened( - *potential_open_workers, stimulus_id=stimulus_id - ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" @@ -5307,9 +5300,7 @@ def handle_task_finished( self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - self.stimulus_queue_slots_maybe_opened( - self.workers[worker], stimulus_id=stimulus_id - ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) def handle_task_erred(self, key: str, worker: str, stimulus_id: str, **msg) -> None: r: tuple = self.stimulus_task_erred( @@ -5319,9 +5310,7 @@ def handle_task_erred(self, key: str, worker: str, stimulus_id: str, **msg) -> N self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - self.stimulus_queue_slots_maybe_opened( - self.workers[worker], stimulus_id=stimulus_id - ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None: ts = self.tasks.get(key) @@ -5364,7 +5353,7 @@ def handle_long_running( ws.add_to_long_running(ts) self.check_idle_saturated(ws) - self.stimulus_queue_slots_maybe_opened(ws, stimulus_id=stimulus_id) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) def handle_worker_status_change( self, status: str | Status, worker: str | WorkerState, stimulus_id: str From e98f416e63f3f81703cc4e42eaabd4c0f9cc42cb Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 13:24:11 -0700 Subject: [PATCH 05/12] `peekn(1)` should not error on empty `HeapSet` `peen(2)` would give you an empty iterator, so the behavior is inconsistent if you specify 1 --- distributed/collections.py | 2 +- distributed/tests/test_collections.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/distributed/collections.py b/distributed/collections.py index 4b67807ed4..4ce0fcefa4 100644 --- a/distributed/collections.py +++ b/distributed/collections.py @@ -121,7 +121,7 @@ def peekn(self, n: int) -> Iterator[T]: """Iterate over the n smallest elements without removing them. This is O(1) for n == 1; O(n*logn) otherwise. """ - if n <= 0: + if n <= 0 or not self: return # empty iterator if n == 1: yield self.peek() diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 066cf147a3..6db1811072 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -148,8 +148,14 @@ def test_heapset(): assert list(heap.peekn(1)) == [cx] heap.remove(cw) assert list(heap.peekn(1)) == [cx] + heap.remove(cx) + assert list(heap.peekn(-1)) == [] + assert list(heap.peekn(0)) == [] + assert list(heap.peekn(1)) == [] + assert list(heap.peekn(2)) == [] # Test resilience to failure in key() + heap.add(cx) bad_key = C("bad_key", 0) del bad_key.i with pytest.raises(AttributeError): From 231dbcbfc07930a1a23023f0a985fd7ca2ed304d Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 13:52:05 -0700 Subject: [PATCH 06/12] use `stimulus_queue_slots_maybe_opened` everywhere now fully split from `bulk_schedule`. this is feeling cleaner. --- distributed/scheduler.py | 90 ++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5d43367d3c..ba5c4a5fa9 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3054,24 +3054,20 @@ def remove_all_replicas(self, ts: TaskState) -> None: self.replicated_tasks.remove(ts) ts.who_has.clear() - def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs: - """Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can - handle. + def bulk_schedule_unrunnable_after_adding_worker(self, ws: WorkerState) -> Recs: + """Send ``no-worker`` tasks to ``processing`` that this worker can handle. Returns priority-ordered recommendations. """ - maybe_runnable = list(self._next_queued_tasks_for_worker(ws))[::-1] - - # Schedule any restricted tasks onto the new worker, if the worker can run them + runnable: list[TaskState] = [] for ts in self.unrunnable: valid = self.valid_workers(ts) if valid is None or ws in valid: - maybe_runnable.append(ts) + runnable.append(ts) # Recommendations are processed LIFO, hence the reversed order - maybe_runnable.sort(key=operator.attrgetter("priority"), reverse=True) - # Note not all will necessarily be run; transition->processing will decide - return {ts.key: "processing" for ts in maybe_runnable} + runnable.sort(key=operator.attrgetter("priority"), reverse=True) + return {ts.key: "processing" for ts in runnable} def _validate_ready(self, ts: TaskState) -> None: """Validation for ready states (processing, queued, no-worker)""" @@ -3128,22 +3124,6 @@ def _exit_processing_common(self, ts: TaskState) -> WorkerState | None: return ws - def _next_queued_tasks_for_worker(self, ws: WorkerState) -> Iterator[TaskState]: - """Queued tasks to run, in priority order, on all open slots on a worker""" - if not self.queued or ws.status != Status.running: - return - - # NOTE: this is called most frequently because a single task has completed, so - # there are <= 1 task slots available on the worker. - # `peekn` has fast paths for the cases N<=0 and N==1. - for qts in self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION)): - if self.validate: - assert qts.state == "queued", qts.state - assert not qts.processing_on, (qts, qts.processing_on) - assert not qts.waiting_on, (qts, qts.processing_on) - assert qts.who_wants or qts.waiters, qts - yield qts - def _add_to_memory( self, ts: TaskState, @@ -4229,7 +4209,10 @@ async def add_worker( logger.exception(e) if ws.status == Status.running: - self.transitions(self.bulk_schedule_after_adding_worker(ws), stimulus_id) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) logger.info("Register worker %s", ws) @@ -4612,19 +4595,38 @@ def update_graph( def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened a spot on the threadpool of a worker - Selects the appropriate number of tasks from the front of the queue (potentially - 0), and transitions them to ``processing``. + Selects the appropriate number of tasks from the front of the queue according to + the total number of task slots available on workers (potentially 0), and + transitions them to ``processing``. + + Notes + ----- + Other transitions related to this stimulus should be fully processed beforehand, + so any tasks that became runnable are already in ``processing``. Otherwise, + overproduction can occur if queued tasks get scheduled before downstream tasks. + + Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up + to date. """ - if self.idle_task_count: - # FIXME we'll pick the same tasks for each worker!!! (because this doesn't pop off the queue) - recommendations: Recs = { - qts.key: "processing" - for ws in self.idle_task_count - for qts in self._next_queued_tasks_for_worker(ws) - } - # TODO we already know the worker, pass `ws` as an argument to - # `transition_queued_procssing` and bypass `decide_worker` - self.transitions(recommendations, stimulus_id) + if not self.queued: + return + slots_available = sum( + _task_slots_available(ws, self.WORKER_SATURATION) + for ws in self.idle_task_count + ) + if slots_available == 0: + return + + recommendations: Recs = {} + for qts in self.queued.peekn(slots_available): + if self.validate: + assert qts.state == "queued", qts.state + assert not qts.processing_on, (qts, qts.processing_on) + assert not qts.waiting_on, (qts, qts.processing_on) + assert qts.who_wants or qts.waiters, qts + recommendations[qts.key] = "processing" + + self.transitions(recommendations, stimulus_id) def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" @@ -5379,12 +5381,10 @@ def handle_worker_status_change( if ws.status == Status.running: self.running.add(ws) self.check_idle_saturated(ws) - recs = self.bulk_schedule_after_adding_worker(ws) - if recs: - client_msgs: Msgs = {} - worker_msgs: Msgs = {} - self._transitions(recs, client_msgs, worker_msgs, stimulus_id) - self.send_all(client_msgs, worker_msgs) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) self.idle.pop(ws.address, None) From afc4ada14347eaea7901d19454c34ee4bb56747e Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 15:35:41 -0700 Subject: [PATCH 07/12] remove unnecessary diff --- distributed/scheduler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ba5c4a5fa9..0508f04196 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5304,10 +5304,8 @@ def handle_task_finished( self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) - def handle_task_erred(self, key: str, worker: str, stimulus_id: str, **msg) -> None: - r: tuple = self.stimulus_task_erred( - key=key, worker=worker, stimulus_id=stimulus_id, **msg - ) + def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None: + r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) From 89da28ecf9e73a12429a9d482d002f4f148e678d Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 15:36:15 -0700 Subject: [PATCH 08/12] less hardcoded test --- distributed/tests/test_scheduler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 8343917f03..e797eaa6ea 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -411,7 +411,8 @@ async def test_queued_release_multiple_workers(c, s, *workers): # All of the second batch should be queued after the first batch assert [ts.key for ts in s.queued.sorted()] == [ - f.key for f in itertools.chain(first_batch[3:], second_batch) + f.key + for f in itertools.chain(first_batch[s.total_nthreads :], second_batch) ] # Cancel the first batch. @@ -424,7 +425,9 @@ async def test_queued_release_multiple_workers(c, s, *workers): await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5) # Second batch should move up the queue and start processing - assert len(s.queued) == len(second_batch) - 3, list(s.queued.sorted()) + assert len(s.queued) == len(second_batch) - s.total_nthreads, list( + s.queued.sorted() + ) await event.set() await c2.gather(second_batch) From 79525791d3d710533571cd49aff450539c949ffe Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 13 Dec 2022 15:41:36 -0700 Subject: [PATCH 09/12] pluralize docstring --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0508f04196..089da240f1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4593,7 +4593,7 @@ def update_graph( # TODO: balance workers def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: - """Respond to an event which may have opened a spot on the threadpool of a worker + """Respond to an event which may have opened spots on worker threadpools Selects the appropriate number of tasks from the front of the queue according to the total number of task slots available on workers (potentially 0), and From 31975e5b44552688d586c12c694d9a4716378cf1 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 14 Dec 2022 09:09:55 -0700 Subject: [PATCH 10/12] add `test_restart_while_processing` fails on main --- distributed/tests/test_scheduler.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e797eaa6ea..71ec122f62 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -383,6 +383,17 @@ async def test_forget_tasks_while_processing(c, s, a, b): assert not s.tasks +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_restart_while_processing(c, s, a, b): + events = [Event() for _ in range(10)] + + futures = c.map(Event.wait, events) + await events[0].set() + await futures[0] + await c.restart() + assert not s.tasks + + @gen_cluster( client=True, nthreads=[("", 1)] * 3, From 6de7c4538561988a7fff49fb25f23e525fe35431 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 14 Dec 2022 09:35:34 -0700 Subject: [PATCH 11/12] workers must be Nannies to restart --- distributed/tests/test_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 71ec122f62..8711e96fed 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -383,7 +383,7 @@ async def test_forget_tasks_while_processing(c, s, a, b): assert not s.tasks -@gen_cluster(client=True, nthreads=[("", 1)]) +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_restart_while_processing(c, s, a, b): events = [Event() for _ in range(10)] From 1450522638967af0265a5baa01f065d762f0698c Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 14 Dec 2022 09:36:55 -0700 Subject: [PATCH 12/12] slow --- distributed/tests/test_scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 8711e96fed..8a5374e515 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -383,6 +383,7 @@ async def test_forget_tasks_while_processing(c, s, a, b): assert not s.tasks +@pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_restart_while_processing(c, s, a, b): events = [Event() for _ in range(10)] @@ -390,6 +391,7 @@ async def test_restart_while_processing(c, s, a, b): futures = c.map(Event.wait, events) await events[0].set() await futures[0] + # TODO slow because worker waits a while for the task to finish await c.restart() assert not s.tasks