diff --git a/distributed/stealing.py b/distributed/stealing.py index cdbcce30c4..2d0917710a 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -4,6 +4,7 @@ import logging from collections import defaultdict, deque from collections.abc import Container +from functools import partial from math import log2 from time import time from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -522,12 +523,12 @@ def _get_thief( ) -> WorkerState | None: valid_workers = scheduler.valid_workers(ts) if valid_workers is not None: - subset = potential_thieves & valid_workers - if subset: - return next(iter(subset)) + valid_thieves = potential_thieves & valid_workers + if valid_thieves: + potential_thieves = valid_thieves elif not ts.loose_restrictions: return None - return next(iter(potential_thieves)) + return min(potential_thieves, key=partial(scheduler.worker_objective, ts)) fast_tasks = {"split-shuffle"} diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index e10bd3f021..fd42f5bae8 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -7,17 +7,30 @@ import math import random import weakref +from collections import defaultdict from operator import mul from time import sleep +from typing import Callable, Iterable, Mapping, Sequence import numpy as np import pytest -from tlz import sliding_window +from tlz import merge, sliding_window import dask from dask.utils import key_split -from distributed import Event, Lock, Nanny, Worker, profile, wait, worker_client +from distributed import ( + Client, + Event, + Lock, + Nanny, + Scheduler, + Worker, + profile, + wait, + worker_client, +) +from distributed.client import Future from distributed.compatibility import LINUX from distributed.core import Status from distributed.metrics import time @@ -25,7 +38,6 @@ from distributed.utils_test import ( NO_AMM, BlockedGetData, - SizeOf, captured_logger, freeze_batched_send, gen_cluster, @@ -50,6 +62,11 @@ teardown_module = nodebug_teardown_module +@pytest.fixture(params=[True, False]) +def recompute_saturation(request): + yield request.param + + @gen_cluster(client=True, nthreads=[("", 2), ("", 2)]) async def test_work_stealing(c, s, a, b): [x] = await c._scatter([1], workers=a.address) @@ -664,7 +681,7 @@ def block(*args, event, **kwargs): for t in sorted(ts, reverse=True): if t: [dat] = await c.scatter( - [SizeOf(int(t * s.bandwidth))], workers=w.address + [gen_nbytes(int(t * s.bandwidth))], workers=w.address ) else: dat = 123 @@ -710,24 +727,37 @@ def block(*args, event, **kwargs): raise Exception(f"Expected: {expected2}; got: {result2}") -@pytest.mark.parametrize("recompute_saturation", [True, False]) @pytest.mark.parametrize( "inp,expected", [ - ([[1], []], [[1], []]), # don't move unnecessarily - ([[0, 0], []], [[0], [0]]), # balance - ([[0.1, 0.1], []], [[0], [0]]), # balance even if results in even - ([[0, 0, 0], []], [[0, 0], [0]]), # don't over balance - ([[0, 0], [0, 0, 0], []], [[0, 0], [0, 0], [0]]), # move from larger - ([[0, 0, 0], [0], []], [[0, 0], [0], [0]]), # move to smaller - ([[0, 1], []], [[1], [0]]), # choose easier first - ([[0, 0, 0, 0], [], []], [[0, 0], [0], [0]]), # spread evenly - ([[1, 0, 2, 0], [], []], [[2, 1], [0], [0]]), # move easier - ([[1, 1, 1], []], [[1, 1], [1]]), # be willing to move costly items - ([[1, 1, 1, 1], []], [[1, 1, 1], [1]]), # but don't move too many - ( - [[0, 0], [0, 0], [0, 0], []], # no one clearly saturated + pytest.param([[1], []], [[1], []], id="don't move 
unnecessarily"), + pytest.param([[0, 0], []], [[0], [0]], id="balance"), + pytest.param( + [[0, 0, 0, 0, 0, 0, 0, 0], []], + [[0, 0, 0, 0, 0, 0], [0, 0]], + id="balance until none idle", + ), + pytest.param( + [[0.1, 0.1], []], [[0], [0]], id="balance even if results in even" + ), + pytest.param([[0, 0, 0], []], [[0, 0], [0]], id="don't over balance"), + pytest.param( + [[0, 0], [0, 0, 0], []], [[0, 0], [0, 0], [0]], id="move from larger" + ), + pytest.param([[0, 0, 0], [0], []], [[0, 0], [0], [0]], id="move to smaller"), + pytest.param([[0, 1], []], [[1], [0]], id="choose easier first"), + pytest.param([[0, 0, 0, 0], [], []], [[0, 0], [0], [0]], id="spread evenly"), + pytest.param([[1, 0, 2, 0], [], []], [[2, 1], [0], [0]], id="move easier"), + pytest.param( + [[1, 1, 1], []], [[1, 1], [1]], id="be willing to move costly items" + ), + pytest.param( + [[1, 1, 1, 1], []], [[1, 1, 1], [1]], id="but don't move too many" + ), + pytest.param( + [[0, 0], [0, 0], [0, 0], []], [[0, 0], [0, 0], [0], [0]], + id="no one clearly saturated", ), # NOTE: There is a timing issue that workers may already start executing # tasks before we call balance, i.e. the workers will reject the @@ -735,9 +765,10 @@ def block(*args, event, **kwargs): # Particularly tests with many input tasks are more likely to fail since # the test setup takes longer and allows the workers more time to # schedule a task on the threadpool - ( + pytest.param( [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []], [[4, 2, 2, 2], [4, 2, 1, 1], [2], [1], [1]], + id="balance multiple saturated workers", ), ], ) @@ -1418,3 +1449,428 @@ def func(*args): ideal = ntasks / len(workers) assert (ntasks_per_worker > ideal * 0.5).all(), (ideal, ntasks_per_worker) assert (ntasks_per_worker < ideal * 1.5).all(), (ideal, ntasks_per_worker) + + +def test_balance_even_with_replica(recompute_saturation): + dependencies = {"a": 1} + dependency_placement = [["a"], ["a"]] + task_placement = [[["a"], ["a"]], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + return actual_task_counts == [ + 1, + 1, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_to_replica(recompute_saturation): + dependencies = {"a": 2} + dependency_placement = [["a"], ["a"], []] + task_placement = [[["a"], ["a"]], [], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + return actual_task_counts == [ + 1, + 1, + 0, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_multiple_to_replica(recompute_saturation): + dependencies = {"a": 6} + dependency_placement = [["a"], ["a"], []] + task_placement = [[["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], [], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + # FIXME: A better task placement would be even but the current balancing + # logic aborts as soon as a worker is no longer classified as idle + # return actual_task_counts == [ + # 4, + # 4, + # 0, + # ] + return actual_task_counts == [ + 6, + 2, + 0, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_to_larger_dependency(recompute_saturation): + dependencies = {"a": 2, "b": 1} + dependency_placement = 
[["a", "b"], ["a"], ["b"]] + task_placement = [[["a", "b"], ["a", "b"], ["a", "b"]], [], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + return actual_task_counts == [ + 2, + 1, + 0, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_prefers_busier_with_dependency(): + recompute_saturation = True + dependencies = {"a": 5, "b": 1} + dependency_placement = [["a"], ["a", "b"], []] + task_placement = [ + [["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], + [["b"]], + [], + ] + + def _correct_placement(actual): + actual_task_placements = [sorted(placed) for placed in actual] + # FIXME: A better task placement would be even but the current balancing + # logic aborts as soon as a worker is no longer classified as idle + # return actual_task_placements == [ + # [["a"], ["a"], ["a"], ["a"]], + # [["a"], ["a"], ["b"]], + # [], + # ] + return actual_task_placements == [ + [["a"], ["a"], ["a"], ["a"], ["a"]], + [["a"], ["b"]], + [], + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + # This test relies on disabling queueing to flag workers as idle + config={ + "distributed.scheduler.worker-saturation": float("inf"), + }, + ) + + +def _run_dependency_balance_test( + dependencies: Mapping[str, int], + dependency_placement: list[list[str]], + task_placement: list[list[list[str]]], + correct_placement_fn: Callable[[list[list[list[str]]]], bool], + recompute_saturation: bool, + config: dict | None = None, +) -> None: + """Run a test for balancing with task dependencies according to the provided + specifications. + + This method executes the test logic for all permutations of worker placements + and generates a new cluster for each one. + + Parameters + ---------- + dependencies + Mapping of task dependencies to their weight. + dependency_placement + List of list of dependencies to be placed on the worker corresponding + to the index of the outer list. + task_placement + List of list of tasks to be placed on the worker corresponding to the + index of the outer list. Each task is a list of names of dependencies. + correct_placement_fn + Callable used to determine if stealing placed the tasks as expected. + recompute_saturation + Whether to recompute worker saturation before stealing. + config + Optional configuration to apply to the test. 
+ See Also + -------- + _dependency_balance_test_permutation + """ + nworkers = len(task_placement) + for permutation in itertools.permutations(range(nworkers)): + + async def _run( + *args, + permutation=permutation, + **kwargs, + ): + await _dependency_balance_test_permutation( + dependencies, + dependency_placement, + task_placement, + correct_placement_fn, + recompute_saturation, + permutation, + *args, + **kwargs, + ) + + gen_cluster( + client=True, + nthreads=[("", 1)] * len(task_placement), + config=merge( + config or {}, + { + "distributed.scheduler.unknown-task-duration": "1s", + }, + ), + )(_run)() + + +async def _dependency_balance_test_permutation( + dependencies: Mapping[str, int], + dependency_placement: list[list[str]], + task_placement: list[list[list[str]]], + correct_placement_fn: Callable[[list[list[list[str]]]], bool], + recompute_saturation: bool, + permutation: list[int], + c: Client, + s: Scheduler, + *workers: Worker, +) -> None: + """Run a test for balancing with task dependencies according to the provided + specifications and worker permutations. + + Parameters + ---------- + dependencies + Mapping of task dependencies to their weight. + dependency_placement + List of list of dependencies to be placed on the worker corresponding + to the index of the outer list. + task_placement + List of list of tasks to be placed on the worker corresponding to the + index of the outer list. Each task is a list of names of dependencies. + correct_placement_fn + Callable used to determine if stealing placed the tasks as expected. + recompute_saturation + Whether to recompute worker saturation before stealing. + permutation + Permutation of workers to use for this run. + + See Also + -------- + _run_dependency_balance_test + """ + steal = s.extensions["stealing"] + await steal.stop() + + inverse = [permutation.index(i) for i in range(len(permutation))] + permutated_dependency_placement = [dependency_placement[i] for i in permutation] + permutated_task_placement = [task_placement[i] for i in permutation] + + dependency_futures = await _place_dependencies( + dependencies, permutated_dependency_placement, c, s, workers + ) + + ev, futures = await _place_tasks( + permutated_task_placement, + permutated_dependency_placement, + dependency_futures, + c, + s, + workers, + ) + + if recompute_saturation: + for ws in s.workers.values(): + s._reevaluate_occupancy_worker(ws) + try: + for _ in range(20): + steal.balance() + await steal.stop() + + permutated_actual_placement = _get_task_placement(s, workers) + actual_placement = [permutated_actual_placement[i] for i in inverse] + + if correct_placement_fn(actual_placement): + return + finally: + # Release the threadpools + await ev.set() + await c.gather(futures) + + raise AssertionError(actual_placement, permutation) + + +async def _place_dependencies( + dependencies: Mapping[str, int], + placement: list[list[str]], + c: Client, + s: Scheduler, + workers: Sequence[Worker], +) -> dict[str, Future]: + """Places the dependencies on the workers as specified. + + Parameters + ---------- + dependencies + Mapping of task dependencies to their weight. + placement + List of list of dependencies to be placed on the worker corresponding to the + index of the outer list. + + Returns + ------- + Dictionary of futures matching the input dependencies. 
+ + See Also + -------- + _run_dependency_balance_test + """ + dependencies_to_workers = defaultdict(set) + for worker_idx, placed in enumerate(placement): + for dependency in placed: + dependencies_to_workers[dependency].add(workers[worker_idx].address) + + futures = {} + for name, multiplier in dependencies.items(): + worker_addresses = dependencies_to_workers[name] + futs = await c.scatter( + {name: gen_nbytes(int(multiplier * s.bandwidth))}, + workers=worker_addresses, + broadcast=True, + ) + futures[name] = futs[name] + + await c.gather(futures.values()) + + _assert_dependency_placement(placement, workers) + + return futures + + +def _assert_dependency_placement(expected, workers): + """Assert that dependencies are placed on the workers as expected.""" + actual = [] + for worker in workers: + actual.append(list(worker.state.tasks.keys())) + + assert actual == expected + + +async def _place_tasks( + placement: list[list[list[str]]], + dependency_placement: list[list[str]], + dependency_futures: Mapping[str, Future], + c: Client, + s: Scheduler, + workers: Sequence[Worker], +) -> tuple[Event, list[Future]]: + """Places the tasks on the workers as specified. + + Parameters + ---------- + placement + List of list of tasks to be placed on the worker corresponding to the + index of the outer list. Each task is a list of names of dependencies. + dependency_placement + List of list of dependencies to be placed on the worker corresponding to the + index of the outer list. + dependency_futures + Mapping of dependency names to their corresponding futures. + + Returns + ------- + Tuple of the event blocking the placed tasks and list of futures matching + the input task placement. + + See Also + -------- + _run_dependency_balance_test + """ + ev = Event() + + def block(*args, event, **kwargs): + event.wait() + + counter = itertools.count() + futures = [] + for worker_idx, tasks in enumerate(placement): + for dependencies in tasks: + i = next(counter) + dep_key = "".join(sorted(dependencies)) + key = f"{dep_key}-{i}" + f = c.submit( + block, + [dependency_futures[dependency] for dependency in dependencies], + event=ev, + key=key, + workers=workers[worker_idx].address, + allow_other_workers=True, + pure=False, + priority=-i, + ) + futures.append(f) + + while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures): + await asyncio.sleep(0.001) + + while any( + len(w.state.tasks) < (len(tasks) + len(dependencies)) + for w, dependencies, tasks in zip(workers, dependency_placement, placement) + ): + await asyncio.sleep(0.001) + + assert_task_placement(placement, s, workers) + + return ev, futures + + +def _get_task_placement( + s: Scheduler, workers: Iterable[Worker] +) -> list[list[list[str]]]: + """Return the placement of tasks on this worker""" + actual = [] + for w in workers: + actual.append( + [list(key_split(ts.key)) for ts in s.workers[w.address].processing] + ) + return _deterministic_placement(actual) + + +def _equal_placement(left, right): + """Return True IFF the two input placements are equal.""" + return _deterministic_placement(left) == _deterministic_placement(right) + + +def _deterministic_placement(placement): + """Return a deterministic ordering of the tasks or dependencies on each worker.""" + return [sorted(placed) for placed in placement] + + +def assert_task_placement(expected, s, workers): + """Assert that tasks are placed on the workers as expected.""" + actual = _get_task_placement(s, workers) + assert _equal_placement(actual, expected)
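Below the patch, for illustration only: a self-contained sketch of the thief-selection rule that the `_get_thief` hunk switches to, i.e. narrowing the candidates to the task's valid workers (when it has restrictions) and then taking the candidate that minimizes an objective function rather than an arbitrary set element. The `pick_thief` and `worker_objective` names and the occupancy numbers are stand-ins invented for this sketch; the real code calls `Scheduler.worker_objective(ts, ws)` on `WorkerState` objects, which roughly weighs expected data-transfer and queueing cost.

# Standalone sketch (not part of the patch) of the selection logic introduced
# in _get_thief: restrict to valid workers if the task has restrictions, then
# pick the candidate minimizing an objective instead of next(iter(...)).
from __future__ import annotations

from functools import partial


def pick_thief(
    potential_thieves: set[str],
    valid_workers: set[str] | None,
    loose_restrictions: bool,
    objective,
) -> str | None:
    if valid_workers is not None:
        valid_thieves = potential_thieves & valid_workers
        if valid_thieves:
            potential_thieves = valid_thieves
        elif not loose_restrictions:
            return None
    return min(potential_thieves, key=objective)


# Hypothetical stand-in for Scheduler.worker_objective(ts, ws); the occupancy
# numbers are invented for this example only.
occupancy = {"w1": 3.0, "w2": 0.5, "w3": 1.2}


def worker_objective(ts: str, ws: str) -> float:
    return occupancy[ws]


# Unrestricted task: the least-loaded candidate wins.
assert pick_thief(set(occupancy), None, False, partial(worker_objective, "x")) == "w2"
# Restricted task: only valid workers are considered.
assert pick_thief(set(occupancy), {"w1", "w3"}, False, partial(worker_objective, "x")) == "w3"
# No valid candidate and no loose restrictions: nothing is stolen.
assert pick_thief(set(occupancy), {"w9"}, False, partial(worker_objective, "x")) is None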