From bacac078090f34a683acf302d50d72376b15fc57 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 14:39:28 -0600 Subject: [PATCH 01/20] Support quantile, median with method="blockwise". We allow method="blockwise" when grouping by a dask array. This can only work if we have expected_groups, and set reindex=True. --- flox/aggregate_npg.py | 48 +++++++++++++++++++++++++++++++++++++++++++ flox/aggregations.py | 10 +++++++-- flox/core.py | 28 +++++++++++++++---------- flox/xarray.py | 2 +- 4 files changed, 74 insertions(+), 14 deletions(-) diff --git a/flox/aggregate_npg.py b/flox/aggregate_npg.py index 30e0eb257..af4f09000 100644 --- a/flox/aggregate_npg.py +++ b/flox/aggregate_npg.py @@ -100,3 +100,51 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None, len = partial(_len, func="len") nanlen = partial(_len, func="nanlen") + + +def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=np.median, + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=np.nanmedian, + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(np.quantile, q=q), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(np.nanquantile, q=q), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) diff --git a/flox/aggregations.py b/flox/aggregations.py index e5013032a..ee8ce5dc2 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -467,8 +467,10 @@ def _pick_second(*x): # numpy_groupies does not support median # And the dask version is really hard! -# median = Aggregation("median", chunk=None, combine=None, fill_value=None) -# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None) +median = Aggregation(name="median", fill_value=-1, chunk=None, combine=None) +nanmedian = Aggregation(name="nanmedian", fill_value=-1, chunk=None, combine=None) +quantile = Aggregation(name="quantile", fill_value=-1, chunk=None, combine=None) +nanquantile = Aggregation(name="nanquantile", fill_value=-1, chunk=None, combine=None) aggregations = { "any": any_, @@ -496,6 +498,10 @@ def _pick_second(*x): "nanfirst": nanfirst, "last": last, "nanlast": nanlast, + "median": median, + "nanmedian": nanmedian, + "quantile": quantile, + "nanquantile": nanquantile, } diff --git a/flox/core.py b/flox/core.py index 98c37bc6f..c637fba2f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1307,9 +1307,6 @@ def dask_groupby_agg( assert isinstance(axis, Sequence) assert all(ax >= 0 for ax in axis) - if method == "blockwise" and not isinstance(by, np.ndarray): - raise NotImplementedError - inds = tuple(range(array.ndim)) name = f"groupby_{agg.name}" token = dask.base.tokenize(array, by, agg, expected_groups, axis) @@ -1471,11 +1468,15 @@ def dask_groupby_agg( # Here one input chunk → one output chunks # find number of groups in each chunk, this is needed for output chunks # along the reduced axis - slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) - groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) - groups = (np.concatenate(groups_in_block),) - ngroups_per_block = tuple(len(grp) for grp in groups_in_block) - group_chunks = (ngroups_per_block,) + if isinstance(by, dask.array.Array): + groups = (expected_groups,) + group_chunks = ((len(expected_groups),),) + else: + slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) + groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) + groups = (np.concatenate(groups_in_block),) + ngroups_per_block = tuple(len(grp) for grp in groups_in_block) + group_chunks = (ngroups_per_block,) else: raise ValueError(f"Unknown method={method}.") @@ -1544,6 +1545,7 @@ def _validate_reindex( is_dask_array: bool, ) -> bool: all_numpy = not is_dask_array and not any_by_dask + if reindex is True and not all_numpy: if _is_arg_reduction(func): raise NotImplementedError @@ -1562,7 +1564,11 @@ def _validate_reindex( # have to do the grouped_combine since there's no good fill_value reindex = False - if method == "blockwise" or _is_arg_reduction(func): + if method == "blockwise": + # for grouping by dask arrays, we set reindex=True + reindex = any_by_dask + + elif _is_arg_reduction(func): reindex = False elif method == "cohorts": @@ -1835,7 +1841,7 @@ def groupby_reduce( boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions. finalize_kwargs : dict, optional - Kwargs passed to finalize the reduction such as ``ddof`` for var, std. + Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile. Returns ------- @@ -2023,7 +2029,7 @@ def groupby_reduce( result, groups = partial_agg( array, by_, - expected_groups=None if method == "blockwise" else expected_groups, + expected_groups=expected_groups, agg=agg, reindex=reindex, method=method, diff --git a/flox/xarray.py b/flox/xarray.py index c85ad7113..dfa30cc53 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -164,7 +164,7 @@ def xarray_reduce( boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions. **finalize_kwargs - kwargs passed to the finalize function, like ``ddof`` for var, std. + kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile. Returns ------- From 3857afd750d112107a23d1628222b218674627ac Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 19:12:08 -0600 Subject: [PATCH 02/20] Update flox/core.py --- flox/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index c637fba2f..07597cae6 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1545,7 +1545,6 @@ def _validate_reindex( is_dask_array: bool, ) -> bool: all_numpy = not is_dask_array and not any_by_dask - if reindex is True and not all_numpy: if _is_arg_reduction(func): raise NotImplementedError From c0be4c7a7edb843afc1653980b047f9942beb181 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 19:25:19 -0600 Subject: [PATCH 03/20] fix comment --- flox/core.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/flox/core.py b/flox/core.py index 07597cae6..ac645e5dd 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1465,13 +1465,17 @@ def dask_groupby_agg( elif method == "blockwise": reduced = intermediate - # Here one input chunk → one output chunks - # find number of groups in each chunk, this is needed for output chunks - # along the reduced axis if isinstance(by, dask.array.Array): + # TODO: we could have `expected_groups` be a dask array with appropriate chunks + # for now, we have a numpy array that is interpreted as listing all group labels + # that are present in every chunk groups = (expected_groups,) group_chunks = ((len(expected_groups),),) else: + # Here one input chunk → one output chunks + # find number of groups in each chunk, this is needed for output chunks + # along the reduced axis + # TODO: this logic is very specialized for the resampling case slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) groups = (np.concatenate(groups_in_block),) From 32028a98131822aa0600dda87edd57dee7cbc5b2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 19:25:25 -0600 Subject: [PATCH 04/20] Update validate_reindex test --- flox/core.py | 2 +- tests/test_core.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index ac645e5dd..85e5971f3 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1552,7 +1552,7 @@ def _validate_reindex( if reindex is True and not all_numpy: if _is_arg_reduction(func): raise NotImplementedError - if method in ["blockwise", "cohorts"]: + if method == "cohorts" or (method == "blockwise" and not any_by_dask): raise ValueError( "reindex=True is not a valid choice for method='blockwise' or method='cohorts'." ) diff --git a/tests/test_core.py b/tests/test_core.py index 453958b2d..5763c6893 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1387,6 +1387,33 @@ def test_validate_reindex() -> None: ) assert actual is False + with pytest.raises(ValueError): + _validate_reindex( + True, + "sum", + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=False, + is_dask_array=True, + ) + + assert _validate_reindex( + True, + "sum", + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=True, + is_dask_array=True, + ) + assert _validate_reindex( + None, + "sum", + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=True, + is_dask_array=True, + ) + @requires_dask def test_1d_blockwise_sort_optimization(): From effc7619817c45761cc4143d059726df8107ab43 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 19:34:55 -0600 Subject: [PATCH 05/20] Fix --- flox/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flox/core.py b/flox/core.py index 85e5971f3..4aa28cc6f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1346,7 +1346,6 @@ def dask_groupby_agg( # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction. # This allows us to discover groups at compute time, support argreductions, lower intermediate # memory usage (but method="cohorts" would also work to reduce memory in some cases) - do_simple_combine = not _is_arg_reduction(agg) if method == "blockwise": @@ -1372,7 +1371,7 @@ def dask_groupby_agg( partial( blockwise_method, axis=axis, - expected_groups=None if method == "cohorts" else expected_groups, + expected_groups=expected_groups if reindex else None, engine=engine, sort=sort, ), @@ -1465,7 +1464,7 @@ def dask_groupby_agg( elif method == "blockwise": reduced = intermediate - if isinstance(by, dask.array.Array): + if reindex: # TODO: we could have `expected_groups` be a dask array with appropriate chunks # for now, we have a numpy array that is interpreted as listing all group labels # that are present in every chunk From 77a784a17ad6906078759a7b645fbf5622d410d2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 19:34:55 -0600 Subject: [PATCH 06/20] Fix --- flox/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flox/core.py b/flox/core.py index 4aa28cc6f..195e6de30 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1313,6 +1313,8 @@ def dask_groupby_agg( if expected_groups is None and reindex: expected_groups = _get_expected_groups(by, sort=sort) + if method == "cohorts": + assert reindex is False by_input = by From 460a167206af6de9e3a2d0d686ea32f8ba91c016 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 20:00:42 -0600 Subject: [PATCH 07/20] Raise early if `q` is not provided for quantile --- flox/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flox/core.py b/flox/core.py index 195e6de30..945b99925 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1865,6 +1865,9 @@ def groupby_reduce( "Try engine='numpy' or engine='numba' instead." ) + if func == "quantile" and "q" not in finalize_kwargs: + raise ValueError("Please pass `q` for quantile calculations.") + bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) by_is_dask = tuple(is_duck_dask_array(b) for b in bys) From 2a371299e06ce8011dab8c31979f48b306579892 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 20:35:39 -0600 Subject: [PATCH 08/20] WIP test --- tests/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 5763c6893..aa8dc53e8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -75,6 +75,8 @@ def dask_array_ones(*args): "nanlast", pytest.param("median", marks=(pytest.mark.skip,)), pytest.param("nanmedian", marks=(pytest.mark.skip,)), + pytest.param("quantile", marks=(pytest.mark.skip,)), + pytest.param("nanquantile", marks=(pytest.mark.skip,)), ) if TYPE_CHECKING: From 01d1ddcebb71e5a9988f2ff725d9a84c0318b969 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 21:16:52 -0600 Subject: [PATCH 09/20] narrow type --- flox/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flox/core.py b/flox/core.py index 945b99925..fc8ace94c 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1467,6 +1467,8 @@ def dask_groupby_agg( elif method == "blockwise": reduced = intermediate if reindex: + if TYPE_CHECKING: + assert expected_groups is not None # TODO: we could have `expected_groups` be a dask array with appropriate chunks # for now, we have a numpy array that is interpreted as listing all group labels # that are present in every chunk From d88835d331e48b8f27001bdf574ff207d3b5f51b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 3 Oct 2023 21:34:08 -0600 Subject: [PATCH 10/20] fix type --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index fc8ace94c..576634b81 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1867,7 +1867,7 @@ def groupby_reduce( "Try engine='numpy' or engine='numba' instead." ) - if func == "quantile" and "q" not in finalize_kwargs: + if func == "quantile" and (finalize_kwargs is None or "q" not in finalize_kwargs): raise ValueError("Please pass `q` for quantile calculations.") bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) From 6f00e2d22a25546b9f31dc197fb596057d0b373e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 08:27:45 -0600 Subject: [PATCH 11/20] Mode + Tests --- ci/environment.yml | 1 + flox/aggregate_npg.py | 33 ++++++++++++++++++ flox/aggregations.py | 16 +++++---- tests/__init__.py | 1 + tests/test_core.py | 78 ++++++++++++++++++++++++++++++++----------- 5 files changed, 104 insertions(+), 25 deletions(-) diff --git a/ci/environment.yml b/ci/environment.yml index 93f0e891f..9d5aa6d01 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -22,3 +22,4 @@ dependencies: - pooch - toolz - numba + - scipy diff --git a/flox/aggregate_npg.py b/flox/aggregate_npg.py index af4f09000..966bd43b8 100644 --- a/flox/aggregate_npg.py +++ b/flox/aggregate_npg.py @@ -148,3 +148,36 @@ def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=N fill_value=fill_value, dtype=dtype, ) + + +def mode_(array, nan_policy, dtype): + from scipy.stats import mode + + # npg splits `array` into object arrays for each group + # scipy.stats.mode does not like that + # here we cast back + return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode + + +def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(mode_, nan_policy="propagate", dtype=array.dtype), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def nanmode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(mode_, nan_policy="omit", dtype=array.dtype), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) diff --git a/flox/aggregations.py b/flox/aggregations.py index ee8ce5dc2..3cd339525 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -465,12 +465,14 @@ def _pick_second(*x): final_dtype=bool, ) -# numpy_groupies does not support median -# And the dask version is really hard! -median = Aggregation(name="median", fill_value=-1, chunk=None, combine=None) -nanmedian = Aggregation(name="nanmedian", fill_value=-1, chunk=None, combine=None) -quantile = Aggregation(name="quantile", fill_value=-1, chunk=None, combine=None) -nanquantile = Aggregation(name="nanquantile", fill_value=-1, chunk=None, combine=None) +# Support statistical quantities only blockwise +# The parallel versions will be approximate and are hard to implement! +median = Aggregation(name="median", fill_value=dtypes.NA, chunk=None, combine=None) +nanmedian = Aggregation(name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None) +quantile = Aggregation(name="quantile", fill_value=dtypes.NA, chunk=None, combine=None) +nanquantile = Aggregation(name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None) +mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) +nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) aggregations = { "any": any_, @@ -502,6 +504,8 @@ def _pick_second(*x): "nanmedian": nanmedian, "quantile": quantile, "nanquantile": nanquantile, + "mode": mode, + "nanmode": nanmode, } diff --git a/tests/__init__.py b/tests/__init__.py index 3e43e94d6..f1c8ec6bf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -47,6 +47,7 @@ def LooseVersion(vstring): has_dask, requires_dask = _importorskip("dask") has_numba, requires_numba = _importorskip("numba") +has_scipy, requires_scipy = _importorskip("scipy") has_xarray, requires_xarray = _importorskip("xarray") diff --git a/tests/test_core.py b/tests/test_core.py index aa8dc53e8..03beb7013 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -31,6 +31,7 @@ has_dask, raise_if_dask_computes, requires_dask, + requires_scipy, ) labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) @@ -77,7 +78,15 @@ def dask_array_ones(*args): pytest.param("nanmedian", marks=(pytest.mark.skip,)), pytest.param("quantile", marks=(pytest.mark.skip,)), pytest.param("nanquantile", marks=(pytest.mark.skip,)), + "median", + "nanmedian", + "quantile", + "nanquantile", + pytest.param("mode", marks=requires_scipy), + pytest.param("nanmode", marks=requires_scipy), ) +BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile", "mode", "nanmode") +SCIPY_STATS_FUNCS = ("mode", "nanmode") if TYPE_CHECKING: from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method @@ -92,6 +101,19 @@ def npfunc(x): elif func in ["nanfirst", "nanlast"]: npfunc = getattr(xrutils, func) + elif func in SCIPY_STATS_FUNCS: + import scipy.stats + + if "nan" in func: + func = func[3:] + nan_policy = "omit" + else: + nan_policy = "propagate" + + def npfunc(x, **kwargs): + spfunc = partial(getattr(scipy.stats, func), nan_policy=nan_policy) + return getattr(spfunc(x, **kwargs), func) + else: npfunc = getattr(np, func) @@ -209,6 +231,9 @@ def gen_array_by(size, func): def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if "arg" in func and engine == "flox": pytest.skip() + if func in BLOCKWISE_FUNCS: + if chunks != -1: + pytest.skip() array, by = gen_array_by(size, func) if chunks: @@ -226,6 +251,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}] fill_value = np.nan tolerance = {"rtol": 1e-14, "atol": 1e-16} + elif "quantile" in func: + finalize_kwargs = [{"q": 0.5}] + fill_value = None + tolerance = None else: fill_value = None tolerance = None @@ -248,15 +277,16 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): func_ = f"nan{func}" if "nan" not in func else func array_[..., nanmask] = np.nan expected = getattr(np, func_)(array_, axis=-1, **kwargs) - # elif func in ["first", "last"]: - # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs) - elif func in ["nanfirst", "nanlast"]: - expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs) else: - expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs) + array_func = _get_array_func(func) + expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs) for _ in range(nby): expected = np.expand_dims(expected, -1) + if func in BLOCKWISE_FUNCS: + assert chunks == -1 + flox_kwargs["method"] = "blockwise" + actual, *groups = groupby_reduce(array, *by, **flox_kwargs) assert actual.ndim == (array.ndim + nby - 1) assert expected.ndim == (array.ndim + nby - 1) @@ -267,7 +297,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): assert actual.dtype.kind == "i" assert_equal(actual, expected, tolerance) - if not has_dask or chunks is None: + if not has_dask or chunks is None or func in BLOCKWISE_FUNCS: continue params = list(itertools.product(["map-reduce"], [True, False, None])) @@ -398,7 +428,7 @@ def test_numpy_reduce_nd_md(): def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex): """Tests groupby_reduce with dask arrays against groupby_reduce with numpy arrays""" - if func in ["first", "last"]: + if func in ["first", "last"] or func in BLOCKWISE_FUNCS: pytest.skip() if "arg" in func and (engine == "flox" or reindex): @@ -553,7 +583,7 @@ def test_first_last_disallowed_dask(func): "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)] ) def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): - if "arg" in func and engine == "flox": + if ("arg" in func and engine == "flox") or func in BLOCKWISE_FUNCS: pytest.skip() if not isinstance(axis, int): @@ -849,7 +879,7 @@ def test_rechunk_for_cohorts(chunk_at, expected): def test_fill_value_behaviour(func, chunks, fill_value, engine): # fill_value = np.nan tests promotion of int counts to float # This is used by xarray - if func in ["all", "any"] or "arg" in func: + if (func in ["all", "any"] or "arg" in func) or func in BLOCKWISE_FUNCS: pytest.skip() npfunc = _get_array_func(func) @@ -905,8 +935,17 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype): @requires_dask @pytest.mark.parametrize("func", ALL_FUNCS) @pytest.mark.parametrize("axis", (-1, None)) -@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce", "split-reduce"]) +@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce"]) def test_cohorts_nd_by(func, method, axis, engine): + if ( + ("arg" in func and (axis is None or engine == "flox")) + or (method != "blockwise" and func in BLOCKWISE_FUNCS) + or (axis is None and ("first" in func or "last" in func)) + ): + pytest.skip() + if axis is not None and method != "map-reduce": + pytest.xfail() + o = dask.array.ones((3,), chunks=-1) o2 = dask.array.ones((2, 3), chunks=-1) @@ -917,20 +956,14 @@ def test_cohorts_nd_by(func, method, axis, engine): by[0, 4] = 31 array = np.broadcast_to(array, (2, 3) + array.shape) - if "arg" in func and (axis is None or engine == "flox"): - pytest.skip() - if func in ["any", "all"]: fill_value = False else: fill_value = -123 - if axis is not None and method != "map-reduce": - pytest.xfail() - if axis is None and ("first" in func or "last" in func): - pytest.skip() - kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value) + if "quantile" in func: + kwargs["finalize_kwargs"] = {"q": 0.9} actual, groups = groupby_reduce(array, by, **kwargs) expected, sorted_groups = groupby_reduce(array.compute(), by, **kwargs) assert_equal(groups, sorted_groups) @@ -992,6 +1025,8 @@ def test_datetime_binning(): def test_bool_reductions(func, engine): if "arg" in func and engine == "flox": pytest.skip() + if "quantile" in func or "mode" in func: + pytest.skip() groups = np.array([1, 1, 1]) data = np.array([True, True, False]) npfunc = _get_array_func(func) @@ -1244,9 +1279,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty def test_dtype(func, dtype, engine): if "arg" in func or func in ["any", "all"]: pytest.skip() + + finalize_kwargs = {"q": 0.6} if "quantile" in func else {} + arr = np.ones((4, 12), dtype=dtype) labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) - actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64, engine=engine) + actual, _ = groupby_reduce( + arr, labels, func=func, dtype=np.float64, engine=engine, finalize_kwargs=finalize_kwargs + ) assert actual.dtype == np.dtype("float64") From d46c3aef3b9c5d9dadd6555c0e7b1b628c0633e6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 08:28:05 -0600 Subject: [PATCH 12/20] limit tests --- tests/conftest.py | 9 ++++++--- tests/test_core.py | 48 +++++++++++++++++++++------------------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index eb5971784..604e6b7a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ import pytest -from . import requires_numba - @pytest.fixture( - scope="module", params=["flox", "numpy", pytest.param("numba", marks=requires_numba)] + scope="module", + params=[ + # "flox", + "numpy", + # pytest.param("numba", marks=requires_numba) + ], ) def engine(request): return request.param diff --git a/tests/test_core.py b/tests/test_core.py index 03beb7013..f889389d9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -52,32 +52,28 @@ def dask_array_ones(*args): ALL_FUNCS = ( - "sum", - "nansum", - "argmax", - "nanfirst", - "nanargmax", - "prod", - "nanprod", - "mean", - "nanmean", - "var", - "nanvar", - "std", - "nanstd", - "max", - "nanmax", - "min", - "nanmin", - "argmin", - "nanargmin", - "any", - "all", - "nanlast", - pytest.param("median", marks=(pytest.mark.skip,)), - pytest.param("nanmedian", marks=(pytest.mark.skip,)), - pytest.param("quantile", marks=(pytest.mark.skip,)), - pytest.param("nanquantile", marks=(pytest.mark.skip,)), + # "sum", + # "nansum", + # "argmax", + # "nanfirst", + # "nanargmax", + # "prod", + # "nanprod", + # "mean", + # "nanmean", + # "var", + # "nanvar", + # "std", + # "nanstd", + # "max", + # "nanmax", + # "min", + # "nanmin", + # "argmin", + # "nanargmin", + # "any", + # "all", + # "nanlast", "median", "nanmedian", "quantile", From 78c94a990cef17860b49d9b6ae30a18116317637 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 08:32:26 -0600 Subject: [PATCH 13/20] cleanup tests --- tests/test_core.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index f889389d9..4bfdfd484 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -51,6 +51,8 @@ def dask_array_ones(*args): return None +SCIPY_STATS_FUNCS = ("mode", "nanmode") +BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS ALL_FUNCS = ( # "sum", # "nansum", @@ -78,11 +80,7 @@ def dask_array_ones(*args): "nanmedian", "quantile", "nanquantile", - pytest.param("mode", marks=requires_scipy), - pytest.param("nanmode", marks=requires_scipy), -) -BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile", "mode", "nanmode") -SCIPY_STATS_FUNCS = ("mode", "nanmode") +) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS) if TYPE_CHECKING: from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method @@ -225,11 +223,8 @@ def gen_array_by(size, func): @pytest.mark.parametrize("add_nan_by", [True, False]) @pytest.mark.parametrize("func", ALL_FUNCS) def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): - if "arg" in func and engine == "flox": + if ("arg" in func and engine == "flox") or (func in BLOCKWISE_FUNCS and chunks != -1): pytest.skip() - if func in BLOCKWISE_FUNCS: - if chunks != -1: - pytest.skip() array, by = gen_array_by(size, func) if chunks: From f6107c36a164729136c94c2dd37f0bfe87fe3076 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 12:27:56 -0600 Subject: [PATCH 14/20] fix bool tests --- flox/aggregations.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 3cd339525..34053cc21 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -467,10 +467,18 @@ def _pick_second(*x): # Support statistical quantities only blockwise # The parallel versions will be approximate and are hard to implement! -median = Aggregation(name="median", fill_value=dtypes.NA, chunk=None, combine=None) -nanmedian = Aggregation(name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None) -quantile = Aggregation(name="quantile", fill_value=dtypes.NA, chunk=None, combine=None) -nanquantile = Aggregation(name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None) +median = Aggregation( + name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +nanmedian = Aggregation( + name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +quantile = Aggregation( + name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +nanquantile = Aggregation( + name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) From 6e5505f2c7ec1c7ce991deef3a6d475fbd15996f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 12:28:17 -0600 Subject: [PATCH 15/20] Revert "limit tests" This reverts commit d46c3aef3b9c5d9dadd6555c0e7b1b628c0633e6. --- tests/conftest.py | 9 +++------ tests/test_core.py | 48 +++++++++++++++++++++++++--------------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 604e6b7a3..eb5971784 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,10 @@ import pytest +from . import requires_numba + @pytest.fixture( - scope="module", - params=[ - # "flox", - "numpy", - # pytest.param("numba", marks=requires_numba) - ], + scope="module", params=["flox", "numpy", pytest.param("numba", marks=requires_numba)] ) def engine(request): return request.param diff --git a/tests/test_core.py b/tests/test_core.py index 4bfdfd484..35ce8b74e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -54,28 +54,32 @@ def dask_array_ones(*args): SCIPY_STATS_FUNCS = ("mode", "nanmode") BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS ALL_FUNCS = ( - # "sum", - # "nansum", - # "argmax", - # "nanfirst", - # "nanargmax", - # "prod", - # "nanprod", - # "mean", - # "nanmean", - # "var", - # "nanvar", - # "std", - # "nanstd", - # "max", - # "nanmax", - # "min", - # "nanmin", - # "argmin", - # "nanargmin", - # "any", - # "all", - # "nanlast", + "sum", + "nansum", + "argmax", + "nanfirst", + "nanargmax", + "prod", + "nanprod", + "mean", + "nanmean", + "var", + "nanvar", + "std", + "nanstd", + "max", + "nanmax", + "min", + "nanmin", + "argmin", + "nanargmin", + "any", + "all", + "nanlast", + pytest.param("median", marks=(pytest.mark.skip,)), + pytest.param("nanmedian", marks=(pytest.mark.skip,)), + pytest.param("quantile", marks=(pytest.mark.skip,)), + pytest.param("nanquantile", marks=(pytest.mark.skip,)), "median", "nanmedian", "quantile", From 73a80c1502e82a876b3ab6c571809875a5da50fb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 12:30:11 -0600 Subject: [PATCH 16/20] Small cleanup --- tests/test_core.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 35ce8b74e..de26b9eb8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -51,6 +51,7 @@ def dask_array_ones(*args): return None +DEFAULT_QUANTILE = 0.9 SCIPY_STATS_FUNCS = ("mode", "nanmode") BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS ALL_FUNCS = ( @@ -93,12 +94,13 @@ def dask_array_ones(*args): def _get_array_func(func: str) -> Callable: if func == "count": - def npfunc(x): + def npfunc(x, **kwargs): x = np.asarray(x) return (~np.isnan(x)).sum() elif func in ["nanfirst", "nanlast"]: npfunc = getattr(xrutils, func) + elif func in SCIPY_STATS_FUNCS: import scipy.stats @@ -247,7 +249,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): fill_value = np.nan tolerance = {"rtol": 1e-14, "atol": 1e-16} elif "quantile" in func: - finalize_kwargs = [{"q": 0.5}] + finalize_kwargs = [{"q": DEFAULT_QUANTILE}] fill_value = None tolerance = None else: @@ -958,7 +960,7 @@ def test_cohorts_nd_by(func, method, axis, engine): kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value) if "quantile" in func: - kwargs["finalize_kwargs"] = {"q": 0.9} + kwargs["finalize_kwargs"] = {"q": DEFAULT_QUANTILE} actual, groups = groupby_reduce(array, by, **kwargs) expected, sorted_groups = groupby_reduce(array.compute(), by, **kwargs) assert_equal(groups, sorted_groups) @@ -1275,7 +1277,7 @@ def test_dtype(func, dtype, engine): if "arg" in func or func in ["any", "all"]: pytest.skip() - finalize_kwargs = {"q": 0.6} if "quantile" in func else {} + finalize_kwargs = {"q": DEFAULT_QUANTILE} if "quantile" in func else {} arr = np.ones((4, 12), dtype=dtype) labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) From a5b9c9c282c6f58bf74d504bef096b8d9c414b99 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 12:43:22 -0600 Subject: [PATCH 17/20] more cleanup --- tests/test_core.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index de26b9eb8..440431304 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -77,10 +77,6 @@ def dask_array_ones(*args): "any", "all", "nanlast", - pytest.param("median", marks=(pytest.mark.skip,)), - pytest.param("nanmedian", marks=(pytest.mark.skip,)), - pytest.param("quantile", marks=(pytest.mark.skip,)), - pytest.param("nanquantile", marks=(pytest.mark.skip,)), "median", "nanmedian", "quantile", From e0b33722be9f060f5a068aa776f246f6bf05b961 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 12:46:07 -0600 Subject: [PATCH 18/20] fix test --- flox/aggregations.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 34053cc21..c06ef3509 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, TypedDict import numpy as np -import numpy_groupies as npg from numpy.typing import DTypeLike from . import aggregate_flox, aggregate_npg, xrutils @@ -35,6 +34,16 @@ class AggDtype(TypedDict): intermediate: tuple[np.dtype | type[np.intp], ...] +def get_npg_aggregation(func, *, engine): + try: + method_ = getattr(aggregate_npg, func) + method = partial(method_, engine=engine) + except AttributeError: + aggregate = aggregate_npg._get_aggregate(engine).aggregate + method = partial(aggregate, func=func) + return method + + def generic_aggregate( group_idx, array, @@ -51,14 +60,11 @@ def generic_aggregate( try: method = getattr(aggregate_flox, func) except AttributeError: - method = partial(npg.aggregate_numpy.aggregate, func=func) + method = get_npg_aggregation(func, engine="numpy") + elif engine in ["numpy", "numba"]: - try: - method_ = getattr(aggregate_npg, func) - method = partial(method_, engine=engine) - except AttributeError: - aggregate = aggregate_npg._get_aggregate(engine).aggregate - method = partial(aggregate, func=func) + method = get_npg_aggregation(func, engine=engine) + else: raise ValueError( f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead." From 45d7dcb2d87b4d854fe0803c4276d70e12640a91 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 14:09:35 -0600 Subject: [PATCH 19/20] ignore scipy typing --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e41fb17eb..e387657d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,8 @@ module=[ "matplotlib.*", "pandas", "setuptools", - "toolz" + "scipy.*", + "toolz", ] ignore_missing_imports = true From b3aa7f53a9563320bc197b93154cf87e8d7373c0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 5 Oct 2023 16:17:28 -0600 Subject: [PATCH 20/20] update docs --- docs/source/aggregations.md | 7 +++++-- docs/source/user-stories/custom-aggregations.ipynb | 9 +++++++-- flox/core.py | 5 ++++- flox/xarray.py | 7 +++++-- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/docs/source/aggregations.md b/docs/source/aggregations.md index e6c10e4ba..d3591d2dc 100644 --- a/docs/source/aggregations.md +++ b/docs/source/aggregations.md @@ -11,8 +11,11 @@ the `func` kwarg: - `"std"`, `"nanstd"` - `"argmin"` - `"argmax"` -- `"first"` -- `"last"` +- `"first"`, `"nanfirst"` +- `"last"`, `"nanlast"` +- `"median"`, `"nanmedian"` +- `"mode"`, `"nanmode"` +- `"quantile"`, `"nanquantile"` ```{tip} We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome! diff --git a/docs/source/user-stories/custom-aggregations.ipynb b/docs/source/user-stories/custom-aggregations.ipynb index f191c77e0..8b9be09e9 100644 --- a/docs/source/user-stories/custom-aggregations.ipynb +++ b/docs/source/user-stories/custom-aggregations.ipynb @@ -15,8 +15,13 @@ ">\n", "> A = da.groupby(['lon_bins', 'lat_bins']).mode()\n", "\n", - "This notebook will describe how to accomplish this using a custom `Aggregation`\n", - "since `mode` and `median` aren't supported by flox yet.\n" + "This notebook will describe how to accomplish this using a custom `Aggregation`.\n", + "\n", + "\n", + "```{tip}\n", + "flox now supports `mode`, `nanmode`, `quantile`, `nanquantile`, `median`, `nanmedian` using exactly the same \n", + "approach as shown below\n", + "```\n" ] }, { diff --git a/flox/core.py b/flox/core.py index 576634b81..484303295 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1779,7 +1779,10 @@ def groupby_reduce( *by : ndarray or DaskArray Array of labels to group over. Must be aligned with ``array`` so that ``array.shape[-by.ndim :] == by.shape`` - func : str or Aggregation + func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ + "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ + "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "first", "nanfirst", "last", "nanlast"} or Aggregation Single function name or an Aggregation instance expected_groups : (optional) Sequence Expected unique labels. diff --git a/flox/xarray.py b/flox/xarray.py index dfa30cc53..eb35da387 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -88,8 +88,11 @@ def xarray_reduce( Xarray object to reduce *by : DataArray or iterable of str or iterable of DataArray Variables with which to group by ``obj`` - func : str or Aggregation - Reduction method + func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ + "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ + "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "first", "nanfirst", "last", "nanlast"} or Aggregation + Single function name or an Aggregation instance expected_groups : str or sequence expected group labels corresponding to each `by` variable isbin : iterable of bool