[DNM] P2P shuffle skeleton - scheduler plugin #5524

Closed
8 changes: 6 additions & 2 deletions distributed/scheduler.py
@@ -51,7 +51,7 @@
 
 from distributed.utils import recursive_to_dict
 
-from . import preloading, profile
+from . import preloading, profile, shuffle
 from . import versions as version_module
 from .active_memory_manager import ActiveMemoryManagerExtension
 from .batched import BatchedSend
@@ -188,6 +188,10 @@ def nogil(func):
     ActiveMemoryManagerExtension,
     MemorySamplerExtension,
 ]
+DEFAULT_PLUGINS: "tuple[SchedulerPlugin, ...]" = (
+    (shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else ()
+)
+# ^ TODO this assumes one Scheduler per process; probably a bad idea.
 
 ALL_TASK_STATES = declare(
     set, {"released", "waiting", "no-worker", "processing", "erred", "memory"}
@@ -3623,7 +3627,7 @@ def __init__(
         http_prefix="/",
         preload=None,
         preload_argv=(),
-        plugins=(),
+        plugins=DEFAULT_PLUGINS,
         **kwargs,
     ):
         self._setup_logging(logger)
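The TODO above is worth dwelling on: `DEFAULT_PLUGINS` is evaluated once at import time, so every Scheduler in the process would share a single `ShuffleSchedulerPlugin` instance and its state. A minimal sketch of one alternative, not part of this PR: build the defaults per scheduler via a hypothetical `default_plugins` helper that `Scheduler.__init__` could call when no plugins are passed.

from __future__ import annotations

from distributed import shuffle
from distributed.diagnostics.plugin import SchedulerPlugin


def default_plugins() -> tuple[SchedulerPlugin, ...]:
    # Called per Scheduler (e.g. when plugins=None is passed to __init__),
    # so each scheduler gets a fresh plugin instance with its own state.
    return (shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else ()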
20 changes: 20 additions & 0 deletions distributed/shuffle/__init__.py
@@ -0,0 +1,20 @@
try:
    import pandas
except ImportError:
    SHUFFLE_AVAILABLE = False
else:
    del pandas
    SHUFFLE_AVAILABLE = True

from .common import ShuffleId
from .shuffle_scheduler import ShuffleSchedulerPlugin
from .shuffle_worker import ShuffleWorkerExtension

if SHUFFLE_AVAILABLE:
    # graph.py imports dask.dataframe at module scope, which itself requires
    # pandas, so this import must stay behind the availability check.
    from .graph import rearrange_by_column_p2p

__all__ = [
    "SHUFFLE_AVAILABLE",
    "ShuffleId",
    "rearrange_by_column_p2p",
    "ShuffleWorkerExtension",
    "ShuffleSchedulerPlugin",
]
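Callers can then feature-check cheaply before touching the pandas-dependent pieces. A hypothetical call-site guard (`ddf` stands in for a dask DataFrame whose `_partitions` column holds precomputed target-partition indices):

from distributed import shuffle

if shuffle.SHUFFLE_AVAILABLE:
    shuffled = shuffle.rearrange_by_column_p2p(ddf, "_partitions")
else:
    raise RuntimeError("p2p shuffling requires pandas to be installed")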
34 changes: 34 additions & 0 deletions distributed/shuffle/common.py
@@ -0,0 +1,34 @@
from __future__ import annotations

import math
from typing import NewType

ShuffleId = NewType("ShuffleId", str)


def worker_for(output_partition: int, npartitions: int, workers: list[str]) -> str:
    "Get the address of the worker which should hold this output partition number"
    if output_partition < 0:
        raise IndexError(f"Negative output partition: {output_partition}")
    if output_partition >= npartitions:
        raise IndexError(
            f"Output partition {output_partition} does not exist in a shuffle producing {npartitions} partitions"
        )
    i = len(workers) * output_partition // npartitions
    return workers[i]


def partition_range(
    worker: str, npartitions: int, workers: list[str]
) -> tuple[int, int]:
    "Get the output partition numbers (inclusive) that a worker will hold"
    i = workers.index(worker)
    first = math.ceil(npartitions * i / len(workers))
    last = math.ceil(npartitions * (i + 1) / len(workers)) - 1
    return first, last


def npartitions_for(worker: str, npartitions: int, workers: list[str]) -> int:
    "Get the number of output partitions a worker will hold"
    first, last = partition_range(worker, npartitions, workers)
    return last - first + 1
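These helpers carve the output partitions into contiguous, near-equal blocks, one per worker; `partition_range` is exactly the inverse of `worker_for` (worker `i` holds the partitions `p` with `len(workers) * p // npartitions == i`). A worked example with hypothetical worker addresses:

from distributed.shuffle.common import npartitions_for, partition_range, worker_for

workers = ["tcp://a:1234", "tcp://b:1234", "tcp://c:1234"]

# 10 output partitions over 3 workers: contiguous blocks of sizes 4, 3, 3.
assert [worker_for(p, 10, workers) for p in range(10)] == (
    [workers[0]] * 4 + [workers[1]] * 3 + [workers[2]] * 3
)
assert partition_range("tcp://a:1234", 10, workers) == (0, 3)
assert partition_range("tcp://b:1234", 10, workers) == (4, 6)
assert partition_range("tcp://c:1234", 10, workers) == (7, 9)
assert [npartitions_for(w, 10, workers) for w in workers] == [4, 3, 3]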
117 changes: 117 additions & 0 deletions distributed/shuffle/graph.py
@@ -0,0 +1,117 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from dask.base import tokenize
from dask.blockwise import BlockwiseDepDict, blockwise
from dask.dataframe import DataFrame
from dask.dataframe.core import partitionwise_graph
from dask.highlevelgraph import HighLevelGraph

from .common import ShuffleId
from .shuffle_worker import ShuffleWorkerExtension

if TYPE_CHECKING:
    import pandas as pd


def get_shuffle_extension() -> ShuffleWorkerExtension:
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e
    extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle")
    if not extension:
        raise RuntimeError(
            f"The worker {worker.address} does not have a ShuffleExtension. "
            "Is pandas installed on the worker?"
        )
    return extension


def shuffle_transfer(
    data: pd.DataFrame, id: ShuffleId, npartitions: int, column: str
) -> None:
    ext = get_shuffle_extension()
    ext.sync(ext.add_partition(data, id, npartitions, column))


def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None:
    # `transfers` is unused: it exists only so this task depends on
    # (and therefore runs after) every transfer task.
    ext = get_shuffle_extension()
    ext.sync(ext.barrier(id))


def shuffle_unpack(
    id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None
) -> pd.DataFrame:
    # `barrier` is likewise a dependency-only argument; `empty` carries the
    # output meta (column names and dtypes).
    return get_shuffle_extension().get_output_partition(id, i, empty)


def rearrange_by_column_p2p(
    df: DataFrame,
    column: str,
    npartitions: int | None = None,
):
    npartitions = npartitions or df.npartitions
    token = tokenize(df, column, npartitions)

    # We use `partitionwise_graph` instead of `map_partitions` so we can pass in our own key.
    # The scheduler needs the task key to contain the shuffle ID; it's the only way it knows
    # what shuffle a task belongs to.
    # (Yes, this is rather brittle.)
    transfer_name = "shuffle-transfer-" + token
    transfer_dsk = partitionwise_graph(
        shuffle_transfer, transfer_name, df, token, npartitions, column
    )

    # A single barrier task that depends on every transfer task.
    barrier_name = "shuffle-barrier-" + token
    barrier_dsk = {
        barrier_name: (
            shuffle_barrier,
            token,
            [(transfer_name, i) for i in range(df.npartitions)],
        )
    }

    unpack_name = "shuffle-unpack-" + token
    unpack_dsk = blockwise(
        shuffle_unpack,
        unpack_name,
        "i",
        token,
        None,
        BlockwiseDepDict({(i,): i for i in range(npartitions)}),
        "i",
        df._meta,
        None,
        barrier_name,
        None,
        numblocks={},
    )

    hlg = HighLevelGraph(
        {
            transfer_name: transfer_dsk,
            barrier_name: barrier_dsk,
            unpack_name: unpack_dsk,
            **df.dask.layers,
        },
        {
            transfer_name: set(df.__dask_layers__()),
            barrier_name: {transfer_name},
            unpack_name: {barrier_name},
            **df.dask.dependencies,
        },
    )

    return DataFrame(
        hlg,
        unpack_name,
        df._meta,
        [None] * (npartitions + 1),  # unknown divisions
    )
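For context, a hypothetical end-to-end usage sketch, not part of this PR. It assumes a running distributed Client whose workers carry the ShuffleWorkerExtension, and, as with dask's existing `rearrange_by_column` variants, that `column` names an integer column of precomputed target-partition indices:

import pandas as pd
import dask.dataframe as dd

from distributed import Client
from distributed.shuffle import rearrange_by_column_p2p

if __name__ == "__main__":
    client = Client()  # local cluster; workers need the ShuffleWorkerExtension

    pdf = pd.DataFrame({"x": range(8), "_partitions": [i % 4 for i in range(8)]})
    ddf = dd.from_pandas(pdf, npartitions=2)

    # Regroup rows so output partition i holds exactly the rows whose
    # `_partitions` value is i (four output partitions here).
    out = rearrange_by_column_p2p(ddf, column="_partitions", npartitions=4)
    print(out.compute())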