From 19db5b3b456974230932b4119c4735216813f5f7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 8 Nov 2023 21:25:08 -0700 Subject: [PATCH] Fix numbagg aggregations (#282) * Fix numbagg version check Closes #281 * Enable numbagg for count * Better numbagg special-casing * Fixes. * A bunch of typing * Handle fill_value in core numbagg reduction. * Update flox/aggregate_numbagg.py * cleanup * [WIP] test hacky fix * [wip] * Cleanup functions * Fix casting * Fix fill_value masking * optimize * Update flox/aggregations.py * Small cleanup * Fix. * Fix typing * Another bugfix * Optimize seen_groups * Be careful about raveling * Fix benchmark skipping for numbagg * add test --- asv_bench/benchmarks/reduce.py | 2 +- flox/aggregate_numbagg.py | 112 ++++++++++++++++++--------------- flox/aggregations.py | 22 +++---- flox/core.py | 62 +++++++++++++++--- flox/xrutils.py | 2 +- pyproject.toml | 1 + tests/test_core.py | 27 ++++++++ tests/test_xarray.py | 21 +++++++ 8 files changed, 175 insertions(+), 74 deletions(-) diff --git a/asv_bench/benchmarks/reduce.py b/asv_bench/benchmarks/reduce.py index aa3edd611..d7a4b6695 100644 --- a/asv_bench/benchmarks/reduce.py +++ b/asv_bench/benchmarks/reduce.py @@ -18,7 +18,7 @@ numbagg_skip = [] for name in expected_names: numbagg_skip.extend( - list((func, expected_names[0], "numbagg") for func in funcs if func not in NUMBAGG_FUNCS) + list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS) ) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index b0b06d86e..a5f12d7e0 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -4,64 +4,75 @@ import numbagg.grouped import numpy as np +DEFAULT_FILL_VALUE = { + "nansum": 0, + "nanmean": np.nan, + "nanvar": np.nan, + "nanstd": np.nan, + "nanmin": np.nan, + "nanmax": np.nan, + "nanany": False, + "nanall": False, + "nansum_of_squares": 0, + "nanprod": 1, + "nancount": 0, + "nanargmax": np.nan, + "nanargmin": np.nan, + "nanfirst": np.nan, + "nanlast": np.nan, +} + +CAST_TO = { + # "nansum": {np.bool_: np.int64}, + "nanmean": {np.int_: np.float64}, + "nanvar": {np.int_: np.float64}, + "nanstd": {np.int_: np.float64}, +} + + +FILLNA = {"nansum": 0, "nanprod": 1} + def _numbagg_wrapper( group_idx, array, *, + func, axis=-1, - func="sum", size=None, fill_value=None, dtype=None, - numbagg_func=None, ): - return numbagg_func( - array, - group_idx, - axis=axis, - num_labels=size, - # The following are unsupported - # fill_value=fill_value, - # dtype=dtype, - ) + cast_to = CAST_TO.get(func, None) + if cast_to: + for from_, to_ in cast_to.items(): + if np.issubdtype(array.dtype, from_): + array = array.astype(to_) + func_ = getattr(numbagg.grouped, f"group_{func}") -def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): - if np.issubdtype(array.dtype, np.bool_): - array = array.astype(np.in64) - return numbagg.grouped.group_nansum( + result = func_( array, group_idx, axis=axis, num_labels=size, + # The following are unsupported # fill_value=fill_value, # dtype=dtype, - ) - + ).astype(dtype, copy=False) -def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): - if np.issubdtype(array.dtype, np.int_): - array = array.astype(np.float64) - return numbagg.grouped.group_nanmean( - array, - group_idx, - axis=axis, - num_labels=size, - # fill_value=fill_value, - # dtype=dtype, - ) + return result def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0): assert ddof != 0 - if np.issubdtype(array.dtype, np.int_): - array = array.astype(np.float64) - return numbagg.grouped.group_nanvar( - array, + + return _numbagg_wrapper( group_idx, + array, axis=axis, - num_labels=size, + size=size, + func="nanvar", # ddof=0, # fill_value=fill_value, # dtype=dtype, @@ -70,30 +81,33 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0): assert ddof != 0 - if np.issubdtype(array.dtype, np.int_): - array = array.astype(np.float64) - return numbagg.grouped.group_nanstd( - array, + + return _numbagg_wrapper( group_idx, + array, axis=axis, - num_labels=size, + size=size, + func="nanstd" # ddof=0, # fill_value=fill_value, # dtype=dtype, ) -nansum_of_squares = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nansum_of_squares) -nanlen = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nancount) -nanprod = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanprod) -nanfirst = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanfirst) -nanlast = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanlast) -# nanargmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmax) -# nanargmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmin) -nanmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmax) -nanmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmin) -any = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanany) -all = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanall) +nansum = partial(_numbagg_wrapper, func="nansum") +nanmean = partial(_numbagg_wrapper, func="nanmean") +nanprod = partial(_numbagg_wrapper, func="nanprod") +nansum_of_squares = partial(_numbagg_wrapper, func="nansum_of_squares") +nanlen = partial(_numbagg_wrapper, func="nancount") +nanprod = partial(_numbagg_wrapper, func="nanprod") +nanfirst = partial(_numbagg_wrapper, func="nanfirst") +nanlast = partial(_numbagg_wrapper, func="nanlast") +# nanargmax = partial(_numbagg_wrapper, func="nanargmax) +# nanargmin = partial(_numbagg_wrapper, func="nanargmin) +nanmax = partial(_numbagg_wrapper, func="nanmax") +nanmin = partial(_numbagg_wrapper, func="nanmin") +any = partial(_numbagg_wrapper, func="nanany") +all = partial(_numbagg_wrapper, func="nanall") # sum = nansum # mean = nanmean diff --git a/flox/aggregations.py b/flox/aggregations.py index 131c49cbd..b91d191b2 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -3,7 +3,7 @@ import copy import warnings from functools import partial -from typing import TYPE_CHECKING, Any, Callable, TypedDict +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np from numpy.typing import DTypeLike @@ -13,6 +13,7 @@ if TYPE_CHECKING: FuncTuple = tuple[Callable | str, ...] + OptionalFuncTuple = tuple[Callable | str | None, ...] def _is_arg_reduction(func: str | Aggregation) -> bool: @@ -152,7 +153,7 @@ def __init__( final_fill_value=dtypes.NA, dtypes=None, final_dtype: DTypeLike | None = None, - reduction_type="reduce", + reduction_type: Literal["reduce", "argreduce"] = "reduce", ): """ Blueprint for computing grouped aggregations. @@ -203,11 +204,11 @@ def __init__( self.reduction_type = reduction_type self.numpy: FuncTuple = (numpy,) if numpy else (self.name,) # initialize blockwise reduction - self.chunk: FuncTuple = _atleast_1d(chunk) + self.chunk: OptionalFuncTuple = _atleast_1d(chunk) # how to aggregate results after first round of reduction - self.combine: FuncTuple = _atleast_1d(combine) + self.combine: OptionalFuncTuple = _atleast_1d(combine) # simpler reductions used with the "simple combine" algorithm - self.simple_combine: tuple[Callable, ...] = () + self.simple_combine: OptionalFuncTuple = () # finalize results (see mean) self.finalize: Callable | None = finalize @@ -279,13 +280,7 @@ def __repr__(self) -> str: sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0) nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0) prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1) -nanprod = Aggregation( - "nanprod", - chunk="nanprod", - combine="prod", - fill_value=1, - final_fill_value=dtypes.NA, -) +nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1) def _mean_finalize(sum_, count): @@ -579,6 +574,7 @@ def _initialize_aggregation( } # Replace sentinel fill values according to dtype + agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( _get_fill_value(dt, fv) for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) @@ -613,7 +609,7 @@ def _initialize_aggregation( else: agg.min_count = 0 - simple_combine: list[Callable] = [] + simple_combine: list[Callable | None] = [] for combine in agg.combine: if isinstance(combine, str): if combine in ["nanfirst", "nanlast"]: diff --git a/flox/core.py b/flox/core.py index edf41c727..2f83e014c 100644 --- a/flox/core.py +++ b/flox/core.py @@ -86,6 +86,26 @@ DUMMY_AXIS = -2 +def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): + """Account for numbagg not providing a fill_value kwarg.""" + from .aggregate_numbagg import DEFAULT_FILL_VALUE + + if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE: + return result + # The condition needs to be + # len(found_groups) < size; if so we mask with fill_value (?) + default_fv = DEFAULT_FILL_VALUE[func] + needs_masking = fill_value is not None and not np.array_equal( + fill_value, default_fv, equal_nan=True + ) + groups = np.arange(size) + if needs_masking: + mask = np.isin(groups, seen_groups, assume_unique=True, invert=True) + if mask.any(): + result[..., groups[mask]] = fill_value + return result + + def _issorted(arr: np.ndarray) -> bool: return bool((arr[:-1] <= arr[1:]).all()) @@ -780,7 +800,11 @@ def chunk_reduce( group_idx, grps, found_groups_shape, _, size, props = factorize_( (by,), axes, expected_groups=(expected_groups,), reindex=reindex, sort=sort ) - groups = grps[0] + (groups,) = grps + + # do this *before* possible broadcasting below. + # factorize_ has already taken care of offsetting + seen_groups = _unique(group_idx) order = "C" if nax > 1: @@ -850,6 +874,16 @@ def chunk_reduce( result = generic_aggregate( group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func ).astype(dt, copy=False) + if engine == "numbagg": + result = _postprocess_numbagg( + result, + func=reduction, + size=size, + fill_value=fv, + # Unfortunately, we cannot reuse found_groups, it has not + # been "offset" and is really expected_groups in nearly all cases + seen_groups=seen_groups, + ) if np.any(props.nanmask): # remove NaN group label which should be last result = result[..., :-1] @@ -1053,6 +1087,8 @@ def _grouped_combine( """Combine intermediates step of tree reduction.""" from dask.utils import deepmap + combine = agg.combine + if isinstance(x_chunk, dict): # Only one block at final step; skip one extra groupby return x_chunk @@ -1093,7 +1129,8 @@ def _grouped_combine( results = chunk_argreduce( array_idx, groups, - func=agg.combine[slicer], # count gets treated specially next + # count gets treated specially next + func=combine[slicer], # type: ignore[arg-type] axis=axis, expected_groups=None, fill_value=agg.fill_value["intermediate"][slicer], @@ -1127,9 +1164,10 @@ def _grouped_combine( elif agg.reduction_type == "reduce": # Here we reduce the intermediates individually results = {"groups": None, "intermediates": []} - for idx, (combine, fv, dtype) in enumerate( - zip(agg.combine, agg.fill_value["intermediate"], agg.dtype["intermediate"]) + for idx, (combine_, fv, dtype) in enumerate( + zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"]) ): + assert combine_ is not None array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) if array.shape[-1] == 0: # all empty when combined @@ -1143,7 +1181,7 @@ def _grouped_combine( _results = chunk_reduce( array, groups, - func=combine, + func=combine_, axis=axis, expected_groups=None, fill_value=(fv,), @@ -1788,8 +1826,13 @@ def _choose_engine(by, agg: Aggregation): # numbagg only supports nan-skipping reductions # without dtype specified - if HAS_NUMBAGG and "nan" in agg.name: - if not_arg_reduce and dtype is None: + has_blockwise_nan_skipping = (agg.chunk[0] is None and "nan" in agg.name) or any( + (isinstance(func, str) and "nan" in func) for func in agg.chunk + ) + if HAS_NUMBAGG: + if agg.name in ["all", "any"] or ( + not_arg_reduce and has_blockwise_nan_skipping and dtype is None + ): return "numbagg" if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)): @@ -2050,7 +2093,7 @@ def groupby_reduce( nax = len(axis_) # When axis is a subset of possible values; then npg will - # apply it to groups that don't exist along a particular axis (for e.g.) + # apply the fill_value to groups that don't exist along a particular axis (for e.g.) # since these count as a group that is absent. thoo! # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found. # The only way to do this consistently is mask out using min_count @@ -2090,8 +2133,7 @@ def groupby_reduce( # TODO: How else to narrow that array.chunks is there? assert isinstance(array, DaskArray) - # TODO: fix typing of FuncTuple in Aggregation - if agg.chunk[0] is None and method != "blockwise": # type: ignore[unreachable] + if agg.chunk[0] is None and method != "blockwise": raise NotImplementedError( f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'." f"Received method={method!r}" diff --git a/flox/xrutils.py b/flox/xrutils.py index 497cd7b24..0ced6fbb6 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -339,6 +339,6 @@ def module_available(module: str, minversion: Optional[str] = None) -> bool: has = importlib.util.find_spec(module) is not None if has: mod = importlib.import_module(module) - return Version(mod.__version__) < Version(minversion) if minversion is not None else True + return Version(mod.__version__) >= Version(minversion) if minversion is not None else True else: return False diff --git a/pyproject.toml b/pyproject.toml index 9cf0b4ca4..c507a5222 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ show_error_codes = true warn_unused_ignores = true warn_unreachable = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +exclude=["asv_bench/pkgs"] [[tool.mypy.overrides]] module=[ diff --git a/tests/test_core.py b/tests/test_core.py index ba8ad9340..0d953b3c4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1566,6 +1566,19 @@ def test_choose_engine(dtype): finalize_kwargs=None, ) + # count_engine + for method in ["all", "any", "count"]: + agg = _initialize_aggregation( + method, + dtype=None, + array_dtype=dtype, + fill_value=0, + min_count=0, + finalize_kwargs=None, + ) + engine = _choose_engine(np.array([1, 1, 2, 2]), agg=agg) + assert engine == ("numbagg" if HAS_NUMBAGG else "flox") + # sorted by -> flox sorted_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=mean) assert sorted_engine == ("numbagg" if numbagg_possible else "flox") @@ -1573,3 +1586,17 @@ def test_choose_engine(dtype): assert _choose_engine(np.array([3, 1, 1]), agg=mean) == default # argmax does not give engine="flox" assert _choose_engine(np.array([1, 1, 2, 2]), agg=argmax) == "numpy" + + +def test_xarray_fill_value_behaviour(): + bar = np.array([1, 2, 3, np.nan, np.nan, np.nan, 4, 5, np.nan, np.nan]) + times = np.arange(0, 20, 2) + actual, _ = groupby_reduce(bar, times, func="nansum", expected_groups=(np.arange(19),)) + nan = np.nan + # fmt: off + expected = np.array( + [ 1., nan, 2., nan, 3., nan, 0., nan, 0., + nan, 0., nan, 4., nan, 5., nan, 0., nan, 0.] + ) + # fmt: on + assert_equal(expected, actual) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 116937052..6028a1139 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -561,3 +561,24 @@ def test_preserve_multiindex(): ) assert "region" in hist.coords + + +def test_fill_value_xarray_behaviour(): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = xr.Dataset( + { + "bar": ( + "time", + [1, 2, 3, np.nan, np.nan, np.nan, 4, 5, np.nan, np.nan], + {"meta": "data"}, + ), + "time": times, + } + ) + + expected_time = pd.date_range("2000-01-01", freq="3H", periods=19) + expected = ds.reindex(time=expected_time) + expected = ds.resample(time="3H").sum() + with xr.set_options(use_flox=True): + actual = ds.resample(time="3H").sum() + xr.testing.assert_identical(expected, actual)