Skip to content

Commit

Permalink
Delete resample_reduce (#246)
Browse files Browse the repository at this point in the history
Closes #245
  • Loading branch information
dcherian authored Jun 14, 2023
1 parent 0d353ec commit 5d223a7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 100 deletions.
48 changes: 5 additions & 43 deletions flox/xarray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union

import numpy as np
Expand All @@ -21,7 +20,6 @@
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric

if TYPE_CHECKING:
from xarray.core.resample import Resample
from xarray.core.types import T_DataArray, T_Dataset

Dims = Union[str, Iterable[Hashable], None]
Expand Down Expand Up @@ -450,7 +448,11 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
actual = actual.drop_vars(name)
# When grouping by MultiIndex, expect is a pd.Index wrapping
# an object array of tuples
if name in ds_broad.indexes and isinstance(ds_broad.indexes[name], pd.MultiIndex):
if (
name in ds_broad.indexes
and isinstance(ds_broad.indexes[name], pd.MultiIndex)
and not isinstance(expect, pd.RangeIndex)
):
levelnames = ds_broad.indexes[name].names
expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
actual[name] = expect
Expand Down Expand Up @@ -582,43 +584,3 @@ def _rechunk(func, obj, dim, labels, **kwargs):
)

return obj


def resample_reduce(
    resampler: Resample,
    func: str | Aggregation,
    keep_attrs: bool = True,
    **kwargs,
):
    """Reduce a resampled Xarray object with ``xarray_reduce`` (deprecated).

    Builds an integer label array from the resampler's group slices, reduces
    the underlying object over those labels with ``method="blockwise"``, and
    then restores the resampled dimension name and coordinate on the result.

    Parameters
    ----------
    resampler : Resample
        Xarray resample object. Its private attributes ``_obj``,
        ``_group_dim``, ``_group_indices`` and ``_unique_coord`` are read.
    func : str or Aggregation
        Reduction forwarded to ``xarray_reduce`` as ``func``.
    keep_attrs : bool, default True
        Forwarded to ``xarray_reduce``.
    **kwargs
        Any further keyword arguments are forwarded to ``xarray_reduce``.

    Returns
    -------
    The result of ``xarray_reduce``, with the grouped dimension renamed back
    to the resampled dimension, moved to the front, and labelled with
    ``resampler._unique_coord.data``.

    Warns
    -----
    DeprecationWarning
        Always; use Xarray's own resample reductions instead.
    """
    warnings.warn(
        "flox.xarray.resample_reduce is now deprecated. Please use Xarray's resample method directly.",
        DeprecationWarning,
        # Attribute the warning to the caller's line. With the default
        # stacklevel it is attributed to this module, and Python's default
        # filters then hide DeprecationWarning from end users.
        stacklevel=2,
    )

    obj = resampler._obj
    dim = resampler._group_dim

    # this creates a label DataArray since resample doesn't do that somehow
    tostack = []
    for idx, slicer in enumerate(resampler._group_indices):
        # An open-ended slice runs to the end of the resampled dimension.
        stop = obj.sizes[dim] if slicer.stop is None else slicer.stop
        # Every element covered by this slice gets the group's integer label.
        tostack.append(np.full((stop - slicer.start,), idx, dtype=np.int32))
    by = xr.DataArray(np.hstack(tostack), dims=(dim,), name="__resample_dim__")

    result = (
        xarray_reduce(
            obj,
            by,
            func=func,
            method="blockwise",
            keep_attrs=keep_attrs,
            **kwargs,
        )
        .rename({"__resample_dim__": dim})
        .transpose(dim, ...)
    )
    # Replace the integer group labels with the resampler's own coordinate values.
    result[dim] = resampler._unique_coord.data
    return result
58 changes: 1 addition & 57 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
xr = pytest.importorskip("xarray")
# isort: on

from flox.xarray import rechunk_for_blockwise, resample_reduce, xarray_reduce
from flox.xarray import rechunk_for_blockwise, xarray_reduce

from . import assert_equal, has_dask, raise_if_dask_computes, requires_dask

Expand Down Expand Up @@ -245,48 +245,6 @@ def test_xarray_reduce_errors():
xarray_reduce(da, by.chunk(), func="mean")


@pytest.mark.parametrize("isdask", [True, False])
@pytest.mark.parametrize("dataarray", [True, False])
@pytest.mark.parametrize("chunklen", [27, 4 * 31 + 1, 4 * 31 + 20])
def test_xarray_resample(chunklen, isdask, dataarray, engine):
    """Monthly-mean resampling via ``resample_reduce`` and via Xarray's
    ``use_flox`` option must both agree with Xarray's own ``mean``."""
    if isdask and not has_dask:
        pytest.skip()

    # Only chunk along time when exercising the dask code path.
    open_kwargs = {"chunks": {"time": chunklen}} if isdask else {}
    ds = xr.tutorial.open_dataset("air_temperature", **open_kwargs)

    obj = ds.air if dataarray else ds
    resampler = obj.resample(time="M")

    with pytest.warns(DeprecationWarning):
        via_flox = resample_reduce(resampler, "mean", engine=engine)
    expected = resampler.mean()
    xr.testing.assert_allclose(via_flox, expected)

    with xr.set_options(use_flox=True):
        via_option = resampler.mean()
    xr.testing.assert_allclose(via_option, expected)


@requires_dask
def test_xarray_resample_dataset_multiple_arrays(engine):
    """Regression test for #35: resampling a chunked Dataset with several variables."""
    time_index = pd.date_range("2000", periods=5)
    variables = [
        xr.DataArray(range(5), dims=["time"], coords=[time_index], name="foo"),
        xr.DataArray(range(1, 6), dims=["time"], coords=[time_index], name="bar"),
    ]
    chunked = xr.merge(variables).chunk({"time": 4})

    resampler = chunked.resample(time="4D")
    # The separate computes are necessary here to force xarray
    # to compute all variables in result at the same time.
    expected = resampler.mean().compute()
    with pytest.warns(DeprecationWarning):
        lazy_result = resample_reduce(resampler, "mean", engine=engine)
        result = lazy_result.compute()
    xr.testing.assert_allclose(expected, result)


@requires_dask
@pytest.mark.parametrize(
"inchunks, expected",
Expand Down Expand Up @@ -439,20 +397,6 @@ def test_cache():
assert len(cache.data) == 2


@pytest.mark.parametrize("use_cftime", [True, False])
@pytest.mark.parametrize("func", ["count", "mean"])
def test_datetime_array_reduce(use_cftime, func, engine):
    """``resample_reduce`` on a datetime-valued array must match Xarray's own
    yearly resample reduction."""
    stamps = xr.date_range("2009-01-01", "2012-12-31", use_cftime=use_cftime)
    time = xr.DataArray(stamps, dims=("time",), name="time")

    # Reference result from Xarray's own reduction method of the same name.
    reference_reduction = getattr(time.resample(time="YS"), func)
    expected = reference_reduction()

    with pytest.warns(DeprecationWarning):
        actual = resample_reduce(time.resample(time="YS"), func=func, engine=engine)
    assert_equal(expected, actual)


@requires_dask
@pytest.mark.parametrize("method", ["cohorts", "map-reduce"])
def test_groupby_bins_indexed_coordinate(method):
Expand Down

0 comments on commit 5d223a7

Please sign in to comment.