Support quantile, median, mode with method="blockwise". #269

Merged · 21 commits · Oct 5, 2023
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -22,3 +22,4 @@ dependencies:
- pooch
- toolz
- numba
- scipy
7 changes: 5 additions & 2 deletions docs/source/aggregations.md
@@ -11,8 +11,11 @@ the `func` kwarg:
- `"std"`, `"nanstd"`
- `"argmin"`
- `"argmax"`
- `"first"`
- `"last"`
- `"first"`, `"nanfirst"`
- `"last"`, `"nanlast"`
- `"median"`, `"nanmedian"`
- `"mode"`, `"nanmode"`
- `"quantile"`, `"nanquantile"`

```{tip}
We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!
9 changes: 7 additions & 2 deletions docs/source/user-stories/custom-aggregations.ipynb
@@ -15,8 +15,13 @@
">\n",
"> A = da.groupby(['lon_bins', 'lat_bins']).mode()\n",
"\n",
"This notebook will describe how to accomplish this using a custom `Aggregation`\n",
"since `mode` and `median` aren't supported by flox yet.\n"
"This notebook will describe how to accomplish this using a custom `Aggregation`.\n",
"\n",
"\n",
"```{tip}\n",
"flox now supports `mode`, `nanmode`, `quantile`, `nanquantile`, `median`, `nanmedian` using exactly the same \n",
"approach as shown below\n",
"```\n"
]
},
{
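
A hedged sketch of what the new tip refers to: grouping with the built-in `"mode"` through flox's xarray layer, no custom `Aggregation` needed (variable names are illustrative; scipy must be installed since the numpy path wraps `scipy.stats.mode`):

```python
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    [1, 1, 2, 3, 3, 4],
    dims="time",
    coords={"labels": ("time", ["a", "a", "a", "b", "b", "b"])},
)
result = xarray_reduce(da, "labels", func="mode")
```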
81 changes: 81 additions & 0 deletions flox/aggregate_npg.py
@@ -100,3 +100,84 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None,

len = partial(_len, func="len")
nanlen = partial(_len, func="nanlen")


def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=np.median,
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=np.nanmedian,
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=partial(np.quantile, q=q),
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=partial(np.nanquantile, q=q),
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)
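
A standalone illustration of the wrapper pattern above (not flox API): numpy_groupies applies an arbitrary callable to each group's values, so binding `q` with `functools.partial` turns `np.quantile` into a per-group reduction.

```python
from functools import partial

import numpy as np
import numpy_groupies as npg

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 2.0, 3.0, 4.0, 10.0])

# each group's values are passed to the bound callable
result = npg.aggregate_numpy.aggregate(
    group_idx, array, func=partial(np.quantile, q=0.5), fill_value=np.nan
)
# result -> array([1.5, 4.]), the per-group median
```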


def mode_(array, nan_policy, dtype):
from scipy.stats import mode

# npg splits `array` into object arrays for each group
# scipy.stats.mode does not like that
# here we cast back
return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode


def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=partial(mode_, nan_policy="propagate", dtype=array.dtype),
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def nanmode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=partial(mode_, nan_policy="omit", dtype=array.dtype),
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)
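
Why `mode_` casts before calling scipy (a minimal repro, independent of flox): numpy_groupies hands each group's values to generic callables as an object-dtype array, which `scipy.stats.mode` rejects.

```python
import numpy as np
from scipy.stats import mode

group_values = np.array([1, 1, 2], dtype=object)  # what npg passes through
# mode(group_values, axis=-1)  # raises: object dtype is not supported
result = mode(group_values.astype(np.int64), axis=-1).mode  # -> 1
```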
48 changes: 36 additions & 12 deletions flox/aggregations.py
@@ -6,7 +6,6 @@
from typing import TYPE_CHECKING, Any, Callable, TypedDict

import numpy as np
import numpy_groupies as npg
from numpy.typing import DTypeLike

from . import aggregate_flox, aggregate_npg, xrutils
@@ -35,6 +34,16 @@ class AggDtype(TypedDict):
intermediate: tuple[np.dtype | type[np.intp], ...]


def get_npg_aggregation(func, *, engine):
try:
method_ = getattr(aggregate_npg, func)
method = partial(method_, engine=engine)
except AttributeError:
aggregate = aggregate_npg._get_aggregate(engine).aggregate
method = partial(aggregate, func=func)
return method
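
A usage sketch of the refactored dispatch (internal API, with the signature as shown in this diff): `"median"` is defined in `flox.aggregate_npg`, so `get_npg_aggregation` returns the engine-aware wrapper; names absent from that module fall back to numpy_groupies' own string dispatch.

```python
import numpy as np

from flox.aggregations import generic_aggregate

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 2.0, 3.0, 4.0, 10.0])

# routed through aggregate_npg.median, added in this PR
result = generic_aggregate(group_idx, array, engine="numpy", func="median")
# result -> array([1.5, 4.])
```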


def generic_aggregate(
group_idx,
array,
@@ -51,14 +60,11 @@ def generic_aggregate(
try:
method = getattr(aggregate_flox, func)
except AttributeError:
method = partial(npg.aggregate_numpy.aggregate, func=func)
method = get_npg_aggregation(func, engine="numpy")

elif engine in ["numpy", "numba"]:
try:
method_ = getattr(aggregate_npg, func)
method = partial(method_, engine=engine)
except AttributeError:
aggregate = aggregate_npg._get_aggregate(engine).aggregate
method = partial(aggregate, func=func)
method = get_npg_aggregation(func, engine=engine)

else:
raise ValueError(
f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
@@ -465,10 +471,22 @@ def _pick_second(*x):
final_dtype=bool,
)

# numpy_groupies does not support median
# And the dask version is really hard!
# median = Aggregation("median", chunk=None, combine=None, fill_value=None)
# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None)
# Support statistical quantities only blockwise
# The parallel versions will be approximate and are hard to implement!
median = Aggregation(
name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
)
nanmedian = Aggregation(
name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
)
quantile = Aggregation(
name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
)
nanquantile = Aggregation(
name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
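
What `chunk=None, combine=None` implies: there is no per-chunk intermediate and nothing to tree-combine, so with dask inputs these reductions only work with `method="blockwise"`, where each group must sit entirely within one chunk. A hedged sketch with toy data:

```python
import dask.array as da
import numpy as np

import flox

array = da.from_array(np.arange(6.0), chunks=3)
labels = np.array([0, 0, 0, 1, 1, 1])  # group boundaries align with chunks

result, groups = flox.groupby_reduce(
    array, labels, func="median", method="blockwise"
)
result.compute()  # -> array([1., 4.])
```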

aggregations = {
"any": any_,
@@ -496,6 +514,12 @@ def _pick_second(*x):
"nanfirst": nanfirst,
"last": last,
"nanlast": nanlast,
"median": median,
"nanmedian": nanmedian,
"quantile": quantile,
"nanquantile": nanquantile,
"mode": mode,
"nanmode": nanmode,
}


54 changes: 36 additions & 18 deletions flox/core.py
@@ -1307,15 +1307,14 @@ def dask_groupby_agg(
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

if method == "blockwise" and not isinstance(by, np.ndarray):
raise NotImplementedError

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis)

if expected_groups is None and reindex:
expected_groups = _get_expected_groups(by, sort=sort)
if method == "cohorts":
assert reindex is False

by_input = by

@@ -1349,7 +1348,6 @@
# b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)

do_simple_combine = not _is_arg_reduction(agg)

if method == "blockwise":
@@ -1375,7 +1373,7 @@
partial(
blockwise_method,
axis=axis,
expected_groups=None if method == "cohorts" else expected_groups,
expected_groups=expected_groups if reindex else None,
engine=engine,
sort=sort,
),
@@ -1468,14 +1466,24 @@

elif method == "blockwise":
reduced = intermediate
# Here one input chunk → one output chunk
# find number of groups in each chunk, this is needed for output chunks
# along the reduced axis
slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
groups = (np.concatenate(groups_in_block),)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
group_chunks = (ngroups_per_block,)
if reindex:
if TYPE_CHECKING:
assert expected_groups is not None
# TODO: we could have `expected_groups` be a dask array with appropriate chunks
# for now, we have a numpy array that is interpreted as listing all group labels
# that are present in every chunk
groups = (expected_groups,)
group_chunks = ((len(expected_groups),),)
else:
# Here one input chunk → one output chunk
# find number of groups in each chunk, this is needed for output chunks
# along the reduced axis
# TODO: this logic is very specialized for the resampling case
slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
groups = (np.concatenate(groups_in_block),)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
group_chunks = (ngroups_per_block,)
else:
raise ValueError(f"Unknown method={method}.")

@@ -1547,7 +1555,7 @@ def _validate_reindex(
if reindex is True and not all_numpy:
if _is_arg_reduction(func):
raise NotImplementedError
if method in ["blockwise", "cohorts"]:
if method == "cohorts" or (method == "blockwise" and not any_by_dask):
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)
@@ -1562,7 +1570,11 @@
# have to do the grouped_combine since there's no good fill_value
reindex = False

if method == "blockwise" or _is_arg_reduction(func):
if method == "blockwise":
# for grouping by dask arrays, we set reindex=True
reindex = any_by_dask

elif _is_arg_reduction(func):
reindex = False

elif method == "cohorts":
@@ -1767,7 +1779,10 @@ def groupby_reduce(
*by : ndarray or DaskArray
Array of labels to group over. Must be aligned with ``array`` so that
``array.shape[-by.ndim :] == by.shape``
func : str or Aggregation
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
"max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
"quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
"first", "nanfirst", "last", "nanlast"} or Aggregation
Single function name or an Aggregation instance
expected_groups : (optional) Sequence
Expected unique labels.
@@ -1835,7 +1850,7 @@
boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions.
finalize_kwargs : dict, optional
Kwargs passed to finalize the reduction such as ``ddof`` for var, std.
Kwargs passed to finalize the reduction, such as ``ddof`` for var and std, or ``q`` for quantile.

Returns
-------
@@ -1855,6 +1870,9 @@
"Try engine='numpy' or engine='numba' instead."
)

if func == "quantile" and (finalize_kwargs is None or "q" not in finalize_kwargs):
raise ValueError("Please pass `q` for quantile calculations.")

bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
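
A toy sketch of the guard above: quantile reductions fail fast without `q`, and succeed once it is supplied through `finalize_kwargs`.

```python
import numpy as np

import flox

array = np.arange(6.0)
labels = np.array([0, 0, 0, 1, 1, 1])

# flox.groupby_reduce(array, labels, func="quantile")  # ValueError: Please pass `q`
result, groups = flox.groupby_reduce(
    array, labels, func="quantile", finalize_kwargs={"q": 0.25}
)
```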
@@ -2023,7 +2041,7 @@
result, groups = partial_agg(
array,
by_,
expected_groups=None if method == "blockwise" else expected_groups,
expected_groups=expected_groups,
agg=agg,
reindex=reindex,
method=method,
9 changes: 6 additions & 3 deletions flox/xarray.py
@@ -88,8 +88,11 @@ def xarray_reduce(
Xarray object to reduce
*by : DataArray or iterable of str or iterable of DataArray
Variables with which to group by ``obj``
func : str or Aggregation
Reduction method
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
"max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
"quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
"first", "nanfirst", "last", "nanlast"} or Aggregation
Single function name or an Aggregation instance
expected_groups : str or sequence
expected group labels corresponding to each `by` variable
isbin : iterable of bool
@@ -164,7 +167,7 @@
boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions.
**finalize_kwargs
kwargs passed to the finalize function, like ``ddof`` for var, std.
kwargs passed to the finalize function, such as ``ddof`` for var and std, or ``q`` for quantile.

Returns
-------
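
In the xarray layer, finalize kwargs are forwarded as plain keyword arguments rather than a dict, so `q` is passed directly. A minimal sketch with made-up data:

```python
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    np.arange(6.0),
    dims="x",
    coords={"labels": ("x", ["a", "a", "a", "b", "b", "b"])},
)
out = xarray_reduce(da, "labels", func="quantile", q=0.75)
```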
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -118,7 +118,8 @@ module=[
"matplotlib.*",
"pandas",
"setuptools",
"toolz"
"scipy.*",
"toolz",
]
ignore_missing_imports = true

1 change: 1 addition & 0 deletions tests/__init__.py
@@ -47,6 +47,7 @@ def LooseVersion(vstring):

has_dask, requires_dask = _importorskip("dask")
has_numba, requires_numba = _importorskip("numba")
has_scipy, requires_scipy = _importorskip("scipy")
has_xarray, requires_xarray = _importorskip("xarray")


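
A sketch of how the new marker would be used in the test suite (following the existing `requires_dask` pattern; the test body and import path are hypothetical):

```python
import numpy as np

import flox
from tests import requires_scipy  # assumed import path for this sketch

@requires_scipy  # skips when scipy is missing, via pytest.mark.skipif
def test_mode_basic():
    array = np.array([1.0, 1.0, 2.0])
    labels = np.array([0, 0, 0])
    result, groups = flox.groupby_reduce(array, labels, func="mode")
    np.testing.assert_equal(result, [1.0])
```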