diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a8571e8b62..3bcae43b9a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -950,6 +950,9 @@ class TaskGroup: _start: double _stop: double _all_durations: object + _last_worker: WorkerState + _last_worker_tasks_left: int # TODO Py_ssize_t? + _last_worker_priority: tuple # TODO remove (debugging only) def __init__(self, name: str): self._name = name @@ -964,6 +967,9 @@ def __init__(self, name: str): self._start = 0.0 self._stop = 0.0 self._all_durations = defaultdict(float) + self._last_worker = None + self._last_worker_tasks_left = 0 + self._last_worker_priority = () @property def name(self): @@ -1009,6 +1015,26 @@ def start(self): def stop(self): return self._stop + @property + def last_worker(self): + return self._last_worker + + @property + def last_worker_tasks_left(self): + return self._last_worker_tasks_left + + @last_worker_tasks_left.setter + def last_worker_tasks_left(self, n: int): + self._last_worker_tasks_left = n + + @property + def last_worker_priority(self): + return self._last_worker_priority + + @last_worker_priority.setter + def last_worker_priority(self, x: tuple): + self._last_worker_priority = x + @ccall def add(self, o): ts: TaskState = o @@ -2337,14 +2363,20 @@ def decide_worker(self, ts: TaskState) -> WorkerState: ts.state = "no-worker" return ws - if ts._dependencies or valid_workers is not None: + if ( + ts._dependencies + or valid_workers is not None + or ts._group._last_worker is not None + ): ws = decide_worker( ts, self._workers_dv.values(), valid_workers, partial(self.worker_objective, ts), + self._total_nthreads, ) else: + # Fastpath when there are no related tasks or restrictions worker_pool = self._idle or self._workers worker_pool_dv = cast(dict, worker_pool) wp_vals = worker_pool.values() @@ -2366,6 +2398,15 @@ def decide_worker(self, ts: TaskState) -> WorkerState: else: # dumb but fast in large case ws = wp_vals[self._n_tasks % n_workers] + ts._group._last_worker = ws + group_tasks_per_thread = ( + len(ts._group) / self._total_nthreads if self._total_nthreads > 0 else 0 + ) + ts._group._last_worker_tasks_left = ( + math.floor(group_tasks_per_thread * ws._nthreads) - 1 + ) + ts._group._last_worker_priority = ts._priority + if self._validate: assert ws is None or isinstance(ws, WorkerState), ( type(ws), @@ -4671,6 +4712,9 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): recommendations[ts._key] = "released" else: # pure data recommendations[ts._key] = "forgotten" + if ts._group._last_worker is ws: + ts._group._last_worker = None + ts._group._last_worker_tasks_left = 0 ws._has_what.clear() self.transitions(recommendations) @@ -6244,8 +6288,9 @@ async def retire_workers( logger.info("Retire workers %s", workers) # Keys orphaned by retiring those workers - keys = {k for w in workers for k in w.has_what} - keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} + tasks = {ts for w in workers for ts in w.has_what} + keys = {ts._key for ts in tasks if ts._who_has.issubset(workers)} + groups = {ts._group for ts in tasks} if keys: other_workers = set(parent._workers_dv.values()) - workers @@ -6260,6 +6305,11 @@ async def retire_workers( lock=False, ) + for group in groups: + if group._last_worker in workers: + group._last_worker = None + group._last_worker_tasks_left = 0 + worker_keys = {ws._address: ws.identity() for ws in workers} if close_workers: await asyncio.gather( @@ -7471,11 +7521,52 @@ def 
_reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
 @cfunc
 @exceptval(check=False)
 def decide_worker(
-    ts: TaskState, all_workers, valid_workers: set, objective
+    ts: TaskState,
+    all_workers,
+    valid_workers: set,
+    objective,
+    total_nthreads: Py_ssize_t,
 ) -> WorkerState:
-    """
+    r"""
     Decide which worker should take task *ts*.

+    There are two modes: root(ish) tasks, and normal tasks.
+
+    Root(ish) tasks
+    ~~~~~~~~~~~~~~~
+
+    Root(ish) tasks have no (or very few) dependencies and fan out widely:
+    they belong to TaskGroups that contain more tasks than there are workers.
+    We want neighboring root tasks to run on the same worker, since there's a
+    good chance those neighbors will be combined in a downstream operation:
+
+          i       j
+         / \     / \
+        e   f   g   h
+        |   |   |   |
+        a   b   c   d
+         \   \ /   /
+             X
+
+    In the above case, we want ``a`` and ``b`` to run on the same worker,
+    and ``c`` and ``d`` to run on the same worker, reducing future
+    data transfer. We can also ignore the location of ``X``, because
+    as a common dependency, it will eventually get transferred everywhere.
+
+    Calculating this directly from the graph would be expensive, so instead
+    we use task priority as a proxy. We aim to send tasks close in priority
+    within a `TaskGroup` to the same worker. To do this efficiently, we rely
+    on the fact that `decide_worker` is generally called in priority order
+    for root tasks (because `Scheduler.update_graph` creates recommendations
+    in priority order), and track only the last worker used for a `TaskGroup`,
+    and how many more tasks can be assigned to it before picking a new one.
+
+    By colocating related root tasks, we make it much more likely that their
+    downstream normal tasks can be scheduled without data transfer.
+
+    Normal tasks
+    ~~~~~~~~~~~~
+
     We choose the worker that has the data on which *ts* depends.

     If several workers have dependencies then we choose the less-busy worker.
@@ -7488,36 +7579,83 @@ def decide_worker(
     of bytes sent between workers. This is determined by calling the
     *objective* function.
     """
-    ws: WorkerState = None
     wws: WorkerState
-    dts: TaskState
+
+    group: TaskGroup = ts._group
+    ws: WorkerState = group._last_worker
+
+    if valid_workers is not None:
+        total_nthreads = sum(wws._nthreads for wws in valid_workers)
+
+    group_tasks_per_thread = (len(group) / total_nthreads) if total_nthreads > 0 else 0
+    ignore_deps_while_picking: bool = False
+
+    # Try to schedule sibling root-like tasks on the same workers.
+    if (
+        ws is not None
+        and group._last_worker_priority is not None
+        # ^ `decide_worker` hasn't previously been called out of priority order
+        and group_tasks_per_thread > 1
+        and sum(map(len, group._dependencies)) < 5  # TODO what number
+    ):
+        if group._last_worker_tasks_left > 0:
+            group._last_worker_tasks_left -= 1
+            if group._last_worker_priority < ts.priority and (
+                valid_workers is None or ws in valid_workers
+            ):
+                group._last_worker_priority = ts.priority
+                return ws
+
+            # `decide_worker` was called out of priority order, or the last used worker is not valid for this task.
+            # This is probably not actually a root-ish task; disable root-ish mode for this group from now on.
+            group._last_worker = None
+            group._last_worker_tasks_left = 0
+            group._last_worker_priority = None
+
+        # Previous worker is fully assigned, so pick a new worker.
+ ignore_deps_while_picking = True + deps: set = ts._dependencies + dts: TaskState candidates: set assert all([dts._who_has for dts in deps]) - if ts._actor: - candidates = set(all_workers) + if ignore_deps_while_picking: + candidates = valid_workers if valid_workers is not None else set(all_workers) else: - candidates = {wws for dts in deps for wws in dts._who_has} - if valid_workers is None: - if not candidates: + if ts._actor: candidates = set(all_workers) - else: - candidates &= valid_workers - if not candidates: - candidates = valid_workers + else: + candidates = {wws for dts in deps for wws in dts._who_has} + if valid_workers is None: if not candidates: - if ts._loose_restrictions: - ws = decide_worker(ts, all_workers, None, objective) - return ws + candidates = set(all_workers) + else: + candidates &= valid_workers + if not candidates: + candidates = valid_workers + if not candidates: + if ts._loose_restrictions: + ws = decide_worker( + ts, all_workers, None, objective, total_nthreads + ) + return ws ncandidates: Py_ssize_t = len(candidates) if ncandidates == 0: pass elif ncandidates == 1: + # NOTE: this is the ideal case: all the deps are already on the same worker. for ws in candidates: break else: ws = min(candidates, key=objective) + + if group._last_worker_priority is not None: + group._last_worker = ws + group._last_worker_tasks_left = ( + math.floor(group_tasks_per_thread * ws._nthreads) - 1 + ) + group._last_worker_priority = ts.priority return ws diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f1aeef606d..3b40ca5bbe 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -17,7 +17,7 @@ import dask from dask import delayed -from dask.utils import apply +from dask.utils import apply, stringify from distributed import Client, Nanny, Worker, fire_and_forget, wait from distributed.comm import Comm @@ -126,6 +126,114 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c): assert x.key in a.data or x.key in b.data +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 3, + config={"distributed.scheduler.work-stealing": False}, +) +async def test_decide_worker_select_candidate_holding_no_deps(client, s, a, b, c): + await client.submit(slowinc, 10, delay=0.1) # learn that slowinc is slow + root = await client.scatter(1) + assert sum(root.key in worker.data for worker in [a, b, c]) == 1 + + start = time() + tasks = client.map(slowinc, [root] * 6, delay=0.1, pure=False) + await wait(tasks) + elapsed = time() - start + + assert elapsed <= 4 + assert all(root.key in worker.data for worker in [a, b, c]), [ + list(worker.data.keys()) for worker in [a, b, c] + ] + + +@pytest.mark.parametrize("ndeps", [0, 1, 4]) +@pytest.mark.parametrize( + "nthreads", + [ + [("127.0.0.1", 1)] * 5, + [("127.0.0.1", 3), ("127.0.0.1", 2), ("127.0.0.1", 1)], + ], +) +def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads): + @gen_cluster( + client=True, + nthreads=nthreads, + config={"distributed.scheduler.work-stealing": False}, + ) + async def test(c, s, *workers): + """Ensure that related tasks end up on the same node""" + da = pytest.importorskip("dask.array") + np = pytest.importorskip("numpy") + + if ndeps == 0: + x = da.random.random((100, 100), chunks=(10, 10)) + else: + + def random(**kwargs): + assert len(kwargs) == ndeps + return np.random.random((10, 10)) + + trivial_deps = {f"k{i}": delayed(object()) for i in range(ndeps)} + + # TODO is there a simpler (non-blockwise) way to make 
this sort of graph? + x = da.blockwise( + random, + "yx", + new_axes={"y": (10,) * 10, "x": (10,) * 10}, + dtype=float, + **trivial_deps, + ) + + xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20)) + await xsum + + # Check that each chunk-row of the array is (mostly) stored on the same worker + primary_worker_key_fractions = [] + secondary_worker_key_fractions = [] + for i, keys in enumerate(x.__dask_keys__()): + # Iterate along rows of the array. + keys = set(stringify(k) for k in keys) + + # No more than 2 workers should have any keys + assert sum(any(k in w.data for k in keys) for w in workers) <= 2 + + # What fraction of the keys for this row does each worker hold? + key_fractions = [ + len(set(w.data).intersection(keys)) / len(keys) for w in workers + ] + key_fractions.sort() + # Primary worker: holds the highest percentage of keys + # Secondary worker: holds the second highest percentage of keys + primary_worker_key_fractions.append(key_fractions[-1]) + secondary_worker_key_fractions.append(key_fractions[-2]) + + # There may be one or two rows that were poorly split across workers, + # but the vast majority of rows should only be on one worker. + assert np.mean(primary_worker_key_fractions) >= 0.9 + assert np.median(primary_worker_key_fractions) == 1.0 + assert np.mean(secondary_worker_key_fractions) <= 0.1 + assert np.median(secondary_worker_key_fractions) == 0.0 + + # Check that there were few transfers + unexpected_transfers = [] + for worker in workers: + for log in worker.incoming_transfer_log: + keys = log["keys"] + # The root-ish tasks should never be transferred + assert not any(k.startswith("random") for k in keys), keys + # `object-` keys (the trivial deps of the root random tasks) should be transferred + if any(not k.startswith("object") for k in keys): + # But not many other things should be + unexpected_transfers.append(list(keys)) + + # A transfer at the very end to move aggregated results is fine (necessary with unbalanced workers in fact), + # but generally there should be very very few transfers. + assert len(unexpected_transfers) <= 2, unexpected_transfers + + test() + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) async def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = await client.scatter([1], workers=b.address)
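
Reviewer note (illustration, not part of the patch): the heart of the root-ish branch is the per-group bookkeeping of a "last worker" plus a countdown of how many more sibling tasks it should receive, with the quota computed as floor(len(group) / total_nthreads * ws._nthreads) - 1. The standalone sketch below mimics only that bookkeeping, using hypothetical stand-ins (Worker, Group, assign) rather than the real WorkerState/TaskGroup, so the block sizes can be checked without running a scheduler. With 12 tasks on 6 total threads, group_tasks_per_thread is 2 (> 1, so the group would qualify as root-ish provided it has fewer than 5 total dependencies), and a 3-thread worker should receive 6 consecutive tasks, a 2-thread worker 4, and a 1-thread worker 2.

    import math
    from typing import Callable, List, Optional


    class Worker:
        """Hypothetical stand-in for WorkerState (only the thread count matters here)."""

        def __init__(self, name: str, nthreads: int):
            self.name = name
            self.nthreads = nthreads


    class Group:
        """Hypothetical stand-in for TaskGroup's last-worker bookkeeping."""

        def __init__(self, ntasks: int):
            self.ntasks = ntasks
            self.last_worker: Optional[Worker] = None
            self.last_worker_tasks_left = 0


    def assign(
        group: Group,
        workers: List[Worker],
        total_nthreads: int,
        pick: Callable[[List[Worker]], Worker],
    ) -> Worker:
        """Reuse the last worker while its quota lasts; otherwise pick a new one
        and reset the quota in proportion to its thread count (as in the diff)."""
        if group.last_worker is not None and group.last_worker_tasks_left > 0:
            group.last_worker_tasks_left -= 1
            return group.last_worker
        tasks_per_thread = group.ntasks / total_nthreads
        ws = pick(workers)  # the scheduler would use its occupancy-based objective here
        group.last_worker = ws
        group.last_worker_tasks_left = math.floor(tasks_per_thread * ws.nthreads) - 1
        return ws


    if __name__ == "__main__":
        workers = [Worker("a", 3), Worker("b", 2), Worker("c", 1)]
        total = sum(w.nthreads for w in workers)  # 6 threads
        group = Group(ntasks=12)  # 12 root tasks -> 2 tasks per thread

        fresh = iter(workers)  # pretend the objective picks each worker once, in order
        placements = [
            assign(group, workers, total, lambda _: next(fresh)).name for _ in range(12)
        ]
        print(placements)  # ['a'] * 6 + ['b'] * 4 + ['c'] * 2

The "- 1" accounts for the task being assigned right now, so the countdown stored on the group only covers the siblings that will follow it.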
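
Reviewer note (illustration, not part of the patch): the new test checks co-location through the scheduler internals; the snippet below is a rough way to eyeball the same effect from the client side on a local cluster. It assumes this branch is installed and uses only existing public APIs (Client, wait, Client.who_has, dask.utils.stringify); it prints per-row worker counts rather than asserting anything.

    import dask.array as da
    from dask.utils import stringify
    from distributed import Client, wait

    if __name__ == "__main__":
        client = Client(n_workers=3, threads_per_worker=1)

        # 100 root-ish chunks, far more than the 3 threads in the cluster.
        x = da.random.random((100, 100), chunks=(10, 10)).persist()
        wait(x)

        who_has = client.who_has()  # stringified key -> list of worker addresses
        for i, row in enumerate(x.__dask_keys__()):
            # With co-assignment, each chunk-row should sit (almost) entirely on one worker.
            holders = {addr for key in row for addr in who_has.get(stringify(key), ())}
            print(f"row {i:2d}: {len(holders)} worker(s)")

        client.close()

Without the scheduler change, the fastpath's round-robin tends to scatter a row's chunks across all workers, so most rows print 2 or 3 holders instead of 1.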