diff --git a/intake_esm/core.py b/intake_esm/core.py
index 85a40880..3b994716 100644
--- a/intake_esm/core.py
+++ b/intake_esm/core.py
@@ -7,6 +7,7 @@
 import pandas as pd
 import pydantic
 import xarray as xr
+import xcollection as xc
 from fastprogress.fastprogress import progress_bar
 from intake.catalog import Catalog
 
@@ -247,6 +248,7 @@ def __dir__(self) -> typing.List[str]:
         rv = [
             'df',
             'to_dataset_dict',
+            'to_collection',
             'to_dask',
             'keys',
             'serialize',
@@ -573,6 +575,95 @@ def to_dataset_dict(
         self.datasets = self._create_derived_variables(datasets, skip_on_error)
         return self.datasets
 
+    @pydantic.validate_arguments
+    def to_collection(
+        self,
+        xarray_open_kwargs: typing.Dict[str, typing.Any] = None,
+        xarray_combine_by_coords_kwargs: typing.Dict[str, typing.Any] = None,
+        preprocess: typing.Callable = None,
+        storage_options: typing.Dict[pydantic.StrictStr, typing.Any] = None,
+        progressbar: pydantic.StrictBool = None,
+        aggregate: pydantic.StrictBool = None,
+        skip_on_error: pydantic.StrictBool = False,
+        **kwargs,
+    ) -> xc.Collection:
+        """
+        Load catalog entries into a Collection of xarray datasets.
+
+        The datasets are loaded with ``to_dataset_dict`` (every argument is
+        forwarded unchanged) and the resulting mapping is wrapped in an
+        ``xcollection.Collection``, which is also stored on ``self.datasets``.
+
+        Parameters
+        ----------
+        xarray_open_kwargs : dict, optional
+            Keyword arguments to pass to :py:func:`~xarray.open_dataset` function.
+        xarray_combine_by_coords_kwargs : dict, optional
+            Keyword arguments to pass to :py:func:`~xarray.combine_by_coords` function.
+        preprocess : callable, optional
+            If provided, call this function on each dataset prior to aggregation.
+        storage_options : dict, optional
+            Parameters passed to the backend file-system such as Google Cloud Storage,
+            Amazon Web Service S3.
+        progressbar : bool, optional
+            If True, will print a progress bar to standard error (stderr)
+            when loading assets into :py:class:`~xarray.Dataset`.
+        aggregate : bool, optional
+            If False, no aggregation will be done.
+        skip_on_error : bool, optional
+            If True, skip datasets that cannot be loaded and/or variables we are unable to derive.
+        **kwargs
+            Additional keyword arguments forwarded to ``to_dataset_dict``.
+
+        Returns
+        -------
+        dsets : Collection
+            A Collection of xarray :py:class:`~xarray.Dataset`.
+
+        Examples
+        --------
+        >>> import intake
+        >>> col = intake.open_esm_datastore("glade-cmip6.json")
+        >>> cat = col.search(
+        ...     source_id=["BCC-CSM2-MR", "CNRM-CM6-1", "CNRM-ESM2-1"],
+        ...     experiment_id=["historical", "ssp585"],
+        ...     variable_id="pr",
+        ...     table_id="Amon",
+        ...     grid_label="gn",
+        ... )
+        >>> dsets = cat.to_collection()
+        >>> dsets.keys()
+        dict_keys(['CMIP.BCC.BCC-CSM2-MR.historical.Amon.gn', 'ScenarioMIP.BCC.BCC-CSM2-MR.ssp585.Amon.gn'])
+        >>> dsets["CMIP.BCC.BCC-CSM2-MR.historical.Amon.gn"]
+        <xarray.Dataset>
+        Dimensions:    (bnds: 2, lat: 160, lon: 320, member_id: 3, time: 1980)
+        Coordinates:
+          * lon        (lon) float64 0.0 1.125 2.25 3.375 ... 355.5 356.6 357.8 358.9
+          * lat        (lat) float64 -89.14 -88.03 -86.91 -85.79 ... 86.91 88.03 89.14
+          * time       (time) object 1850-01-16 12:00:00 ... 2014-12-16 12:00:00
+          * member_id  (member_id) <U8 ...
+        Dimensions without coordinates: bnds
+        Data variables:
+            lat_bnds   (lat, bnds) float64 dask.array<chunksize=(...), meta=np.ndarray>
+            lon_bnds   (lon, bnds) float64 dask.array<chunksize=(...), meta=np.ndarray>
+            time_bnds  (time, bnds) object dask.array<chunksize=(...), meta=np.ndarray>
+            pr         (member_id, time, lat, lon) float32 dask.array<chunksize=(...), meta=np.ndarray>
+        """
+        # ``to_dataset_dict`` also stores its result on ``self.datasets``; that
+        # attribute is then rebound to the Collection so it matches the return value.
+        datasets = self.to_dataset_dict(
+            xarray_open_kwargs=xarray_open_kwargs,
+            xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs,
+            preprocess=preprocess,
+            storage_options=storage_options,
+            progressbar=progressbar,
+            aggregate=aggregate,
+            skip_on_error=skip_on_error,
+            **kwargs,
+        )
+        self.datasets = xc.Collection(datasets)
+        return self.datasets
+
     def to_dask(self, **kwargs) -> xr.Dataset:
         """
         Convert result to dataset.
diff --git a/requirements.txt b/requirements.txt
index 39868d6f..630f89e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ requests>=2.24.0
 xarray>=0.19,!=0.20.0,!=0.20.1
 zarr>=2.5
 pydantic>=1.8.2
+xcollection
diff --git a/tests/test_core.py b/tests/test_core.py
index 1e7d7319..8fb2fdac 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -5,6 +5,7 @@
 import pydantic
 import pytest
 import xarray as xr
+import xcollection as xc
 
 import intake_esm
 
@@ -241,6 +242,39 @@ def test_to_dataset_dict(path, query, xarray_open_kwargs):
     assert ds.time.encoding
 
 
+@pytest.mark.parametrize(
+    'path, query, xarray_open_kwargs',
+    [
+        (
+            zarr_col_pangeo_cmip6,
+            dict(
+                variable_id=['pr'],
+                experiment_id='ssp370',
+                activity_id='AerChemMIP',
+                source_id='BCC-ESM1',
+                table_id='Amon',
+                grid_label='gn',
+            ),
+            {'consolidated': True, 'backend_kwargs': {'storage_options': {'token': 'anon'}}},
+        ),
+        (
+            cdf_col_sample_cmip6,
+            dict(source_id=['CNRM-ESM2-1', 'CNRM-CM6-1', 'BCC-ESM1'], variable_id=['tasmax']),
+            {'chunks': {'time': 1}},
+        ),
+    ],
+)
+def test_to_collection(path, query, xarray_open_kwargs):
+    """``to_collection`` returns an ``xc.Collection``; one popped dataset is sanity-checked."""
+    subset = intake.open_esm_datastore(path).search(**query)
+    collection = subset.to_collection(xarray_open_kwargs=xarray_open_kwargs)
+    _key, dataset = collection.popitem()
+    assert 'member_id' in dataset.dims
+    assert len(dataset.__dask_keys__()) > 0
+    assert dataset.time.encoding
+    assert isinstance(collection, xc.Collection)
+
+
 @pytest.mark.parametrize(
     'path, query, xarray_open_kwargs',
     [