Skip to content

Commit

Permalink
Set order='F' when raveling group_idx after broadcast
Browse files Browse the repository at this point in the history
This majorly improves the dim=... case for engine="flox" at least.
xref #281

I'm not sure if it is a regression for engine="numpy"

We trade off a single bad reshape for array against argsorting both
array and group_idx for a ~10-20x speedup

```
ds = xr.tutorial.load_dataset('air_temperature')
ds.groupby('lon').count(..., engine="flox")
```
  • Loading branch information
dcherian committed Nov 8, 2023
1 parent 0cea972 commit 746b8fd
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,16 +782,25 @@ def chunk_reduce(
)
groups = grps[0]

order = "C"
if nax > 1:
needs_broadcast = any(
group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
for ax in range(-nax, 0)
)
if needs_broadcast:
# This is the dim=... case, it's a lot faster to ravel group_idx
# in fortran order since group_idx is then sorted
# I'm seeing 400ms -> 23ms for engine="flox"
# Of course we are slower to ravel `array` but we avoid argsorting
# both `array` *and* `group_idx` in _prepare_for_flox
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
# if engine == "flox":
group_idx = group_idx.reshape(-1, order="F")
order = "F"
# always reshape to 1D along group dimensions
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
array = array.reshape(newshape)
array = array.reshape(newshape, order=order)
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1
Expand Down

0 comments on commit 746b8fd

Please sign in to comment.