Skip to content

Commit

Permalink
[WIP] Fix blockwise
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 3, 2023
1 parent f6cb33f commit 796dcd2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
25 changes: 14 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,42 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:

@memoize
def _get_optimal_chunks_for_groups(chunks, labels):
    """Adjust ``chunks`` so that chunk boundaries line up with group
    boundaries in ``labels`` wherever possible.

    Parameters
    ----------
    chunks : tuple of int
        Existing chunk sizes along the grouped axis.
    labels : array
        Group labels, one per element along the grouped axis. May be a
        numpy array or a duck array (e.g. cupy); all intermediate arrays
        are created with ``like=labels`` so computation stays on the same
        device / array type.

    Returns
    -------
    tuple of int
        New chunk sizes summing to ``sum(chunks)``. Returns ``chunks``
        unchanged when every existing boundary already coincides with a
        group boundary.
    """
    # Create the chunk-size array with the same array type as ``labels``
    # (NEP 35 ``like=``) so cumsum/indexing below work for cupy inputs too.
    chunks_array = np.asarray(chunks, like=labels)
    # Index of the last element in each chunk.
    chunkidx = np.cumsum(chunks_array) - 1
    # what are the groups at chunk boundaries
    labels_at_chunk_bounds = _unique(labels[chunkidx])
    # what's the last index of all groups
    last_indexes = npg.aggregate_numpy.aggregate(
        labels, np.arange(len(labels), like=labels), func="last"
    )
    # what's the last index of groups at the chunk boundaries.
    lastidx = last_indexes[labels_at_chunk_bounds]

    # Fast path: every chunk already ends exactly where a group ends.
    if len(chunkidx) == len(lastidx) and (chunkidx == lastidx).all():
        return chunks

    first_indexes = npg.aggregate_numpy.aggregate(
        labels, np.arange(len(labels), like=labels), func="first"
    )
    firstidx = first_indexes[labels_at_chunk_bounds]

    newchunkidx = np.array([0], like=labels)
    for c, f, l in zip(chunkidx, firstidx, lastidx):  # noqa
        # Move each boundary to the nearer of the bounding group's first
        # or (last + 1) index, skipping boundaries already swallowed by a
        # previous move.
        Δf = abs(c - f)
        Δl = abs(c - l)
        if c == 0 or newchunkidx[-1] > l:
            continue
        if Δf < Δl and f > newchunkidx[-1]:
            newchunkidx = np.append(newchunkidx, f)
        else:
            newchunkidx = np.append(newchunkidx, l + 1)
    # Ensure the final boundary covers the whole axis.
    if newchunkidx[-1] != chunkidx[-1] + 1:
        newchunkidx = np.append(newchunkidx, chunkidx[-1] + 1)
    newchunks = np.diff(newchunkidx)

    assert sum(newchunks) == sum(chunks)
    # workaround cupy bug with tuple(array)
    return tuple(newchunks.tolist())


def _unique(a: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -1317,9 +1323,6 @@ def dask_groupby_agg(
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

if method == "blockwise" and not isinstance(by, np.ndarray):
raise NotImplementedError

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,8 +807,8 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
[(10,), (10,)],
],
)
def test_rechunk_for_blockwise(inchunks, expected, array_module):
    """Check optimal chunk computation for known label runs.

    ``array_module`` is a fixture supplying numpy or a duck-array
    namespace (e.g. cupy), exercising the ``like=``-based device-agnostic
    path in ``_get_optimal_chunks_for_groups``.
    """
    labels = array_module.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
    assert _get_optimal_chunks_for_groups(inchunks, labels) == expected


Expand Down

0 comments on commit 796dcd2

Please sign in to comment.