diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 48b08aacebf..e955a2de669 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -70,8 +70,11 @@ ErrorOptions, ErrorOptionsWithWarn, NetcdfWriteModes, + T_ChunkDimFreq, + T_ChunksFreq, ZarrWriteModes, ) + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint # """ # DEVELOPERS' NOTE @@ -1997,3 +2000,82 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: f"/{path}" if path != "." else "/": get_chunksizes(node.variables.values()) for path, node in self.subtree_with_keys } + + def chunk( + self, + chunks: T_ChunksFreq = {}, # noqa: B006 # {} even though it's technically unsafe, is being used intentionally here (#4667) + name_prefix: str = "xarray-", + token: str | None = None, + lock: bool = False, + inline_array: bool = False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + **chunks_kwargs: T_ChunkDimFreq, + ) -> Self: + """Coerce all arrays in all groups in this tree into dask arrays with the given + chunks. + + Non-dask arrays in this tree will be converted to dask arrays. Dask + arrays will be rechunked to the given chunk sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted. + + Parameters + ---------- + chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or + ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``. + name_prefix : str, default: "xarray-" + Prefix for the name of any new dask arrays. + token : str, optional + Token uniquely identifying this dataset. + lock : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + inline_array: bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided + + Returns + ------- + chunked : xarray.DataTree + + See Also + -------- + Dataset.chunk + Dataset.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + return DataTree.from_dict( + { + path: node.dataset.chunk( + chunks, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + **chunks_kwargs, + ) + for path, node in self.subtree_with_keys + }, + name=self.name, + ) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 43a45389b17..e17f7ecee25 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2274,3 +2274,26 @@ def test_compute(self): assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes" assert tree.chunksizes == original_chunksizes, "original tree was modified" + + def test_chunk(self): + ds1 = xr.Dataset({"a": ("x", np.arange(10))}) + ds2 = xr.Dataset({"b": ("y", np.arange(5))}) + ds3 = xr.Dataset({"c": ("z", np.arange(4))}) + ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) + + expected = xr.DataTree.from_dict( + { + "/": ds1.chunk({"x": 5}), + "/group1": ds2.chunk({"y": 3}), + "/group2": ds3.chunk({"z": 2}), + "/group1/subgroup1": ds4.chunk({"x": 5}), + } + ) + + tree = xr.DataTree.from_dict( + {"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4} + ) + actual = tree.chunk({"x": 5, "y": 3, "z": 2}) + + assert_identical(actual, expected) + assert actual.chunksizes == expected.chunksizes