Respect task ordering when making worker assignments #4922

Open · wants to merge 7 commits into main
77 changes: 65 additions & 12 deletions distributed/scheduler.py
@@ -729,11 +729,15 @@ def clean(self):
         return ws
 
     def __repr__(self):
-        return "<WorkerState %r, name: %s, memory: %d, processing: %d>" % (
-            self._address,
-            self._name,
-            len(self._has_what),
-            len(self._processing),
+        return (
+            "<WorkerState %r, name: %s, memory: %d, processing: %d, occupancy: %s>"
+            % (
+                self._address,
+                self._name,
+                len(self._has_what),
+                len(self._processing),
+                format_time(self.occupancy),
+            )
         )
 
     def _repr_html_(self):
@@ -950,6 +954,7 @@ class TaskGroup:
     _start: double
     _stop: double
     _all_durations: object
+    _last_scheduled_worker: WorkerState
 
     def __init__(self, name: str):
         self._name = name

@@ -964,6 +969,7 @@ def __init__(self, name: str):
         self._start = 0.0
         self._stop = 0.0
         self._all_durations = defaultdict(float)
+        self._last_scheduled_worker = None
 
     @property
     def name(self):
@@ -2337,12 +2343,18 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
ts.state = "no-worker"
return ws

if ts._dependencies or valid_workers is not None:
if (
ts._dependencies
or valid_workers is not None
or len(ts._group) > self.total_nthreads * 2
):
ws = decide_worker(
ts,
self._workers_dv.values(),
valid_workers,
partial(self.worker_objective, ts),
ts=ts,
all_workers=self._workers_dv.values(),
valid_workers=valid_workers,
objective=partial(self.worker_objective, ts),
nthreads=self.total_nthreads,
unknown_task_duration=self.UNKNOWN_TASK_DURATION,
)
else:
worker_pool = self._idle or self._workers
@@ -7478,7 +7490,12 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
 @cfunc
 @exceptval(check=False)
 def decide_worker(
-    ts: TaskState, all_workers, valid_workers: set, objective
+    ts: TaskState,
+    all_workers,
+    valid_workers: set,
+    objective,
+    nthreads: int,
+    unknown_task_duration: float,
 ) -> WorkerState:
     """
     Decide which worker should take task *ts*.

@@ -7514,7 +7531,14 @@ def decide_worker(
     candidates = valid_workers
     if not candidates:
         if ts._loose_restrictions:
-            ws = decide_worker(ts, all_workers, None, objective)
+            ws = decide_worker(
+                ts=ts,
+                all_workers=all_workers,
+                valid_workers=None,
+                objective=objective,
+                nthreads=nthreads,
+                unknown_task_duration=unknown_task_duration,
+            )
             return ws
 
     ncandidates: Py_ssize_t = len(candidates)

@@ -7525,6 +7549,35 @@ def decide_worker(
             break
     else:
         ws = min(candidates, key=objective)
+
+    # If our group is large with few dependencies
+    # Then assign sequential tasks to similar workers, even if occupancy isn't ideal
+    if len(ts._group) > nthreads * 2 and sum(map(len, ts._group._dependencies)) < 5:
Member:
  • Isn't the length of all dependencies of a TG potentially very expensive? The length of a group iterates over all TaskStates in a given group. For some topologies, this would require us to iterate over all tasks (-1), wouldn't it?
  • Is there any way to reason about the numeric values here? I think I'm still lacking intuition for TGs to tell how stable this heuristic is.

Member Author:

> The length of a group iterates over all TaskStates in a given group

The code looks like this:

    def __init__(self, ...):
        self._states = {"memory": 0, "processing": 0, ...}

    def __len__(self):
        return sum(self._states.values())

So it's not as bad as it sounds. However, iterating over a dict of even a few elements could still be concerning. If so, we could always keep a _len value around; it would be cheap to maintain.
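
For illustration, a minimal sketch of what that cached counter could look like; the `_len` attribute and the `add`/`transition` hooks here are hypothetical, not part of this PR:

    from collections import defaultdict


    class TaskGroup:
        """Sketch only: a cached-length variant of the states counter."""

        def __init__(self, name: str):
            self._name = name
            self._states = defaultdict(int)  # e.g. {"memory": 0, "processing": 0, ...}
            self._len = 0  # hypothetical cached total, kept in sync below

        def add(self, state: str):
            # A task joins the group: bump its state count and the cached total.
            self._states[state] += 1
            self._len += 1

        def transition(self, old: str, new: str):
            # State changes shuffle counts around but leave the total unchanged.
            self._states[old] -= 1
            self._states[new] += 1

        def __len__(self):
            return self._len  # O(1), instead of summing the states dict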

Member Author:

> Is there any way to reason about the numeric values here? I think I'm still lacking intuition for TGs to tell how stable this heuristic is.

The 2 is because we want more than one task per worker to be allocated. If there are at least as many workers as tasks, then we're unlikely to co-schedule any tasks on the same worker, so the point is moot.

The < 5 is really saying "we want there to be almost no dependencies for the tasks in this group, but we're going to accept a common case of all tasks depending on some parameter or something like an abstract zarr file". We're looking for cases where the dependency won't significantly affect the distribution of tasks throughout the cluster. This could be len(dependencies) in (0, 1) but we figured we'd allow a couple of these just in case.

I expect that the distribution here will be bimodal, with dependency counts either in (0, 1) or in the hundreds or thousands. Five seemed like a good separator value in that distribution. I think that, given the distribution, this choice is stable and defensible.
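
To make those numbers concrete, a small worked example (the cluster size and counts below are made up for illustration and only restate the condition in the diff):

    # Assumed example cluster: 10 workers x 4 threads each.
    total_nthreads = 40

    # A "root-ish" group: many tasks, nearly no dependencies overall.
    group_size = 100        # len(ts._group)
    total_group_deps = 1    # sum(map(len, ts._group._dependencies)), e.g. one shared zarr file

    is_root_ish = group_size > total_nthreads * 2 and total_group_deps < 5
    assert is_root_ish  # 100 > 80 and 1 < 5: co-scheduling kicks in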

Member:

> The code looks like this

Right, that's state as in {Running, Memory, Released}, not state as in TaskState, and it's an aggregated dict. I was already a bit thrown off when I saw that. That's perfectly fine.

> I expect that the distribution here will be bi-modal with tasks either in (0, 1) or in the hundreds or thousands.

Thanks for the detailed description. I think I was thrown off by the TaskGroup semantics again. I was thinking about our typical tree reductions, where we usually have task splits like 8 or 16. These are the situations where one would want to group all dependencies for the first reduction.
However, for group dependencies this should be a trivial dependency count of one, correct?

Then, five is conservative, I agree 👍

Member Author:

Perhaps state should have been called state_counts. Oh well.

Ah, it's not len(ts._group._dependencies), which is what you're describing, I think. It's sum(map(len, ts._group._dependencies)) < 5.

We're counting up all of the dependencies for all of the tasks that are like this task. So in a tree reduction, this number would likely be in the thousands for any non-trivially sized computation. It is non-zero and less than five only in cases like the following:

a1 a2 a3 a4 a5 a6 a7 a8
 \  \  \  \  /  /  /  /
            b
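
Spelling out the diagram's numbers under that rule (a sketch; the dependency sets are written out by hand):

    # Group "a": eight tasks, none of which has dependencies.
    deps_a = [set() for _ in range(8)]
    assert sum(map(len, deps_a)) == 0   # < 5, so the "a" group is root-ish

    # Group "b": a single task depending on all eight "a" tasks.
    deps_b = [{"a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8"}]
    assert sum(map(len, deps_b)) == 8   # >= 5, so "b" is not root-ish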

Member Author:

Really we're looking for cases where the number of dependencies, amortized over all similar tasks, is near-zero.

Member Author:

This is the "ish" in "root-ish" tasks that we sometimes talk about here.

Member:

> Perhaps state should have been called state_counts. Oh well.

naming is hard

> We're counting up all of the dependencies for all of the tasks that are like this task. So in a tree reduction, this number would likely be in the thousands for any non-trivially sized computation. It is non-zero and less than five only in cases like the following:
> Really we're looking for cases where the number of dependencies, amortized over all similar tasks, is near-zero.
> This is the "ish" in "root-ish" tasks that we sometimes talk about here.

I think I got it now. That's an interesting approach to gauge the local topology. What I'm currently wondering is whether this or a closely related metric (e.g. the ratio of group dependents to dependencies) could be used to estimate whether a task has the potential to increase or decrease parallelism. That'd be an interesting metric for work stealing.

Anyhow, I don't want to increase the scope here; this is a discussion we can delay. I'll let the professionals get back to work! Thanks!

Member Author:

I think that it could be a useful metric for memory consuming/producing tasks.

It's also, yes, a good metric for increasing parallelism. My experience, though, is that we are always in a state of abundant parallelism, and that scheduling to increase parallelism is not worth considering in our domain.

Instead, we should focus our scheduling decisions on reducing memory use and freeing intermediate tasks quickly.

+        if ts._group._last_scheduled_worker is None:  # First time
+            if (
+                ts._group.states["released"] > len(ts._group) / 2
+            ):  # many tasks to be scheduled
+                ts._group._last_scheduled_worker = ws
+        else:
+            duration = ts._prefix.duration_average
+            if duration < 0.0:
+                duration = unknown_task_duration
+
+            alternate = ts._group._last_scheduled_worker
+            ratio = math.ceil(len(ts._group) / nthreads)
+
+            # Allow a few tasks to pile up before moving to the next worker
+            if (
+                alternate.occupancy < ws.occupancy + duration * ratio
+                and alternate in all_workers
+            ):
+                ws = alternate
+            else:
+                ts._group._last_scheduled_worker = ws
+
+        if ts._group.states["released"] == 0:  # all done, reset
+            ts._group._last_scheduled_worker = None
 
     return ws
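
As a rough worked example of the "pile up" condition above (all numbers invented for illustration; this only restates the comparison, it is not code from the PR):

    import math

    group_size = 100   # len(ts._group)
    nthreads = 40      # total threads in the cluster
    duration = 0.5     # average task duration, seconds

    ratio = math.ceil(group_size / nthreads)  # ceil(100 / 40) == 3

    # The last-scheduled worker keeps receiving tasks as long as it is less
    # than duration * ratio (here 1.5 s) busier than the objective-optimal one.
    alternate_occupancy = 2.0  # occupancy of ts._group._last_scheduled_worker
    best_occupancy = 1.0       # occupancy of the worker `objective` picked

    if alternate_occupancy < best_occupancy + duration * ratio:  # 2.0 < 2.5
        chosen = "alternate"   # stick with the group's previous worker
    else:
        chosen = "best"        # move on and remember the new worker

    assert chosen == "alternate"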


37 changes: 36 additions & 1 deletion distributed/tests/test_scheduler.py
@@ -17,7 +17,7 @@

 import dask
 from dask import delayed
-from dask.utils import apply
+from dask.utils import apply, stringify
 
 from distributed import Client, Nanny, Worker, fire_and_forget, wait
 from distributed.comm import Comm
@@ -2800,3 +2800,38 @@ async def test_rebalance_least_recently_inserted_sender_min(c, s, *_):
         a: (large_future.key,),
         b: tuple(f.key for f in small_futures),
     }
+
+
+@gen_cluster(
+    client=True,
+    ncores=[("127.0.0.1", 1)] * 5,
+    config={"distributed.scheduler.work-stealing": False},
+)
+async def test_coschedule_order_neighbors(c, s, *workers):
+    """Ensure that similar tasks end up on similar nodes"""
+
+    da = pytest.importorskip("dask.array")
+
+    x = da.random.random((100, 100), chunks=(10, 10))
+    xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20))
+    await xsum
+
+    for i, keys in enumerate(x.__dask_keys__()):
+        # One worker has most of the keys
+        assert any(sum(stringify(k) in w.data for k in keys) >= 5 for w in workers)
+        # We might have split between two workers, but at most two
+        assert sum(sum(stringify(k) in w.data for k in keys) >= 1 for w in workers) <= 2
+
+    x = da.random.random((100, 100), chunks=(10, 10))
+    xx, xsum = dask.persist(x, x.sum(axis=0, split_every=20))
+    await xsum
+
+    for i, keys in enumerate(zip(*x.__dask_keys__())):
+        # One worker has most of the keys
+        assert any(sum(stringify(k) in w.data for k in keys) >= 5 for w in workers)
+        # We might have split between two workers, but at most two
+        assert sum(sum(stringify(k) in w.data for k in keys) >= 1 for w in workers) <= 2
+
+    # There were very few transfers
+    assert sum(len(w.incoming_transfer_log) for w in workers) < 5