Select queued tasks in stimuli, not transitions #7402

Merged · 12 commits · Dec 14, 2022
distributed/collections.py: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def peekn(self, n: int) -> Iterator[T]:
"""Iterate over the n smallest elements without removing them.
This is O(1) for n == 1; O(n*logn) otherwise.
"""
if n <= 0:
if n <= 0 or not self:
return # empty iterator
if n == 1:
yield self.peek()

distributed/scheduler.py: 63 additions & 58 deletions
@@ -2317,18 +2317,13 @@ def transition_processing_memory(
############################
# Update State Information #
############################
recommendations: Recs = {}
client_msgs: Msgs = {}

if nbytes is not None:
ts.set_nbytes(nbytes)

# NOTE: recommendations for queued tasks are added first, so they'll be popped
# last, allowing higher-priority downstream tasks to be transitioned first.
# FIXME: this would be incorrect if queued tasks are user-annotated as higher
# priority.
self._exit_processing_common(ts, recommendations)
self._exit_processing_common(ts)

recommendations: Recs = {}
client_msgs: Msgs = {}
self._add_to_memory(
ts, ws, recommendations, client_msgs, type=type, typename=typename
)
@@ -2507,7 +2502,7 @@ def transition_processing_released(self, key: str, stimulus_id: str) -> RecsMsgs
assert not ts.waiting_on
assert ts.state == "processing"

ws = self._exit_processing_common(ts, recommendations)
ws = self._exit_processing_common(ts)
if ws:
worker_msgs[ws.address] = [
{
@@ -2574,7 +2569,7 @@ def transition_processing_erred(
assert ws
ws.actors.remove(ts)

self._exit_processing_common(ts, recommendations)
self._exit_processing_common(ts)

ts.erred_on.add(worker)
if exception is not None:
@@ -3059,24 +3054,20 @@ def remove_all_replicas(self, ts: TaskState) -> None:
self.replicated_tasks.remove(ts)
ts.who_has.clear()

def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs:
"""Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can
handle.
def bulk_schedule_unrunnable_after_adding_worker(self, ws: WorkerState) -> Recs:
"""Send ``no-worker`` tasks to ``processing`` that this worker can handle.

Returns priority-ordered recommendations.
"""
maybe_runnable = list(self._next_queued_tasks_for_worker(ws))[::-1]

# Schedule any restricted tasks onto the new worker, if the worker can run them
runnable: list[TaskState] = []
for ts in self.unrunnable:
valid = self.valid_workers(ts)
if valid is None or ws in valid:
maybe_runnable.append(ts)
runnable.append(ts)

# Recommendations are processed LIFO, hence the reversed order
maybe_runnable.sort(key=operator.attrgetter("priority"), reverse=True)
# Note not all will necessarily be run; transition->processing will decide
return {ts.key: "processing" for ts in maybe_runnable}
runnable.sort(key=operator.attrgetter("priority"), reverse=True)
return {ts.key: "processing" for ts in runnable}

def _validate_ready(self, ts: TaskState) -> None:
"""Validation for ready states (processing, queued, no-worker)"""
@@ -3108,9 +3099,7 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs:

return {ws.address: [self._task_to_msg(ts)]}

def _exit_processing_common(
self, ts: TaskState, recommendations: Recs
) -> WorkerState | None:
def _exit_processing_common(self, ts: TaskState) -> WorkerState | None:
"""Remove *ts* from the set of processing tasks.

Returns
@@ -3133,28 +3122,8 @@ def _exit_processing_common(
self.check_idle_saturated(ws)
self.release_resources(ts, ws)

for qts in self._next_queued_tasks_for_worker(ws):
if self.validate:
assert qts.key not in recommendations, recommendations[qts.key]
recommendations[qts.key] = "processing"

return ws

def _next_queued_tasks_for_worker(self, ws: WorkerState) -> Iterator[TaskState]:
"""Queued tasks to run, in priority order, on all open slots on a worker"""
if not self.queued or ws.status != Status.running:
return

# NOTE: this is called most frequently because a single task has completed, so
# there are <= 1 task slots available on the worker.
# `peekn` has fast paths for the cases N<=0 and N==1.
for qts in self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION)):
if self.validate:
assert qts.state == "queued", qts.state
assert not qts.processing_on
assert not qts.waiting_on
yield qts

def _add_to_memory(
self,
ts: TaskState,
@@ -4240,7 +4209,10 @@ async def add_worker(
logger.exception(e)

if ws.status == Status.running:
self.transitions(self.bulk_schedule_after_adding_worker(ws), stimulus_id)
self.transitions(
self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id
)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

logger.info("Register worker %s", ws)

@@ -4620,6 +4592,42 @@ def update_graph(

# TODO: balance workers

def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None:
"""Respond to an event which may have opened spots on worker threadpools

Selects the appropriate number of tasks from the front of the queue according to
the total number of task slots available on workers (potentially 0), and
transitions them to ``processing``.

Notes
-----
Other transitions related to this stimulus should be fully processed beforehand,
so any tasks that became runnable are already in ``processing``. Otherwise,
overproduction can occur if queued tasks get scheduled before downstream tasks.

Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up
to date.
"""
if not self.queued:
return
slots_available = sum(
_task_slots_available(ws, self.WORKER_SATURATION)
for ws in self.idle_task_count
)
Comment on lines +4613 to +4616
Collaborator (PR author):

This is hardly worth mentioning, but there is potential for a negligible scheduler performance change here (which may not actually be possible in practice):

Now, every time a task completes, we run _task_slots_available on all idle workers. Before, we only ran it on the one worker that just completed a task. In most of those cases, len(self.idle_task_count) should be 1, so no difference.

I mention this because _task_slots_available already shows up in py-spy profiles of the scheduler (usually around 0.5-1%), since it's called frequently. This change could allow it to be called even more frequently.

But I don't think it can actually run more than it needs to:

  • When there are more threads than root tasks (so many workers are idle), queued would be empty, so this wouldn't run.

  • When there are more root tasks than threads, and we're queuing, at most one worker should ever be idle: as soon as it becomes idle, it gets another queued task and is no longer idle.

    The one exception is client.close() releasing many processing tasks at once, while a different client has tasks on the queue. In that case, there's a lot of rescheduling to do, and we do need to look at all the workers that just became idle, so no unnecessary work there either.

tl;dr I don't see a theoretical way for this to be a problem, but I haven't benchmarked or profiled to be sure.
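
For orientation, here is a rough sketch of the kind of per-worker computation being discussed. It is a paraphrase for illustration only, not the exact _task_slots_available implementation; the standalone signature is an assumption (the real helper reads these counts off a WorkerState):

import math

def task_slots_available_sketch(
    nthreads: int, n_processing: int, n_long_running: int, saturation: float
) -> int:
    # How many more tasks this worker could accept before being considered
    # saturated: thread capacity scaled by the worker-saturation factor,
    # minus tasks currently occupying threads (long-running tasks excluded).
    capacity = max(math.ceil(saturation * nthreads), 1)
    occupied = n_processing - n_long_running
    return capacity - occupied

Each call is O(1); the question above is only about how many workers it runs over per stimulus.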

Member:

Thanks for highlighting.

Apart from the theoretical analysis, this is also something we can easily optimize if it becomes a problem: e.g. saturation * nthreads is not something we need to compute every time, and even the processing - long_running term could be replaced with a counter we increment/decrement during transitions.

TL;DR: Not concerned, and "fixing" it right now feels premature.
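
As a hypothetical sketch of that counter idea (not something this PR implements; all names below are made up for illustration):

class SlotCounterSketch:
    """Keep a running count of occupied threadpool slots instead of
    recomputing len(ws.processing) - len(ws.long_running) on every call."""

    def __init__(self, nthreads: int, saturation: float):
        self.capacity = max(int(saturation * nthreads), 1)  # computed once
        self.occupied = 0

    def on_transition_to_processing(self) -> None:
        self.occupied += 1  # inc when a task starts occupying a thread

    def on_task_done_or_long_running(self) -> None:
        self.occupied -= 1  # dec when the thread frees up again

    def slots_available(self) -> int:
        return self.capacity - self.occupied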

if slots_available == 0:
return

recommendations: Recs = {}
for qts in self.queued.peekn(slots_available):
if self.validate:
assert qts.state == "queued", qts.state
assert not qts.processing_on, (qts, qts.processing_on)
assert not qts.waiting_on, (qts, qts.waiting_on)
assert qts.who_wants or qts.waiters, qts
recommendations[qts.key] = "processing"

self.transitions(recommendations, stimulus_id)

def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
logger.debug("Stimulus task finished %s, %s", key, worker)
@@ -4946,6 +4954,8 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations)
self.transitions(recommendations, stimulus_id)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def client_heartbeat(self, client=None):
"""Handle heartbeats from Client"""
cs: ClientState = self.clients[client]
@@ -5290,15 +5300,18 @@ def handle_task_finished(
)
recommendations, client_msgs, worker_msgs = r
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)

self.send_all(client_msgs, worker_msgs)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None:
r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg)
recommendations, client_msgs, worker_msgs = r
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
ts = self.tasks.get(key)
ws = self.workers.get(worker)
@@ -5340,13 +5353,7 @@ def handle_long_running(
ws.add_to_long_running(ts)
self.check_idle_saturated(ws)

recommendations: Recs = {
qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws)
}
if self.validate:
assert len(recommendations) <= 1, (ws, recommendations)

self.transitions(recommendations, stimulus_id)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def handle_worker_status_change(
self, status: str | Status, worker: str | WorkerState, stimulus_id: str
@@ -5372,12 +5379,10 @@ def handle_worker_status_change(
if ws.status == Status.running:
self.running.add(ws)
self.check_idle_saturated(ws)
recs = self.bulk_schedule_after_adding_worker(ws)
if recs:
client_msgs: Msgs = {}
worker_msgs: Msgs = {}
self._transitions(recs, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
self.transitions(
self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id
)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)
else:
self.running.discard(ws)
self.idle.pop(ws.address, None)

distributed/tests/test_collections.py: 6 additions & 0 deletions
@@ -148,8 +148,14 @@ def test_heapset():
assert list(heap.peekn(1)) == [cx]
heap.remove(cw)
assert list(heap.peekn(1)) == [cx]
heap.remove(cx)
assert list(heap.peekn(-1)) == []
assert list(heap.peekn(0)) == []
assert list(heap.peekn(1)) == []
assert list(heap.peekn(2)) == []

# Test resilience to failure in key()
heap.add(cx)
bad_key = C("bad_key", 0)
del bad_key.i
with pytest.raises(AttributeError):

distributed/tests/test_scheduler.py: 79 additions & 4 deletions
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import itertools
import json
import logging
import math
@@ -86,14 +87,14 @@ async def test_administration(s, a, b):

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_respect_data_in_memory(c, s, a):
x = delayed(inc)(1)
y = delayed(inc)(x)
x = delayed(inc)(1, dask_key_name="x")
y = delayed(inc)(x, dask_key_name="y")
f = c.persist(y)
await wait([f])

assert s.tasks[y.key].who_has == {s.workers[a.address]}

z = delayed(operator.add)(x, y)
z = delayed(operator.add)(x, y, dask_key_name="z")
f2 = c.persist(z)
while f2.key not in s.tasks or not s.tasks[f2.key]:
assert s.tasks[y.key].who_has
@@ -371,6 +372,80 @@ def __del__(self):
assert max(Refcount.log) <= s.total_nthreads


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_forget_tasks_while_processing(c, s, a, b):
events = [Event() for _ in range(10)]

futures = c.map(Event.wait, events)
await events[0].set()
await futures[0]
await c.close()
assert not s.tasks


@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_restart_while_processing(c, s, a, b):
events = [Event() for _ in range(10)]

futures = c.map(Event.wait, events)
await events[0].set()
await futures[0]
# TODO slow because worker waits a while for the task to finish
await c.restart()
assert not s.tasks


@gen_cluster(
client=True,
nthreads=[("", 1)] * 3,
config={"distributed.scheduler.worker-saturation": 1.0},
)
async def test_queued_release_multiple_workers(c, s, *workers):
async with Client(s.address, asynchronous=True) as c2:
event = Event(client=c2)

rootish_threshold = s.total_nthreads * 2 + 1

first_batch = c.map(
lambda i: event.wait(),
range(rootish_threshold),
key=[f"first-{i}" for i in range(rootish_threshold)],
)
await async_wait_for(lambda: s.queued, 5)

second_batch = c2.map(
lambda i: event.wait(),
range(rootish_threshold),
key=[f"second-{i}" for i in range(rootish_threshold)],
fifo_timeout=0,
)
await async_wait_for(lambda: second_batch[0].key in s.tasks, 5)

# All of the second batch should be queued after the first batch
assert [ts.key for ts in s.queued.sorted()] == [
f.key
for f in itertools.chain(first_batch[s.total_nthreads :], second_batch)
]

# Cancel the first batch.
# Use `Client.close` instead of `del first_batch` because deleting futures sends cancellation
# messages one at a time. We're testing here that when multiple workers have open slots, we don't
# recommend the same queued tasks for every worker, so we need a bulk cancellation operation.
await c.close()
del c, first_batch

await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5)

# Second batch should move up the queue and start processing
assert len(s.queued) == len(second_batch) - s.total_nthreads, list(
s.queued.sorted()
)

await event.set()
await c2.gather(second_batch)
Comment on lines +399 to +446
Member:

This test feels a bit heavyweight considering that we have a couple of very simple reproducers, see #7396 (comment).

If you feel strongly about this test, fine, but please add the other two very simple reproducers as well. Regardless of all the intricate timings, queuing, etc., the reproducers there should hold no matter what the internals look like.

Collaborator (PR author), @gjoseph92, Dec 14, 2022:

I added the reproducer from #7396 (comment) as well (with an Event to avoid timing issues):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_forget_tasks_while_processing(c, s, a, b):
events = [Event() for _ in range(10)]
futures = c.map(Event.wait, events)
await events[0].set()
await futures[0]
await c.close()
assert not s.tasks

This test is for #7401, which is a different issue (and I think maybe the same as #7398). Client.close is the only codepath that can cause client_releases_keys to be called with multiple keys, which is the case that breaks.

EDIT: Client.restart is the other code path that calls client_releases_keys with multiple keys...
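
A minimal sketch of the difference between the two release paths (illustrative only; inc is a stand-in function, and the comments describe the scheduler-side effect discussed above):

from dask.distributed import Client

def inc(x):
    return x + 1

client = Client()  # local cluster, just for illustration

futures = client.map(inc, range(10))
del futures     # futures are released one at a time as each refcount drops,
                # so client_releases_keys sees a single key per call

futures = client.map(inc, range(10))
client.close()  # (like Client.restart) releases everything this client still
                # holds in one go, so client_releases_keys gets many keys at once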

Collaborator (PR author):

Added a simple test with restart as well (fails on main, reproducer for #7398):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_restart_while_processing(c, s, a, b):
events = [Event() for _ in range(10)]
futures = c.map(Event.wait, events)
await events[0].set()
await futures[0]
await c.restart()
assert not s.tasks



@gen_cluster(
client=True,
nthreads=[("", 2)] * 2,
@@ -4237,7 +4312,7 @@ def assert_rootish():
await asyncio.sleep(0.005)
assert_rootish()
if rootish:
assert all(s.tasks[k] in s.queued for k in keys)
assert all(s.tasks[k] in s.queued for k in keys), [s.tasks[k] for k in keys]
await block.set()
# At this point we need/want to wait for the task-finished message to
# arrive on the scheduler. There is no proper hook to wait, therefore we