From 766da3480f50d7672fe1a7c1cdf3aa32d8181fcf Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Mon, 18 Dec 2023 14:30:18 -0500 Subject: [PATCH 1/7] Generalize cumulative reduction (scan) to non-dask types (#8019) * add scan to ChunkManager ABC * implement scan for dask using cumreduction * generalize push to work for non-dask chunked arrays * whatsnew * fix importerror * Allow arbitrary kwargs Co-authored-by: Deepak Cherian * Type hint return value of T_ChunkedArray Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Type hint return value of Dask array * ffill -> bfill in doc/whats-new.rst Co-authored-by: Deepak Cherian * hopefully fix docs warning --------- Co-authored-by: Deepak Cherian Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- doc/whats-new.rst | 4 ++++ xarray/core/daskmanager.py | 22 +++++++++++++++++++++ xarray/core/parallelcompat.py | 37 +++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4188af98e3f..c0917b7443b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -589,6 +589,10 @@ Internal Changes - :py:func:`as_variable` now consistently includes the variable name in any exceptions raised. (:pull:`7995`). By `Peter Hill `_ +- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`, + potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to + use non-dask chunked array types. + (:pull:`8019`) By `Tom Nicholas `_. - :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`). `By Ian Carroll `_. diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py index 56d8dc9e23a..efa04bc3df2 100644 --- a/xarray/core/daskmanager.py +++ b/xarray/core/daskmanager.py @@ -97,6 +97,28 @@ def reduction( keepdims=keepdims, ) + def scan( + self, + func: Callable, + binop: Callable, + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: np.dtype | None = None, + **kwargs, + ) -> DaskArray: + from dask.array.reductions import cumreduction + + return cumreduction( + func, + binop, + ident, + arr, + axis=axis, + dtype=dtype, + **kwargs, + ) + def apply_gufunc( self, func: Callable, diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 333059e00ae..37542925dde 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -403,6 +403,43 @@ def reduction( """ raise NotImplementedError() + def scan( + self, + func: Callable, + binop: Callable, + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: np.dtype | None = None, + **kwargs, + ) -> T_ChunkedArray: + """ + General version of a 1D scan, also known as a cumulative array reduction. + + Used in ``ffill`` and ``bfill`` in xarray. 
+ + Parameters + ---------- + func: callable + Cumulative function like np.cumsum or np.cumprod + binop: callable + Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` + ident: Number + Associated identity like ``np.cumsum->0`` or ``np.cumprod->1`` + arr: dask Array + axis: int, optional + dtype: dtype + + Returns + ------- + Chunked array + + See also + -------- + dask.array.cumreduction + """ + raise NotImplementedError() + @abstractmethod def apply_gufunc( self, From 219ef0ce5e5c38f6033b285c356085ea0cce61e5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 18 Dec 2023 13:30:40 -0800 Subject: [PATCH 2/7] Offer a fixture for unifying DataArray & Dataset tests (#8533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Cumulative aggregation Offer a fixture for unifying `DataArray` & `Dataset` tests (stacked on #8512, worth reviewing after that's merged) Some tests are literally copy & pasted between DataArray & Dataset tests. This change allows them to use a single test. Not everything will work — sometimes we want to check specifics — but sometimes they will... --- xarray/tests/conftest.py | 43 +++++++++++++++++++++++ xarray/tests/test_rolling.py | 67 ++++++++++++++---------------------- 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 6a8cf008f9f..f153c2f4dc0 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pandas as pd import pytest @@ -77,3 +79,44 @@ def da(request, backend): return da else: raise ValueError + + +@pytest.fixture(params=[Dataset, DataArray]) +def type(request): + return request.param + + +@pytest.fixture(params=[1]) +def d(request, backend, type) -> DataArray | Dataset: + """ + For tests which can test either a DataArray or a Dataset. 
+ """ + result: DataArray | Dataset + if request.param == 1: + ds = Dataset( + dict( + a=(["x", "z"], np.arange(24).reshape(2, 12)), + b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)), + ), + dict( + x=("x", np.linspace(0, 1.0, 2)), + y=range(3), + z=("z", pd.date_range("2000-01-01", periods=12)), + w=("x", ["a", "b"]), + ), + ) + if type == DataArray: + result = ds["a"].assign_coords(w=ds.coords["w"]) + elif type == Dataset: + result = ds + else: + raise ValueError + else: + raise ValueError + + if backend == "dask": + return result.chunk() + elif backend == "numpy": + return result + else: + raise ValueError diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 645ec1f85e6..7cb2cd70d29 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -36,6 +36,31 @@ def compute_backend(request): yield request.param +@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("min_periods", [1, 10]) +def test_cumulative(d, func, min_periods) -> None: + # One dim + result = getattr(d.cumulative("z", min_periods=min_periods), func)() + expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)() + assert_identical(result, expected) + + # Multiple dim + result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)() + expected = getattr( + d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods), + func, + )() + assert_identical(result, expected) + + +def test_cumulative_vs_cum(d) -> None: + result = d.cumulative("z").sum() + expected = d.cumsum("z") + # cumsum drops the coord of the dimension; cumulative doesn't + expected = expected.assign_coords(z=result["z"]) + assert_identical(result, expected) + + class TestDataArrayRolling: @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("center", [True, False]) @@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None: ): da.rolling_exp(time=10, keep_attrs=True) - @pytest.mark.parametrize("func", ["mean", "sum"]) - @pytest.mark.parametrize("min_periods", [1, 20]) - def test_cumulative(self, da, func, min_periods) -> None: - # One dim - result = getattr(da.cumulative("time", min_periods=min_periods), func)() - expected = getattr( - da.rolling(time=da.time.size, min_periods=min_periods), func - )() - assert_identical(result, expected) - - # Multiple dim - result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)() - expected = getattr( - da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods), - func, - )() - assert_identical(result, expected) - - def test_cumulative_vs_cum(self, da) -> None: - result = da.cumulative("time").sum() - expected = da.cumsum("time") - assert_identical(result, expected) - class TestDatasetRolling: @pytest.mark.parametrize( @@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)() assert_allclose(actual, expected) - @pytest.mark.parametrize("func", ["mean", "sum"]) - @pytest.mark.parametrize("ds", (2,), indirect=True) - @pytest.mark.parametrize("min_periods", [1, 10]) - def test_cumulative(self, ds, func, min_periods) -> None: - # One dim - result = getattr(ds.cumulative("time", min_periods=min_periods), func)() - expected = getattr( - ds.rolling(time=ds.time.size, min_periods=min_periods), func - )() - assert_identical(result, expected) - - # Multiple dim - result = getattr(ds.cumulative(["time", "x"], 
min_periods=min_periods), func)() - expected = getattr( - ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods), - func, - )() - assert_identical(result, expected) - @requires_numbagg class TestDatasetRollingExp: From b3890a3859993dc53064ff14c2362bb0134b7c56 Mon Sep 17 00:00:00 2001 From: Niclas Rieger <45175997+nicrie@users.noreply.github.com> Date: Tue, 19 Dec 2023 15:39:37 +0100 Subject: [PATCH 3/7] add xeofs to ecosystem.rst (#8561) Suggestion to include [xeofs](https://github.com/nicrie/xeofs) in the xarray ecosystem documentation. xeofs enables fully multidimensional PCA / EOF analysis and related techniques with large datasets, thanks to the integration of xarray and dask. References: - [Github repository](https://github.com/nicrie/xeofs) - [Documentation](https://xeofs.readthedocs.io/en/latest/) - [JOSS review](https://github.com/openjournals/joss-reviews/issues/6060) --- doc/ecosystem.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index fc5ae963a1d..561e9cdb5b2 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -78,6 +78,7 @@ Extend xarray capabilities - `xarray-dataclasses `_: xarray extension for typed DataArray and Dataset creation. - `xarray_einstats `_: Statistics, linear algebra and einops for xarray - `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). +- `xeofs `_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data. - `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API. - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. From b4444388cb0647c4375d6a364290e4fa5e5f94ba Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 20 Dec 2023 10:11:16 -0700 Subject: [PATCH 4/7] Adapt map_blocks to use new Coordinates API (#8560) * Adapt map_blocks to use new Coordinates API * cleanup * typing fixes * optimize * small cleanups * Typing fixes --- xarray/core/coordinates.py | 2 +- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- xarray/core/parallel.py | 89 ++++++++++++++++++++++++-------------- xarray/tests/test_dask.py | 19 ++++++++ 5 files changed, 79 insertions(+), 35 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cdf1d354be6..c59c5deba16 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates): :py:class:`~xarray.Coordinates` object is passed, its indexes will be added to the new created object. indexes: dict-like, optional - Mapping of where keys are coordinate names and values are + Mapping where keys are coordinate names and values are :py:class:`~xarray.indexes.Index` objects. If None (default), pandas indexes will be created for each dimension coordinate. Passing an empty dictionary will skip this default behavior. 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0335ad3bdda..0f245ff464b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -80,7 +80,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None try: from dask.delayed import Delayed except ImportError: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9ec39e74ad1..a6fc0e2ca18 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -171,7 +171,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None # list of attributes of pd.DatetimeIndex that are ndarrays of time info diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f971556b3f7..ef505b55345 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -4,19 +4,29 @@ import itertools import operator from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.indexes import Index +from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection if TYPE_CHECKING: from xarray.core.types import T_Xarray +class ExpectedDict(TypedDict): + shapes: dict[Hashable, int] + coords: set[Hashable] + data_vars: set[Hashable] + indexes: dict[Hashable, Index] + + def unzip(iterable): return zip(*iterable) @@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset): def check_result_variables( - result: DataArray | Dataset, expected: Mapping[str, Any], kind: str + result: DataArray | Dataset, + expected: ExpectedDict, + kind: Literal["coords", "data_vars"], ): if kind == "coords": nice_str = "coordinate" @@ -254,7 +266,7 @@ def _wrapper( args: list, kwargs: dict, arg_is_array: Iterable[bool], - expected: dict, + expected: ExpectedDict, ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -345,33 +357,45 @@ def _wrapper( for arg in aligned ) + merged_coordinates = merge([arg.coords for arg in aligned]).coords + _, npargs = unzip( sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) ) # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg._indexes) + coordinates: Coordinates if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template._indexes) - preserved_indexes = template_indexes & set(input_indexes) - new_indexes = template_indexes - set(input_indexes) - indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template._indexes[k] for k in new_indexes}) + template_coords = set(template.coords) + preserved_coord_vars = template_coords & set(merged_coordinates) + new_coord_vars = template_coords - set(merged_coordinates) + + preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars] + # preserved_coords contains all coordinates bariables that share a dimension + # with any 
index variable in preserved_indexes + # Drop any unneeded vars in a second pass, this is required for e.g. + # if the mapped function were to drop a non-dimension coordinate variable. + preserved_coords = preserved_coords.drop_vars( + tuple(k for k in preserved_coords.variables if k not in template_coords) + ) + + coordinates = merge( + (preserved_coords, template.coords.to_dataset()[new_coord_vars]) + ).coords output_chunks: Mapping[Hashable, tuple[int, ...]] = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template._indexes) + coordinates = template.coords output_chunks = template.chunksizes if not output_chunks: raise ValueError( @@ -473,6 +497,9 @@ def subset_dataset_to_block( return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + # variable names that depend on the computation. Currently, indexes + # cannot be modified in the mapped function, so we exclude thos + computed_variables = set(template.variables) - set(coordinates.xindexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index @@ -485,19 +512,23 @@ def subset_dataset_to_block( for isxr, arg in zip(is_xarray, npargs) ] - # expected["shapes", "coords", "data_vars", "indexes"] are used to # raise nice error messages in _wrapper - expected = {} - # input chunk 0 along a dimension maps to output chunk 0 along the same dimension - # even if length of dimension is changed by the applied function - expected["shapes"] = { - k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks - } - expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] - expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - expected["indexes"] = { - dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in indexes + expected: ExpectedDict = { + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + "shapes": { + k: output_chunks[k][v] + for k, v in chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + "indexes": { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in coordinates.xindexes + }, } from_wrapper = (gname,) + chunk_tuple @@ -505,9 +536,8 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} - for name, variable in template.variables.items(): - if name in indexes: - continue + for name in computed_variables: + variable = template.variables[name] gname_l = f"{name}-{gname}" var_key_map[name] = gname_l @@ -543,12 +573,7 @@ def subset_dataset_to_block( }, ) - # TODO: benbovy - flexible indexes: make it work with custom indexes - # this will need to pass both indexes and coords to the Dataset constructor - result = Dataset( - coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, - attrs=template.attrs, - ) + result = Dataset(coords=coordinates, attrs=template.attrs) for index in result._indexes: result[index].attrs = template[index].attrs diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c2a77c97d85..137d6020829 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1367,6 
+1367,25 @@ def test_map_blocks_da_ds_with_template(obj): assert_identical(actual, template) +def test_map_blocks_roundtrip_string_index(): + ds = xr.Dataset( + {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]} + ).chunk(label=1) + assert ds.label.dtype == np.dtype(" Date: Thu, 21 Dec 2023 03:08:00 +0100 Subject: [PATCH 5/7] FIX: reverse index output of bottleneck move_argmax/move_argmin functions (#8552) * FIX: reverse index output of bottleneck move_argmax/move_argmin functions, add move_argmax/move_argmin to bottleneck tests * add whats-new.rst entry --- doc/whats-new.rst | 3 +++ xarray/core/rolling.py | 5 +++++ xarray/tests/test_rolling.py | 12 ++++++++++-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c0917b7443b..b7a31b1e7bd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`). + By `Kai Mühlbauer `_. + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 819c31642d0..2188599962a 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -596,6 +596,11 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): values = func( padded.data, window=self.window[0], min_count=min_count, axis=axis ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func in [bottleneck.move_argmin, bottleneck.move_argmax]: + values = self.window[0] - 1 - values if self.center[0]: values = values[valid] diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 7cb2cd70d29..6db4d38b53e 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -120,7 +120,9 @@ def test_rolling_properties(self, da) -> None: ): da.rolling(foo=2) - @pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median")) + @pytest.mark.parametrize( + "name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax") + ) @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) @pytest.mark.parametrize("backend", ["numpy"], indirect=True) @@ -133,9 +135,15 @@ def test_rolling_wrapped_bottleneck( func_name = f"move_{name}" actual = getattr(rolling_obj, name)() + window = 7 expected = getattr(bn, func_name)( - da.values, window=7, axis=1, min_count=min_periods + da.values, window=window, axis=1, min_count=min_periods ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func_name in ["move_argmin", "move_argmax"]: + expected = window - 1 - expected # Using assert_allclose because we get tiny (1e-17) differences in numbagg. np.testing.assert_allclose(actual.values, expected) From a04900d724f05b3ab8757144bcd23845ff5eee94 Mon Sep 17 00:00:00 2001 From: Markel Date: Thu, 21 Dec 2023 16:24:15 +0100 Subject: [PATCH 6/7] Support for the new compression arguments. (#7551) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support for the new compression arguments. Use a dict for the arguments and update it with the encoding, so all variables are passed. 
* significant_digit and other missing keys added Should close #7634 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test for the new compression argument * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move the new test to TestNetCDF4Data * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * simplify this line (code review) * Added entry to whats-new Also removed an unnecesary call to monkeypatch fixture. * bump netcdf4 to 1.6.2 in min-all-deps.yml * parametrize compression in test * Revert "bump netcdf4 to 1.6.2 in min-all-deps.yml" This reverts commit c2ce8d5c92e8f5823ad6bbafdd79b5c7f1148d5e. * check netCDF4 version and skip test if netcdf4 version <1.6.2 * fix typing * Larger chunks to avoid random blosc errors With smaller chunks it raises "Blosc_FIlter Error: blosc_filter: Buffer is uncompressible." one out of three times. * use decorator to skip old netCDF4 versions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove stale version-property * fix whats-new.rst * fix requires-decorator * fix for asserts of other tests that use test data * Apply suggestions from code review * Update xarray/tests/__init__.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/tests/test_backends.py --------- Co-authored-by: garciam Co-authored-by: Deepak Cherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ryan Abernathey Co-authored-by: Kai Mühlbauer Co-authored-by: Kai Mühlbauer --- doc/whats-new.rst | 4 ++ xarray/backends/netCDF4_.py | 25 ++++++++---- xarray/tests/__init__.py | 32 +++++++++++---- xarray/tests/test_backends.py | 73 ++++++++++++++++++++++++++++++++++- 4 files changed, 117 insertions(+), 17 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b7a31b1e7bd..24268406406 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,10 @@ New Features - :py:meth:`xr.cov` and :py:meth:`xr.corr` now support using weights (:issue:`8527`, :pull:`7392`). By `Llorenç Lledó `_. +- Accept the compression arguments new in netCDF 1.6.0 in the netCDF4 backend. + See `netCDF4 documentation `_ for details. + By `Markel García-Díez `_. (:issue:`6929`, :pull:`7551`) Note that some + new compression filters needs plugins to be installed which may not be available in all netCDF distributions. 
Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 1aee4c1c726..cf753828242 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -257,6 +257,12 @@ def _extract_nc4_variable_encoding( "_FillValue", "dtype", "compression", + "significant_digits", + "quantize_mode", + "blosc_shuffle", + "szip_coding", + "szip_pixels_per_block", + "endian", } if lsd_okay: valid_encodings.add("least_significant_digit") @@ -497,20 +503,23 @@ def prepare_variable( if name in self.ds.variables: nc4_var = self.ds.variables[name] else: - nc4_var = self.ds.createVariable( + default_args = dict( varname=name, datatype=datatype, dimensions=variable.dims, - zlib=encoding.get("zlib", False), - complevel=encoding.get("complevel", 4), - shuffle=encoding.get("shuffle", True), - fletcher32=encoding.get("fletcher32", False), - contiguous=encoding.get("contiguous", False), - chunksizes=encoding.get("chunksizes"), + zlib=False, + complevel=4, + shuffle=True, + fletcher32=False, + contiguous=False, + chunksizes=None, endian="native", - least_significant_digit=encoding.get("least_significant_digit"), + least_significant_digit=None, fill_value=fill_value, ) + default_args.update(encoding) + default_args.pop("_FillValue", None) + nc4_var = self.ds.createVariable(**default_args) nc4_var.setncatts(attrs) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index b3a31b28016..7e173528222 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -2,6 +2,7 @@ import importlib import platform +import string import warnings from contextlib import contextmanager, nullcontext from unittest import mock # noqa: F401 @@ -112,6 +113,10 @@ def _importorskip( not has_h5netcdf_ros3[0], reason="requires h5netcdf 1.3.0" ) +has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip( + "netCDF4", "1.6.2" +) + # change some global options for tests set_options(warn_for_unclosed_files=True) @@ -262,28 +267,41 @@ def assert_allclose(a, b, check_default_indexes=True, **kwargs): xarray.testing._assert_internal_invariants(b, check_default_indexes) -def create_test_data(seed: int | None = None, add_attrs: bool = True) -> Dataset: +_DEFAULT_TEST_DIM_SIZES = (8, 9, 10) + + +def create_test_data( + seed: int | None = None, + add_attrs: bool = True, + dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES, +) -> Dataset: rs = np.random.RandomState(seed) _vars = { "var1": ["dim1", "dim2"], "var2": ["dim1", "dim2"], "var3": ["dim3", "dim1"], } - _dims = {"dim1": 8, "dim2": 9, "dim3": 10} + _dims = {"dim1": dim_sizes[0], "dim2": dim_sizes[1], "dim3": dim_sizes[2]} obj = Dataset() obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"])) - obj["dim3"] = ("dim3", list("abcdefghij")) + if _dims["dim3"] > 26: + raise RuntimeError( + f'Not enough letters for filling this dimension size ({_dims["dim3"]})' + ) + obj["dim3"] = ("dim3", list(string.ascii_lowercase[0 : _dims["dim3"]])) obj["time"] = ("time", pd.date_range("2000-01-01", periods=20)) for v, dims in sorted(_vars.items()): data = rs.normal(size=tuple(_dims[d] for d in dims)) obj[v] = (dims, data) if add_attrs: obj[v].attrs = {"foo": "variable"} - obj.coords["numbers"] = ( - "dim3", - np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"), - ) + + if dim_sizes == _DEFAULT_TEST_DIM_SIZES: + numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") + else: + numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64") + obj.coords["numbers"] = ("dim3", 
numbers_values) obj.encoding = {"foo": "bar"} assert all(obj.data.flags.writeable for obj in obj.variables.values()) return obj diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 062f5de7d20..a8722d59659 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -72,6 +72,7 @@ requires_h5netcdf_ros3, requires_iris, requires_netCDF4, + requires_netCDF4_1_6_2_or_above, requires_pydap, requires_pynio, requires_scipy, @@ -1486,7 +1487,7 @@ def test_dump_and_open_encodings(self) -> None: assert ds.variables["time"].getncattr("units") == units assert_array_equal(ds.variables["time"], np.arange(10) + 4) - def test_compression_encoding(self) -> None: + def test_compression_encoding_legacy(self) -> None: data = create_test_data() data["var2"].encoding.update( { @@ -1767,6 +1768,74 @@ def test_setncattr_string(self) -> None: assert_array_equal(one_element_list_of_strings, totest.attrs["bar"]) assert one_string == totest.attrs["baz"] + @pytest.mark.parametrize( + "compression", + [ + None, + "zlib", + "szip", + "zstd", + "blosc_lz", + "blosc_lz4", + "blosc_lz4hc", + "blosc_zlib", + "blosc_zstd", + ], + ) + @requires_netCDF4_1_6_2_or_above + @pytest.mark.xfail(ON_WINDOWS, reason="new compression not yet implemented") + def test_compression_encoding(self, compression: str | None) -> None: + data = create_test_data(dim_sizes=(20, 80, 10)) + encoding_params: dict[str, Any] = dict(compression=compression, blosc_shuffle=1) + data["var2"].encoding.update(encoding_params) + data["var2"].encoding.update( + { + "chunksizes": (20, 40), + "original_shape": data.var2.shape, + "blosc_shuffle": 1, + "fletcher32": False, + } + ) + with self.roundtrip(data) as actual: + expected_encoding = data["var2"].encoding.copy() + # compression does not appear in the retrieved encoding, that differs + # from the input encoding. shuffle also chantges. Here we modify the + # expected encoding to account for this + compression = expected_encoding.pop("compression") + blosc_shuffle = expected_encoding.pop("blosc_shuffle") + if compression is not None: + if "blosc" in compression and blosc_shuffle: + expected_encoding["blosc"] = { + "compressor": compression, + "shuffle": blosc_shuffle, + } + expected_encoding["shuffle"] = False + elif compression == "szip": + expected_encoding["szip"] = { + "coding": "nn", + "pixels_per_block": 8, + } + expected_encoding["shuffle"] = False + else: + # This will set a key like zlib=true which is what appears in + # the encoding when we read it. + expected_encoding[compression] = True + if compression == "zstd": + expected_encoding["shuffle"] = False + else: + expected_encoding["shuffle"] = False + + actual_encoding = actual["var2"].encoding + assert expected_encoding.items() <= actual_encoding.items() + if ( + encoding_params["compression"] is not None + and "blosc" not in encoding_params["compression"] + ): + # regression test for #156 + expected = data.isel(dim1=0) + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) + @pytest.mark.skip(reason="https://github.com/Unidata/netcdf4-python/issues/1195") def test_refresh_from_disk(self) -> None: super().test_refresh_from_disk() @@ -4518,7 +4587,7 @@ def test_extract_nc4_variable_encoding(self) -> None: assert {} == encoding @requires_netCDF4 - def test_extract_nc4_variable_encoding_netcdf4(self, monkeypatch): + def test_extract_nc4_variable_encoding_netcdf4(self): # New netCDF4 1.6.0 compression argument. 
var = xr.Variable(("x",), [1, 2, 3], {}, {"compression": "szlib"}) _extract_nc4_variable_encoding(var, backend="netCDF4", raise_on_invalid=True) From 03ec3cb4d3fc97dd31b2269887ddc63a13ee518c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 21 Dec 2023 08:24:51 -0700 Subject: [PATCH 7/7] Fix mypy type ignore (#8564) * Fix mypy type ignore * Better ignore * more specific --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0f245ff464b..dcdc9edbd26 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -84,7 +84,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None # type: ignore + Delayed = None # type: ignore[misc,assignment] try: from iris.cube import Cube as iris_Cube except ImportError: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a6fc0e2ca18..b4460e956df 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -167,7 +167,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None # type: ignore + Delayed = None # type: ignore[misc,assignment] try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError:
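
The ``scan`` entrypoint added above is a thin wrapper around ``dask.array.reductions.cumreduction``. A minimal sketch of that machinery, using a cumulative sum over an illustrative chunked array (the array contents and chunk sizes are made up for demonstration and are not part of the patch):

    import operator

    import dask.array as da
    import numpy as np
    from dask.array.reductions import cumreduction

    x = da.from_array(np.arange(12.0).reshape(3, 4), chunks=(1, 2))

    # func is applied blockwise, binop folds the running total of the
    # preceding blocks into each block, and ident is binop's identity.
    result = cumreduction(np.cumsum, operator.add, 0.0, x, axis=1, dtype=x.dtype)

    np.testing.assert_allclose(result.compute(), np.cumsum(x.compute(), axis=1))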
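
The index reversal in the bottleneck fix exists because ``move_argmax``/``move_argmin`` report the position of the extremum counted backwards from the newest element of each window, while xarray's rolling aggregations index the window from the left. A small worked example of the same correction applied in ``_bottleneck_reduce`` (the input array is invented for illustration):

    import bottleneck as bn
    import numpy as np

    a = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
    window = 3

    raw = bn.move_argmax(a, window=window, min_count=1)
    # raw counts from the right edge of each window (newest element = 0),
    # so the left-to-right index inside the window is:
    fixed = window - 1 - raw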
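
With the encoding keys accepted above, the new-style netCDF4 compression arguments can be passed straight through ``to_netcdf``. A hedged usage sketch, assuming netCDF4 >= 1.6.2 and, for the blosc/zstd filters, the matching HDF5 plugins; the file name and data below are placeholders:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"var2": (("x", "y"), np.random.default_rng(0).random((20, 80)))})

    # "compression" supersedes the legacy zlib=True/complevel encoding and,
    # with this patch, is forwarded to netCDF4's createVariable along with
    # the other new keys (blosc_shuffle, significant_digits, ...).
    ds.to_netcdf(
        "compressed.nc",
        engine="netcdf4",
        encoding={"var2": {"compression": "zlib", "complevel": 4, "chunksizes": (20, 40)}},
    )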