Skip to content

Commit

Permalink
Add DataArrayCoarsen.reduce and DatasetCoarsen.reduce methods (#4939)
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerkclark authored Feb 23, 2021
1 parent 200c2b2 commit f554d0a
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 2 deletions.
2 changes: 2 additions & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
core.rolling.DatasetCoarsen.median
core.rolling.DatasetCoarsen.min
core.rolling.DatasetCoarsen.prod
core.rolling.DatasetCoarsen.reduce
core.rolling.DatasetCoarsen.std
core.rolling.DatasetCoarsen.sum
core.rolling.DatasetCoarsen.var
Expand Down Expand Up @@ -190,6 +191,7 @@
core.rolling.DataArrayCoarsen.median
core.rolling.DataArrayCoarsen.min
core.rolling.DataArrayCoarsen.prod
core.rolling.DataArrayCoarsen.reduce
core.rolling.DataArrayCoarsen.std
core.rolling.DataArrayCoarsen.sum
core.rolling.DataArrayCoarsen.var
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ New Features
(including globs for the latter) for ``engine="zarr"``, and so allow reading from
many remote and other file systems (:pull:`4461`)
By `Martin Durant <https://github.com/martindurant>`_
- :py:class:`DataArrayCoarsen` and :py:class:`DatasetCoarsen` now implement a
``reduce`` method, enabling coarsening operations with custom reduction
functions (:issue:`3741`, :pull:`4939`). By `Spencer Clark
<https://github.com/spencerkclark>`_.

Bug fixes
~~~~~~~~~
Expand Down
62 changes: 60 additions & 2 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,9 @@ class DataArrayCoarsen(Coarsen):
_reduce_extra_args_docstring = """"""

@classmethod
def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool):
def _reduce_method(
cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False
):
"""
Return a wrapped function for injecting reduction methods.
see ops.inject_reduce_methods
Expand Down Expand Up @@ -871,14 +873,48 @@ def wrapped_func(self, **kwargs):

return wrapped_func

def reduce(self, func: Callable, **kwargs):
    """Coarsen this DataArray with a user-supplied reduction function.

    Parameters
    ----------
    func : callable
        Reduction callable as ``func(x, axis, **kwargs)`` that collapses an
        np.ndarray over the coarsening dimensions. ``axis`` may be handed
        over as a tuple of integers, so `func` has to accept that form.
    **kwargs : dict
        Extra keyword arguments forwarded to `func`.

    Returns
    -------
    reduced : DataArray
        Array with summarized data.

    Examples
    --------
    >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b"))
    >>> coarsen = da.coarsen(b=2)
    >>> coarsen.reduce(np.sum)
    <xarray.DataArray (a: 2, b: 2)>
    array([[ 1,  5],
           [ 9, 13]])
    Dimensions without coordinates: a, b
    """
    # Wrap ``func`` with the ``_reduce_method`` factory (leaving
    # ``include_skipna`` and ``numeric_only`` at their defaults) and apply
    # the resulting wrapper to this object right away.
    return self._reduce_method(func)(self, **kwargs)


class DatasetCoarsen(Coarsen):
__slots__ = ()

_reduce_extra_args_docstring = """"""

@classmethod
def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool):
def _reduce_method(
cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False
):
"""
Return a wrapped function for injecting reduction methods.
see ops.inject_reduce_methods
Expand Down Expand Up @@ -917,6 +953,28 @@ def wrapped_func(self, **kwargs):

return wrapped_func

def reduce(self, func: Callable, **kwargs):
    """Coarsen this Dataset with a user-supplied reduction function.

    Parameters
    ----------
    func : callable
        Reduction callable as ``func(x, axis, **kwargs)`` that collapses an
        np.ndarray over the coarsening dimensions. ``axis`` may be handed
        over as a tuple of integers, so `func` has to accept that form.
    **kwargs : dict
        Extra keyword arguments forwarded to `func`.

    Returns
    -------
    reduced : Dataset
        Arrays with summarized data.
    """
    # Wrap ``func`` with the ``_reduce_method`` factory (leaving
    # ``include_skipna`` and ``numeric_only`` at their defaults) and apply
    # the resulting wrapper to this object right away.
    return self._reduce_method(func)(self, **kwargs)


# Attach the generated reduction methods to both coarsen classes at import
# time. NOTE(review): presumably this is what provides the named reductions
# (sum, mean, std, ...) by way of each class's ``_reduce_method`` factory --
# confirm against ops.inject_reduce_methods.
inject_reduce_methods(DataArrayCoarsen)
inject_reduce_methods(DatasetCoarsen)
16 changes: 16 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6382,6 +6382,22 @@ def test_coarsen_keep_attrs():
xr.testing.assert_identical(da, da2)


@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
def test_coarsen_reduce(da, window, name):
    """Check that ``reduce`` with ``np.nan<name>`` matches the ``<name>`` method."""
    nan_count = da.isnull().sum()
    if nan_count > 1 and window == 1:
        pytest.skip("These parameters lead to all-NaN slices")

    # boundary="trim" accommodates every window size used in this test.
    coarsened = da.coarsen(time=window, boundary="trim")

    # The nan-prefixed numpy function is used to get behavior similar to
    # bottleneck (per the original author's note).
    nan_reducer = getattr(np, f"nan{name}")
    result = coarsened.reduce(nan_reducer)
    reference = getattr(coarsened, name)()
    assert_allclose(result, reference)


@pytest.mark.parametrize("da", (1, 2), indirect=True)
def test_rolling_iter(da):
rolling_obj = da.rolling(time=7)
Expand Down
21 changes: 21 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6055,6 +6055,27 @@ def test_coarsen_keep_attrs():
xr.testing.assert_identical(ds, ds2)


@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median"))
def test_coarsen_reduce(ds, window, name):
    """Check that ``reduce`` with ``np.nan<name>`` matches the ``<name>`` method."""
    # boundary="trim" accommodates every window size used in this test.
    coarsened = ds.coarsen(time=window, boundary="trim")

    # The nan-prefixed numpy function is used to get behavior similar to
    # bottleneck (per the original author's note).
    nan_reducer = getattr(np, f"nan{name}")
    result = coarsened.reduce(nan_reducer)
    reference = getattr(coarsened, name)()
    assert_allclose(result, reference)

    # The data variables must come back in their original order.
    assert list(ds.data_vars) == list(result.data_vars)

    # Every variable must keep its original dimension order.
    for var_name, original in ds.data_vars.items():
        assert original.dims == result[var_name].dims


@pytest.mark.parametrize(
"funcname, argument",
[
Expand Down

0 comments on commit f554d0a

Please sign in to comment.