diff --git a/asv_bench/benchmarks/reduce.py b/asv_bench/benchmarks/reduce.py
index c475a8bc7..aa3edd611 100644
--- a/asv_bench/benchmarks/reduce.py
+++ b/asv_bench/benchmarks/reduce.py
@@ -110,7 +110,7 @@ def setup(self, *args, **kwargs):
 class ChunkReduce2DAllAxes(ChunkReduce):
     def setup(self, *args, **kwargs):
         self.array = np.ones((N, N))
-        self.labels = np.repeat(np.arange(N // 5), repeats=5)
+        self.labels = np.repeat(np.arange(N // 5), repeats=5)[np.newaxis, :]
         self.axis = None
         setup_jit()
 
diff --git a/flox/core.py b/flox/core.py
index 8041ee475..edf41c727 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -782,16 +782,25 @@ def chunk_reduce(
         )
         groups = grps[0]
 
+    order = "C"
     if nax > 1:
         needs_broadcast = any(
             group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
             for ax in range(-nax, 0)
         )
         if needs_broadcast:
+            # This is the dim=... case, it's a lot faster to ravel group_idx
+            # in fortran order since group_idx is then sorted
+            # I'm seeing 400ms -> 23ms for engine="flox"
+            # Of course we are slower to ravel `array` but we avoid argsorting
+            # both `array` *and* `group_idx` in _prepare_for_flox
             group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
+            if engine == "flox":
+                group_idx = group_idx.reshape(-1, order="F")
+                order = "F"
         # always reshape to 1D along group dimensions
         newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
-        array = array.reshape(newshape)
+        array = array.reshape(newshape, order=order)  # type: ignore[call-overload]
         group_idx = group_idx.reshape(-1)
 
     assert group_idx.ndim == 1
@@ -1814,7 +1823,8 @@ def groupby_reduce(
         Array to be reduced, possibly nD
     *by : ndarray or DaskArray
         Array of labels to group over. Must be aligned with ``array`` so that
-        ``array.shape[-by.ndim :] == by.shape``
+        ``array.shape[-by.ndim :] == by.shape`` or any disagreements in that
+        equality check are for dimensions of size 1 in `by`.
     func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
 "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
 "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
diff --git a/tests/test_core.py b/tests/test_core.py
index 99f181255..ba8ad9340 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -203,7 +203,7 @@ def test_groupby_reduce(
 def gen_array_by(size, func):
     by = np.ones(size[-1])
     rng = np.random.default_rng(12345)
-    array = rng.random(size)
+    array = rng.random(tuple(6 if s == 1 else s for s in size))
     if "nan" in func and "nanarg" not in func:
         array[[1, 4, 5], ...] = np.nan
     elif "nanarg" in func and len(size) > 1:
@@ -222,8 +222,8 @@ def gen_array_by(size, func):
         pytest.param(4, marks=requires_dask),
     ],
 )
+@pytest.mark.parametrize("size", ((1, 12), (12,), (12, 9)))
 @pytest.mark.parametrize("nby", [1, 2, 3])
-@pytest.mark.parametrize("size", ((12,), (12, 9)))
 @pytest.mark.parametrize("add_nan_by", [True, False])
 @pytest.mark.parametrize("func", ALL_FUNCS)
 def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
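
Note: the comment added to chunk_reduce above is easiest to see with a small standalone numpy sketch. The shapes, variable names, and N value below mirror the ChunkReduce2DAllAxes benchmark and are purely illustrative; this is not flox internals, just a demonstration of why a Fortran-order ravel of a broadcast group_idx comes out (nearly) sorted, so the argsort in _prepare_for_flox becomes cheap.

    import numpy as np

    # Illustrative only: labels have a leading size-1 dim, as in the benchmark.
    N = 10
    labels = np.repeat(np.arange(N // 5), repeats=5)[np.newaxis, :]  # shape (1, N)
    array = np.ones((N, N))

    # Broadcasting against the trailing dims of `array` makes every row identical.
    group_idx = np.broadcast_to(labels, array.shape)

    c_ravel = group_idx.reshape(-1)              # C order: the row pattern repeats N times
    f_ravel = group_idx.reshape(-1, order="F")   # F order: each label is contiguous for N entries

    print(bool((np.diff(c_ravel) >= 0).all()))   # False -> sorting does real work
    print(bool((np.diff(f_ravel) >= 0).all()))   # True  -> already sorted

The flip side, as the comment says, is that `array` must then be raveled with the same memory order so each value stays aligned with its label, which is what the `array.reshape(newshape, order=order)` change does.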