Add keep_variables keyword to open_dataset() #8450

Open · wants to merge 4 commits into base: main
13 changes: 13 additions & 0 deletions xarray/backends/api.py
@@ -403,6 +403,7 @@ def open_dataset(
concat_characters: bool | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
@@ -494,6 +495,10 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
keep_variables : str or iterable of str, optional
A variable or list of variables to keep when parsing the dataset;
all other variables are excluded. This is useful when you only need
some of the variables in a file and do not want to spend time
loading the rest. Default (``None``) is to keep all variables.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default (``inline_array=False``) the array is included in a task by
@@ -572,6 +577,7 @@
backend_ds = backend.open_dataset(
filename_or_obj,
drop_variables=drop_variables,
keep_variables=keep_variables,
**decoders,
**kwargs,
)
@@ -586,6 +592,7 @@
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
keep_variables=keep_variables,
**decoders,
**kwargs,
)
@@ -606,6 +613,7 @@ def open_dataarray(
concat_characters: bool | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
@@ -699,6 +707,10 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
keep_variables : str or iterable of str, optional
A variable or list of variables to keep when parsing the dataset;
all other variables are excluded. This is useful when you only need
some of the variables in a file and do not want to spend time
loading the rest. Default (``None``) is to keep all variables.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default (``inline_array=False``) the array is included in a task by
@@ -756,6 +768,7 @@
chunks=chunks,
cache=cache,
drop_variables=drop_variables,
keep_variables=keep_variables,
inline_array=inline_array,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
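To make the intent concrete, here is a sketch of how the new keyword would be used once this change lands. The file name and variable names below are hypothetical; note that the filter also applies to coordinate variables, so any coordinates you still need must be listed explicitly (see the keep_variables tests further down).

import xarray as xr

# Hypothetical file and variable names. Only the listed variables are
# parsed from the file; everything else is skipped before decoding.
ds = xr.open_dataset("weather.nc", keep_variables=["t2m", "time"])

# A single string is accepted as well, mirroring drop_variables:
ds_t2m = xr.open_dataset("weather.nc", keep_variables="t2m")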
5 changes: 3 additions & 2 deletions xarray/backends/common.py
@@ -452,8 +452,8 @@ class BackendEntrypoint:

- ``open_dataset`` method: it shall implement reading from file, variables
decoding and it returns an instance of :py:class:`~xarray.Dataset`.
It shall take in input at least ``filename_or_obj`` argument and
``drop_variables`` keyword argument.
It shall take in input at least the ``filename_or_obj`` argument and the
``keep_variables`` and ``drop_variables`` keyword arguments.
For more details see :ref:`RST open_dataset`.
- ``guess_can_open`` method: it shall return ``True`` if the backend is able to open
``filename_or_obj``, ``False`` otherwise. The implementation of this
@@ -490,6 +490,7 @@ def open_dataset(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
**kwargs: Any,
) -> Dataset:
"""
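For third-party backends the updated contract means ``open_dataset`` must now accept ``keep_variables`` alongside ``drop_variables``. A minimal sketch under that contract; ``read_variables`` is a hypothetical helper standing in for the backend's real I/O layer:

from __future__ import annotations

from collections.abc import Iterable

import xarray as xr
from xarray.backends import BackendEntrypoint


class ExampleBackendEntrypoint(BackendEntrypoint):
    def open_dataset(
        self,
        filename_or_obj,
        *,
        drop_variables: str | Iterable[str] | None = None,
        keep_variables: str | Iterable[str] | None = None,
        **kwargs,
    ) -> xr.Dataset:
        # Normalize both keywords; keep_variables=None means "keep all".
        if isinstance(drop_variables, str):
            drop_variables = [drop_variables]
        drop_variables = set(drop_variables or [])
        if isinstance(keep_variables, str):
            keep_variables = [keep_variables]
        if keep_variables is not None:
            keep_variables = set(keep_variables)

        # read_variables (hypothetical) returns a {name: xarray.Variable}
        # mapping for everything stored in the file.
        variables = {
            name: var
            for name, var in read_variables(filename_or_obj).items()
            if name not in drop_variables
            and (keep_variables is None or name in keep_variables)
        }
        return xr.Dataset(variables)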
2 changes: 2 additions & 0 deletions xarray/backends/h5netcdf_.py
@@ -394,6 +394,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
format=None,
@@ -427,6 +428,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
2 changes: 2 additions & 0 deletions xarray/backends/netCDF4_.py
@@ -588,6 +588,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group=None,
@@ -621,6 +622,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
2 changes: 2 additions & 0 deletions xarray/backends/pydap_.py
@@ -179,6 +179,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
application=None,
@@ -207,6 +208,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
2 changes: 2 additions & 0 deletions xarray/backends/pynio_.py
@@ -134,6 +134,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
mode="r",
@@ -155,6 +156,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
2 changes: 2 additions & 0 deletions xarray/backends/scipy_.py
@@ -297,6 +297,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
mode="r",
@@ -319,6 +320,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
2 changes: 2 additions & 0 deletions xarray/backends/store.py
@@ -35,6 +35,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
) -> Dataset:
@@ -51,6 +52,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
8 changes: 8 additions & 0 deletions xarray/backends/zarr.py
@@ -775,6 +775,7 @@ def open_zarr(
concat_characters=True,
decode_coords=True,
drop_variables=None,
keep_variables=None,
consolidated=None,
overwrite_encoded_chunks=False,
chunk_store=None,
@@ -836,6 +837,10 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
keep_variables : str or iterable of str, optional
A variable or list of variables to keep when parsing the dataset;
all other variables are excluded. This is useful when you only need
some of the variables in a file and do not want to spend time
loading the rest. Default (``None``) is to keep all variables.
consolidated : bool, optional
Whether to open the store using zarr's consolidated metadata
capability. Only works for stores that have already been consolidated.
@@ -933,6 +938,7 @@
engine="zarr",
chunks=chunks,
drop_variables=drop_variables,
keep_variables=keep_variables,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
backend_kwargs=backend_kwargs,
@@ -977,6 +983,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
keep_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group=None,
@@ -1011,6 +1018,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
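The same keyword flows through ``open_zarr``, which forwards it to ``open_dataset``. A short sketch, with a hypothetical store path and variable names:

import xarray as xr

# Hypothetical store and variable names; only the listed arrays are
# parsed from the store, mirroring the open_dataset behaviour above.
ds = xr.open_zarr("observations.zarr", keep_variables=["precip", "time"])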
16 changes: 15 additions & 1 deletion xarray/conventions.py
@@ -45,6 +45,7 @@
T_Variables = Mapping[Any, Variable]
T_Attrs = MutableMapping[Any, Any]
T_DropVariables = Union[str, Iterable[Hashable], None]
T_KeepVariables = Union[str, Iterable[Hashable], None]
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


@@ -380,6 +381,7 @@ def decode_cf_variables(
decode_times: bool = True,
decode_coords: bool | Literal["coordinates", "all"] = True,
drop_variables: T_DropVariables = None,
keep_variables: T_KeepVariables = None,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
) -> tuple[T_Variables, T_Attrs, set[Hashable]]:
@@ -410,13 +412,19 @@ def stackable(dim: Hashable) -> bool:
drop_variables = []
drop_variables = set(drop_variables)

if isinstance(keep_variables, str):
keep_variables = [keep_variables]
if keep_variables is not None:
keep_variables = set(keep_variables)

# Time bounds coordinates might miss the decoding attributes
if decode_times:
_update_bounds_attributes(variables)

new_vars = {}
for k, v in variables.items():
if k in drop_variables:
if k in drop_variables or (
keep_variables is not None and k not in keep_variables
):
continue
stack_char_dim = (
concat_characters
@@ -496,6 +504,7 @@ def decode_cf(
decode_times: bool = True,
decode_coords: bool | Literal["coordinates", "all"] = True,
drop_variables: T_DropVariables = None,
keep_variables: T_KeepVariables = None,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
) -> Dataset:
@@ -527,6 +536,10 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
keep_variables : str or iterable of str, optional
A variable or list of variables to keep when parsing the dataset;
all other variables are excluded. This is useful when you only need
some of the variables in a file and do not want to spend time
loading the rest. Default (``None``) is to keep all variables.
use_cftime : bool, optional
Only relevant if encoded dates come from a standard calendar
(e.g. "gregorian", "proleptic_gregorian", "standard", or not
@@ -574,6 +587,7 @@
decode_times,
decode_coords,
drop_variables=drop_variables,
keep_variables=keep_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
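The filtering in ``decode_cf_variables`` reduces to a single predicate per variable name. A standalone sketch of the intended semantics, assuming the keywords have been normalized as above (``drop_variables`` a set, ``keep_variables`` either ``None`` for "no filtering" or a set of names to retain):

def should_skip(name, drop_variables, keep_variables):
    # A name is skipped if it was explicitly dropped, or if an explicit
    # keep list exists and the name is not on it.
    return name in drop_variables or (
        keep_variables is not None and name not in keep_variables
    )

assert not should_skip("t", set(), None)   # default: keep everything
assert should_skip("foo", set(), {"t"})    # not on the keep list
assert should_skip("t", {"t"}, None)       # explicitly dropped
assert should_skip("t", set(), set())      # empty keep list drops all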
2 changes: 1 addition & 1 deletion xarray/tests/test_backends_api.py
@@ -69,7 +69,7 @@ def open_dataset(
class PassThroughBackendEntrypoint(xr.backends.BackendEntrypoint):
"""Access an object passed to the `open_dataset` method."""

def open_dataset(self, dataset, *, drop_variables=None):
def open_dataset(self, dataset, *, drop_variables=None, keep_variables=None):
"""Return the first argument."""
return dataset

37 changes: 37 additions & 0 deletions xarray/tests/test_conventions.py
@@ -92,7 +92,7 @@
"time": ("time", [0.0], {"units": "hours since 2017-01-01"}),
}
)
decoded = conventions.decode_cf(orig, decode_coords=True)

assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT"
assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"]

@@ -327,6 +327,43 @@
assert_identical(expected, actual)
assert_identical(expected, actual2)

def test_decode_cf_with_keep_variables(self) -> None:
original = Dataset(
{
"t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}),
"x": ("x", [9, 8, 7], {"units": "km"}),
"foo": (
("t", "x"),
[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
{"units": "bar"},
),
"y": ("t", [5, 10, -999], {"_FillValue": -999}),
}
)
expected = Dataset(
{
"t": pd.date_range("2000-01-01", periods=3),
"foo": (
("t", "x"),
[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
{"units": "bar"},
),
"y": ("t", [5, 10, np.nan]),
}
)
expected2 = Dataset(
{
"t": pd.date_range("2000-01-01", periods=3),
}
)
expected3 = Dataset()
actual = conventions.decode_cf(original, keep_variables=("t", "foo", "y"))
actual2 = conventions.decode_cf(original, keep_variables="t")
actual3 = conventions.decode_cf(original, keep_variables=[])
assert_identical(expected, actual)
assert_identical(expected2, actual2)
assert_identical(expected3, actual3)

@pytest.mark.filterwarnings("ignore:Ambiguous reference date string")
def test_invalid_time_units_raises_eagerly(self) -> None:
ds = Dataset({"time": ("time", [0, 1], {"units": "foobar since 123"})})