Skip to content

Commit

Permalink
Concatenate using task-based shuffling before P2P
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Aug 14, 2024
1 parent d8a9c8d commit f709b3c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 11 deletions.
75 changes: 67 additions & 8 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@

from __future__ import annotations

import math
import mmap
import os
from collections import defaultdict
Expand All @@ -111,7 +112,7 @@
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from itertools import product
from itertools import chain, product
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast

Expand All @@ -124,6 +125,7 @@
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
from dask.typing import Key
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
from distributed.metrics import context_meter
Expand Down Expand Up @@ -221,7 +223,7 @@ def rechunk_p2p(
return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
from dask.array.core import new_da_object

prechunked = _prechunk_for_partials(x.chunks, chunks)
prechunked = _prechunk_for_partials(x.chunks, chunks, x.dtype)
if prechunked != x.chunks:
x = cast(
"da.Array",
Expand Down Expand Up @@ -434,8 +436,10 @@ def _construct_graph(self) -> _T_LowLevelGraph:


def _prechunk_for_partials(
old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
old_chunks: ChunkedAxes, new_chunks: ChunkedAxes, dtype: np.dtype
) -> ChunkedAxes:
import numpy as np

from dask.array.rechunk import old_to_new

_old_to_new = old_to_new(old_chunks, new_chunks)
Expand All @@ -449,6 +453,7 @@ def _prechunk_for_partials(
old_axis = old_chunks[axis_index]
split_axis = []
for slice_ in slices:
partial_chunks = []
first_new_chunk = slice_.start
first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0]
last_new_chunk = slice_.stop - 1
Expand All @@ -474,22 +479,76 @@ def _prechunk_for_partials(
assert first_old_slice.start == 0
chunk_size = last_old_slice.stop

split_axis.append(chunk_size)
split_axis.append([chunk_size])
continue

split_axis.append(first_chunk_size - first_old_slice.start)
partial_chunks.append(first_chunk_size - first_old_slice.start)

split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk])
partial_chunks.extend(old_axis[first_old_chunk + 1 : last_old_chunk])

if last_old_slice.stop is not None:
chunk_size = last_old_slice.stop
else:
chunk_size = last_chunk_size

split_axis.append(chunk_size)
partial_chunks.append(chunk_size)
split_axis.append(partial_chunks)

split_axes.append(split_axis)
return tuple(tuple(axis) for axis in split_axes)

has_nans = (any(math.isnan(y) for y in x) for x in old_chunks)

if len(new_chunks) <= 1 or not all(new_chunks) or any(has_nans):
return tuple(tuple(chain(*axis)) for axis in split_axes)

if dtype is None or dtype.hasobject or dtype.itemsize == 0:
return tuple(tuple(chain(*axis)) for axis in split_axes)

# TODO: Incorporate block_size_limit from dask.array.rechunk(..., block_size_limit=...)
block_size_limit = dask.config.get("array.chunk-size")
if isinstance(block_size_limit, str):
block_size_limit = parse_bytes(block_size_limit)

# Make it a number of elements
block_size_limit //= dtype.itemsize

# We verified earlier that we do not have any NaNs
largest_old_block = _largest_block_size(old_chunks) # type: ignore[arg-type]
largest_new_block = _largest_block_size(new_chunks) # type: ignore[arg-type]
block_size_limit = max([block_size_limit, largest_old_block, largest_new_block])

max_chunk_sizes = tuple(max(chain(*axis)) for axis in split_axes)

block_reduction_ratio = tuple(
sum(len(partial) for partial in split_axis) / len(new_axis)
for split_axis, new_axis in zip(split_axes, new_chunks)
)

ascending = np.argsort(block_reduction_ratio)

concatenated_axes: list[list[float]] = [[] for _ in ascending]

for axis_index in ascending:
concatenated_axis = concatenated_axes[axis_index]
multiplier = math.prod(
max_chunk_sizes[:axis_index] + max_chunk_sizes[axis_index + 1 :]
)
axis_limit = block_size_limit // multiplier

for partial in split_axes[axis_index]:
current = partial[0]
for chunk in partial[1:]:
if (current + chunk) > axis_limit:
concatenated_axis.append(current)
current = chunk
else:
current += chunk
concatenated_axis.append(current)
return tuple(tuple(axis) for axis in concatenated_axes)


def _largest_block_size(chunks: tuple[tuple[int, ...], ...]) -> int:
return math.prod(map(max, chunks))


def _split_partials(
Expand Down
7 changes: 4 additions & 3 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ async def test_rechunk_avoid_needless_chunking(c, s, *ws):
x = da.ones(16, chunks=2)
y = x.rechunk(8, method="p2p")
dsk = y.__dask_graph__()
assert len(dsk) <= 8 + 2
# 8 inputs, 2 concatenations of small inputs, 2 outputs
assert len(dsk) <= 8 + 2 + 2


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1365,7 +1366,7 @@ async def test_partial_rechunk_taskgroups(c, s):
],
)
def test_prechunk_for_partials_1d(old, new, expected):
actual = _prechunk_for_partials(old, new)
actual = _prechunk_for_partials(old, new, np.dtype(np.int16))
assert actual == expected


Expand All @@ -1385,5 +1386,5 @@ def test_prechunk_for_partials_1d(old, new, expected):
],
)
def test_prechunk_for_partials_2d(old, new, expected):
actual = _prechunk_for_partials(old, new)
actual = _prechunk_for_partials(old, new, np.dtype(np.int16))
assert actual == expected

0 comments on commit f709b3c

Please sign in to comment.