Fix numbagg aggregations (#282)
* Fix numbagg version check

Closes #281

* Enable numbagg for count

* Better numbagg special-casing

* Fixes.

* A bunch of typing

* Handle fill_value in core numbagg reduction.

* Update flox/aggregate_numbagg.py

* cleanup

* [WIP] test hacky fix

* [wip]

* Cleanup functions

* Fix casting

* Fix fill_value masking

* optimize

* Update flox/aggregations.py

* Small cleanup

* Fix.

* Fix typing

* Another bugfix

* Optimize seen_groups

* Be careful about raveling

* Fix benchmark skipping for numbagg

* add test
dcherian authored Nov 9, 2023
1 parent 273d319 commit 19db5b3
Showing 8 changed files with 175 additions and 74 deletions.
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/reduce.py
```diff
@@ -18,7 +18,7 @@
 numbagg_skip = []
 for name in expected_names:
     numbagg_skip.extend(
-        list((func, expected_names[0], "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
+        list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
     )
```
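For context on the benchmark fix: the old comprehension used `expected_names[0]` instead of the loop variable `name`, so numbagg skips were only ever registered against the first benchmark name. A minimal standalone sketch of the corrected behavior (the names and function lists below are illustrative, not the real benchmark values):

```python
funcs = ["nansum", "nanmedian"]  # hypothetical benchmarked functions
expected_names = ["time_reduce", "time_reduce_bare"]
NUMBAGG_FUNCS = ["nansum"]  # hypothetical set of numbagg-supported functions

numbagg_skip = []
for name in expected_names:
    numbagg_skip.extend(
        (func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS
    )
# Unsupported funcs are now skipped for *every* benchmark name:
# [('nanmedian', 'time_reduce', 'numbagg'), ('nanmedian', 'time_reduce_bare', 'numbagg')]
```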
112 changes: 63 additions & 49 deletions flox/aggregate_numbagg.py
```diff
@@ -4,64 +4,75 @@
 import numbagg.grouped
 import numpy as np
 
+DEFAULT_FILL_VALUE = {
+    "nansum": 0,
+    "nanmean": np.nan,
+    "nanvar": np.nan,
+    "nanstd": np.nan,
+    "nanmin": np.nan,
+    "nanmax": np.nan,
+    "nanany": False,
+    "nanall": False,
+    "nansum_of_squares": 0,
+    "nanprod": 1,
+    "nancount": 0,
+    "nanargmax": np.nan,
+    "nanargmin": np.nan,
+    "nanfirst": np.nan,
+    "nanlast": np.nan,
+}
+
+CAST_TO = {
+    # "nansum": {np.bool_: np.int64},
+    "nanmean": {np.int_: np.float64},
+    "nanvar": {np.int_: np.float64},
+    "nanstd": {np.int_: np.float64},
+}
+
+
+FILLNA = {"nansum": 0, "nanprod": 1}
+
 
 def _numbagg_wrapper(
     group_idx,
     array,
     *,
+    func,
     axis=-1,
-    func="sum",
     size=None,
     fill_value=None,
     dtype=None,
-    numbagg_func=None,
 ):
-    return numbagg_func(
-        array,
-        group_idx,
-        axis=axis,
-        num_labels=size,
-        # The following are unsupported
-        # fill_value=fill_value,
-        # dtype=dtype,
-    )
+    cast_to = CAST_TO.get(func, None)
+    if cast_to:
+        for from_, to_ in cast_to.items():
+            if np.issubdtype(array.dtype, from_):
+                array = array.astype(to_)
+
+    func_ = getattr(numbagg.grouped, f"group_{func}")
 
-
-def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
-    if np.issubdtype(array.dtype, np.bool_):
-        array = array.astype(np.in64)
-    return numbagg.grouped.group_nansum(
+    result = func_(
         array,
         group_idx,
         axis=axis,
         num_labels=size,
         # The following are unsupported
         # fill_value=fill_value,
         # dtype=dtype,
-    )
+    ).astype(dtype, copy=False)
 
-
-def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
-    if np.issubdtype(array.dtype, np.int_):
-        array = array.astype(np.float64)
-    return numbagg.grouped.group_nanmean(
-        array,
-        group_idx,
-        axis=axis,
-        num_labels=size,
-        # fill_value=fill_value,
-        # dtype=dtype,
-    )
+    return result
 
 
 def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
     assert ddof != 0
-    if np.issubdtype(array.dtype, np.int_):
-        array = array.astype(np.float64)
-    return numbagg.grouped.group_nanvar(
-        array,
+
+    return _numbagg_wrapper(
         group_idx,
+        array,
         axis=axis,
-        num_labels=size,
+        size=size,
+        func="nanvar",
         # ddof=0,
         # fill_value=fill_value,
         # dtype=dtype,
@@ -70,30 +81,33 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
 
 def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
     assert ddof != 0
-    if np.issubdtype(array.dtype, np.int_):
-        array = array.astype(np.float64)
-    return numbagg.grouped.group_nanstd(
-        array,
+
+    return _numbagg_wrapper(
         group_idx,
+        array,
         axis=axis,
-        num_labels=size,
+        size=size,
+        func="nanstd",
         # ddof=0,
         # fill_value=fill_value,
         # dtype=dtype,
     )
 
 
-nansum_of_squares = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nansum_of_squares)
-nanlen = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nancount)
-nanprod = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanprod)
-nanfirst = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanfirst)
-nanlast = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanlast)
-# nanargmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmax)
-# nanargmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmin)
-nanmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmax)
-nanmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmin)
-any = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanany)
-all = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanall)
+nansum = partial(_numbagg_wrapper, func="nansum")
+nanmean = partial(_numbagg_wrapper, func="nanmean")
+nanprod = partial(_numbagg_wrapper, func="nanprod")
+nansum_of_squares = partial(_numbagg_wrapper, func="nansum_of_squares")
+nanlen = partial(_numbagg_wrapper, func="nancount")
+nanprod = partial(_numbagg_wrapper, func="nanprod")
+nanfirst = partial(_numbagg_wrapper, func="nanfirst")
+nanlast = partial(_numbagg_wrapper, func="nanlast")
+# nanargmax = partial(_numbagg_wrapper, func="nanargmax")
+# nanargmin = partial(_numbagg_wrapper, func="nanargmin")
+nanmax = partial(_numbagg_wrapper, func="nanmax")
+nanmin = partial(_numbagg_wrapper, func="nanmin")
+any = partial(_numbagg_wrapper, func="nanany")
+all = partial(_numbagg_wrapper, func="nanall")
 
 # sum = nansum
 # mean = nanmean
```
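A minimal usage sketch of the consolidated wrapper (assumes numbagg is installed; the data are made up). The `partial`s bind `func`, the wrapper looks up `numbagg.grouped.group_{func}` via `getattr`, and `CAST_TO` promotes integer input for the mean/var/std family:

```python
import numpy as np
from flox.aggregate_numbagg import nanmean

group_idx = np.array([0, 0, 1, 1, 2])
array = np.array([1, 2, 3, 4, 5])  # int input: CAST_TO promotes to float64 for "nanmean"

# Dispatches to numbagg.grouped.group_nanmean(array, group_idx, axis=-1, num_labels=3)
out = nanmean(group_idx, array, size=3, dtype=np.float64)
print(out)  # [1.5 3.5 5. ]
```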
22 changes: 9 additions & 13 deletions flox/aggregations.py
```diff
@@ -3,7 +3,7 @@
 import copy
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, TypedDict
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -13,6 +13,7 @@
 
 if TYPE_CHECKING:
     FuncTuple = tuple[Callable | str, ...]
+    OptionalFuncTuple = tuple[Callable | str | None, ...]
 
 
 def _is_arg_reduction(func: str | Aggregation) -> bool:
@@ -152,7 +153,7 @@ def __init__(
         final_fill_value=dtypes.NA,
         dtypes=None,
         final_dtype: DTypeLike | None = None,
-        reduction_type="reduce",
+        reduction_type: Literal["reduce", "argreduce"] = "reduce",
     ):
         """
         Blueprint for computing grouped aggregations.
@@ -203,11 +204,11 @@ def __init__(
         self.reduction_type = reduction_type
         self.numpy: FuncTuple = (numpy,) if numpy else (self.name,)
         # initialize blockwise reduction
-        self.chunk: FuncTuple = _atleast_1d(chunk)
+        self.chunk: OptionalFuncTuple = _atleast_1d(chunk)
         # how to aggregate results after first round of reduction
-        self.combine: FuncTuple = _atleast_1d(combine)
+        self.combine: OptionalFuncTuple = _atleast_1d(combine)
         # simpler reductions used with the "simple combine" algorithm
-        self.simple_combine: tuple[Callable, ...] = ()
+        self.simple_combine: OptionalFuncTuple = ()
         # finalize results (see mean)
         self.finalize: Callable | None = finalize
@@ -279,13 +280,7 @@ def __repr__(self) -> str:
 sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
 nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
 prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
-nanprod = Aggregation(
-    "nanprod",
-    chunk="nanprod",
-    combine="prod",
-    fill_value=1,
-    final_fill_value=dtypes.NA,
-)
+nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1)
 
 
 def _mean_finalize(sum_, count):
@@ -579,6 +574,7 @@ def _initialize_aggregation(
     }
 
     # Replace sentinel fill values according to dtype
+    agg.fill_value["user"] = fill_value
     agg.fill_value["intermediate"] = tuple(
         _get_fill_value(dt, fv)
         for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
@@ -613,7 +609,7 @@ def _initialize_aggregation(
     else:
         agg.min_count = 0
 
-    simple_combine: list[Callable] = []
+    simple_combine: list[Callable | None] = []
     for combine in agg.combine:
         if isinstance(combine, str):
             if combine in ["nanfirst", "nanlast"]:
```
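The widening to `OptionalFuncTuple` exists because some aggregations carry `None` entries: blockwise-only reductions (median, for example, which flox only supports with `method="blockwise"`) have no chunk/combine functions at all. A small illustrative sketch, not the library's actual definitions:

```python
from typing import Callable, Union

OptionalFuncTuple = tuple[Union[Callable, str, None], ...]

mean_combine: OptionalFuncTuple = ("sum", "sum")  # mean combines (sum, count) intermediates
median_chunk: OptionalFuncTuple = (None,)  # blockwise-only: nothing to run per chunk
```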
62 changes: 52 additions & 10 deletions flox/core.py
```diff
@@ -86,6 +86,26 @@
 DUMMY_AXIS = -2
 
 
+def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
+    """Account for numbagg not providing a fill_value kwarg."""
+    from .aggregate_numbagg import DEFAULT_FILL_VALUE
+
+    if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
+        return result
+    # The condition needs to be
+    # len(found_groups) < size; if so we mask with fill_value (?)
+    default_fv = DEFAULT_FILL_VALUE[func]
+    needs_masking = fill_value is not None and not np.array_equal(
+        fill_value, default_fv, equal_nan=True
+    )
+    groups = np.arange(size)
+    if needs_masking:
+        mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
+        if mask.any():
+            result[..., groups[mask]] = fill_value
+    return result
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())
```
```diff
@@ -780,7 +800,11 @@ def chunk_reduce(
     group_idx, grps, found_groups_shape, _, size, props = factorize_(
         (by,), axes, expected_groups=(expected_groups,), reindex=reindex, sort=sort
     )
-    groups = grps[0]
+    (groups,) = grps
+
+    # do this *before* possible broadcasting below.
+    # factorize_ has already taken care of offsetting
+    seen_groups = _unique(group_idx)
 
     order = "C"
     if nax > 1:
@@ -850,6 +874,16 @@ def chunk_reduce(
         result = generic_aggregate(
             group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func
         ).astype(dt, copy=False)
+        if engine == "numbagg":
+            result = _postprocess_numbagg(
+                result,
+                func=reduction,
+                size=size,
+                fill_value=fv,
+                # Unfortunately, we cannot reuse found_groups, it has not
+                # been "offset" and is really expected_groups in nearly all cases
+                seen_groups=seen_groups,
+            )
         if np.any(props.nanmask):
             # remove NaN group label which should be last
             result = result[..., :-1]
@@ -1053,6 +1087,8 @@ def _grouped_combine(
     """Combine intermediates step of tree reduction."""
     from dask.utils import deepmap
 
+    combine = agg.combine
+
     if isinstance(x_chunk, dict):
         # Only one block at final step; skip one extra groupby
         return x_chunk
@@ -1093,7 +1129,8 @@ def _grouped_combine(
             results = chunk_argreduce(
                 array_idx,
                 groups,
-                func=agg.combine[slicer],  # count gets treated specially next
+                # count gets treated specially next
+                func=combine[slicer],  # type: ignore[arg-type]
                 axis=axis,
                 expected_groups=None,
                 fill_value=agg.fill_value["intermediate"][slicer],
@@ -1127,9 +1164,10 @@ def _grouped_combine(
     elif agg.reduction_type == "reduce":
         # Here we reduce the intermediates individually
         results = {"groups": None, "intermediates": []}
-        for idx, (combine, fv, dtype) in enumerate(
-            zip(agg.combine, agg.fill_value["intermediate"], agg.dtype["intermediate"])
+        for idx, (combine_, fv, dtype) in enumerate(
+            zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"])
         ):
+            assert combine_ is not None
             array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis)
             if array.shape[-1] == 0:
                 # all empty when combined
@@ -1143,7 +1181,7 @@ def _grouped_combine(
                 _results = chunk_reduce(
                     array,
                     groups,
-                    func=combine,
+                    func=combine_,
                     axis=axis,
                     expected_groups=None,
                     fill_value=(fv,),
@@ -1788,8 +1826,13 @@ def _choose_engine(by, agg: Aggregation):
 
     # numbagg only supports nan-skipping reductions
     # without dtype specified
-    if HAS_NUMBAGG and "nan" in agg.name:
-        if not_arg_reduce and dtype is None:
+    has_blockwise_nan_skipping = (agg.chunk[0] is None and "nan" in agg.name) or any(
+        (isinstance(func, str) and "nan" in func) for func in agg.chunk
+    )
+    if HAS_NUMBAGG:
+        if agg.name in ["all", "any"] or (
+            not_arg_reduce and has_blockwise_nan_skipping and dtype is None
+        ):
             return "numbagg"
 
     if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
```
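To make the widened eligibility rule concrete, here is a standalone restatement of the predicate (a sketch: `name`, `chunk`, `not_arg_reduce`, and `dtype` stand in for the `Aggregation` attributes and locals used in `_choose_engine`):

```python
def numbagg_eligible(name, chunk, not_arg_reduce, dtype) -> bool:
    # True when the blockwise step skips NaNs: either the aggregation is
    # blockwise-only (chunk[0] is None) with a "nan" name, or some chunk func is nan-skipping.
    has_blockwise_nan_skipping = (chunk[0] is None and "nan" in name) or any(
        isinstance(func, str) and "nan" in func for func in chunk
    )
    # any/all always map to numbagg's nanany/nanall; everything else needs a
    # nan-skipping blockwise step, a non-arg reduction, and no user-requested dtype.
    return name in ["all", "any"] or (
        not_arg_reduce and has_blockwise_nan_skipping and dtype is None
    )

print(numbagg_eligible("nansum", ("nansum",), True, None))  # True
print(numbagg_eligible("sum", ("sum",), True, None))        # False
```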
```diff
@@ -2050,7 +2093,7 @@ def groupby_reduce(
     nax = len(axis_)
 
     # When axis is a subset of possible values; then npg will
-    # apply it to groups that don't exist along a particular axis (for e.g.)
+    # apply the fill_value to groups that don't exist along a particular axis (for e.g.)
     # since these count as a group that is absent. thoo!
     # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
     # The only way to do this consistently is mask out using min_count
@@ -2090,8 +2133,7 @@ def groupby_reduce(
     # TODO: How else to narrow that array.chunks is there?
     assert isinstance(array, DaskArray)
 
-    # TODO: fix typing of FuncTuple in Aggregation
-    if agg.chunk[0] is None and method != "blockwise":  # type: ignore[unreachable]
+    if agg.chunk[0] is None and method != "blockwise":
         raise NotImplementedError(
             f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
             f"Received method={method!r}"
```
2 changes: 1 addition & 1 deletion flox/xrutils.py
```diff
@@ -339,6 +339,6 @@ def module_available(module: str, minversion: Optional[str] = None) -> bool:
     has = importlib.util.find_spec(module) is not None
     if has:
         mod = importlib.import_module(module)
-        return Version(mod.__version__) < Version(minversion) if minversion is not None else True
+        return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
     else:
         return False
```
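This one-character fix inverts the comparison: `module_available` previously returned True only when the installed version was *below* `minversion`, which disabled numbagg for exactly the versions that should be usable (the version-check bug tracked in #281). A usage sketch (the version string here is illustrative):

```python
from flox.xrutils import module_available

# True only if numbagg is importable AND its version is >= the stated minimum
HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
```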
1 change: 1 addition & 0 deletions pyproject.toml
```diff
@@ -111,6 +111,7 @@ show_error_codes = true
 warn_unused_ignores = true
 warn_unreachable = true
 enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
+exclude=["asv_bench/pkgs"]
 
 [[tool.mypy.overrides]]
 module=[
```