From 6e14df62f0b01d8ca5b04bd0ed2b5ee45444265d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 11 May 2021 10:21:25 +0200 Subject: [PATCH 1/7] Flexible indexes: add Index base class and xindexes properties (#5102) * add IndexAdapter class + move PandasIndexAdapter * wip: xarray_obj.indexes -> IndexAdapter objects * fix more broken tests * fix merge glitch * fix group bins tests * add xindexes property Use it internally instead of indexes * rename IndexAdapter -> Index * rename _to_index_adpater (typo) -> _to_xindex * add Index.to_pandas_index() method Also improve xarray_obj.indexes property implementation * rename PandasIndexAdpater -> PandasIndex * update index type in tests * ensure .indexes only returns pd.Index objects * PandasIndex: normalize other index in cmp funcs * fix merge lint errors * fix PandasIndex union/intersection * [skip-ci] add TODO comment about index sizes * address more PR comments * [skip-ci] update what's new * fix coord_names normalization * move what's new entry to unreleased section --- doc/whats-new.rst | 5 +- xarray/core/alignment.py | 39 ++-- xarray/core/combine.py | 14 +- xarray/core/common.py | 5 +- xarray/core/coordinates.py | 22 ++- xarray/core/dataarray.py | 35 +++- xarray/core/dataset.py | 134 ++++++++----- xarray/core/indexes.py | 226 ++++++++++++++++++++-- xarray/core/indexing.py | 113 +---------- xarray/core/merge.py | 34 ++-- xarray/core/missing.py | 8 +- xarray/core/parallel.py | 30 +-- xarray/core/variable.py | 26 +-- xarray/testing.py | 5 +- xarray/tests/test_backends.py | 4 +- xarray/tests/test_cftimeindex.py | 2 +- xarray/tests/test_cftimeindex_resample.py | 8 +- xarray/tests/test_concat.py | 2 +- xarray/tests/test_conventions.py | 2 +- xarray/tests/test_dataarray.py | 44 +++-- xarray/tests/test_dataset.py | 79 +++++--- xarray/tests/test_variable.py | 8 +- 22 files changed, 534 insertions(+), 311 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 26c975c859e..3f81678b8d5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,10 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Explicit indexes refactor: add an ``xarray.Index`` base class and + ``Dataset.xindexes`` / ``DataArray.xindexes`` properties. Also rename + ``PandasIndexAdapter`` to ``PandasIndex``, which now inherits from + ``xarray.Index`` (:pull:`5102`). By `Benoit Bovy `_. .. _whats-new.0.18.0: @@ -268,7 +272,6 @@ Internal Changes (:pull:`5188`), (:pull:`5191`). By `Maximilian Roos `_. - .. _whats-new.0.17.0: v0.17.0 (24 Feb 2021) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 98cbadcb25c..f6e026c0109 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -17,9 +17,10 @@ import numpy as np import pandas as pd -from . import dtypes, utils +from . 
import dtypes +from .indexes import Index, PandasIndex from .indexing import get_indexer_nd -from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str +from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index from .variable import IndexVariable, Variable if TYPE_CHECKING: @@ -30,11 +31,11 @@ DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) -def _get_joiner(join): +def _get_joiner(join, index_cls): if join == "outer": - return functools.partial(functools.reduce, pd.Index.union) + return functools.partial(functools.reduce, index_cls.union) elif join == "inner": - return functools.partial(functools.reduce, pd.Index.intersection) + return functools.partial(functools.reduce, index_cls.intersection) elif join == "left": return operator.itemgetter(0) elif join == "right": @@ -63,7 +64,7 @@ def _override_indexes(objects, all_indexes, exclude): objects = list(objects) for idx, obj in enumerate(objects[1:]): new_indexes = {} - for dim in obj.indexes: + for dim in obj.xindexes: if dim not in exclude: new_indexes[dim] = all_indexes[dim][0] objects[idx + 1] = obj._overwrite_indexes(new_indexes) @@ -284,7 +285,7 @@ def align( if dim not in exclude: all_coords[dim].append(obj.coords[dim]) try: - index = obj.indexes[dim] + index = obj.xindexes[dim] except KeyError: unlabeled_dim_sizes[dim].add(obj.sizes[dim]) else: @@ -298,16 +299,19 @@ def align( # - It ensures it's possible to do operations that don't require alignment # on indexes with duplicate values (which cannot be reindexed with # pandas). This is useful, e.g., for overwriting such duplicate indexes. - joiner = _get_joiner(join) joined_indexes = {} for dim, matching_indexes in all_indexes.items(): if dim in indexes: - index = utils.safe_cast_to_index(indexes[dim]) + # TODO: benbovy - flexible indexes. maybe move this logic in util func + if isinstance(indexes[dim], Index): + index = indexes[dim] + else: + index = PandasIndex(safe_cast_to_index(indexes[dim])) if ( any(not index.equals(other) for other in matching_indexes) or dim in unlabeled_dim_sizes ): - joined_indexes[dim] = indexes[dim] + joined_indexes[dim] = index else: if ( any( @@ -318,6 +322,7 @@ def align( ): if join == "exact": raise ValueError(f"indexes along dimension {dim!r} are not equal") + joiner = _get_joiner(join, type(matching_indexes[0])) index = joiner(matching_indexes) # make sure str coords are not cast to object index = maybe_coerce_to_str(index, all_coords[dim]) @@ -327,6 +332,9 @@ def align( if dim in unlabeled_dim_sizes: unlabeled_sizes = unlabeled_dim_sizes[dim] + # TODO: benbovy - flexible indexes: expose a size property for xarray.Index? + # Some indexes may not have a defined size (e.g., built from multiple coords of + # different sizes) labeled_size = index.size if len(unlabeled_sizes | {labeled_size}) > 1: raise ValueError( @@ -469,7 +477,7 @@ def reindex_like_indexers( ValueError If any dimensions without labels have different sizes. 
""" - indexers = {k: v for k, v in other.indexes.items() if k in target.dims} + indexers = {k: v for k, v in other.xindexes.items() if k in target.dims} for dim in other.dims: if dim not in indexers and dim in target.dims: @@ -487,14 +495,14 @@ def reindex_like_indexers( def reindex_variables( variables: Mapping[Any, Variable], sizes: Mapping[Any, int], - indexes: Mapping[Any, pd.Index], + indexes: Mapping[Any, Index], indexers: Mapping, method: Optional[str] = None, tolerance: Any = None, copy: bool = True, fill_value: Optional[Any] = dtypes.NA, sparse: bool = False, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Conform a dictionary of aligned variables onto a new set of variables, filling in missing values with NaN. @@ -559,10 +567,11 @@ def reindex_variables( "from that to be indexed along {:s}".format(str(indexer.dims), dim) ) - target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim]) + target = new_indexes[dim] = PandasIndex(safe_cast_to_index(indexers[dim])) if dim in indexes: - index = indexes[dim] + # TODO (benbovy - flexible indexes): support other indexes than pd.Index? + index = indexes[dim].to_pandas_index() if not index.is_unique: raise ValueError( diff --git a/xarray/core/combine.py b/xarray/core/combine.py index e907fc32c07..105e0a5a66c 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -69,13 +69,17 @@ def _infer_concat_order_from_coords(datasets): if dim in ds0: # Need to read coordinate values to do ordering - indexes = [ds.indexes.get(dim) for ds in datasets] + indexes = [ds.xindexes.get(dim) for ds in datasets] if any(index is None for index in indexes): raise ValueError( "Every dimension needs a coordinate for " "inferring concatenation order" ) + # TODO (benbovy, flexible indexes): all indexes should be Pandas.Index + # get pd.Index objects from Index objects + indexes = [index.array for index in indexes] + # If dimension coordinate values are same on every dataset then # should be leaving this dimension alone (it's just a "bystander") if not all(index.equals(indexes[0]) for index in indexes[1:]): @@ -801,9 +805,13 @@ def combine_by_coords( ) # Check the overall coordinates are monotonically increasing + # TODO (benbovy - flexible indexes): only with pandas.Index? 
for dim in concat_dims: - indexes = concatenated.indexes.get(dim) - if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing): + indexes = concatenated.xindexes.get(dim) + if not ( + indexes.array.is_monotonic_increasing + or indexes.array.is_monotonic_decreasing + ): raise ValueError( "Resulting object does not have monotonic" " global indexes along dimension {}".format(dim) diff --git a/xarray/core/common.py b/xarray/core/common.py index c9386c4e15f..e4a5264d8e6 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -406,7 +406,7 @@ def get_index(self, key: Hashable) -> pd.Index: raise KeyError(key) try: - return self.indexes[key] + return self.xindexes[key].to_pandas_index() except KeyError: return pd.Index(range(self.sizes[key]), name=key) @@ -1162,7 +1162,8 @@ def resample( category=FutureWarning, ) - if isinstance(self.indexes[dim_name], CFTimeIndex): + # TODO (benbovy - flexible indexes): update when CFTimeIndex is an xarray Index subclass + if isinstance(self.xindexes[dim_name].to_pandas_index(), CFTimeIndex): from .resample_cftime import CFTimeGrouper grouper = CFTimeGrouper(freq, closed, label, base, loffset) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 16eecef6efe..50be8a7f677 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -17,7 +17,7 @@ import pandas as pd from . import formatting, indexing -from .indexes import Indexes +from .indexes import Index, Indexes from .merge import merge_coordinates_without_align, merge_coords from .utils import Frozen, ReprObject, either_dict_or_kwargs from .variable import Variable @@ -52,6 +52,10 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]: def indexes(self) -> Indexes: return self._data.indexes # type: ignore[attr-defined] + @property + def xindexes(self) -> Indexes: + return self._data.xindexes # type: ignore[attr-defined] + @property def variables(self): raise NotImplementedError() @@ -157,7 +161,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: def update(self, other: Mapping[Hashable, Any]) -> None: other_vars = getattr(other, "variables", other) coords, indexes = merge_coords( - [self.variables, other_vars], priority_arg=1, indexes=self.indexes + [self.variables, other_vars], priority_arg=1, indexes=self.xindexes ) self._update_coords(coords, indexes) @@ -165,7 +169,7 @@ def _merge_raw(self, other, reflexive): """For use with binary arithmetic.""" if other is None: variables = dict(self.variables) - indexes = dict(self.indexes) + indexes = dict(self.xindexes) else: coord_list = [self, other] if not reflexive else [other, self] variables, indexes = merge_coordinates_without_align(coord_list) @@ -180,7 +184,9 @@ def _merge_inplace(self, other): # don't include indexes in prioritized, because we didn't align # first and we want indexes to be checked prioritized = { - k: (v, None) for k, v in self.variables.items() if k not in self.indexes + k: (v, None) + for k, v in self.variables.items() + if k not in self.xindexes } variables, indexes = merge_coordinates_without_align( [self, other], prioritized @@ -265,7 +271,7 @@ def to_dataset(self) -> "Dataset": return self._data._copy_listed(names) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] ) -> None: from .dataset import calculate_dimensions @@ -285,7 +291,7 @@ def _update_coords( # TODO(shoyer): once ._indexes is always populated 
by a dict, modify # it to update inplace instead. - original_indexes = dict(self._data.indexes) + original_indexes = dict(self._data.xindexes) original_indexes.update(indexes) self._data._indexes = original_indexes @@ -328,7 +334,7 @@ def __getitem__(self, key: Hashable) -> "DataArray": return self._data._getitem_coord(key) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] ) -> None: from .dataset import calculate_dimensions @@ -343,7 +349,7 @@ def _update_coords( # TODO(shoyer): once ._indexes is always populated by a dict, modify # it to update inplace instead. - original_indexes = dict(self._data.indexes) + original_indexes = dict(self._data.xindexes) original_indexes.update(indexes) self._data._indexes = original_indexes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 740493b863c..21daed1cec1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -51,7 +51,7 @@ ) from .dataset import Dataset, split_indexes from .formatting import format_item -from .indexes import Indexes, default_indexes, propagate_indexes +from .indexes import Index, Indexes, PandasIndex, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS, _get_keep_attrs @@ -345,7 +345,7 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic): _cache: Dict[str, Any] _coords: Dict[Any, Variable] _close: Optional[Callable[[], None]] - _indexes: Optional[Dict[Hashable, pd.Index]] + _indexes: Optional[Dict[Hashable, Index]] _name: Optional[Hashable] _variable: Variable @@ -478,7 +478,9 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": # switch from dimension to level names, if necessary dim_names: Dict[Any, str] = {} for dim, idx in indexes.items(): - if not isinstance(idx, pd.MultiIndex) and idx.name != dim: + # TODO: benbovy - flexible indexes: update when MultiIndex has its own class + pd_idx = idx.array + if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: dim_names[dim] = idx.name if dim_names: obj = obj.rename(dim_names) @@ -772,7 +774,21 @@ def encoding(self, value: Mapping[Hashable, Any]) -> None: @property def indexes(self) -> Indexes: - """Mapping of pandas.Index objects used for label based indexing""" + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Dataset has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + DataArray.xindexes + + """ + return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + + @property + def xindexes(self) -> Indexes: + """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._coords, self.dims) return Indexes(self._indexes) @@ -990,7 +1006,12 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray": if self._indexes is None: indexes = self._indexes else: - indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()} + # TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index) + # xarray Index needs a copy method. 
+ indexes = { + k: PandasIndex(v.to_pandas_index().copy(deep=deep)) + for k, v in self._indexes.items() + } return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> "DataArray": @@ -2169,7 +2190,9 @@ def to_unstacked_dataset(self, dim, level=0): Dataset.to_stacked_array """ - idx = self.indexes[dim] + # TODO: benbovy - flexible indexes: update when MultIndex has its own + # class inheriting from xarray.Index + idx = self.xindexes[dim].to_pandas_index() if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6b9f297dee1..706ccbde8c4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -61,7 +61,9 @@ ) from .duck_array_ops import datetime_to_numeric from .indexes import ( + Index, Indexes, + PandasIndex, default_indexes, isel_variable_and_index, propagate_indexes, @@ -692,7 +694,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): _dims: Dict[Hashable, int] _encoding: Optional[Dict[Hashable, Any]] _close: Optional[Callable[[], None]] - _indexes: Optional[Dict[Hashable, pd.Index]] + _indexes: Optional[Dict[Hashable, Index]] _variables: Dict[Hashable, Variable] __slots__ = ( @@ -1087,7 +1089,7 @@ def _replace( coord_names: Set[Hashable] = None, dims: Dict[Any, int] = None, attrs: Union[Dict[Hashable, Any], None, Default] = _default, - indexes: Union[Dict[Any, pd.Index], None, Default] = _default, + indexes: Union[Dict[Any, Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, ) -> "Dataset": @@ -1136,7 +1138,7 @@ def _replace_with_new_dims( variables: Dict[Hashable, Variable], coord_names: set = None, attrs: Union[Dict[Hashable, Any], None, Default] = _default, - indexes: Union[Dict[Hashable, pd.Index], None, Default] = _default, + indexes: Union[Dict[Hashable, Index], None, Default] = _default, inplace: bool = False, ) -> "Dataset": """Replace variables with recalculated dimensions.""" @@ -1164,12 +1166,12 @@ def _replace_vars_and_dims( variables, coord_names, dims, attrs, indexes=None, inplace=inplace ) - def _overwrite_indexes(self, indexes: Mapping[Any, pd.Index]) -> "Dataset": + def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset": if not indexes: return self variables = self._variables.copy() - new_indexes = dict(self.indexes) + new_indexes = dict(self.xindexes) for name, idx in indexes.items(): variables[name] = IndexVariable(name, idx) new_indexes[name] = idx @@ -1178,8 +1180,9 @@ def _overwrite_indexes(self, indexes: Mapping[Any, pd.Index]) -> "Dataset": # switch from dimension to level names, if necessary dim_names: Dict[Hashable, str] = {} for dim, idx in indexes.items(): - if not isinstance(idx, pd.MultiIndex) and idx.name != dim: - dim_names[dim] = idx.name + pd_idx = idx.to_pandas_index() + if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: + dim_names[dim] = pd_idx.name if dim_names: obj = obj.rename(dim_names) return obj @@ -1315,9 +1318,11 @@ def _level_coords(self) -> Dict[str, Hashable]: coordinate name. """ level_coords: Dict[str, Hashable] = {} - for name, index in self.indexes.items(): - if isinstance(index, pd.MultiIndex): - level_names = index.names + for name, index in self.xindexes.items(): + # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. 
+ pd_index = index.to_pandas_index() + if isinstance(pd_index, pd.MultiIndex): + level_names = pd_index.names (dim,) = self.variables[name].dims level_coords.update({lname: dim for lname in level_names}) return level_coords @@ -1328,7 +1333,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": """ variables: Dict[Hashable, Variable] = {} coord_names = set() - indexes: Dict[Hashable, pd.Index] = {} + indexes: Dict[Hashable, Index] = {} for name in names: try: @@ -1341,7 +1346,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": if ref_name in self._coord_names or ref_name in self.dims: coord_names.add(var_name) if (var_name,) == var.dims: - indexes[var_name] = var.to_index() + indexes[var_name] = var._to_xindex() needed_dims: Set[Hashable] = set() for v in variables.values(): @@ -1357,8 +1362,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": if set(self.variables[k].dims) <= needed_dims: variables[k] = self._variables[k] coord_names.add(k) - if k in self.indexes: - indexes[k] = self.indexes[k] + if k in self.xindexes: + indexes[k] = self.xindexes[k] return self._replace(variables, coord_names, dims, indexes=indexes) @@ -1527,7 +1532,7 @@ def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" del self._variables[key] self._coord_names.discard(key) - if key in self.indexes: + if key in self.xindexes: assert self._indexes is not None del self._indexes[key] self._dims = calculate_dimensions(self._variables) @@ -1604,7 +1609,21 @@ def identical(self, other: "Dataset") -> bool: @property def indexes(self) -> Indexes: - """Mapping of pandas.Index objects used for label based indexing""" + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Dataset has indexes that cannot be coerced + to pandas.Index objects. 
+ + See Also + -------- + Dataset.xindexes + + """ + return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + + @property + def xindexes(self) -> Indexes: + """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._variables, self._dims) return Indexes(self._indexes) @@ -2069,7 +2088,9 @@ def _validate_indexers( v = np.asarray(v) if v.dtype.kind in "US": - index = self.indexes[k] + # TODO: benbovy - flexible indexes + # update when CFTimeIndex has its own xarray index class + index = self.xindexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): @@ -2218,7 +2239,7 @@ def isel( continue if indexes and var_name in indexes: if var_value.ndim == 1: - indexes[var_name] = var_value.to_index() + indexes[var_name] = var_value._to_xindex() else: del indexes[var_name] variables[var_name] = var_value @@ -2246,16 +2267,16 @@ def _isel_fancy( indexers_list = list(self._validate_indexers(indexers, missing_dims)) variables: Dict[Hashable, Variable] = {} - indexes: Dict[Hashable, pd.Index] = {} + indexes: Dict[Hashable, Index] = {} for name, var in self.variables.items(): var_indexers = {k: v for k, v in indexers_list if k in var.dims} if drop and name in var_indexers: continue # drop this variable - if name in self.indexes: + if name in self.xindexes: new_var, new_index = isel_variable_and_index( - name, var, self.indexes[name], var_indexers + name, var, self.xindexes[name], var_indexers ) if new_index is not None: indexes[name] = new_index @@ -2814,7 +2835,7 @@ def _reindex( variables, indexes = alignment.reindex_variables( self.variables, self.sizes, - self.indexes, + self.xindexes, indexers, method, tolerance, @@ -3030,7 +3051,7 @@ def _validate_interp_indexer(x, new_x): variables[name] = var coord_names = obj._coord_names & variables.keys() - indexes = {k: v for k, v in obj.indexes.items() if k not in indexers} + indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) @@ -3040,7 +3061,7 @@ def _validate_interp_indexer(x, new_x): for k, v in indexers.items(): assert isinstance(v, Variable) if v.dims == (k,): - indexes[k] = v.to_index() + indexes[k] = v._to_xindex() # Extract coordinates from indexers coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) @@ -3136,16 +3157,18 @@ def _rename_indexes(self, name_dict, dims_set): if self._indexes is None: return None indexes = {} - for k, v in self.indexes.items(): + for k, v in self.xindexes.items(): + # TODO: benbovy - flexible indexes: make it compatible with any xarray Index + index = v.to_pandas_index() new_name = name_dict.get(k, k) if new_name not in dims_set: continue - if isinstance(v, pd.MultiIndex): - new_names = [name_dict.get(k, k) for k in v.names] - index = v.rename(names=new_names) + if isinstance(index, pd.MultiIndex): + new_names = [name_dict.get(k, k) for k in index.names] + new_index = index.rename(names=new_names) else: - index = v.rename(new_name) - indexes[new_name] = index + new_index = index.rename(new_name) + indexes[new_name] = PandasIndex(new_index) return indexes def _rename_all(self, name_dict, dims_dict): @@ -3362,19 +3385,19 @@ def swap_dims( coord_names.update({dim for dim in dims_dict.values() if dim in self.variables}) variables: Dict[Hashable, Variable] = {} - indexes: Dict[Hashable, pd.Index] = {} + indexes: 
Dict[Hashable, Index] = {} for k, v in self.variables.items(): dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) if k in result_dims: var = v.to_index_variable() - if k in self.indexes: - indexes[k] = self.indexes[k] + if k in self.xindexes: + indexes[k] = self.xindexes[k] else: new_index = var.to_index() if new_index.nlevels == 1: # make sure index name matches dimension name new_index = new_index.rename(k) - indexes[k] = new_index + indexes[k] = PandasIndex(new_index) else: var = v.to_base_variable() var.dims = dims @@ -3637,15 +3660,17 @@ def reorder_levels( """ dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() - indexes = dict(self.indexes) + indexes = dict(self.xindexes) for dim, order in dim_order.items(): coord = self._variables[dim] - index = self.indexes[dim] + # TODO: benbovy - flexible indexes: update when MultiIndex + # has its own class inherited from xarray.Index + index = self.xindexes[dim].to_pandas_index() if not isinstance(index, pd.MultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") new_index = index.reorder_levels(order) variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = new_index + indexes[dim] = PandasIndex(new_index) return self._replace(variables, indexes=indexes) @@ -3672,8 +3697,8 @@ def _stack_once(self, dims, new_dim): coord_names = set(self._coord_names) - set(dims) | {new_dim} - indexes = {k: v for k, v in self.indexes.items() if k not in dims} - indexes[new_dim] = idx + indexes = {k: v for k, v in self.xindexes.items() if k not in dims} + indexes[new_dim] = PandasIndex(idx) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes @@ -3825,7 +3850,9 @@ def ensure_stackable(val): # coerce the levels of the MultiIndex to have the same type as the # input dimensions. This code is messy, so it might be better to just # input a dummy value for the singleton dimension. 
- idx = data_array.indexes[new_dim] + # TODO: benbovy - flexible indexes: update when MultIndex has its own + # class inheriting from xarray.Index + idx = data_array.xindexes[new_dim].to_pandas_index() levels = [idx.levels[0]] + [ level.astype(self[level.name].dtype) for level in idx.levels[1:] ] @@ -3842,7 +3869,7 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": index = remove_unused_levels_categories(index) variables: Dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.indexes.items() if k != dim} + indexes = {k: v for k, v in self.xindexes.items() if k != dim} for name, var in self.variables.items(): if name != dim: @@ -3860,7 +3887,7 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": for name, lev in zip(index.names, index.levels): variables[name] = IndexVariable(name, lev) - indexes[name] = lev + indexes[name] = PandasIndex(lev) coord_names = set(self._coord_names) - {dim} | set(index.names) @@ -3887,7 +3914,7 @@ def _unstack_full_reindex( new_dim_sizes = [lev.size for lev in index.levels] variables: Dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.indexes.items() if k != dim} + indexes = {k: v for k, v in self.xindexes.items() if k != dim} for name, var in obj.variables.items(): if name != dim: @@ -3899,7 +3926,7 @@ def _unstack_full_reindex( for name, lev in zip(new_dim_names, index.levels): variables[name] = IndexVariable(name, lev) - indexes[name] = lev + indexes[name] = PandasIndex(lev) coord_names = set(self._coord_names) - {dim} | set(new_dim_names) @@ -4161,7 +4188,7 @@ def drop_vars( variables = {k: v for k, v in self._variables.items() if k not in names} coord_names = {k for k in self._coord_names if k in variables} - indexes = {k: v for k, v in self.indexes.items() if k not in names} + indexes = {k: v for k, v in self.xindexes.items() if k not in names} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -4871,7 +4898,7 @@ def reduce( ) coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.indexes.items() if k in variables} + indexes = {k: v for k, v in self.xindexes.items() if k in variables} attrs = self.attrs if keep_attrs else None return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes @@ -5660,9 +5687,12 @@ def diff(self, dim, n=1, label="upper"): else: variables[name] = var - indexes = dict(self.indexes) + indexes = dict(self.xindexes) if dim in indexes: - indexes[dim] = indexes[dim][kwargs_new[dim]] + # TODO: benbovy - flexible indexes: check slicing of xarray indexes? + # or only allow this for pandas indexes? 
+ index = indexes[dim].to_pandas_index() + indexes[dim] = PandasIndex(index[kwargs_new[dim]]) difference = self._replace_with_new_dims(variables, indexes=indexes) @@ -5799,14 +5829,14 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): if roll_coords: indexes = {} - for k, v in self.indexes.items(): + for k, v in self.xindexes.items(): (dim,) = self.variables[k].dims if dim in shifts: indexes[k] = roll_index(v, shifts[dim]) else: indexes[k] = v else: - indexes = dict(self.indexes) + indexes = dict(self.xindexes) return self._replace(variables, indexes=indexes) @@ -5999,7 +6029,7 @@ def quantile( # construct the new dataset coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.indexes.items() if k in variables} + indexes = {k: v for k, v in self.xindexes.items() if k in variables} if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None @@ -6223,7 +6253,7 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): variables[k] = Variable(v_dims, integ) else: variables[k] = v - indexes = {k: v for k, v in self.indexes.items() if k in variables} + indexes = {k: v for k, v in self.xindexes.items() if k in variables} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b33d08985e4..be362e1c942 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,12 +1,199 @@ import collections.abc -from typing import Any, Dict, Hashable, Iterable, Mapping, Optional, Tuple, Union +from contextlib import suppress +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Hashable, + Iterable, + Mapping, + Optional, + Tuple, + Union, +) import numpy as np import pandas as pd -from . import formatting +from . import formatting, utils +from .indexing import ExplicitlyIndexedNDArrayMixin, NumpyIndexingAdapter +from .npcompat import DTypeLike from .utils import is_scalar -from .variable import Variable + +if TYPE_CHECKING: + from .variable import Variable + + +class Index: + """Base class inherited by all xarray-compatible indexes.""" + + __slots__ = ("coord_names",) + + def __init__(self, coord_names: Union[Hashable, Iterable[Hashable]]): + if isinstance(coord_names, Hashable): + coord_names = (coord_names,) + self.coord_names = tuple(coord_names) + + @classmethod + def from_variables( + cls, variables: Dict[Hashable, "Variable"], **kwargs + ): # pragma: no cover + raise NotImplementedError() + + def to_pandas_index(self) -> pd.Index: + """Cast this xarray index to a pandas.Index object or raise a TypeError + if this is not supported. + + This method is used by all xarray operations that expect/require a + pandas.Index object. 
+ + """ + raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") + + def equals(self, other): # pragma: no cover + raise NotImplementedError() + + def union(self, other): # pragma: no cover + raise NotImplementedError() + + def intersection(self, other): # pragma: no cover + raise NotImplementedError() + + +class PandasIndex(Index, ExplicitlyIndexedNDArrayMixin): + """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" + + __slots__ = ("array", "_dtype") + + def __init__( + self, array: Any, dtype: DTypeLike = None, coord_name: Optional[Hashable] = None + ): + if coord_name is None: + coord_name = tuple() + super().__init__(coord_name) + + self.array = utils.safe_cast_to_index(array) + + if dtype is None: + if isinstance(array, pd.PeriodIndex): + dtype_ = np.dtype("O") + elif hasattr(array, "categories"): + # category isn't a real numpy dtype + dtype_ = array.categories.dtype + elif not utils.is_valid_numpy_dtype(array.dtype): + dtype_ = np.dtype("O") + else: + dtype_ = array.dtype + else: + dtype_ = np.dtype(dtype) + self._dtype = dtype_ + + @classmethod + def from_variables(cls, variables: Dict[Hashable, "Variable"], **kwargs): + if len(variables) > 1: + raise ValueError("Cannot set a pandas.Index from more than one variable") + + varname, var = list(variables.items())[0] + return cls(var.data, dtype=var.dtype, coord_name=varname) + + def to_pandas_index(self) -> pd.Index: + return self.array + + @property + def dtype(self) -> np.dtype: + return self._dtype + + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + if dtype is None: + dtype = self.dtype + array = self.array + if isinstance(array, pd.PeriodIndex): + with suppress(AttributeError): + # this might not be public API + array = array.astype("object") + return np.asarray(array.values, dtype=dtype) + + @property + def shape(self) -> Tuple[int]: + return (len(self.array),) + + def equals(self, other): + if isinstance(other, pd.Index): + other = PandasIndex(other) + return isinstance(other, PandasIndex) and self.array.equals(other.array) + + def union(self, other): + if isinstance(other, pd.Index): + other = PandasIndex(other) + return PandasIndex(self.array.union(other.array)) + + def intersection(self, other): + if isinstance(other, pd.Index): + other = PandasIndex(other) + return PandasIndex(self.array.intersection(other.array)) + + def __getitem__( + self, indexer + ) -> Union[ + "PandasIndex", + NumpyIndexingAdapter, + np.ndarray, + np.datetime64, + np.timedelta64, + ]: + key = indexer.tuple + if isinstance(key, tuple) and len(key) == 1: + # unpack key so it can index a pandas.Index object (pandas.Index + # objects don't like tuples) + (key,) = key + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + return NumpyIndexingAdapter(self.array.values)[indexer] + + result = self.array[key] + + if isinstance(result, pd.Index): + result = PandasIndex(result, dtype=self.dtype) + else: + # result is a scalar + if result is pd.NaT: + # work around the impossibility of casting NaT with asarray + # note: it probably would be better in general to return + # pd.Timestamp rather np.than datetime64 but this is easier + # (for now) + result = np.datetime64("NaT", "ns") + elif isinstance(result, timedelta): + result = np.timedelta64(getattr(result, "value", result), "ns") + elif isinstance(result, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + result = 
np.asarray(result.to_datetime64()) + elif self.dtype != object: + result = np.asarray(result, dtype=self.dtype) + + # as for numpy.ndarray indexing, we always want the result to be + # a NumPy array. + result = utils.to_0d_array(result) + + return result + + def transpose(self, order) -> pd.Index: + return self.array # self.array should be always one-dimensional + + def __repr__(self) -> str: + return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" + + def copy(self, deep: bool = True) -> "PandasIndex": + # Not the same as just writing `self.array.copy(deep=deep)`, as + # shallow copies of the underlying numpy.ndarrays become deep ones + # upon pickling + # >>> len(pickle.dumps((self.array, self.array))) + # 4000281 + # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) + # 8000341 + array = self.array.copy(deep=True) if deep else self.array + return PandasIndex(array, self._dtype) def remove_unused_levels_categories(index: pd.Index) -> pd.Index: @@ -68,8 +255,8 @@ def __repr__(self): def default_indexes( - coords: Mapping[Any, Variable], dims: Iterable -) -> Dict[Hashable, pd.Index]: + coords: Mapping[Any, "Variable"], dims: Iterable +) -> Dict[Hashable, Index]: """Default indexes for a Dataset/DataArray. Parameters @@ -84,16 +271,18 @@ def default_indexes( Mapping from indexing keys (levels/dimension names) to indexes used for indexing along that dimension. """ - return {key: coords[key].to_index() for key in dims if key in coords} + return {key: coords[key]._to_xindex() for key in dims if key in coords} def isel_variable_and_index( name: Hashable, - variable: Variable, - index: pd.Index, - indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]], -) -> Tuple[Variable, Optional[pd.Index]]: + variable: "Variable", + index: Index, + indexers: Mapping[Hashable, Union[int, slice, np.ndarray, "Variable"]], +) -> Tuple["Variable", Optional[Index]]: """Index a Variable and pandas.Index together.""" + from .variable import Variable + if not indexers: # nothing to index return variable.copy(deep=False), index @@ -114,22 +303,25 @@ def isel_variable_and_index( indexer = indexers[dim] if isinstance(indexer, Variable): indexer = indexer.data - new_index = index[indexer] + pd_index = index.to_pandas_index() + new_index = PandasIndex(pd_index[indexer]) return new_variable, new_index -def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index: +def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: """Roll an pandas.Index.""" - count %= index.shape[0] + pd_index = index.to_pandas_index() + count %= pd_index.shape[0] if count != 0: - return index[-count:].append(index[:-count]) + new_idx = pd_index[-count:].append(pd_index[:-count]) else: - return index[:] + new_idx = pd_index[:] + return PandasIndex(new_idx) def propagate_indexes( - indexes: Optional[Dict[Hashable, pd.Index]], exclude: Optional[Any] = None -) -> Optional[Dict[Hashable, pd.Index]]: + indexes: Optional[Dict[Hashable, Index]], exclude: Optional[Any] = None +) -> Optional[Dict[Hashable, Index]]: """Creates new indexes dict from existing dict optionally excluding some dimensions.""" if exclude is None: exclude = () diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 82e4530f428..76a0c6888b2 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2,8 +2,6 @@ import functools import operator from collections import defaultdict -from contextlib import suppress -from datetime import timedelta from distutils.version import 
LooseVersion from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union @@ -18,7 +16,6 @@ DASK_VERSION = LooseVersion("0") from . import duck_array_ops, nputils, utils -from .npcompat import DTypeLike from .pycompat import ( dask_array_type, integer_types, @@ -119,6 +116,8 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No dimension. If `index` is a pandas.MultiIndex and depending on `label`, return a new pandas.Index or pandas.MultiIndex (otherwise return None). """ + from .indexes import PandasIndex + new_index = None if isinstance(label, slice): @@ -208,6 +207,10 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No indexer = get_indexer_nd(index, label, method, tolerance) if np.any(indexer < 0): raise KeyError(f"not all values found in index {index_name!r}") + + if new_index is not None: + new_index = PandasIndex(new_index) + return indexer, new_index @@ -262,7 +265,7 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): dim_indexers = get_dim_indexers(data_obj, indexers) for dim, label in dim_indexers.items(): try: - index = data_obj.indexes[dim] + index = data_obj.xindexes[dim].to_pandas_index() except KeyError: # no index for this dimension: reuse the provided labels if method is not None or tolerance is not None: @@ -726,7 +729,9 @@ def as_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) if isinstance(array, pd.Index): - return PandasIndexAdapter(array) + from .indexes import PandasIndex + + return PandasIndex(array) if isinstance(array, dask_array_type): return DaskIndexingAdapter(array) if hasattr(array, "__array_function__"): @@ -1414,101 +1419,3 @@ def __setitem__(self, key, value): def transpose(self, order): return self.array.transpose(order) - - -class PandasIndexAdapter(ExplicitlyIndexedNDArrayMixin): - """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" - - __slots__ = ("array", "_dtype") - - def __init__(self, array: Any, dtype: DTypeLike = None): - self.array = utils.safe_cast_to_index(array) - if dtype is None: - if isinstance(array, pd.PeriodIndex): - dtype_ = np.dtype("O") - elif hasattr(array, "categories"): - # category isn't a real numpy dtype - dtype_ = array.categories.dtype - elif not utils.is_valid_numpy_dtype(array.dtype): - dtype_ = np.dtype("O") - else: - dtype_ = array.dtype - else: - dtype_ = np.dtype(dtype) - self._dtype = dtype_ - - @property - def dtype(self) -> np.dtype: - return self._dtype - - def __array__(self, dtype: DTypeLike = None) -> np.ndarray: - if dtype is None: - dtype = self.dtype - array = self.array - if isinstance(array, pd.PeriodIndex): - with suppress(AttributeError): - # this might not be public API - array = array.astype("object") - return np.asarray(array.values, dtype=dtype) - - @property - def shape(self) -> Tuple[int]: - return (len(self.array),) - - def __getitem__( - self, indexer - ) -> Union[NumpyIndexingAdapter, np.ndarray, np.datetime64, np.timedelta64]: - key = indexer.tuple - if isinstance(key, tuple) and len(key) == 1: - # unpack key so it can index a pandas.Index object (pandas.Index - # objects don't like tuples) - (key,) = key - - if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional - return NumpyIndexingAdapter(self.array.values)[indexer] - - result = self.array[key] - - if isinstance(result, pd.Index): - result = PandasIndexAdapter(result, dtype=self.dtype) - else: - # result is a scalar - if result is pd.NaT: - # work around the 
impossibility of casting NaT with asarray - # note: it probably would be better in general to return - # pd.Timestamp rather np.than datetime64 but this is easier - # (for now) - result = np.datetime64("NaT", "ns") - elif isinstance(result, timedelta): - result = np.timedelta64(getattr(result, "value", result), "ns") - elif isinstance(result, pd.Timestamp): - # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 - # numpy fails to convert pd.Timestamp to np.datetime64[ns] - result = np.asarray(result.to_datetime64()) - elif self.dtype != object: - result = np.asarray(result, dtype=self.dtype) - - # as for numpy.ndarray indexing, we always want the result to be - # a NumPy array. - result = utils.to_0d_array(result) - - return result - - def transpose(self, order) -> pd.Index: - return self.array # self.array should be always one-dimensional - - def __repr__(self) -> str: - return "{}(array={!r}, dtype={!r})".format( - type(self).__name__, self.array, self.dtype - ) - - def copy(self, deep: bool = True) -> "PandasIndexAdapter": - # Not the same as just writing `self.array.copy(deep=deep)`, as - # shallow copies of the underlying numpy.ndarrays become deep ones - # upon pickling - # >>> len(pickle.dumps((self.array, self.array))) - # 4000281 - # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) - # 8000341 - array = self.array.copy(deep=True) if deep else self.array - return PandasIndexAdapter(array, self._dtype) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 4d83855a15d..6747957ca75 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -20,6 +20,7 @@ from . import dtypes, pdcompat from .alignment import deep_align from .duck_array_ops import lazy_array_equiv +from .indexes import Index, PandasIndex from .utils import Frozen, compat_dict_union, dict_equiv, equivalent from .variable import Variable, as_variable, assert_unique_multiindex_level_names @@ -157,7 +158,7 @@ def _assert_compat_valid(compat): ) -MergeElement = Tuple[Variable, Optional[pd.Index]] +MergeElement = Tuple[Variable, Optional[Index]] def merge_collected( @@ -165,7 +166,7 @@ def merge_collected( prioritized: Mapping[Hashable, MergeElement] = None, compat: str = "minimal", combine_attrs="override", -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. 
Parameters @@ -187,7 +188,7 @@ def merge_collected( _assert_compat_valid(compat) merged_vars: Dict[Hashable, Variable] = {} - merged_indexes: Dict[Hashable, pd.Index] = {} + merged_indexes: Dict[Hashable, Index] = {} for name, elements_list in grouped.items(): if name in prioritized: @@ -261,7 +262,7 @@ def collect_variables_and_indexes( from .dataarray import DataArray from .dataset import Dataset - grouped: Dict[Hashable, List[Tuple[Variable, pd.Index]]] = {} + grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} def append(name, variable, index): values = grouped.setdefault(name, []) @@ -273,13 +274,13 @@ def append_all(variables, indexes): for mapping in list_of_mappings: if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping.indexes) + append_all(mapping.variables, mapping.xindexes) continue for name, variable in mapping.items(): if isinstance(variable, DataArray): coords = variable._coords.copy() # use private API for speed - indexes = dict(variable.indexes) + indexes = dict(variable.xindexes) # explicitly overwritten variables should take precedence coords.pop(name, None) indexes.pop(name, None) @@ -288,7 +289,7 @@ def append_all(variables, indexes): variable = as_variable(variable, name=name) if variable.dims == (name,): variable = variable.to_index_variable() - index = variable.to_index() + index = variable._to_xindex() else: index = None append(name, variable, index) @@ -300,11 +301,11 @@ def collect_from_coordinates( list_of_coords: "List[Coordinates]", ) -> Dict[Hashable, List[MergeElement]]: """Collect variables and indexes to be merged from Coordinate objects.""" - grouped: Dict[Hashable, List[Tuple[Variable, pd.Index]]] = {} + grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} for coords in list_of_coords: variables = coords.variables - indexes = coords.indexes + indexes = coords.xindexes for name, variable in variables.items(): value = grouped.setdefault(name, []) value.append((variable, indexes.get(name))) @@ -315,7 +316,7 @@ def merge_coordinates_without_align( objects: "List[Coordinates]", prioritized: Mapping[Hashable, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. This function is used for merging coordinate from pre-existing xarray @@ -448,9 +449,9 @@ def merge_coords( compat: str = "minimal", join: str = "outer", priority_arg: Optional[int] = None, - indexes: Optional[Mapping[Hashable, pd.Index]] = None, + indexes: Optional[Mapping[Hashable, Index]] = None, fill_value: object = dtypes.NA, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge coordinate variables. See merge_core below for argument descriptions. 
This works similarly to @@ -484,7 +485,7 @@ def _extract_indexes_from_coords(coords): for name, variable in coords.items(): variable = as_variable(variable, name=name) if variable.dims == (name,): - yield name, variable.to_index() + yield name, variable._to_xindex() def assert_valid_explicit_coords(variables, dims, explicit_coords): @@ -569,7 +570,7 @@ def merge_core( combine_attrs: Optional[str] = "override", priority_arg: Optional[int] = None, explicit_coords: Optional[Sequence] = None, - indexes: Optional[Mapping[Hashable, pd.Index]] = None, + indexes: Optional[Mapping[Hashable, Index]] = None, fill_value: object = dtypes.NA, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -970,10 +971,11 @@ def dataset_update_method( other[key] = value.drop_vars(coord_names) # use ds.coords and not ds.indexes, else str coords are cast to object - indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()} + # TODO: benbovy - flexible indexes: fix this (it only works with pandas indexes) + indexes = {key: PandasIndex(dataset.coords[key]) for key in dataset.xindexes.keys()} return merge_core( [dataset, other], priority_arg=1, - indexes=indexes, + indexes=indexes, # type: ignore combine_attrs="override", ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index d12ccc65ca6..41205242cce 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -317,9 +317,13 @@ def interp_na( if not is_scalar(max_gap): raise ValueError("max_gap must be a scalar.") + # TODO: benbovy - flexible indexes: update when CFTimeIndex (and DatetimeIndex?) + # has its own class inheriting from xarray.Index if ( - dim in self.indexes - and isinstance(self.indexes[dim], (pd.DatetimeIndex, CFTimeIndex)) + dim in self.xindexes + and isinstance( + self.xindexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) + ) and use_coordinate ): # Convert to float diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 895e939c505..e1d32b7de43 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -27,6 +27,8 @@ import numpy as np +from xarray.core.indexes import PandasIndex + from .alignment import align from .dataarray import DataArray from .dataset import Dataset @@ -291,7 +293,7 @@ def _wrapper( ) # check that index lengths and values are as expected - for name, index in result.indexes.items(): + for name, index in result.xindexes.items(): if name in expected["shapes"]: if len(index) != expected["shapes"][name]: raise ValueError( @@ -357,27 +359,27 @@ def _wrapper( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0].indexes) + input_indexes = dict(npargs[0].xindexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg.indexes) + input_indexes.update(arg.xindexes) if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template.indexes) + template_indexes = set(template.xindexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.indexes[k] for k in new_indexes}) + indexes.update({k: template.xindexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes 
and chunk shapes - indexes = dict(template.indexes) + indexes = dict(template.xindexes) if isinstance(template, DataArray): output_chunks = dict( zip(template.dims, template.chunks) # type: ignore[arg-type] @@ -501,10 +503,16 @@ def subset_dataset_to_block( } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - expected["indexes"] = { - dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in indexes - } + # TODO: benbovy - flexible indexes: clean this up + # for now assumes pandas index (thus can be indexed) but it won't be the case for + # all indexes + expected_indexes = {} + for dim in indexes: + idx = indexes[dim].to_pandas_index()[ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + expected_indexes[dim] = PandasIndex(idx) + expected["indexes"] = expected_indexes from_wrapper = (gname,) + chunk_tuple graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) @@ -550,7 +558,7 @@ def subset_dataset_to_block( ) result = Dataset(coords=indexes, attrs=template.attrs) - for index in result.indexes: + for index in result.xindexes: result[index].attrs = template[index].attrs result[index].encoding = template[index].encoding diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6f828a5128c..cffaf2c3146 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -26,13 +26,8 @@ from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .arithmetic import VariableArithmetic from .common import AbstractArray -from .indexing import ( - BasicIndexer, - OuterIndexer, - PandasIndexAdapter, - VectorizedIndexer, - as_indexable, -) +from .indexes import PandasIndex +from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable from .options import _get_keep_attrs from .pycompat import ( cupy_array_type, @@ -180,11 +175,11 @@ def _maybe_wrap_data(data): Put pandas.Index and numpy.ndarray arguments in adapter objects to ensure they can be indexed properly. - NumpyArrayAdapter, PandasIndexAdapter and LazilyIndexedArray should + NumpyArrayAdapter, PandasIndex and LazilyIndexedArray should all pass through unmodified. 
""" if isinstance(data, pd.Index): - return PandasIndexAdapter(data) + return PandasIndex(data) return data @@ -351,7 +346,7 @@ def nbytes(self): @property def _in_memory(self): - return isinstance(self._data, (np.ndarray, np.number, PandasIndexAdapter)) or ( + return isinstance(self._data, (np.ndarray, np.number, PandasIndex)) or ( isinstance(self._data, indexing.MemoryCachedArray) and isinstance(self._data.array, indexing.NumpyIndexingAdapter) ) @@ -556,6 +551,11 @@ def to_index_variable(self): to_coord = utils.alias(to_index_variable, "to_coord") + def _to_xindex(self): + # temporary function used internally as a replacement of to_index() + # returns an xarray Index instance instead of a pd.Index instance + return PandasIndex(self.to_index()) + def to_index(self): """Convert this variable to a pandas.Index""" return self.to_index_variable().to_index() @@ -2553,8 +2553,8 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): raise ValueError("%s objects must be 1-dimensional" % type(self).__name__) # Unlike in Variable, always eagerly load values into memory - if not isinstance(self._data, PandasIndexAdapter): - self._data = PandasIndexAdapter(self._data) + if not isinstance(self._data, PandasIndex): + self._data = PandasIndex(self._data) def __dask_tokenize__(self): from dask.base import normalize_token @@ -2890,7 +2890,7 @@ def assert_unique_multiindex_level_names(variables): level_names = defaultdict(list) all_level_names = set() for var_name, var in variables.items(): - if isinstance(var._data, PandasIndexAdapter): + if isinstance(var._data, PandasIndex): idx_level_names = var.to_index_variable().level_names if idx_level_names is not None: for n in idx_level_names: diff --git a/xarray/testing.py b/xarray/testing.py index 365b81edc40..40ca12852b9 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -4,12 +4,11 @@ from typing import Hashable, Set, Union import numpy as np -import pandas as pd from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import default_indexes +from xarray.core.indexes import Index, default_indexes from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -254,7 +253,7 @@ def assert_chunks_equal(a, b): def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): assert isinstance(indexes, dict), indexes - assert all(isinstance(v, pd.Index) for v in indexes.values()), { + assert all(isinstance(v, Index) for v in indexes.values()), { k: type(v) for k, v in indexes.items() } diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index feab45b1f00..3e3d6e8b8d0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -35,7 +35,7 @@ from xarray.backends.pydap_ import PydapDataStore from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates -from xarray.core import indexing +from xarray.core import indexes, indexing from xarray.core.options import set_options from xarray.core.pycompat import dask_array_type from xarray.tests import LooseVersion, mock @@ -735,7 +735,7 @@ def find_and_validate_array(obj): elif isinstance(obj.array, dask_array_type): assert isinstance(obj, indexing.DaskIndexingAdapter) elif isinstance(obj.array, pd.Index): - assert isinstance(obj, indexing.PandasIndexAdapter) + assert isinstance(obj, indexes.PandasIndex) else: raise TypeError( "{} is wrapped by 
{}".format(type(obj.array), type(obj)) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 8dee364a08a..725b5efee75 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -696,7 +696,7 @@ def test_concat_cftimeindex(date_type): ) da = xr.concat([da1, da2], dim="time") - assert isinstance(da.indexes["time"], CFTimeIndex) + assert isinstance(da.xindexes["time"].to_pandas_index(), CFTimeIndex) @requires_cftime diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index c4f32795b59..526f3fc30c1 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -99,7 +99,10 @@ def test_resample(freqs, closed, label, base): ) .mean() ) - da_cftime["time"] = da_cftime.indexes["time"].to_datetimeindex() + # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass + da_cftime["time"] = ( + da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() + ) xr.testing.assert_identical(da_cftime, da_datetime) @@ -145,5 +148,6 @@ def test_calendars(calendar): .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) .mean() ) - da_cftime["time"] = da_cftime.indexes["time"].to_datetimeindex() + # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass + da_cftime["time"] = da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 9cfc134e4fe..42232f7df57 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -521,7 +521,7 @@ def test_concat(self): stacked = concat(grouped, ds["x"]) assert_identical(foo, stacked) # with an index as the 'dim' argument - stacked = concat(grouped, ds.indexes["x"]) + stacked = concat(grouped, pd.Index(ds["x"], name="x")) assert_identical(foo, stacked) actual = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 3608a53f747..cd8e3419231 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -280,7 +280,7 @@ def test_decode_cf_with_dask(self): assert all( isinstance(var.data, da.Array) for name, var in decoded.variables.items() - if name not in decoded.indexes + if name not in decoded.xindexes ) assert_identical(decoded, conventions.decode_cf(original).compute()) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ff098ced161..e6c479896e9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -24,7 +24,7 @@ from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.indexes import propagate_indexes +from xarray.core.indexes import Index, PandasIndex, propagate_indexes from xarray.core.utils import is_scalar from xarray.tests import ( LooseVersion, @@ -147,10 +147,15 @@ def test_data_property(self): def test_indexes(self): array = DataArray(np.zeros((2, 3)), [("x", [0, 1]), ("y", ["a", "b", "c"])]) - expected = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])} - assert array.indexes.keys() == expected.keys() - for k in expected: - assert array.indexes[k].equals(expected[k]) + expected_indexes = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])} + expected_xindexes = {k: PandasIndex(idx) for k, idx in expected_indexes.items()} + 
assert array.xindexes.keys() == expected_xindexes.keys() + assert array.indexes.keys() == expected_indexes.keys() + assert all([isinstance(idx, pd.Index) for idx in array.indexes.values()]) + assert all([isinstance(idx, Index) for idx in array.xindexes.values()]) + for k in expected_indexes: + assert array.xindexes[k].equals(expected_xindexes[k]) + assert array.indexes[k].equals(expected_indexes[k]) def test_get_index(self): array = DataArray(np.zeros((2, 3)), coords={"x": ["a", "b"]}, dims=["x", "y"]) @@ -1459,7 +1464,7 @@ def test_coords_alignment(self): def test_set_coords_update_index(self): actual = DataArray([1, 2, 3], [("x", [1, 2, 3])]) actual.coords["x"] = ["a", "b", "c"] - assert actual.indexes["x"].equals(pd.Index(["a", "b", "c"])) + assert actual.xindexes["x"].equals(pd.Index(["a", "b", "c"])) def test_coords_replacement_alignment(self): # regression test for GH725 @@ -1479,7 +1484,7 @@ def test_coords_delitem_delete_indexes(self): # regression test for GH3746 arr = DataArray(np.ones((2,)), dims="x", coords={"x": [0, 1]}) del arr.coords["x"] - assert "x" not in arr.indexes + assert "x" not in arr.xindexes def test_broadcast_like(self): arr1 = DataArray( @@ -1627,18 +1632,19 @@ def test_swap_dims(self): expected = DataArray(array.values, {"y": list("abc")}, dims="y") actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].array, actual.xindexes[dim_name].array ) array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].to_pandas_index(), + actual.xindexes[dim_name].to_pandas_index(), ) # as kwargs @@ -1646,9 +1652,10 @@ def test_swap_dims(self): expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") actual = array.swap_dims(x="y") assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].to_pandas_index(), + actual.xindexes[dim_name].to_pandas_index(), ) # multiindex case @@ -1657,9 +1664,10 @@ def test_swap_dims(self): expected = DataArray(array.values, {"y": idx}, "y") actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].to_pandas_index(), + actual.xindexes[dim_name].to_pandas_index(), ) def test_expand_dims_error(self): @@ -4334,12 +4342,12 @@ def test_matmul_align_coords(self): def test_binary_op_propagate_indexes(self): # regression test for GH2227 self.dv["x"] = np.arange(self.dv.sizes["x"]) - expected = 
self.dv.indexes["x"] + expected = self.dv.xindexes["x"] - actual = (self.dv * 10).indexes["x"] + actual = (self.dv * 10).xindexes["x"] assert expected is actual - actual = (self.dv > 10).indexes["x"] + actual = (self.dv > 10).xindexes["x"] assert expected is actual def test_binary_op_join_setting(self): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 33b5d16fbac..b8e1cd4b03b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -28,6 +28,7 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import dtypes, indexing, utils from xarray.core.common import duck_array_ops, full_like +from xarray.core.indexes import Index from xarray.core.pycompat import integer_types from xarray.core.utils import is_scalar @@ -582,9 +583,15 @@ def test_properties(self): assert "numbers" not in ds.data_vars assert len(ds.data_vars) == 3 + assert set(ds.xindexes) == {"dim2", "dim3", "time"} + assert len(ds.xindexes) == 3 + assert "dim2" in repr(ds.xindexes) + assert all([isinstance(idx, Index) for idx in ds.xindexes.values()]) + assert set(ds.indexes) == {"dim2", "dim3", "time"} assert len(ds.indexes) == 3 assert "dim2" in repr(ds.indexes) + assert all([isinstance(idx, pd.Index) for idx in ds.indexes.values()]) assert list(ds.coords) == ["time", "dim2", "dim3", "numbers"] assert "dim2" in ds.coords @@ -747,12 +754,12 @@ def test_coords_modify(self): # regression test for GH3746 del actual.coords["x"] - assert "x" not in actual.indexes + assert "x" not in actual.xindexes def test_update_index(self): actual = Dataset(coords={"x": [1, 2, 3]}) actual["x"] = ["a", "b", "c"] - assert actual.indexes["x"].equals(pd.Index(["a", "b", "c"])) + assert actual.xindexes["x"].equals(pd.Index(["a", "b", "c"])) def test_coords_setitem_with_new_dimension(self): actual = Dataset() @@ -1044,19 +1051,19 @@ def test_isel(self): assert {"time": 20, "dim2": 9, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(ret.indexes) + assert set(data.xindexes) == set(ret.xindexes) ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) assert {"time": 2, "dim2": 5, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(ret.indexes) + assert set(data.xindexes) == set(ret.xindexes) ret = data.isel(time=0, dim1=0, dim2=slice(5)) assert {"dim2": 5, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(list(ret.indexes) + ["time"]) + assert set(data.xindexes) == set(list(ret.xindexes) + ["time"]) def test_isel_fancy(self): # isel with fancy indexing. 
@@ -1392,13 +1399,13 @@ def test_sel_dataarray_mindex(self): ) actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="x")) - actual_sel = mds.sel(x=DataArray(mds.indexes["x"][:3], dims="x")) + actual_sel = mds.sel(x=DataArray(midx[:3], dims="x")) assert actual_isel["x"].dims == ("x",) assert actual_sel["x"].dims == ("x",) assert_identical(actual_isel, actual_sel) actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="z")) - actual_sel = mds.sel(x=Variable("z", mds.indexes["x"][:3])) + actual_sel = mds.sel(x=Variable("z", midx[:3])) assert actual_isel["x"].dims == ("z",) assert actual_sel["x"].dims == ("z",) assert_identical(actual_isel, actual_sel) @@ -1408,7 +1415,7 @@ def test_sel_dataarray_mindex(self): x=xr.DataArray(np.arange(3), dims="z", coords={"z": [0, 1, 2]}) ) actual_sel = mds.sel( - x=xr.DataArray(mds.indexes["x"][:3], dims="z", coords={"z": [0, 1, 2]}) + x=xr.DataArray(midx[:3], dims="z", coords={"z": [0, 1, 2]}) ) assert actual_isel["x"].dims == ("z",) assert actual_sel["x"].dims == ("z",) @@ -2421,7 +2428,7 @@ def test_drop_labels_by_keyword(self): with pytest.warns(FutureWarning): data.drop(arr.coords) with pytest.warns(FutureWarning): - data.drop(arr.indexes) + data.drop(arr.xindexes) assert_array_equal(ds1.coords["x"], ["b"]) assert_array_equal(ds2.coords["x"], ["b"]) @@ -2711,21 +2718,23 @@ def test_rename_does_not_change_CFTimeIndex_type(self): orig = Dataset(coords={"time": time}) renamed = orig.rename(time="time_new") - assert "time_new" in renamed.indexes - assert isinstance(renamed.indexes["time_new"], CFTimeIndex) - assert renamed.indexes["time_new"].name == "time_new" + assert "time_new" in renamed.xindexes + # TODO: benbovy - flexible indexes: update when CFTimeIndex + # inherits from xarray.Index + assert isinstance(renamed.xindexes["time_new"].to_pandas_index(), CFTimeIndex) + assert renamed.xindexes["time_new"].to_pandas_index().name == "time_new" # check original has not changed - assert "time" in orig.indexes - assert isinstance(orig.indexes["time"], CFTimeIndex) - assert orig.indexes["time"].name == "time" + assert "time" in orig.xindexes + assert isinstance(orig.xindexes["time"].to_pandas_index(), CFTimeIndex) + assert orig.xindexes["time"].to_pandas_index().name == "time" # note: rename_dims(time="time_new") drops "ds.indexes" renamed = orig.rename_dims() - assert isinstance(renamed.indexes["time"], CFTimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), CFTimeIndex) renamed = orig.rename_vars() - assert isinstance(renamed.indexes["time"], CFTimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), CFTimeIndex) def test_rename_does_not_change_DatetimeIndex_type(self): # make sure DatetimeIndex is conderved on rename @@ -2734,21 +2743,23 @@ def test_rename_does_not_change_DatetimeIndex_type(self): orig = Dataset(coords={"time": time}) renamed = orig.rename(time="time_new") - assert "time_new" in renamed.indexes - assert isinstance(renamed.indexes["time_new"], DatetimeIndex) - assert renamed.indexes["time_new"].name == "time_new" + assert "time_new" in renamed.xindexes + # TODO: benbovy - flexible indexes: update when DatetimeIndex + # inherits from xarray.Index? 
+ assert isinstance(renamed.xindexes["time_new"].to_pandas_index(), DatetimeIndex) + assert renamed.xindexes["time_new"].to_pandas_index().name == "time_new" # check original has not changed - assert "time" in orig.indexes - assert isinstance(orig.indexes["time"], DatetimeIndex) - assert orig.indexes["time"].name == "time" + assert "time" in orig.xindexes + assert isinstance(orig.xindexes["time"].to_pandas_index(), DatetimeIndex) + assert orig.xindexes["time"].to_pandas_index().name == "time" # note: rename_dims(time="time_new") drops "ds.indexes" renamed = orig.rename_dims() - assert isinstance(renamed.indexes["time"], DatetimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), DatetimeIndex) renamed = orig.rename_vars() - assert isinstance(renamed.indexes["time"], DatetimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), DatetimeIndex) def test_swap_dims(self): original = Dataset({"x": [1, 2, 3], "y": ("x", list("abc")), "z": 42}) @@ -2757,7 +2768,10 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal(actual.indexes["y"], expected.indexes["y"]) + pd.testing.assert_index_equal( + actual.xindexes["y"].to_pandas_index(), + expected.xindexes["y"].to_pandas_index(), + ) roundtripped = actual.swap_dims({"y": "x"}) assert_identical(original.set_coords("y"), roundtripped) @@ -2788,7 +2802,10 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal(actual.indexes["y"], expected.indexes["y"]) + pd.testing.assert_index_equal( + actual.xindexes["y"].to_pandas_index(), + expected.xindexes["y"].to_pandas_index(), + ) def test_expand_dims_error(self): original = Dataset( @@ -3165,7 +3182,9 @@ def test_to_stacked_array_dtype_dims(self): D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - assert y.indexes["features"].levels[1].dtype == D.y.dtype + # TODO: benbovy - flexible indexes: update when MultiIndex has its own class + # inherited from xarray.Index + assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") def test_to_stacked_array_to_unstacked_dataset(self): @@ -5565,8 +5584,8 @@ def test_binary_op_propagate_indexes(self): ds = Dataset( {"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})} ) - expected = ds.indexes["x"] - actual = (ds * 2).indexes["x"] + expected = ds.xindexes["x"] + actual = (ds * 2).xindexes["x"] assert expected is actual def test_binary_op_join_setting(self): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index b3334e92c4a..1e0dff45dd2 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -11,6 +11,7 @@ from xarray import Coordinate, DataArray, Dataset, IndexVariable, Variable, set_options from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.common import full_like, ones_like, zeros_like +from xarray.core.indexes import PandasIndex from xarray.core.indexing import ( BasicIndexer, CopyOnWriteArray, @@ -19,7 +20,6 @@ MemoryCachedArray, NumpyIndexingAdapter, OuterIndexer, - PandasIndexAdapter, VectorizedIndexer, ) from xarray.core.pycompat import dask_array_type @@ -535,7 +535,7 @@ def test_copy_index(self): v = self.cls("x", midx) for deep in [True, 
False]: w = v.copy(deep=deep) - assert isinstance(w._data, PandasIndexAdapter) + assert isinstance(w._data, PandasIndex) assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) @@ -2145,7 +2145,7 @@ def test_multiindex_default_level_names(self): def test_data(self): x = IndexVariable("x", np.arange(3.0)) - assert isinstance(x._data, PandasIndexAdapter) + assert isinstance(x._data, PandasIndex) assert isinstance(x.data, np.ndarray) assert float == x.dtype assert_array_equal(np.arange(3), x) @@ -2287,7 +2287,7 @@ def test_coarsen_2d(self): class TestAsCompatibleData: def test_unchanged_types(self): - types = (np.asarray, PandasIndexAdapter, LazilyIndexedArray) + types = (np.asarray, PandasIndex, LazilyIndexedArray) for t in types: for data in [ np.arange(3), From 1fa7b9ba8b695820474b0d045059995a34a1c684 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 13 May 2021 17:28:15 +0200 Subject: [PATCH 2/7] Allow dataset interpolation with different datatypes (#5008) Co-authored-by: Deepak Cherian --- xarray/core/dataset.py | 48 +++++++++++++++++++++++++++++++++---- xarray/tests/test_interp.py | 22 ++++++++++------- 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 706ccbde8c4..19af7f6c3cd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2853,6 +2853,7 @@ def interp( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, + method_non_numeric: str = "nearest", **coords_kwargs: Any, ) -> "Dataset": """Multidimensional interpolation of Dataset. @@ -2877,6 +2878,9 @@ def interp( Additional keyword arguments passed to scipy's interpolator. Valid options and their behavior depend on if 1-dimensional or multi-dimensional interpolation is used. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. **coords_kwargs : {dim: coordinate, ...}, optional The keyword arguments form of ``coords``. One of coords or coords_kwargs must be provided. @@ -3034,6 +3038,7 @@ def _validate_interp_indexer(x, new_x): } variables: Dict[Hashable, Variable] = {} + to_reindex: Dict[Hashable, Variable] = {} for name, var in obj._variables.items(): if name in indexers: continue @@ -3043,20 +3048,45 @@ def _validate_interp_indexer(x, new_x): else: use_indexers = validated_indexers - if var.dtype.kind in "uifc": + dtype_kind = var.dtype.kind + if dtype_kind in "uifc": + # For normal number types do the interpolation: var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims} variables[name] = missing.interp(var, var_indexers, method, **kwargs) + elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): + # For types that we do not understand do stepwise + # interpolation to avoid modifying the elements. 
+ # Use reindex_variables instead because it supports + # booleans and objects and retains the dtype but inside + # this loop there might be some duplicate code that slows it + # down, therefore collect these signals and run it later: + to_reindex[name] = var elif all(d not in indexers for d in var.dims): - # keep unrelated object array + # For anything else we can only keep variables if they + # are not dependent on any coords that are being + # interpolated along: variables[name] = var + if to_reindex: + # Reindex variables: + variables_reindex = alignment.reindex_variables( + variables=to_reindex, + sizes=obj.sizes, + indexes=obj.xindexes, + indexers={k: v[-1] for k, v in validated_indexers.items()}, + method=method_non_numeric, + )[0] + variables.update(variables_reindex) + + # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() + # Get the indexes that are not being interpolated along: indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) - # attach indexer as coordinate + # Attach indexer as coordinate variables.update(indexers) for k, v in indexers.items(): assert isinstance(v, Variable) @@ -3077,6 +3107,7 @@ def interp_like( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, + method_non_numeric: str = "nearest", ) -> "Dataset": """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -3098,6 +3129,9 @@ def interp_like( values. kwargs : dict, optional Additional keyword passed to scipy's interpolator. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. Returns ------- @@ -3133,7 +3167,13 @@ def interp_like( # We do not support interpolation along object coordinate. # reindex instead. 
ds = self.reindex(object_coords) - return ds.interp(numeric_coords, method, assume_sorted, kwargs) + return ds.interp( + coords=numeric_coords, + method=method, + assume_sorted=assume_sorted, + kwargs=kwargs, + method_non_numeric=method_non_numeric, + ) # Helper methods for rename() def _rename_vars(self, name_dict, dims_dict): diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 97a6d236f0a..ab023dc1558 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -416,15 +416,19 @@ def test_errors(use_dask): @requires_scipy def test_dtype(): - ds = xr.Dataset( - {"var1": ("x", [0, 1, 2]), "var2": ("x", ["a", "b", "c"])}, - coords={"x": [0.1, 0.2, 0.3], "z": ("x", ["a", "b", "c"])}, - ) - actual = ds.interp(x=[0.15, 0.25]) - assert "var1" in actual - assert "var2" not in actual - # object array should be dropped - assert "z" not in actual.coords + data_vars = dict( + a=("time", np.array([1, 1.25, 2])), + b=("time", np.array([True, True, False], dtype=bool)), + c=("time", np.array(["start", "start", "end"], dtype=str)), + ) + time = np.array([0, 0.25, 1], dtype=float) + expected = xr.Dataset(data_vars, coords=dict(time=time)) + actual = xr.Dataset( + {k: (dim, arr[[0, -1]]) for k, (dim, arr) in data_vars.items()}, + coords=dict(time=time[[0, -1]]), + ) + actual = actual.interp(time=time, method="linear") + assert_identical(expected, actual) @requires_scipy From 504caeafc5830937c5c315e552bd0486522c848c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 13 May 2021 18:30:20 +0200 Subject: [PATCH 3/7] Add whats new for dataset interpolation with non-numerics (#5297) --- doc/whats-new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3f81678b8d5..9024f0efe37 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,7 +21,9 @@ v0.18.1 (unreleased) New Features ~~~~~~~~~~~~ - +- :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes, + such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`). + By `Jimmy Westling `_. Breaking changes ~~~~~~~~~~~~~~~~ From 4067c0141262e472a74da3fa6c22b61f7b02df9e Mon Sep 17 00:00:00 2001 From: Zachary Moon Date: Thu, 13 May 2021 12:36:17 -0400 Subject: [PATCH 4/7] FacetGrid docstrings (#5293) Co-authored-by: keewis --- doc/api.rst | 1 + xarray/plot/facetgrid.py | 54 +++++++++++++++++++++------------------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 7c01a8af0f1..1bd4eee9b12 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -846,6 +846,7 @@ Faceting plot.FacetGrid plot.FacetGrid.add_colorbar plot.FacetGrid.add_legend + plot.FacetGrid.add_quiverkey plot.FacetGrid.map plot.FacetGrid.map_dataarray plot.FacetGrid.map_dataarray_line diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index ab6d524aee4..28dd82e76f5 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -35,13 +35,13 @@ def _nicetitle(coord, value, maxchar, template): class FacetGrid: """ - Initialize the matplotlib figure and FacetGrid object. + Initialize the Matplotlib figure and FacetGrid object. The :class:`FacetGrid` is an object that links a xarray DataArray to - a matplotlib figure with a particular structure. + a Matplotlib figure with a particular structure. 
In particular, :class:`FacetGrid` is used to draw plots with multiple - Axes where each Axes shows the same relationship conditioned on + axes, where each axes shows the same relationship conditioned on different levels of some dimension. It's possible to condition on up to two variables by assigning variables to the rows and columns of the grid. @@ -59,19 +59,19 @@ class FacetGrid: Attributes ---------- - axes : numpy object array - Contains axes in corresponding position, as returned from - plt.subplots - col_labels : list - list of :class:`matplotlib.text.Text` instances corresponding to column titles. - row_labels : list - list of :class:`matplotlib.text.Text` instances corresponding to row titles. - fig : matplotlib.Figure - The figure containing all the axes - name_dicts : numpy object array - Contains dictionaries mapping coordinate names to values. None is - used as a sentinel value for axes which should remain empty, ie. - sometimes the bottom right grid + axes : ndarray of matplotlib.axes.Axes + Array containing axes in corresponding position, as returned from + :py:func:`matplotlib.pyplot.subplots`. + col_labels : list of matplotlib.text.Text + Column titles. + row_labels : list of matplotlib.text.Text + Row titles. + fig : matplotlib.figure.Figure + The figure containing all the axes. + name_dicts : ndarray of dict + Array containing dictionaries mapping coordinate names to values. ``None`` is + used as a sentinel value for axes that should remain empty, i.e., + sometimes the rightmost grid positions in the bottom row. """ def __init__( @@ -91,26 +91,28 @@ def __init__( Parameters ---------- data : DataArray - xarray DataArray to be plotted - row, col : strings + xarray DataArray to be plotted. + row, col : str Dimesion names that define subsets of the data, which will be drawn on separate facets in the grid. col_wrap : int, optional - "Wrap" the column variable at this width, so that the column facets + "Wrap" the grid the for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. sharex : bool, optional - If true, the facets will share x axes + If true, the facets will share *x* axes. sharey : bool, optional - If true, the facets will share y axes + If true, the facets will share *y* axes. figsize : tuple, optional A tuple (width, height) of the figure in inches. If set, overrides ``size`` and ``aspect``. aspect : scalar, optional Aspect ratio of each facet, so that ``aspect * size`` gives the - width of each facet in inches + width of each facet in inches. size : scalar, optional - Height (in inches) of each facet. See also: ``aspect`` + Height (in inches) of each facet. See also: ``aspect``. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots + Dictionary of keyword arguments for Matplotlib subplots + (:py:func:`matplotlib.pyplot.subplots`). """ @@ -431,7 +433,7 @@ def add_legend(self, **kwargs): self._adjust_fig_for_guide(self.figlegend) def add_colorbar(self, **kwargs): - """Draw a colorbar""" + """Draw a colorbar.""" kwargs = kwargs.copy() if self._cmap_extend is not None: kwargs.setdefault("extend", self._cmap_extend) @@ -564,7 +566,7 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, fontsize=_FONTSIZE): """ - Set and control tick behavior + Set and control tick behavior. 
Parameters ---------- From 1a7b285be676d5404a4140fc86e8756de75ee7ac Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Thu, 13 May 2021 09:37:15 -0700 Subject: [PATCH 5/7] Code cleanup (#5234) Co-authored-by: keewis Co-authored-by: Mathias Hauser Co-authored-by: Stephan Hoyer --- xarray/backends/api.py | 45 ++++++++----------- xarray/backends/cfgrib_.py | 3 +- xarray/backends/common.py | 7 ++- xarray/backends/h5netcdf_.py | 15 +++---- xarray/backends/locks.py | 2 +- xarray/backends/netCDF4_.py | 63 ++++++++++++-------------- xarray/backends/netcdf3.py | 2 - xarray/backends/plugins.py | 5 +-- xarray/backends/pydap_.py | 2 +- xarray/backends/rasterio_.py | 2 +- xarray/backends/scipy_.py | 24 +++++----- xarray/backends/zarr.py | 20 ++++----- xarray/coding/cftime_offsets.py | 14 ++---- xarray/coding/cftimeindex.py | 26 +++++------ xarray/coding/frequencies.py | 5 +-- xarray/coding/strings.py | 5 +-- xarray/coding/times.py | 7 ++- xarray/coding/variables.py | 19 ++++---- xarray/conventions.py | 4 +- xarray/core/accessor_str.py | 4 +- xarray/core/alignment.py | 53 ++++++++++------------ xarray/core/combine.py | 5 +-- xarray/core/common.py | 17 +++---- xarray/core/computation.py | 10 ++--- xarray/core/concat.py | 26 +++++------ xarray/core/coordinates.py | 36 +++++++-------- xarray/core/dataarray.py | 55 +++++++++++------------ xarray/core/dataset.py | 80 ++++++++++++++++----------------- xarray/core/dtypes.py | 5 +-- xarray/core/duck_array_ops.py | 6 +-- xarray/core/extensions.py | 7 ++- xarray/core/formatting_html.py | 8 ++-- xarray/core/groupby.py | 35 ++++++--------- xarray/core/merge.py | 33 ++++++-------- xarray/core/missing.py | 2 +- xarray/core/npcompat.py | 7 ++- xarray/core/ops.py | 4 +- xarray/core/options.py | 3 +- xarray/core/parallel.py | 2 +- xarray/core/resample_cftime.py | 20 ++------- xarray/core/rolling.py | 41 +++++++---------- xarray/core/utils.py | 34 ++++++-------- xarray/core/variable.py | 69 +++++++++++----------------- xarray/plot/utils.py | 26 +++++------ xarray/tests/test_groupby.py | 4 +- xarray/tutorial.py | 4 +- xarray/ufuncs.py | 7 ++- xarray/util/print_versions.py | 18 ++++---- 48 files changed, 377 insertions(+), 514 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 29ce46c8c68..e950baed5e0 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -102,12 +102,11 @@ def _get_default_engine_netcdf(): def _get_default_engine(path: str, allow_remote: bool = False): if allow_remote and is_remote_uri(path): - engine = _get_default_engine_remote_uri() + return _get_default_engine_remote_uri() elif path.endswith(".gz"): - engine = _get_default_engine_gz() + return _get_default_engine_gz() else: - engine = _get_default_engine_netcdf() - return engine + return _get_default_engine_netcdf() def _validate_dataset_names(dataset): @@ -282,7 +281,7 @@ def _chunk_ds( mtime = _get_mtime(filename_or_obj) token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) - name_prefix = "open_dataset-%s" % token + name_prefix = f"open_dataset-{token}" variables = {} for name, var in backend_ds.variables.items(): @@ -295,8 +294,7 @@ def _chunk_ds( name_prefix=name_prefix, token=token, ) - ds = backend_ds._replace(variables) - return ds + return backend_ds._replace(variables) def _dataset_from_backend_dataset( @@ -308,12 +306,10 @@ def _dataset_from_backend_dataset( overwrite_encoded_chunks, **extra_tokens, ): - if not (isinstance(chunks, (int, dict)) or chunks is None): - if chunks != "auto": - raise ValueError( - "chunks must be an 
int, dict, 'auto', or None. " - "Instead found %s. " % chunks - ) + if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: + raise ValueError( + f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}." + ) _protect_dataset_variables_inplace(backend_ds, cache) if chunks is None: @@ -331,9 +327,8 @@ def _dataset_from_backend_dataset( ds.set_close(backend_ds._close) # Ensure source filename always stored in dataset object (GH issue #2550) - if "source" not in ds.encoding: - if isinstance(filename_or_obj, str): - ds.encoding["source"] = filename_or_obj + if "source" not in ds.encoding and isinstance(filename_or_obj, str): + ds.encoding["source"] = filename_or_obj return ds @@ -515,7 +510,6 @@ def open_dataset( **decoders, **kwargs, ) - return ds @@ -1015,8 +1009,8 @@ def to_netcdf( elif engine != "scipy": raise ValueError( "invalid engine for creating bytes with " - "to_netcdf: %r. Only the default engine " - "or engine='scipy' is supported" % engine + f"to_netcdf: {engine!r}. Only the default engine " + "or engine='scipy' is supported" ) if not compute: raise NotImplementedError( @@ -1037,7 +1031,7 @@ def to_netcdf( try: store_open = WRITEABLE_STORES[engine] except KeyError: - raise ValueError("unrecognized engine for to_netcdf: %r" % engine) + raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") if format is not None: format = format.upper() @@ -1049,9 +1043,8 @@ def to_netcdf( autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] if autoclose and engine == "scipy": raise NotImplementedError( - "Writing netCDF files with the %s backend " - "is not currently supported with dask's %s " - "scheduler" % (engine, scheduler) + f"Writing netCDF files with the {engine} backend " + f"is not currently supported with dask's {scheduler} scheduler" ) target = path_or_file if path_or_file is not None else BytesIO() @@ -1061,7 +1054,7 @@ def to_netcdf( kwargs["invalid_netcdf"] = invalid_netcdf else: raise ValueError( - "unrecognized option 'invalid_netcdf' for engine %s" % engine + f"unrecognized option 'invalid_netcdf' for engine {engine}" ) store = store_open(target, mode, format, group, **kwargs) @@ -1203,7 +1196,7 @@ def save_mfdataset( Data variables: a (time) float64 0.0 0.02128 0.04255 0.06383 ... 
0.9574 0.9787 1.0 >>> years, datasets = zip(*ds.groupby("time.year")) - >>> paths = ["%s.nc" % y for y in years] + >>> paths = [f"{y}.nc" for y in years] >>> xr.save_mfdataset(datasets, paths) """ if mode == "w" and len(set(paths)) < len(paths): @@ -1215,7 +1208,7 @@ def save_mfdataset( if not isinstance(obj, Dataset): raise TypeError( "save_mfdataset only supports writing Dataset " - "objects, received type %s" % type(obj) + f"objects, received type {type(obj)}" ) if groups is None: diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 24a075aa811..9e5546f052a 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -90,8 +90,7 @@ def get_dimensions(self): def get_encoding(self): dims = self.get_dimensions() - encoding = {"unlimited_dims": {k for k, v in dims.items() if v is None}} - return encoding + return {"unlimited_dims": {k for k, v in dims.items() if v is None}} class CfgribfBackendEntrypoint(BackendEntrypoint): diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 026c7e5c7db..64a245ddead 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -69,9 +69,8 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 base_delay = initial_delay * 2 ** n next_delay = base_delay + np.random.randint(base_delay) msg = ( - "getitem failed, waiting %s ms before trying again " - "(%s tries remaining). Full traceback: %s" - % (next_delay, max_retries - n, traceback.format_exc()) + f"getitem failed, waiting {next_delay} ms before trying again " + f"({max_retries - n} tries remaining). Full traceback: {traceback.format_exc()}" ) logger.debug(msg) time.sleep(1e-3 * next_delay) @@ -336,7 +335,7 @@ def set_dimensions(self, variables, unlimited_dims=None): if dim in existing_dims and length != existing_dims[dim]: raise ValueError( "Unable to update size for existing dimension" - "%r (%d != %d)" % (dim, length, existing_dims[dim]) + f"{dim!r} ({length} != {existing_dims[dim]})" ) elif dim not in existing_dims: is_unlimited = dim in unlimited_dims diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 84e89f80dae..9f744d0c1ef 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -37,8 +37,7 @@ class H5NetCDFArrayWrapper(BaseNetCDF4Array): def get_array(self, needs_lock=True): ds = self.datastore._acquire(needs_lock) - variable = ds.variables[self.variable_name] - return variable + return ds.variables[self.variable_name] def __getitem__(self, key): return indexing.explicit_indexing_adapter( @@ -102,7 +101,7 @@ def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=Fal if group is None: root, group = find_root_and_group(manager) else: - if not type(manager) is h5netcdf.File: + if type(manager) is not h5netcdf.File: raise ValueError( "must supply a h5netcdf.File if the group " "argument is provided" @@ -233,11 +232,9 @@ def get_dimensions(self): return self.ds.dimensions def get_encoding(self): - encoding = {} - encoding["unlimited_dims"] = { - k for k, v in self.ds.dimensions.items() if v is None + return { + "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None} } - return encoding def set_dimension(self, name, length, is_unlimited=False): if is_unlimited: @@ -266,9 +263,9 @@ def prepare_variable( "h5netcdf does not yet support setting a fill value for " "variable-length strings " "(https://github.com/shoyer/h5netcdf/issues/37). 
" - "Either remove '_FillValue' from encoding on variable %r " + f"Either remove '_FillValue' from encoding on variable {name!r} " "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." % name + "NC_CHAR type." ) if dtype is str: diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 5303ea49381..59417336f5f 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -167,7 +167,7 @@ def locked(self): return any(lock.locked for lock in self.locks) def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) + return f"CombinedLock({list(self.locks)!r})" class DummyLock: diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a60c940c3c4..694b0d2fdd2 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -122,25 +122,23 @@ def _encode_nc4_variable(var): def _check_encoding_dtype_is_vlen_string(dtype): if dtype is not str: raise AssertionError( # pragma: no cover - "unexpected dtype encoding %r. This shouldn't happen: please " - "file a bug report at github.com/pydata/xarray" % dtype + f"unexpected dtype encoding {dtype!r}. This shouldn't happen: please " + "file a bug report at github.com/pydata/xarray" ) def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False): if nc_format == "NETCDF4": - datatype = _nc4_dtype(var) - else: - if "dtype" in var.encoding: - encoded_dtype = var.encoding["dtype"] - _check_encoding_dtype_is_vlen_string(encoded_dtype) - if raise_on_invalid_encoding: - raise ValueError( - "encoding dtype=str for vlen strings is only supported " - "with format='NETCDF4'." - ) - datatype = var.dtype - return datatype + return _nc4_dtype(var) + if "dtype" in var.encoding: + encoded_dtype = var.encoding["dtype"] + _check_encoding_dtype_is_vlen_string(encoded_dtype) + if raise_on_invalid_encoding: + raise ValueError( + "encoding dtype=str for vlen strings is only supported " + "with format='NETCDF4'." + ) + return var.dtype def _nc4_dtype(var): @@ -178,7 +176,7 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message - raise OSError("group not found: %s" % key, e) + raise OSError(f"group not found: {key}", e) return ds @@ -203,7 +201,7 @@ def _force_native_endianness(var): # if endian exists, remove it from the encoding. var.encoding.pop("endian", None) # check to see if encoding has a value for endian its 'native' - if not var.encoding.get("endian", "native") == "native": + if var.encoding.get("endian", "native") != "native": raise NotImplementedError( "Attempt to write non-native endian type, " "this is not supported by the netCDF4 " @@ -270,8 +268,8 @@ def _extract_nc4_variable_encoding( invalid = [k for k in encoding if k not in valid_encodings] if invalid: raise ValueError( - "unexpected encoding parameters for %r backend: %r. Valid " - "encodings are: %r" % (backend, invalid, valid_encodings) + f"unexpected encoding parameters for {backend!r} backend: {invalid!r}. 
Valid " + f"encodings are: {valid_encodings!r}" ) else: for k in list(encoding): @@ -282,10 +280,8 @@ def _extract_nc4_variable_encoding( def _is_list_of_strings(value): - if np.asarray(value).dtype.kind in ["U", "S"] and np.asarray(value).size > 1: - return True - else: - return False + arr = np.asarray(value) + return arr.dtype.kind in ["U", "S"] and arr.size > 1 class NetCDF4DataStore(WritableCFDataStore): @@ -313,7 +309,7 @@ def __init__( if group is None: root, group = find_root_and_group(manager) else: - if not type(manager) is netCDF4.Dataset: + if type(manager) is not netCDF4.Dataset: raise ValueError( "must supply a root netCDF4.Dataset if the group " "argument is provided" @@ -417,25 +413,22 @@ def open_store_variable(self, name, var): return Variable(dimensions, data, attributes, encoding) def get_variables(self): - dsvars = FrozenDict( + return FrozenDict( (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() ) - return dsvars def get_attrs(self): - attrs = FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs()) - return attrs + return FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs()) def get_dimensions(self): - dims = FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) - return dims + return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) def get_encoding(self): - encoding = {} - encoding["unlimited_dims"] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited() + return { + "unlimited_dims": { + k for k, v in self.ds.dimensions.items() if v.isunlimited() + } } - return encoding def set_dimension(self, name, length, is_unlimited=False): dim_length = length if not is_unlimited else None @@ -473,9 +466,9 @@ def prepare_variable( "netCDF4 does not yet support setting a fill value for " "variable-length strings " "(https://github.com/Unidata/netcdf4-python/issues/730). " - "Either remove '_FillValue' from encoding on variable %r " + f"Either remove '_FillValue' from encoding on variable {name!r} " "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." % name + "NC_CHAR type." 
) encoding = _extract_nc4_variable_encoding( diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index 001af0bf8e1..5fdd0534d57 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -125,8 +125,6 @@ def is_valid_nc3_name(s): """ if not isinstance(s, str): return False - if not isinstance(s, str): - s = s.decode("utf-8") num_bytes = len(s.encode("utf-8")) return ( (unicodedata.normalize("NFC", s) == s) diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index 23e83b0021e..d892e8761a7 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -87,10 +87,7 @@ def build_engines(pkg_entrypoints): backend_entrypoints.update(external_backend_entrypoints) backend_entrypoints = sort_backends(backend_entrypoints) set_missing_parameters(backend_entrypoints) - engines = {} - for name, backend in backend_entrypoints.items(): - engines[name] = backend() - return engines + return {name: backend() for name, backend in backend_entrypoints.items()} @functools.lru_cache(maxsize=1) diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 2372468d934..25d2df9d76a 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -47,7 +47,7 @@ def _getitem(self, key): result = robust_getitem(array, key, catch=ValueError) # in some cases, pydap doesn't squeeze axes automatically like numpy axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types)) - if result.ndim + len(axis) != array.ndim and len(axis) > 0: + if result.ndim + len(axis) != array.ndim and axis: result = np.squeeze(result, axis) return result diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index f5d9b7bf900..49a5a9ec7ae 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -389,7 +389,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc # the filename is probably an s3 bucket rather than a regular file mtime = None token = tokenize(filename, mtime, chunks) - name_prefix = "open_rasterio-%s" % token + name_prefix = f"open_rasterio-{token}" result = result.chunk(chunks, name_prefix=name_prefix, token=token) # Make the file closeable diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index c27716ea44d..9c33b172639 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -128,7 +128,7 @@ def __init__( elif format == "NETCDF3_CLASSIC": version = 1 else: - raise ValueError("invalid format for scipy.io.netcdf backend: %r" % format) + raise ValueError(f"invalid format for scipy.io.netcdf backend: {format!r}") if lock is None and mode != "r" and isinstance(filename_or_obj, str): lock = get_write_lock(filename_or_obj) @@ -174,16 +174,14 @@ def get_dimensions(self): return Frozen(self.ds.dimensions) def get_encoding(self): - encoding = {} - encoding["unlimited_dims"] = { - k for k, v in self.ds.dimensions.items() if v is None + return { + "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None} } - return encoding def set_dimension(self, name, length, is_unlimited=False): if name in self.ds.dimensions: raise ValueError( - "%s does not support modifying dimensions" % type(self).__name__ + f"{type(self).__name__} does not support modifying dimensions" ) dim_length = length if not is_unlimited else None self.ds.createDimension(name, dim_length) @@ -204,12 +202,14 @@ def encode_variable(self, variable): def prepare_variable( self, name, variable, check_encoding=False, unlimited_dims=None ): - if check_encoding and 
variable.encoding: - if variable.encoding != {"_FillValue": None}: - raise ValueError( - "unexpected encoding for scipy backend: %r" - % list(variable.encoding) - ) + if ( + check_encoding + and variable.encoding + and variable.encoding != {"_FillValue": None} + ): + raise ValueError( + f"unexpected encoding for scipy backend: {list(variable.encoding)}" + ) data = variable.data # nb. this still creates a numpy array in all memory, even though we diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index fef7d739d25..72c4e99265d 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -205,8 +205,8 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): dimensions = zarr_obj.attrs[dimension_key] except KeyError: raise KeyError( - "Zarr object is missing the attribute `%s`, which is " - "required for xarray to determine variable dimensions." % (dimension_key) + f"Zarr object is missing the attribute `{dimension_key}`, which is " + "required for xarray to determine variable dimensions." ) attributes = HiddenKeyDict(zarr_obj.attrs, [dimension_key]) return dimensions, attributes @@ -236,7 +236,7 @@ def extract_zarr_variable_encoding( invalid = [k for k in encoding if k not in valid_encodings] if invalid: raise ValueError( - "unexpected encoding parameters for zarr backend: %r" % invalid + f"unexpected encoding parameters for zarr backend: {invalid!r}" ) else: for k in list(encoding): @@ -380,8 +380,7 @@ def get_variables(self): ) def get_attrs(self): - attributes = dict(self.ds.attrs.asdict()) - return attributes + return dict(self.ds.attrs.asdict()) def get_dimensions(self): dimensions = {} @@ -390,16 +389,16 @@ def get_dimensions(self): for d, s in zip(v.attrs[DIMENSION_KEY], v.shape): if d in dimensions and dimensions[d] != s: raise ValueError( - "found conflicting lengths for dimension %s " - "(%d != %d)" % (d, s, dimensions[d]) + f"found conflicting lengths for dimension {d} " + f"({s} != {dimensions[d]})" ) dimensions[d] = s except KeyError: raise KeyError( - "Zarr object is missing the attribute `%s`, " + f"Zarr object is missing the attribute `{DIMENSION_KEY}`, " "which is required for xarray to determine " - "variable dimensions." % (DIMENSION_KEY) + "variable dimensions." 
) return dimensions @@ -459,7 +458,7 @@ def store( variables_without_encoding, attributes ) - if len(existing_variables) > 0: + if existing_variables: # there are variables to append # their encoding must be the same as in the store ds = open_zarr(self.ds.store, group=self.ds.path, chunks=None) @@ -700,7 +699,6 @@ def open_zarr( decode_timedelta=decode_timedelta, use_cftime=use_cftime, ) - return ds diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index c25d5296c41..c031bffb2cd 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -178,8 +178,7 @@ def _get_day_of_month(other, day_option): if day_option == "start": return 1 elif day_option == "end": - days_in_month = _days_in_month(other) - return days_in_month + return _days_in_month(other) elif day_option is None: # Note: unlike `_shift_month`, _get_day_of_month does not # allow day_option = None @@ -291,10 +290,7 @@ def roll_qtrday(other, n, month, day_option, modby=3): def _validate_month(month, default_month): - if month is None: - result_month = default_month - else: - result_month = month + result_month = default_month if month is None else month if not isinstance(result_month, int): raise TypeError( "'self.month' must be an integer value between 1 " @@ -687,11 +683,7 @@ def to_offset(freq): freq = freq_data["freq"] multiples = freq_data["multiple"] - if multiples is None: - multiples = 1 - else: - multiples = int(multiples) - + multiples = 1 if multiples is None else int(multiples) return _FREQUENCIES[freq](n=multiples) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 15f75955e00..a43724a6f31 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -255,7 +255,7 @@ def format_times( indent = first_row_offset if row == 0 else offset row_end = last_row_end if row == n_rows - 1 else intermediate_row_end times_for_row = index[row * n_per_row : (row + 1) * n_per_row] - representation = representation + format_row( + representation += format_row( times_for_row, indent=indent, separator=separator, row_end=row_end ) @@ -268,8 +268,9 @@ def format_attrs(index, separator=", "): "dtype": f"'{index.dtype}'", "length": f"{len(index)}", "calendar": f"'{index.calendar}'", + "freq": f"'{index.freq}'" if len(index) >= 3 else None, } - attrs["freq"] = f"'{index.freq}'" if len(index) >= 3 else None + attrs_str = [f"{k}={v}" for k, v in attrs.items()] attrs_str = f"{separator}".join(attrs_str) return attrs_str @@ -350,14 +351,13 @@ def __repr__(self): attrs_str = format_attrs(self) # oneliner only if smaller than display_width full_repr_str = f"{klass_name}([{datastr}], {attrs_str})" - if len(full_repr_str) <= display_width: - return full_repr_str - else: + if len(full_repr_str) > display_width: # if attrs_str too long, one per line if len(attrs_str) >= display_width - offset: attrs_str = attrs_str.replace(",", f",\n{' '*(offset-2)}") full_repr_str = f"{klass_name}([{datastr}],\n{' '*(offset-1)}{attrs_str})" - return full_repr_str + + return full_repr_str def _partial_date_slice(self, resolution, parsed): """Adapted from @@ -470,15 +470,15 @@ def get_loc(self, key, method=None, tolerance=None): def _maybe_cast_slice_bound(self, label, side, kind): """Adapted from pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound""" - if isinstance(label, str): - parsed, resolution = _parse_iso8601_with_reso(self.date_type, label) - start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) - if self.is_monotonic_decreasing and 
len(self) > 1: - return end if side == "left" else start - return start if side == "left" else end - else: + if not isinstance(label, str): return label + parsed, resolution = _parse_iso8601_with_reso(self.date_type, label) + start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) + if self.is_monotonic_decreasing and len(self) > 1: + return end if side == "left" else start + return start if side == "left" else end + # TODO: Add ability to use integer range outside of iloc? # e.g. series[1:5]. def get_value(self, series, key): diff --git a/xarray/coding/frequencies.py b/xarray/coding/frequencies.py index c83c766f071..e9efef8eb7a 100644 --- a/xarray/coding/frequencies.py +++ b/xarray/coding/frequencies.py @@ -187,7 +187,7 @@ def _get_quartely_rule(self): if len(self.month_deltas) > 1: return None - if not self.month_deltas[0] % 3 == 0: + if self.month_deltas[0] % 3 != 0: return None return {"cs": "QS", "ce": "Q"}.get(month_anchor_check(self.index)) @@ -259,8 +259,7 @@ def month_anchor_check(dates): if calendar_end: cal = date.day == date.daysinmonth - if calendar_end: - calendar_end &= cal + calendar_end &= cal elif not calendar_start: break diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index e16e983fd8a..c217cb0c865 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -111,7 +111,7 @@ def encode(self, variable, name=None): if "char_dim_name" in encoding.keys(): char_dim_name = encoding.pop("char_dim_name") else: - char_dim_name = "string%s" % data.shape[-1] + char_dim_name = f"string{data.shape[-1]}" dims = dims + (char_dim_name,) return Variable(dims, data, attrs, encoding) @@ -140,8 +140,7 @@ def bytes_to_char(arr): chunks=arr.chunks + ((arr.dtype.itemsize,)), new_axis=[arr.ndim], ) - else: - return _numpy_bytes_to_char(arr) + return _numpy_bytes_to_char(arr) def _numpy_bytes_to_char(arr): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 54400414ebc..9f5d1f87aee 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -75,7 +75,7 @@ def _is_standard_calendar(calendar): def _netcdf_to_numpy_timeunit(units): units = units.lower() if not units.endswith("s"): - units = "%ss" % units + units = f"{units}s" return { "nanoseconds": "ns", "microseconds": "us", @@ -147,7 +147,7 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime): result = decode_cf_datetime(example_value, units, calendar, use_cftime) except Exception: calendar_msg = ( - "the default calendar" if calendar is None else "calendar %r" % calendar + "the default calendar" if calendar is None else f"calendar {calendar!r}" ) msg = ( f"unable to decode time units {units!r} with {calendar_msg!r}. Try " @@ -370,8 +370,7 @@ def infer_timedelta_units(deltas): """ deltas = to_timedelta_unboxed(np.asarray(deltas).ravel()) unique_timedeltas = np.unique(deltas[pd.notnull(deltas)]) - units = _infer_time_units_from_diff(unique_timedeltas) - return units + return _infer_time_units_from_diff(unique_timedeltas) def cftime_to_nptime(times): diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 938752c4efc..1ebaab1be02 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -77,7 +77,6 @@ def __repr__(self): def lazy_elemwise_func(array, func, dtype): """Lazily apply an element-wise function to an array. 
- Parameters ---------- array : any valid value of Variable._data @@ -255,10 +254,10 @@ def encode(self, variable, name=None): if "scale_factor" in encoding or "add_offset" in encoding: dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding) data = data.astype(dtype=dtype, copy=True) - if "add_offset" in encoding: - data -= pop_to(encoding, attrs, "add_offset", name=name) - if "scale_factor" in encoding: - data /= pop_to(encoding, attrs, "scale_factor", name=name) + if "add_offset" in encoding: + data -= pop_to(encoding, attrs, "add_offset", name=name) + if "scale_factor" in encoding: + data /= pop_to(encoding, attrs, "scale_factor", name=name) return Variable(dims, data, attrs, encoding) @@ -294,7 +293,7 @@ def encode(self, variable, name=None): # integer data should be treated as unsigned" if encoding.get("_Unsigned", "false") == "true": pop_to(encoding, attrs, "_Unsigned") - signed_dtype = np.dtype("i%s" % data.dtype.itemsize) + signed_dtype = np.dtype(f"i{data.dtype.itemsize}") if "_FillValue" in attrs: new_fill = signed_dtype.type(attrs["_FillValue"]) attrs["_FillValue"] = new_fill @@ -310,7 +309,7 @@ def decode(self, variable, name=None): if data.dtype.kind == "i": if unsigned == "true": - unsigned_dtype = np.dtype("u%s" % data.dtype.itemsize) + unsigned_dtype = np.dtype(f"u{data.dtype.itemsize}") transform = partial(np.asarray, dtype=unsigned_dtype) data = lazy_elemwise_func(data, transform, unsigned_dtype) if "_FillValue" in attrs: @@ -318,7 +317,7 @@ def decode(self, variable, name=None): attrs["_FillValue"] = new_fill elif data.dtype.kind == "u": if unsigned == "false": - signed_dtype = np.dtype("i%s" % data.dtype.itemsize) + signed_dtype = np.dtype(f"i{data.dtype.itemsize}") transform = partial(np.asarray, dtype=signed_dtype) data = lazy_elemwise_func(data, transform, signed_dtype) if "_FillValue" in attrs: @@ -326,8 +325,8 @@ def decode(self, variable, name=None): attrs["_FillValue"] = new_fill else: warnings.warn( - "variable %r has _Unsigned attribute but is not " - "of integer type. Ignoring attribute." % name, + f"variable {name!r} has _Unsigned attribute but is not " + "of integer type. 
Ignoring attribute.", SerializationWarning, stacklevel=3, ) diff --git a/xarray/conventions.py b/xarray/conventions.py index aece572fda3..901d19bd99b 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -110,9 +110,9 @@ def maybe_encode_nonstring_dtype(var, name=None): and "missing_value" not in var.attrs ): warnings.warn( - "saving variable %s with floating " + f"saving variable {name} with floating " "point data as an integer dtype without " - "any _FillValue to use for NaNs" % name, + "any _FillValue to use for NaNs", SerializationWarning, stacklevel=10, ) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index f0e416b52e6..d50163c435b 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -81,7 +81,7 @@ def _contains_obj_type(*, pat: Any, checker: Any) -> bool: return True # If it is not an object array it can't contain compiled re - if not getattr(pat, "dtype", "no") == np.object_: + if getattr(pat, "dtype", "no") != np.object_: return False return _apply_str_ufunc(func=checker, obj=pat).all() @@ -95,7 +95,7 @@ def _contains_str_like(pat: Any) -> bool: if not hasattr(pat, "dtype"): return False - return pat.dtype.kind == "U" or pat.dtype.kind == "S" + return pat.dtype.kind in ["U", "S"] def _contains_compiled_re(pat: Any) -> bool: diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index f6e026c0109..a4794dd28f5 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -48,7 +48,7 @@ def _get_joiner(join, index_cls): # We rewrite all indexes and then use join='left' return operator.itemgetter(0) else: - raise ValueError("invalid value for join: %s" % join) + raise ValueError(f"invalid value for join: {join}") def _override_indexes(objects, all_indexes, exclude): @@ -57,16 +57,16 @@ def _override_indexes(objects, all_indexes, exclude): lengths = {index.size for index in dim_indexes} if len(lengths) != 1: raise ValueError( - "Indexes along dimension %r don't have the same length." - " Cannot use join='override'." % dim + f"Indexes along dimension {dim!r} don't have the same length." + " Cannot use join='override'." 
) objects = list(objects) for idx, obj in enumerate(objects[1:]): - new_indexes = {} - for dim in obj.xindexes: - if dim not in exclude: - new_indexes[dim] = all_indexes[dim][0] + new_indexes = { + dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude + } + objects[idx + 1] = obj._overwrite_indexes(new_indexes) return objects @@ -338,21 +338,17 @@ def align( labeled_size = index.size if len(unlabeled_sizes | {labeled_size}) > 1: raise ValueError( - "arguments without labels along dimension %r cannot be " - "aligned because they have different dimension size(s) %r " - "than the size of the aligned dimension labels: %r" - % (dim, unlabeled_sizes, labeled_size) + f"arguments without labels along dimension {dim!r} cannot be " + f"aligned because they have different dimension size(s) {unlabeled_sizes!r} " + f"than the size of the aligned dimension labels: {labeled_size!r}" ) - for dim in unlabeled_dim_sizes: - if dim not in all_indexes: - sizes = unlabeled_dim_sizes[dim] - if len(sizes) > 1: - raise ValueError( - "arguments without labels along dimension %r cannot be " - "aligned because they have different dimension sizes: %r" - % (dim, sizes) - ) + for dim, sizes in unlabeled_dim_sizes.items(): + if dim not in all_indexes and len(sizes) > 1: + raise ValueError( + f"arguments without labels along dimension {dim!r} cannot be " + f"aligned because they have different dimension sizes: {sizes!r}" + ) result = [] for obj in objects: @@ -486,8 +482,7 @@ def reindex_like_indexers( if other_size != target_size: raise ValueError( "different size for unlabeled " - "dimension on argument %r: %r vs %r" - % (dim, other_size, target_size) + f"dimension on argument {dim!r}: {other_size!r} vs {target_size!r}" ) return indexers @@ -575,8 +570,8 @@ def reindex_variables( if not index.is_unique: raise ValueError( - "cannot reindex or align along dimension %r because the " - "index has duplicate values" % dim + f"cannot reindex or align along dimension {dim!r} because the " + "index has duplicate values" ) int_indexer = get_indexer_nd(index, target, method, tolerance) @@ -603,9 +598,9 @@ def reindex_variables( new_size = indexers[dim].size if existing_size != new_size: raise ValueError( - "cannot reindex or align along dimension %r without an " - "index because its size %r is different from the size of " - "the new index %r" % (dim, existing_size, new_size) + f"cannot reindex or align along dimension {dim!r} without an " + f"index because its size {existing_size!r} is different from the size of " + f"the new index {new_size!r}" ) for name, var in variables.items(): @@ -756,8 +751,6 @@ def broadcast(*args, exclude=None): args = align(*args, join="outer", copy=False, exclude=exclude) dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - result = [] - for arg in args: - result.append(_broadcast_helper(arg, exclude, dims_map, common_coords)) + result = [_broadcast_helper(arg, exclude, dims_map, common_coords) for arg in args] return tuple(result) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 105e0a5a66c..7d9273670a3 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -11,8 +11,7 @@ def _infer_concat_order_from_positions(datasets): - combined_ids = dict(_infer_tile_ids_from_nested_list(datasets, ())) - return combined_ids + return dict(_infer_tile_ids_from_nested_list(datasets, ())) def _infer_tile_ids_from_nested_list(entry, current_pos): @@ -144,7 +143,7 @@ def _check_dimension_depth_tile_ids(combined_tile_ids): nesting_depths = 
[len(tile_id) for tile_id in tile_ids] if not nesting_depths: nesting_depths = [0] - if not set(nesting_depths) == {nesting_depths[0]}: + if set(nesting_depths) != {nesting_depths[0]}: raise ValueError( "The supplied objects do not form a hypercube because" " sub-lists do not have consistent depths" diff --git a/xarray/core/common.py b/xarray/core/common.py index e4a5264d8e6..320c94972d2 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -209,11 +209,11 @@ def __init_subclass__(cls, **kwargs): if not hasattr(object.__new__(cls), "__dict__"): pass elif cls.__module__.startswith("xarray."): - raise AttributeError("%s must explicitly define __slots__" % cls.__name__) + raise AttributeError(f"{cls.__name__} must explicitly define __slots__") else: cls.__setattr__ = cls._setattr_dict warnings.warn( - "xarray subclass %s should explicitly define __slots__" % cls.__name__, + f"xarray subclass {cls.__name__} should explicitly define __slots__", FutureWarning, stacklevel=2, ) @@ -251,10 +251,9 @@ def _setattr_dict(self, name: str, value: Any) -> None: if name in self.__dict__: # Custom, non-slotted attr, or improperly assigned variable? warnings.warn( - "Setting attribute %r on a %r object. Explicitly define __slots__ " + f"Setting attribute {name!r} on a {type(self).__name__!r} object. Explicitly define __slots__ " "to suppress this warning for legitimate custom attributes and " - "raise an error when attempting variables assignments." - % (name, type(self).__name__), + "raise an error when attempting variables assignments.", FutureWarning, stacklevel=2, ) @@ -274,9 +273,8 @@ def __setattr__(self, name: str, value: Any) -> None: ): raise raise AttributeError( - "cannot set attribute %r on a %r object. Use __setitem__ style" + f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" "assignment (e.g., `ds['name'] = ...`) instead of assigning variables." 
- % (name, type(self).__name__) ) from e def __dir__(self) -> List[str]: @@ -655,7 +653,7 @@ def pipe( func, target = func if target in kwargs: raise ValueError( - "%s is both the pipe target and a keyword argument" % target + f"{target} is both the pipe target and a keyword argument" ) kwargs[target] = self return func(*args, **kwargs) @@ -1273,8 +1271,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): if not isinstance(cond, (Dataset, DataArray)): raise TypeError( - "cond argument is %r but must be a %r or %r" - % (cond, Dataset, DataArray) + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" ) # align so we can use integer indexing diff --git a/xarray/core/computation.py b/xarray/core/computation.py index e12938d6965..2bc3ba921a7 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -307,7 +307,7 @@ def assert_and_return_exact_match(all_keys): if keys != first_keys: raise ValueError( "exact match required for all data variable names, " - "but %r != %r" % (keys, first_keys) + f"but {keys!r} != {first_keys!r}" ) return first_keys @@ -516,7 +516,7 @@ def unified_dim_sizes( if len(set(var.dims)) < len(var.dims): raise ValueError( "broadcasting cannot handle duplicate " - "dimensions on a variable: %r" % list(var.dims) + f"dimensions on a variable: {list(var.dims)}" ) for dim, size in zip(var.dims, var.shape): if dim not in exclude_dims: @@ -526,7 +526,7 @@ def unified_dim_sizes( raise ValueError( "operands cannot be broadcast together " "with mismatched lengths for dimension " - "%r: %s vs %s" % (dim, dim_sizes[dim], size) + f"{dim}: {dim_sizes[dim]} vs {size}" ) return dim_sizes @@ -563,8 +563,8 @@ def broadcast_compat_data( if unexpected_dims: raise ValueError( "operand to apply_ufunc encountered unexpected " - "dimensions %r on an input variable: these are core " - "dimensions on other input or output variables" % unexpected_dims + f"dimensions {unexpected_dims!r} on an input variable: these are core " + "dimensions on other input or output variables" ) # for consistency with numpy, keep broadcast dimensions to the left diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 9eca99918d4..368f8992607 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -223,8 +223,7 @@ def concat( if compat not in _VALID_COMPAT: raise ValueError( - "compat=%r invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" - % compat + f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" ) if isinstance(first_obj, DataArray): @@ -234,7 +233,7 @@ def concat( else: raise TypeError( "can only concatenate xarray Dataset and DataArray " - "objects, got %s" % type(first_obj) + f"objects, got {type(first_obj)}" ) return f( objs, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs @@ -293,18 +292,16 @@ def process_subset_opt(opt, subset): if opt == "different": if compat == "override": raise ValueError( - "Cannot specify both %s='different' and compat='override'." - % subset + f"Cannot specify both {subset}='different' and compat='override'." 
) # all nonindexes that are not the same in each dataset for k in getattr(datasets[0], subset): if k not in concat_over: equals[k] = None - variables = [] - for ds in datasets: - if k in ds.variables: - variables.append(ds.variables[k]) + variables = [ + ds.variables[k] for ds in datasets if k in ds.variables + ] if len(variables) == 1: # coords="different" doesn't make sense when only one object @@ -367,12 +364,12 @@ def process_subset_opt(opt, subset): if subset == "coords": raise ValueError( "some variables in coords are not coordinates on " - "the first dataset: %s" % (invalid_vars,) + f"the first dataset: {invalid_vars}" ) else: raise ValueError( "some variables in data_vars are not data variables " - "on the first dataset: %s" % (invalid_vars,) + f"on the first dataset: {invalid_vars}" ) concat_over.update(opt) @@ -439,7 +436,7 @@ def _dataset_concat( both_data_and_coords = coord_names & data_names if both_data_and_coords: raise ValueError( - "%r is a coordinate in some datasets but not others." % both_data_and_coords + f"{both_data_and_coords!r} is a coordinate in some datasets but not others." ) # we don't want the concat dimension in the result dataset yet dim_coords.pop(dim, None) @@ -507,7 +504,7 @@ def ensure_common_dims(vars): try: vars = ensure_common_dims([ds[k].variable for ds in datasets]) except KeyError: - raise ValueError("%r is not present in all datasets." % k) + raise ValueError(f"{k!r} is not present in all datasets.") combined = concat_vars(vars, dim, positions, combine_attrs=combine_attrs) assert isinstance(combined, Variable) result_vars[k] = combined @@ -519,8 +516,7 @@ def ensure_common_dims(vars): absent_coord_names = coord_names - set(result.variables) if absent_coord_names: raise ValueError( - "Variables %r are coordinates in some datasets but not others." - % absent_coord_names + f"Variables {absent_coord_names!r} are coordinates in some datasets but not others." 
) result = result.set_coords(coord_names) result.encoding = result_encoding diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 50be8a7f677..767b76d0d12 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -122,13 +122,13 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: ) cumprod_lengths = np.cumproduct(index_lengths) - if cumprod_lengths[-1] != 0: - # sizes of the repeats - repeat_counts = cumprod_lengths[-1] / cumprod_lengths - else: + if cumprod_lengths[-1] == 0: # if any factor is empty, the cartesian product is empty repeat_counts = np.zeros_like(cumprod_lengths) + else: + # sizes of the repeats + repeat_counts = cumprod_lengths[-1] / cumprod_lengths # sizes of the tiles tile_counts = np.roll(cumprod_lengths, 1) tile_counts[0] = 1 @@ -156,7 +156,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: level_list += levels names += index.names - return pd.MultiIndex(level_list, code_list, names=names) + return pd.MultiIndex(level_list, code_list, names=names) def update(self, other: Mapping[Hashable, Any]) -> None: other_vars = getattr(other, "variables", other) @@ -226,10 +226,9 @@ def merge(self, other: "Coordinates") -> "Dataset": coords, indexes = merge_coordinates_without_align([self, other]) coord_names = set(coords) - merged = Dataset._construct_direct( + return Dataset._construct_direct( variables=coords, coord_names=coord_names, indexes=indexes ) - return merged class DatasetCoordinates(Coordinates): @@ -364,13 +363,13 @@ def to_dataset(self) -> "Dataset": return Dataset._construct_direct(coords, set(coords)) def __delitem__(self, key: Hashable) -> None: - if key in self: - del self._data._coords[key] - if self._data._indexes is not None and key in self._data._indexes: - del self._data._indexes[key] - else: + if key not in self: raise KeyError(f"{key!r} is not a coordinate variable.") + del self._data._coords[key] + if self._data._indexes is not None and key in self._data._indexes: + del self._data._indexes[key] + def _ipython_key_completions_(self): """Provide method for the key-autocompletions in IPython.""" return self._data._ipython_key_completions_() @@ -386,14 +385,11 @@ def assert_coordinate_consistent( """ for k in obj.dims: # make sure there are no conflict in dimension coordinates - if k in coords and k in obj.coords: - if not coords[k].equals(obj[k].variable): - raise IndexError( - "dimension coordinate {!r} conflicts between " - "indexed and indexing objects:\n{}\nvs.\n{}".format( - k, obj[k], coords[k] - ) - ) + if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable): + raise IndexError( + f"dimension coordinate {k!r} conflicts between " + f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" + ) def remap_label_indexers( diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 21daed1cec1..831d0d24ccb 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -98,16 +98,16 @@ def _infer_coords_and_dims( and len(coords) != len(shape) ): raise ValueError( - "coords is not dict-like, but it has %s items, " - "which does not match the %s dimensions of the " - "data" % (len(coords), len(shape)) + f"coords is not dict-like, but it has {len(coords)} items, " + f"which does not match the {len(shape)} dimensions of the " + "data" ) if isinstance(dims, str): dims = (dims,) if dims is None: - dims = ["dim_%s" % n for n in range(len(shape))] + dims = [f"dim_{n}" for n in range(len(shape))] if coords is not None and len(coords) 
== len(shape): # try to infer dimensions from coords if utils.is_dict_like(coords): @@ -125,12 +125,12 @@ def _infer_coords_and_dims( elif len(dims) != len(shape): raise ValueError( "different number of dimensions on data " - "and dims: %s vs %s" % (len(shape), len(dims)) + f"and dims: {len(shape)} vs {len(dims)}" ) else: for d in dims: if not isinstance(d, str): - raise TypeError("dimension %s is not a string" % d) + raise TypeError(f"dimension {d} is not a string") new_coords: Dict[Any, Variable] = {} @@ -147,24 +147,24 @@ def _infer_coords_and_dims( for k, v in new_coords.items(): if any(d not in dims for d in v.dims): raise ValueError( - "coordinate %s has dimensions %s, but these " + f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " - "dimensions %s" % (k, v.dims, dims) + f"dimensions {dims}" ) for d, s in zip(v.dims, v.shape): if s != sizes[d]: raise ValueError( - "conflicting sizes for dimension %r: " - "length %s on the data but length %s on " - "coordinate %r" % (d, sizes[d], s, k) + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" ) if k in sizes and v.shape != (sizes[k],): raise ValueError( - "coordinate %r is a DataArray dimension, but " - "it has shape %r rather than expected shape %r " - "matching the dimension size" % (k, v.shape, (sizes[k],)) + f"coordinate {k!r} is a DataArray dimension, but " + f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " + "matching the dimension size" ) assert_unique_multiindex_level_names(new_coords) @@ -539,8 +539,7 @@ def _to_dataset_whole( indexes = self._indexes coord_names = set(self._coords) - dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes) - return dataset + return Dataset._construct_direct(variables, coord_names, indexes=indexes) def to_dataset( self, @@ -669,9 +668,8 @@ def dims(self, value): def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: if utils.is_dict_like(key): return key - else: - key = indexing.expanded_indexer(key, self.ndim) - return dict(zip(self.dims, key)) + key = indexing.expanded_indexer(key, self.ndim) + return dict(zip(self.dims, key)) @property def _level_coords(self) -> Dict[Hashable, Hashable]: @@ -823,13 +821,12 @@ def reset_coords( dataset = self.coords.to_dataset().reset_coords(names, drop) if drop: return self._replace(coords=dataset._variables) - else: - if self.name is None: - raise ValueError( - "cannot reset_coords with drop=False on an unnamed DataArrray" - ) - dataset[self.name] = self.variable - return dataset + if self.name is None: + raise ValueError( + "cannot reset_coords with drop=False on an unnamed DataArrray" + ) + dataset[self.name] = self.variable + return dataset def __dask_tokenize__(self): from dask.base import normalize_token @@ -2012,7 +2009,7 @@ def reorder_levels( coord = self._coords[dim] index = coord.to_index() if not isinstance(index, pd.MultiIndex): - raise ValueError("coordinate %r has no MultiIndex" % dim) + raise ValueError(f"coordinate {dim!r} has no MultiIndex") replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order)) coords = self._coords.copy() coords.update(replace_coords) @@ -2658,8 +2655,8 @@ def to_pandas(self) -> Union["DataArray", pd.Series, pd.DataFrame]: constructor = constructors[self.ndim] except KeyError: raise ValueError( - "cannot convert arrays with %s dimensions into " - "pandas objects" % self.ndim + f"cannot convert arrays with {self.ndim} dimensions into " + "pandas 
objects" ) indexes = [self.get_index(dim) for dim in self.dims] return constructor(self.values, *indexes) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 19af7f6c3cd..1974301ddd3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -196,16 +196,15 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl for dim, size in zip(var.dims, var.shape): if dim in scalar_vars: raise ValueError( - "dimension %r already exists as a scalar variable" % dim + f"dimension {dim!r} already exists as a scalar variable" ) if dim not in dims: dims[dim] = size last_used[dim] = k elif dims[dim] != size: raise ValueError( - "conflicting sizes for dimension %r: " - "length %s on %r and length %s on %r" - % (dim, size, k, dims[dim], last_used[dim]) + f"conflicting sizes for dimension {dim!r}: " + f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" ) return dims @@ -245,8 +244,7 @@ def merge_indexes( and var.dims != current_index_variable.dims ): raise ValueError( - "dimension mismatch between %r %s and %r %s" - % (dim, current_index_variable.dims, n, var.dims) + f"dimension mismatch between {dim!r} {current_index_variable.dims} and {n!r} {var.dims}" ) if current_index_variable is not None and append: @@ -256,7 +254,7 @@ def merge_indexes( codes.extend(current_index.codes) levels.extend(current_index.levels) else: - names.append("%s_level_0" % dim) + names.append(f"{dim}_level_0") cat = pd.Categorical(current_index.values, ordered=True) codes.append(cat.codes) levels.append(cat.categories) @@ -733,8 +731,7 @@ def __init__( both_data_and_coords = set(data_vars) & set(coords) if both_data_and_coords: raise ValueError( - "variables %r are found in both data_vars and coords" - % both_data_and_coords + f"variables {both_data_and_coords!r} are found in both data_vars and coords" ) if isinstance(coords, Dataset): @@ -1700,7 +1697,7 @@ def reset_coords( bad_coords = set(names) & set(self.dims) if bad_coords: raise ValueError( - "cannot remove index coordinates with reset_coords: %s" % bad_coords + f"cannot remove index coordinates with reset_coords: {bad_coords}" ) obj = self.copy() obj._coord_names.difference_update(names) @@ -2050,7 +2047,7 @@ def chunk( bad_dims = chunks.keys() - self.dims.keys() if bad_dims: raise ValueError( - "some chunks keys are not dimensions on this " "object: %s" % bad_dims + f"some chunks keys are not dimensions on this object: {bad_dims}" ) variables = { @@ -2408,12 +2405,12 @@ def head( if not isinstance(v, int): raise TypeError( "expected integer type indexer for " - "dimension %r, found %r" % (k, type(v)) + f"dimension {k!r}, found {type(v)!r}" ) elif v < 0: raise ValueError( "expected positive integer as indexer " - "for dimension %r, found %s" % (k, v) + f"for dimension {k!r}, found {v}" ) indexers_slices = {k: slice(val) for k, val in indexers.items()} return self.isel(indexers_slices) @@ -2454,12 +2451,12 @@ def tail( if not isinstance(v, int): raise TypeError( "expected integer type indexer for " - "dimension %r, found %r" % (k, type(v)) + f"dimension {k!r}, found {type(v)!r}" ) elif v < 0: raise ValueError( "expected positive integer as indexer " - "for dimension %r, found %s" % (k, v) + f"for dimension {k!r}, found {v}" ) indexers_slices = { k: slice(-val, None) if val != 0 else slice(val) @@ -2504,12 +2501,12 @@ def thin( if not isinstance(v, int): raise TypeError( "expected integer type indexer for " - "dimension %r, found %r" % (k, type(v)) + f"dimension {k!r}, found {type(v)!r}" ) elif v < 0: raise 
ValueError( "expected positive integer as indexer " - "for dimension %r, found %s" % (k, v) + f"for dimension {k!r}, found {v}" ) elif v == 0: raise ValueError("step cannot be zero") @@ -2830,7 +2827,7 @@ def _reindex( bad_dims = [d for d in indexers if d not in self.dims] if bad_dims: - raise ValueError("invalid reindex dimensions: %s" % bad_dims) + raise ValueError(f"invalid reindex dimensions: {bad_dims}") variables, indexes = alignment.reindex_variables( self.variables, @@ -3249,8 +3246,8 @@ def rename( for k in name_dict.keys(): if k not in self and k not in self.dims: raise ValueError( - "cannot rename %r because it is not a " - "variable or dimension in this dataset" % k + f"cannot rename {k!r} because it is not a " + "variable or dimension in this dataset" ) variables, coord_names, dims, indexes = self._rename_all( @@ -3290,8 +3287,8 @@ def rename_dims( for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( - "cannot rename %r because it is not a " - "dimension in this dataset" % k + f"cannot rename {k!r} because it is not a " + "dimension in this dataset" ) if v in self.dims or v in self: raise ValueError( @@ -3334,8 +3331,8 @@ def rename_vars( for k in name_dict: if k not in self: raise ValueError( - "cannot rename %r because it is not a " - "variable or coordinate in this dataset" % k + f"cannot rename {k!r} because it is not a " + "variable or coordinate in this dataset" ) variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict={} @@ -3410,13 +3407,13 @@ def swap_dims( for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( - "cannot swap from dimension %r because it is " - "not an existing dimension" % k + f"cannot swap from dimension {k!r} because it is " + "not an existing dimension" ) if v in self.variables and self.variables[v].dims != (k,): raise ValueError( - "replacement dimension %r is not a 1D " - "variable along the old dimension %r" % (v, k) + f"replacement dimension {v!r} is not a 1D " + f"variable along the old dimension {k!r}" ) result_dims = {dims_dict.get(dim, dim) for dim in self.dims} @@ -4020,7 +4017,7 @@ def unstack( missing_dims = [d for d in dims if d not in self.dims] if missing_dims: raise ValueError( - "Dataset does not contain the dimensions: %s" % missing_dims + f"Dataset does not contain the dimensions: {missing_dims}" ) non_multi_dims = [ @@ -4029,7 +4026,7 @@ def unstack( if non_multi_dims: raise ValueError( "cannot unstack dimensions that do not " - "have a MultiIndex: %s" % non_multi_dims + f"have a MultiIndex: {non_multi_dims}" ) result = self.copy(deep=False) @@ -4346,7 +4343,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): try: index = self.get_index(dim) except KeyError: - raise ValueError("dimension %r does not have coordinate labels" % dim) + raise ValueError(f"dimension {dim!r} does not have coordinate labels") new_index = index.drop(labels_for_dim, errors=errors) ds = ds.loc[{dim: new_index}] return ds @@ -4453,7 +4450,7 @@ def drop_dims( missing_dims = drop_dims - set(self.dims) if missing_dims: raise ValueError( - "Dataset does not contain the dimensions: %s" % missing_dims + f"Dataset does not contain the dimensions: {missing_dims}" ) drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} @@ -4491,8 +4488,8 @@ def transpose(self, *dims: Hashable) -> "Dataset": if dims: if set(dims) ^ set(self.dims) and ... 
not in dims: raise ValueError( - "arguments to transpose (%s) must be " - "permuted dataset dimensions (%s)" % (dims, tuple(self.dims)) + f"arguments to transpose ({dims}) must be " + f"permuted dataset dimensions ({tuple(self.dims)})" ) ds = self.copy() for name, var in self._variables.items(): @@ -4533,7 +4530,7 @@ def dropna( # depending on the order of the supplied axes. if dim not in self.dims: - raise ValueError("%s must be a single dataset dimension" % dim) + raise ValueError(f"{dim} must be a single dataset dimension") if subset is None: subset = iter(self.data_vars) @@ -4555,7 +4552,7 @@ def dropna( elif how == "all": mask = count > 0 elif how is not None: - raise ValueError("invalid how option: %s" % how) + raise ValueError(f"invalid how option: {how}") else: raise TypeError("must specify how or thresh") @@ -4902,7 +4899,7 @@ def reduce( missing_dimensions = [d for d in dims if d not in self.dims] if missing_dimensions: raise ValueError( - "Dataset does not contain the dimensions: %s" % missing_dimensions + f"Dataset does not contain the dimensions: {missing_dimensions}" ) if keep_attrs is None: @@ -5610,8 +5607,7 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): if inplace and set(lhs_data_vars) != set(rhs_data_vars): raise ValueError( "datasets must have the same data variables " - "for in-place arithmetic operations: %s, %s" - % (list(lhs_data_vars), list(rhs_data_vars)) + f"for in-place arithmetic operations: {list(lhs_data_vars)}, {list(rhs_data_vars)}" ) dest_vars = {} @@ -5783,7 +5779,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") invalid = [k for k in shifts if k not in self.dims] if invalid: - raise ValueError("dimensions %r do not exist" % invalid) + raise ValueError(f"dimensions {invalid!r} do not exist") variables = {} for name, var in self.variables.items(): @@ -5845,7 +5841,7 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll") invalid = [k for k in shifts if k not in self.dims] if invalid: - raise ValueError("dimensions %r do not exist" % invalid) + raise ValueError(f"dimensions {invalid!r} do not exist") if roll_coords is None: warnings.warn( @@ -6107,7 +6103,7 @@ def rank(self, dim, pct=False, keep_attrs=None): Variables that do not depend on `dim` are dropped. 
""" if dim not in self.dims: - raise ValueError("Dataset does not contain the dimension: %s" % dim) + raise ValueError(f"Dataset does not contain the dimension: {dim}") variables = {} for name, var in self.variables.items(): diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 51499c3a687..5f9349051b7 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -63,10 +63,7 @@ def maybe_promote(dtype): # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") elif np.issubdtype(dtype, np.integer): - if dtype.itemsize <= 2: - dtype = np.float32 - else: - dtype = np.float64 + dtype = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan elif np.issubdtype(dtype, np.complexfloating): fill_value = np.nan + np.nan * 1j diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cb8a5f9946f..e32fd4be376 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -204,7 +204,7 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - if any([isinstance(x, cupy_array_type) for x in scalars_or_arrays]): + if any(isinstance(x, cupy_array_type) for x in scalars_or_arrays): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] @@ -427,9 +427,7 @@ def _datetime_nanmin(array): def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): """Convert an array containing datetime-like data to numerical values. - Convert the datetime array to a timedelta relative to an offset. - Parameters ---------- array : array-like @@ -442,12 +440,10 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): conversions are not allowed due to non-linear relationships between units. dtype : dtype Output dtype. - Returns ------- array Numerical representation of datetime object relative to an offset. - Notes ----- Some datetime unit conversions won't work, for example from days to years, even diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index 9b7b060107b..3debefe2e0d 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -38,7 +38,7 @@ def __get__(self, obj, cls): # __getattr__ on data object will swallow any AttributeErrors # raised when initializing the accessor, so we need to raise as # something else (GH933): - raise RuntimeError("error initializing %r accessor." % self._name) + raise RuntimeError(f"error initializing {self._name!r} accessor.") cache[self._name] = accessor_obj return accessor_obj @@ -48,9 +48,8 @@ def _register_accessor(name, cls): def decorator(accessor): if hasattr(cls, name): warnings.warn( - "registration of accessor %r under name %r for type %r is " - "overriding a preexisting attribute with the same name." - % (accessor, name, cls), + f"registration of accessor {accessor!r} under name {name!r} for type {cls!r} is " + "overriding a preexisting attribute with the same name.", AccessorRegistrationWarning, stacklevel=2, ) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 5c2d2210ebd..2a480427d4e 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -25,9 +25,8 @@ def short_data_repr_html(array): internal_data = getattr(array, "variable", array)._data if hasattr(internal_data, "_repr_html_"): return internal_data._repr_html_() - else: - text = escape(short_data_repr(array)) - return f"
<pre>{text}</pre>" + text = escape(short_data_repr(array)) + return f"<pre>{text}</pre>
" def format_dims(dims, coord_names): @@ -77,8 +76,7 @@ def summarize_coord(name, var): if is_index: coord = var.variable.to_index_variable() if coord.level_names is not None: - coords = {} - coords[name] = _summarize_coord_multiindex(name, coord) + coords = {name: _summarize_coord_multiindex(name, coord)} for lname in coord.level_names: var = coord.get_level_variable(lname) coords[lname] = summarize_variable(lname, var) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e2678896c0e..c73ef738a29 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -29,8 +29,8 @@ def check_reduce_dims(reduce_dims, dimensions): reduce_dims = [reduce_dims] if any(dim not in dimensions for dim in reduce_dims): raise ValueError( - "cannot reduce over dimensions %r. expected either '...' to reduce over all dimensions or one or more of %r." - % (reduce_dims, dimensions) + f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' " + f"to reduce over all dimensions or one or more of {dimensions!r}." ) @@ -105,7 +105,7 @@ def _consolidate_slices(slices): last_slice = slice(None) for slice_ in slices: if not isinstance(slice_, slice): - raise ValueError("list element is not a slice: %r" % slice_) + raise ValueError(f"list element is not a slice: {slice_!r}") if ( result and last_slice.stop == slice_.start @@ -141,8 +141,7 @@ def _inverse_permutation_indices(positions): return None positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] - indices = nputils.inverse_permutation(np.concatenate(positions)) - return indices + return nputils.inverse_permutation(np.concatenate(positions)) class _DummyGroup: @@ -200,9 +199,8 @@ def _ensure_1d(group, obj): def _unique_and_monotonic(group): if isinstance(group, _DummyGroup): return True - else: - index = safe_cast_to_index(group) - return index.is_unique and index.is_monotonic + index = safe_cast_to_index(group) + return index.is_unique and index.is_monotonic def _apply_loffset(grouper, result): @@ -380,7 +378,7 @@ def __init__( if len(group_indices) == 0: if bins is not None: raise ValueError( - "None of the data falls within bins with edges %r" % bins + f"None of the data falls within bins with edges {bins!r}" ) else: raise ValueError( @@ -475,8 +473,7 @@ def _infer_concat_args(self, applied_example): def _binary_op(self, other, f, reflexive=False): g = f if not reflexive else lambda x, y: f(y, x) applied = self._yield_binary_applied(g, other) - combined = self._combine(applied) - return combined + return self._combine(applied) def _yield_binary_applied(self, func, other): dummy = None @@ -494,8 +491,8 @@ def _yield_binary_applied(self, func, other): if self._group.name not in other.dims: raise ValueError( "incompatible dimensions for a grouped " - "binary operation: the group variable %r " - "is not a dimension on the other argument" % self._group.name + f"binary operation: the group variable {self._group.name!r} " + "is not a dimension on the other argument" ) if dummy is None: dummy = _dummy_copy(other) @@ -548,8 +545,7 @@ def fillna(self, value): Dataset.fillna DataArray.fillna """ - out = ops.fillna(self, value) - return out + return ops.fillna(self, value) def quantile( self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True @@ -655,7 +651,6 @@ def quantile( keep_attrs=keep_attrs, skipna=skipna, ) - return out def where(self, cond, other=dtypes.NA): @@ -737,8 +732,7 @@ def _concat_shortcut(self, applied, dim, positions=None): # compiled language) stacked = Variable.concat(applied, dim, 
shortcut=True) reordered = _maybe_reorder(stacked, dim, positions) - result = self._obj._replace_maybe_drop_dims(reordered) - return result + return self._obj._replace_maybe_drop_dims(reordered) def _restore_dim_order(self, stacked): def lookup_order(dimension): @@ -795,10 +789,7 @@ def map(self, func, shortcut=False, args=(), **kwargs): applied : DataArray or DataArray The result of splitting, applying and combining this array. """ - if shortcut: - grouped = self._iter_grouped_shortcut() - else: - grouped = self._iter_grouped() + grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) return self._combine(applied, shortcut=shortcut) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 6747957ca75..06da058eb1f 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -66,7 +66,7 @@ def broadcast_dimension_size(variables: List[Variable]) -> Dict[Hashable, int]: for var in variables: for dim, size in zip(var.dims, var.shape): if dim in dims and size != dims[dim]: - raise ValueError("index %r not aligned" % dim) + raise ValueError(f"index {dim!r} not aligned") dims[dim] = size return dims @@ -211,17 +211,15 @@ def merge_collected( for _, other_index in indexed_elements[1:]: if not index.equals(other_index): raise MergeError( - "conflicting values for index %r on objects to be " - "combined:\nfirst value: %r\nsecond value: %r" - % (name, index, other_index) + f"conflicting values for index {name!r} on objects to be " + f"combined:\nfirst value: {index!r}\nsecond value: {other_index!r}" ) if compat == "identical": for other_variable, _ in indexed_elements[1:]: if not dict_equiv(variable.attrs, other_variable.attrs): raise MergeError( "conflicting attribute values on combined " - "variable %r:\nfirst value: %r\nsecond value: %r" - % (name, variable.attrs, other_variable.attrs) + f"variable {name!r}:\nfirst value: {variable.attrs!r}\nsecond value: {other_variable.attrs!r}" ) merged_vars[name] = variable merged_vars[name].attrs = merge_attrs( @@ -497,9 +495,9 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords): for coord_name in explicit_coords: if coord_name in dims and variables[coord_name].dims != (coord_name,): raise MergeError( - "coordinate %s shares a name with a dataset dimension, but is " + f"coordinate {coord_name} shares a name with a dataset dimension, but is " "not a 1D variable along that dimension. This is disallowed " - "by the xarray data model." % coord_name + "by the xarray data model." ) @@ -521,7 +519,7 @@ def merge_attrs(variable_attrs, combine_attrs): except ValueError as e: raise MergeError( "combine_attrs='no_conflicts', but some values are not " - "the same. Merging %s with %s" % (str(result), str(attrs)) + f"the same. Merging {str(result)} with {str(attrs)}" ) from e return result elif combine_attrs == "drop_conflicts": @@ -547,12 +545,12 @@ def merge_attrs(variable_attrs, combine_attrs): for attrs in variable_attrs[1:]: if not dict_equiv(result, attrs): raise MergeError( - "combine_attrs='identical', but attrs differ. First is %s " - ", other is %s." % (str(result), str(attrs)) + f"combine_attrs='identical', but attrs differ. First is {str(result)} " + f", other is {str(attrs)}." 
) return result else: - raise ValueError("Unrecognised value for combine_attrs=%s" % combine_attrs) + raise ValueError(f"Unrecognised value for combine_attrs={combine_attrs}") class _MergeResult(NamedTuple): @@ -642,15 +640,11 @@ def merge_core( if ambiguous_coords: raise MergeError( "unable to determine if these variables should be " - "coordinates or not in the merged result: %s" % ambiguous_coords + f"coordinates or not in the merged result: {ambiguous_coords}" ) attrs = merge_attrs( - [ - var.attrs - for var in coerced - if isinstance(var, Dataset) or isinstance(var, DataArray) - ], + [var.attrs for var in coerced if isinstance(var, (Dataset, DataArray))], combine_attrs, ) @@ -895,8 +889,7 @@ def merge( combine_attrs=combine_attrs, fill_value=fill_value, ) - merged = Dataset._construct_direct(**merge_result._asdict()) - return merged + return Dataset._construct_direct(**merge_result._asdict()) def dataset_merge_method( diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 41205242cce..c576e0718c6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -93,7 +93,7 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None): self._left = fill_value self._right = fill_value else: - raise ValueError("%s is not a valid fill_value" % fill_value) + raise ValueError(f"{fill_value} is not a valid fill_value") def __call__(self, x): return self.f( diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 803c7c3ccfe..35bac982d4c 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -46,9 +46,9 @@ def _validate_axis(axis, ndim, argname): axis = list(axis) axis = [a + ndim if a < 0 else a for a in axis] if not builtins.all(0 <= a < ndim for a in axis): - raise ValueError("invalid axis for this array in `%s` argument" % argname) + raise ValueError(f"invalid axis for this array in {argname} argument") if len(set(axis)) != len(axis): - raise ValueError("repeated axis in `%s` argument" % argname) + raise ValueError(f"repeated axis in {argname} argument") return axis @@ -73,8 +73,7 @@ def moveaxis(a, source, destination): for dest, src in sorted(zip(destination, source)): order.insert(dest, src) - result = transpose(order) - return result + return transpose(order) # Type annotations stubs diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 27740d53d45..8265035a25c 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -236,7 +236,7 @@ def func(self, *args, **kwargs): def inject_reduce_methods(cls): methods = ( [ - (name, getattr(duck_array_ops, "array_%s" % name), False) + (name, getattr(duck_array_ops, f"array_{name}"), False) for name in REDUCE_METHODS ] + [(name, getattr(duck_array_ops, name), True) for name in NAN_REDUCE_METHODS] @@ -275,7 +275,7 @@ def inject_cum_methods(cls): def op_str(name): - return "__%s__" % name + return f"__{name}__" def get_op(name): diff --git a/xarray/core/options.py b/xarray/core/options.py index d53c9d5d7d9..45f45c0dcc5 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -166,8 +166,7 @@ def __init__(self, **kwargs): for k, v in kwargs.items(): if k not in OPTIONS: raise ValueError( - "argument name %r is not in the set of valid options %r" - % (k, set(OPTIONS)) + f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}" ) if k in _VALIDATORS and not _VALIDATORS[k](v): if k == ARITHMETIC_JOIN: diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index e1d32b7de43..795d30af28f 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ 
-75,7 +75,7 @@ def check_result_variables( def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): - raise TypeError("Expected Dataset, got %s" % type(obj)) + raise TypeError(f"Expected Dataset, got {type(obj)}") if len(obj.data_vars) > 1: raise TypeError( diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 882664cbb60..4a413902b90 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -146,7 +146,7 @@ def _get_time_bins(index, freq, closed, label, base): if not isinstance(index, CFTimeIndex): raise TypeError( "index must be a CFTimeIndex, but got " - "an instance of %r" % type(index).__name__ + f"an instance of {type(index).__name__!r}" ) if len(index) == 0: datetime_bins = labels = CFTimeIndex(data=[], name=index.name) @@ -163,11 +163,7 @@ def _get_time_bins(index, freq, closed, label, base): datetime_bins, freq, closed, index, labels ) - if label == "right": - labels = labels[1:] - else: - labels = labels[:-1] - + labels = labels[1:] if label == "right" else labels[:-1] # TODO: when CFTimeIndex supports missing values, if the reference index # contains missing values, insert the appropriate NaN value at the # beginning of the datetime_bins and labels indexes. @@ -262,11 +258,7 @@ def _get_range_edges(first, last, offset, closed="left", base=0): first = normalize_date(first) last = normalize_date(last) - if closed == "left": - first = offset.rollback(first) - else: - first = first - offset - + first = offset.rollback(first) if closed == "left" else first - offset last = last + offset return first, last @@ -321,11 +313,7 @@ def _adjust_dates_anchored(first, last, offset, closed="right", base=0): else: lresult = last else: - if foffset.total_seconds() > 0: - fresult = first - foffset - else: - fresult = first - + fresult = first - foffset if foffset.total_seconds() > 0 else first if loffset.total_seconds() > 0: lresult = last + (offset.as_timedelta() - loffset) else: diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index f99a7568282..870df122aa9 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -6,7 +6,6 @@ from . 
import dtypes, duck_array_ops, utils from .arithmetic import CoarsenArithmetic -from .dask_array_ops import dask_rolling_wrapper from .options import _get_keep_attrs from .pycompat import is_duck_dask_array @@ -173,11 +172,10 @@ def _mapping_to_list( if utils.is_dict_like(arg): if allow_default: return [arg.get(d, default) for d in self.dim] - else: - for d in self.dim: - if d not in arg: - raise KeyError(f"argument has no key {d}.") - return [arg[d] for d in self.dim] + for d in self.dim: + if d not in arg: + raise KeyError(f"argument has no key {d}.") + return [arg[d] for d in self.dim] elif allow_allsame: # for single argument return [arg] * len(self.dim) elif len(self.dim) == 1: @@ -439,7 +437,6 @@ def reduce(self, func, keep_attrs=None, **kwargs): obj = self.obj.fillna(fillna) else: obj = self.obj - windows = self._construct( obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna ) @@ -504,9 +501,6 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): if is_duck_dask_array(padded.data): raise AssertionError("should not be reachable") - values = dask_rolling_wrapper( - func, padded.data, window=self.window[0], min_count=min_count, axis=axis - ) else: values = func( padded.data, window=self.window[0], min_count=min_count, axis=axis @@ -549,20 +543,17 @@ def _numpy_or_bottleneck_reduce( return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs ) - else: - if rolling_agg_func: - return rolling_agg_func( - self, keep_attrs=self._get_keep_attrs(keep_attrs) - ) - if fillna is not None: - if fillna is dtypes.INF: - fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True) - elif fillna is dtypes.NINF: - fillna = dtypes.get_neg_infinity(self.obj.dtype, min_for_int=True) - kwargs.setdefault("skipna", False) - kwargs.setdefault("fillna", fillna) + if rolling_agg_func: + return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs)) + if fillna is not None: + if fillna is dtypes.INF: + fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True) + elif fillna is dtypes.NINF: + fillna = dtypes.get_neg_infinity(self.obj.dtype, min_for_int=True) + kwargs.setdefault("skipna", False) + kwargs.setdefault("fillna", fillna) - return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs) + return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs) class DatasetRolling(Rolling): @@ -612,7 +603,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None dims.append(d) center[d] = self.center[i] - if len(dims) > 0: + if dims: w = {d: windows[d] for d in dims} self.rollings[key] = DataArrayRolling(da, w, min_periods, center) @@ -735,7 +726,7 @@ def construct( for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on self.dim dims = [d for d in self.dim if d in da.dims] - if len(dims) > 0: + if dims: wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims} st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims} diff --git a/xarray/core/utils.py b/xarray/core/utils.py index d3b4cd39c53..31ac43ed214 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -218,7 +218,7 @@ def update_safety_check( if k in first_dict and not compat(v, first_dict[k]): raise ValueError( "unsafe to merge dictionaries without " - "overriding values; conflicting key %r" % k + f"overriding values; conflicting key {k!r}" ) @@ -254,7 +254,7 @@ def is_full_slice(value: Any) -> bool: def is_list_like(value: Any) -> bool: - return isinstance(value, list) or 
isinstance(value, tuple) + return isinstance(value, (list, tuple)) def is_duck_array(value: Any) -> bool: @@ -274,22 +274,19 @@ def either_dict_or_kwargs( kw_kwargs: Mapping[str, T], func_name: str, ) -> Mapping[Hashable, T]: - if pos_kwargs is not None: - if not is_dict_like(pos_kwargs): - raise ValueError( - "the first argument to .%s must be a dictionary" % func_name - ) - if kw_kwargs: - raise ValueError( - "cannot specify both keyword and positional " - "arguments to .%s" % func_name - ) - return pos_kwargs - else: + if pos_kwargs is None: # Need an explicit cast to appease mypy due to invariance; see # https://github.com/python/mypy/issues/6228 return cast(Mapping[Hashable, T], kw_kwargs) + if not is_dict_like(pos_kwargs): + raise ValueError(f"the first argument to .{func_name} must be a dictionary") + if kw_kwargs: + raise ValueError( + f"cannot specify both keyword and positional arguments to .{func_name}" + ) + return pos_kwargs + def is_scalar(value: Any, include_0d: bool = True) -> bool: """Whether to treat a value as a scalar. @@ -358,10 +355,7 @@ def dict_equiv( for k in first: if k not in second or not compat(first[k], second[k]): return False - for k in second: - if k not in first: - return False - return True + return all(k in first for k in second) def compat_dict_intersection( @@ -730,7 +724,7 @@ def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]): def _raise_if_hidden(self, key: K) -> None: if key in self._hidden_keys: - raise KeyError("Key `%r` is hidden." % key) + raise KeyError(f"Key `{key!r}` is hidden.") # The next five methods are requirements of the ABC. def __setitem__(self, key: K, value: V) -> None: @@ -863,7 +857,7 @@ def drop_missing_dims( """ if missing_dims == "raise": - supplied_dims_set = set(val for val in supplied_dims if val is not ...) + supplied_dims_set = {val for val in supplied_dims if val is not ...} invalid = supplied_dims_set - set(dims) if invalid: raise ValueError( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index cffaf2c3146..7122346baca 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -145,25 +145,24 @@ def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]": data = as_compatible_data(obj) if data.ndim != 1: raise MissingDimensionsError( - "cannot set variable %r with %r-dimensional data " + f"cannot set variable {name!r} with {data.ndim!r}-dimensional data " "without explicit dimension names. Pass a tuple of " - "(dims, data) instead." % (name, data.ndim) + "(dims, data) instead." ) obj = Variable(name, data, fastpath=True) else: raise TypeError( "unable to convert object into a variable without an " - "explicit list of dimensions: %r" % obj + f"explicit list of dimensions: {obj!r}" ) if name is not None and name in obj.dims: # convert the Variable into an Index if obj.ndim != 1: raise MissingDimensionsError( - "%r has more than 1-dimension and the same name as one of its " - "dimensions %r. xarray disallows such variables because they " - "conflict with the coordinates used to label " - "dimensions." % (name, obj.dims) + f"{name!r} has more than 1-dimension and the same name as one of its " + f"dimensions {obj.dims!r}. xarray disallows such variables because they " + "conflict with the coordinates used to label dimensions." 
) obj = obj.to_index_variable() @@ -236,21 +235,14 @@ def as_compatible_data(data, fastpath=False): else: data = np.asarray(data) - if not isinstance(data, np.ndarray): - if hasattr(data, "__array_function__"): - return data + if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"): + return data # validate whether the data is valid data types. data = np.asarray(data) - if isinstance(data, np.ndarray): - if data.dtype.kind == "O": - data = _possibly_convert_objects(data) - elif data.dtype.kind == "M": - data = _possibly_convert_objects(data) - elif data.dtype.kind == "m": - data = _possibly_convert_objects(data) - + if isinstance(data, np.ndarray) and data.dtype.kind in "OMm": + data = _possibly_convert_objects(data) return _maybe_wrap_data(data) @@ -268,10 +260,7 @@ def _as_array_or_item(data): TODO: remove this (replace with np.asarray) once these issues are fixed """ - if isinstance(data, cupy_array_type): - data = data.get() - else: - data = np.asarray(data) + data = data.get() if isinstance(data, cupy_array_type) else np.asarray(data) if data.ndim == 0: if data.dtype.kind == "M": data = np.datetime64(data, "ns") @@ -584,8 +573,8 @@ def _parse_dimensions(self, dims): dims = tuple(dims) if len(dims) != self.ndim: raise ValueError( - "dimensions %s must have the same length as the " - "number of data dimensions, ndim=%s" % (dims, self.ndim) + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" ) return dims @@ -662,9 +651,7 @@ def _broadcast_indexes_basic(self, key): def _validate_indexers(self, key): """Make sanity checks""" for dim, k in zip(self.dims, key): - if isinstance(k, BASIC_INDEXING_TYPES): - pass - else: + if not isinstance(k, BASIC_INDEXING_TYPES): if not isinstance(k, Variable): k = np.asarray(k) if k.ndim > 1: @@ -852,9 +839,8 @@ def __setitem__(self, key, value): value = as_compatible_data(value) if value.ndim > len(dims): raise ValueError( - "shape mismatch: value array of shape %s could not be " - "broadcast to indexing result with %s dimensions" - % (value.shape, len(dims)) + f"shape mismatch: value array of shape {value.shape} could not be " + f"broadcast to indexing result with {len(dims)} dimensions" ) if value.ndim == 0: value = Variable((), value) @@ -1462,8 +1448,8 @@ def set_dims(self, dims, shape=None): missing_dims = set(self.dims) - set(dims) if missing_dims: raise ValueError( - "new dimensions %r must be a superset of " - "existing dimensions %r" % (dims, self.dims) + f"new dimensions {dims!r} must be a superset of " + f"existing dimensions {self.dims!r}" ) self_dims = set(self.dims) @@ -1487,7 +1473,7 @@ def set_dims(self, dims, shape=None): def _stack_once(self, dims: List[Hashable], new_dim: Hashable): if not set(dims) <= set(self.dims): - raise ValueError("invalid existing dimensions: %s" % dims) + raise ValueError(f"invalid existing dimensions: {dims}") if new_dim in self.dims: raise ValueError( @@ -1554,7 +1540,7 @@ def _unstack_once_full( new_dim_sizes = tuple(dims.values()) if old_dim not in self.dims: - raise ValueError("invalid existing dimension: %s" % old_dim) + raise ValueError(f"invalid existing dimension: {old_dim}") if set(new_dim_names).intersection(self.dims): raise ValueError( @@ -2550,7 +2536,7 @@ class IndexVariable(Variable): def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): super().__init__(dims, data, attrs, encoding, fastpath) if self.ndim != 1: - raise ValueError("%s objects must be 1-dimensional" % type(self).__name__) + raise 
ValueError(f"{type(self).__name__} objects must be 1-dimensional") # Unlike in Variable, always eagerly load values into memory if not isinstance(self._data, PandasIndex): @@ -2601,7 +2587,7 @@ def _finalize_indexing_result(self, dims, data): return self._replace(dims=dims, data=data) def __setitem__(self, key, value): - raise TypeError("%s values cannot be modified" % type(self).__name__) + raise TypeError(f"{type(self).__name__} values cannot be modified") @classmethod def concat( @@ -2744,7 +2730,7 @@ def level_names(self): def get_level_variable(self, level): """Return a new IndexVariable from a given MultiIndex level.""" if self.level_names is None: - raise ValueError("IndexVariable %r has no MultiIndex" % self.name) + raise ValueError(f"IndexVariable {self.name!r} has no MultiIndex") index = self.to_index() return type(self)(self.dims, index.get_level_values(level)) @@ -2769,7 +2755,7 @@ def _unified_dims(variables): if len(set(var_dims)) < len(var_dims): raise ValueError( "broadcasting cannot handle duplicate " - "dimensions: %r" % list(var_dims) + f"dimensions: {list(var_dims)!r}" ) for d, s in zip(var_dims, var.shape): if d not in all_dims: @@ -2777,8 +2763,7 @@ def _unified_dims(variables): elif all_dims[d] != s: raise ValueError( "operands cannot be broadcast together " - "with mismatched lengths for dimension %r: %s" - % (d, (all_dims[d], s)) + f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}" ) return all_dims @@ -2900,12 +2885,12 @@ def assert_unique_multiindex_level_names(variables): for k, v in level_names.items(): if k in variables: - v.append("(%s)" % k) + v.append(f"({k})") duplicate_names = [v for v in level_names.values() if len(v) > 1] if duplicate_names: conflict_str = "\n".join(", ".join(v) for v in duplicate_names) - raise ValueError("conflicting MultiIndex level name(s):\n%s" % conflict_str) + raise ValueError(f"conflicting MultiIndex level name(s):\n{conflict_str}") # Check confliction between level names and dimensions GH:2299 for k, v in variables.items(): for d in v.dims: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 325ea799f28..858695ec538 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -44,14 +44,13 @@ def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax if extend_min and extend_max: - extend = "both" + return "both" elif extend_min: - extend = "min" + return "min" elif extend_max: - extend = "max" + return "max" else: - extend = "neither" - return extend + return "neither" def _build_discrete_cmap(cmap, levels, extend, filled): @@ -320,7 +319,7 @@ def _infer_xy_labels_3d(darray, x, y, rgb): if len(set(not_none)) < len(not_none): raise ValueError( "Dimension names must be None or unique strings, but imshow was " - "passed x=%r, y=%r, and rgb=%r." % (x, y, rgb) + f"passed x={x!r}, y={y!r}, and rgb={rgb!r}." ) for label in not_none: if label not in darray.dims: @@ -342,8 +341,7 @@ def _infer_xy_labels_3d(darray, x, y, rgb): rgb = could_be_color[0] if rgb is not None and darray[rgb].size not in (3, 4): raise ValueError( - "Cannot interpret dim %r of size %s as RGB or RGBA." - % (rgb, darray[rgb].size) + f"Cannot interpret dim {rgb!r} of size {darray[rgb].size} as RGB or RGBA." ) # If rgb dimension is still unknown, there must be two or three dimensions @@ -353,9 +351,9 @@ def _infer_xy_labels_3d(darray, x, y, rgb): rgb = could_be_color[-1] warnings.warn( "Several dimensions of this array could be colors. 
Xarray " - "will use the last possible dimension (%r) to match " + f"will use the last possible dimension ({rgb!r}) to match " "matplotlib.pyplot.imshow. You can pass names of x, y, " - "and/or rgb dimensions to override this guess." % rgb + "and/or rgb dimensions to override this guess." ) assert rgb is not None @@ -662,15 +660,15 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 if vmax < vmin: raise ValueError( - "vmin=%r is less than the default vmax (%r) - you must supply " - "a vmax > vmin in this case." % (vmin, vmax) + f"vmin={vmin!r} is less than the default vmax ({vmax!r}) - you must supply " + "a vmax > vmin in this case." ) elif vmin is None: vmin = 0 if vmin > vmax: raise ValueError( - "vmax=%r is less than the default vmin (0) - you must supply " - "a vmin < vmax in this case." % vmax + f"vmax={vmax!r} is less than the default vmin (0) - you must supply " + "a vmin < vmax in this case." ) # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float # to avoid precision loss, integer over/underflow, etc with extreme inputs. diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index db102eefdc1..355c5dbed32 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -387,7 +387,7 @@ def test_da_groupby_assign_coords(): @pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) def test_groupby_repr(obj, dim): actual = repr(obj.groupby(dim)) - expected = "%sGroupBy" % obj.__class__.__name__ + expected = f"{obj.__class__.__name__}GroupBy" expected += ", grouped over %r" % dim expected += "\n%r groups with labels " % (len(np.unique(obj[dim]))) if dim == "x": @@ -404,7 +404,7 @@ def test_groupby_repr(obj, dim): @pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) def test_groupby_repr_datetime(obj): actual = repr(obj.groupby("t.month")) - expected = "%sGroupBy" % obj.__class__.__name__ + expected = f"{obj.__class__.__name__}GroupBy" expected += ", grouped over 'month'" expected += "\n%r groups with labels " % (len(np.unique(obj.t.dt.month))) expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12." diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 80c5e22513d..62762d29216 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -87,10 +87,10 @@ def open_dataset( if name in external_urls: url = external_urls[name] else: - # process the name - default_extension = ".nc" path = pathlib.Path(name) if not path.suffix: + # process the name + default_extension = ".nc" path = path.with_suffix(default_extension) url = f"{base_url}/raw/{version}/{path.name}" diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index ce01936b9dd..bf80dcf68cd 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -77,8 +77,7 @@ def __call__(self, *args, **kwargs): res = f(*new_args, **kwargs) if res is NotImplemented: raise TypeError( - "%r not implemented for types (%r, %r)" - % (self._name, type(args[0]), type(args[1])) + f"{self._name!r} not implemented for types ({type(args[0])!r}, {type(args[1])!r})" ) return res @@ -127,11 +126,11 @@ def _create_op(name): doc = _remove_unused_reference_labels(_skip_signature(_dedent(doc), name)) func.__doc__ = ( - "xarray specific variant of numpy.%s. Handles " + f"xarray specific variant of numpy.{name}. 
Handles " "xarray.Dataset, xarray.DataArray, xarray.Variable, " "numpy.ndarray and dask.array.Array objects with " "automatic dispatching.\n\n" - "Documentation from numpy:\n\n%s" % (name, doc) + f"Documentation from numpy:\n\n{doc}" ) return func diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index d643d768093..cd5d425efe2 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -42,15 +42,15 @@ def get_sys_info(): [ ("python", sys.version), ("python-bits", struct.calcsize("P") * 8), - ("OS", "%s" % (sysname)), - ("OS-release", "%s" % (release)), - # ("Version", "%s" % (version)), - ("machine", "%s" % (machine)), - ("processor", "%s" % (processor)), - ("byteorder", "%s" % sys.byteorder), - ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), - ("LANG", "%s" % os.environ.get("LANG", "None")), - ("LOCALE", "%s.%s" % locale.getlocale()), + ("OS", f"{sysname}"), + ("OS-release", f"{release}"), + # ("Version", f"{version}"), + ("machine", f"{machine}"), + ("processor", f"{processor}"), + ("byteorder", f"{sys.byteorder}"), + ("LC_ALL", f'{os.environ.get("LC_ALL", "None")}'), + ("LANG", f'{os.environ.get("LANG", "None")}'), + ("LOCALE", f"{locale.getlocale()}"), ] ) except Exception: From 1f52ae0e841d75e3f331f16c0f31cd06a1675e23 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Thu, 13 May 2021 12:38:18 -0400 Subject: [PATCH 6/7] Explained what a deprecation cycle is (#5289) Co-authored-by: Anderson Banihirwe Co-authored-by: Mathias Hauser --- doc/contributing.rst | 30 +++++++++++++++++++++++++++--- doc/whats-new.rst | 4 ++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/doc/contributing.rst b/doc/contributing.rst index e10ceacd59f..f43fc3e312c 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -379,10 +379,34 @@ with ``git commit --no-verify``. Backwards Compatibility ~~~~~~~~~~~~~~~~~~~~~~~ -Please try to maintain backward compatibility. *xarray* has growing number of users with +Please try to maintain backwards compatibility. *xarray* has a growing number of users with lots of existing code, so don't break it if at all possible. If you think breakage is -required, clearly state why as part of the pull request. Also, be careful when changing -method signatures and add deprecation warnings where needed. +required, clearly state why as part of the pull request. + +Be especially careful when changing function and method signatures, because any change +may require a deprecation warning. For example, if your pull request means that the +argument ``old_arg`` to ``func`` is no longer valid, instead of simply raising an error if +a user passes ``old_arg``, we would instead catch it: + +.. code-block:: python + + def func(new_arg, old_arg=None): + if old_arg is not None: + from warnings import warn + + warn( + "`old_arg` has been deprecated, and in the future will raise an error." + "Please use `new_arg` from now on.", + DeprecationWarning, + ) + + # Still do what the user intended here + +This temporary check would then be removed in a subsequent version of xarray. +This process of first warning users before actually breaking their code is known as a +"deprecation cycle", and makes changes significantly easier to handle both for users +of xarray, and for developers of other libraries that depend on xarray. + .. 
_contributing.ci: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9024f0efe37..62da8bf1ea2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,10 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Explanation of deprecation cycles and how to implement them added to contributors + guide. (:pull:`5289`) + By `Tom Nicholas `_. + Internal Changes ~~~~~~~~~~~~~~~~ From 751f76ac95761e18d2bf2b5c7ac3c84bd2ee69ea Mon Sep 17 00:00:00 2001 From: keewis Date: Thu, 13 May 2021 19:25:52 +0200 Subject: [PATCH 7/7] combine keep_attrs and combine_attrs in apply_ufunc (#5041) --- doc/whats-new.rst | 3 + xarray/core/computation.py | 92 ++++--- xarray/core/merge.py | 3 +- xarray/tests/test_computation.py | 396 +++++++++++++++++++++++++++++++ 4 files changed, 460 insertions(+), 34 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 62da8bf1ea2..d07a2741eef 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v0.18.1 (unreleased) New Features ~~~~~~~~~~~~ +- allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of + :py:func:`apply_ufunc` (:pull:`5041`) + By `Justus Magin `_. - :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes, such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`). By `Jimmy Westling `_. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2bc3ba921a7..12dded3e158 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -27,8 +27,8 @@ from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align -from .merge import merge_coordinates_without_align -from .options import OPTIONS +from .merge import merge_attrs, merge_coordinates_without_align +from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array from .utils import is_dict_like from .variable import Variable @@ -50,6 +50,11 @@ def _first_of_type(args, kind): raise ValueError("This should be unreachable.") +def _all_of_type(args, kind): + """Return all objects of type 'kind'""" + return [arg for arg in args if isinstance(arg, kind)] + + class _UFuncSignature: """Core dimensions signature for a given function. @@ -202,7 +207,10 @@ def _get_coords_list(args) -> List["Coordinates"]: def build_output_coords( - args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset() + args: list, + signature: _UFuncSignature, + exclude_dims: AbstractSet = frozenset(), + combine_attrs: str = "override", ) -> "List[Dict[Any, Variable]]": """Build output coordinates for an operation. @@ -230,7 +238,7 @@ def build_output_coords( else: # TODO: save these merged indexes, instead of re-computing them later merged_vars, unused_indexes = merge_coordinates_without_align( - coords_list, exclude_dims=exclude_dims + coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs ) output_coords = [] @@ -248,7 +256,12 @@ def build_output_coords( def apply_dataarray_vfunc( - func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False + func, + *args, + signature, + join="inner", + exclude_dims=frozenset(), + keep_attrs="override", ): """Apply a variable level function over DataArray, Variable and/or ndarray objects. 
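
For context on the feature this patch adds (an illustrative sketch only — the arrays, attrs, and the ``np.add`` call below are made up and not part of the patch), ``keep_attrs`` in ``apply_ufunc`` now also accepts ``combine_attrs`` strategy names such as ``"drop"``, ``"override"``, ``"no_conflicts"``, and ``"drop_conflicts"``:

.. code-block:: python

    import numpy as np
    import xarray as xr

    a = xr.DataArray([1, 2], dims="x", attrs={"units": "m", "source": "a"})
    b = xr.DataArray([3, 4], dims="x", attrs={"units": "m", "source": "b"})

    # "drop_conflicts" keeps attrs that agree and silently drops the rest
    result = xr.apply_ufunc(np.add, a, b, keep_attrs="drop_conflicts")
    print(result.attrs)  # {'units': 'm'} -- the conflicting 'source' is dropped

Booleans keep their previous meaning: as implemented further down in ``apply_ufunc``, ``True`` is mapped to ``"override"`` and ``False`` to ``"drop"``.
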
@@ -260,12 +273,16 @@ def apply_dataarray_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - if keep_attrs: + objs = _all_of_type(args, DataArray) + + if keep_attrs == "drop": + name = result_name(args) + else: first_obj = _first_of_type(args, DataArray) name = first_obj.name - else: - name = result_name(args) - result_coords = build_output_coords(args, signature, exclude_dims) + result_coords = build_output_coords( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) data_vars = [getattr(a, "variable", a) for a in args] result_var = func(*data_vars) @@ -279,13 +296,12 @@ def apply_dataarray_vfunc( (coords,) = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) - if keep_attrs: - if isinstance(out, tuple): - for da in out: - # This is adding attrs in place - da._copy_attrs_from(first_obj) - else: - out._copy_attrs_from(first_obj) + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for da in out: + da.attrs = attrs + else: + out.attrs = attrs return out @@ -400,7 +416,7 @@ def apply_dataset_vfunc( dataset_join="exact", fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), - keep_attrs=False, + keep_attrs="override", ): """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. @@ -414,15 +430,16 @@ def apply_dataset_vfunc( "dataset_fill_value argument." ) - if keep_attrs: - first_obj = _first_of_type(args, Dataset) + objs = _all_of_type(args, Dataset) if len(args) > 1: args = deep_align( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - list_of_coords = build_output_coords(args, signature, exclude_dims) + list_of_coords = build_output_coords( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) args = [getattr(arg, "data_vars", arg) for arg in args] result_vars = apply_dict_of_variables_vfunc( @@ -435,13 +452,13 @@ def apply_dataset_vfunc( (coord_vars,) = list_of_coords out = _fast_dataset(result_vars, coord_vars) - if keep_attrs: - if isinstance(out, tuple): - for ds in out: - # This is adding attrs in place - ds._copy_attrs_from(first_obj) - else: - out._copy_attrs_from(first_obj) + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for ds in out: + ds.attrs = attrs + else: + out.attrs = attrs + return out @@ -609,14 +626,12 @@ def apply_variable_ufunc( dask="forbidden", output_dtypes=None, vectorize=False, - keep_attrs=False, + keep_attrs="override", dask_gufunc_kwargs=None, ): """Apply a ndarray level function over Variable and/or ndarray objects.""" from .variable import Variable, as_compatible_data - first_obj = _first_of_type(args, Variable) - dim_sizes = unified_dim_sizes( (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims ) @@ -736,6 +751,12 @@ def func(*arrays): ) ) + objs = _all_of_type(args, Variable) + attrs = merge_attrs( + [obj.attrs for obj in objs], + combine_attrs=keep_attrs, + ) + output = [] for dims, data in zip(output_dims, result_data): data = as_compatible_data(data) @@ -758,8 +779,7 @@ def func(*arrays): ) ) - if keep_attrs: - var.attrs.update(first_obj.attrs) + var.attrs = attrs output.append(var) if signature.num_outputs == 1: @@ -801,7 +821,7 @@ def apply_ufunc( join: str = "exact", dataset_join: str = "exact", dataset_fill_value: object = _NO_FILL_VALUE, - keep_attrs: bool = False, + keep_attrs: Union[bool, str] = None, kwargs: Mapping = None, dask: str = "forbidden", output_dtypes: Sequence 
= None, @@ -1098,6 +1118,12 @@ def apply_ufunc( if kwargs: func = functools.partial(func, **kwargs) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + if isinstance(keep_attrs, bool): + keep_attrs = "override" if keep_attrs else "drop" + variables_vfunc = functools.partial( apply_variable_ufunc, func, diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 06da058eb1f..4901c5bf312 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -314,6 +314,7 @@ def merge_coordinates_without_align( objects: "List[Coordinates]", prioritized: Mapping[Hashable, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), + combine_attrs: str = "override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. @@ -335,7 +336,7 @@ def merge_coordinates_without_align( else: filtered = collected - return merge_collected(filtered, prioritized) + return merge_collected(filtered, prioritized, combine_attrs=combine_attrs) def determine_coords( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c76633de831..cbfa61a4482 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -29,6 +29,7 @@ def assert_identical(a, b): """A version of this function which accepts numpy arrays""" + __tracebackhide__ = True from xarray.testing import assert_identical as assert_identical_ if hasattr(a, "identical"): @@ -563,6 +564,401 @@ def add(a, b, keep_attrs): assert_identical(actual.x.attrs, a.x.attrs) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_variable(strategy, attrs, expected, error): + a = xr.Variable("x", [0, 1], attrs=attrs[0]) + b = xr.Variable("x", [0, 1], attrs=attrs[1]) + c = xr.Variable("x", [0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Variable("x", [0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, 
{"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray(strategy, attrs, expected, error): + a = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[0]) + b = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[1]) + c = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.DataArray(dims="x", data=[0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "dim": lambda attrs, default: (attrs, default), + "coord": lambda attrs, default: (default, attrs), + }.get(variant) + + dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[0]), "u": ("x", [0, 1], coord_attrs[0])}, + ) + b = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[1]), "u": ("x", [0, 1], coord_attrs[1])}, + ) + c = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[2]), "u": ("x", [0, 1], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.DataArray( + dims="x", + data=[0, 3], + coords={"x": ("x", [0, 1], dim_attrs), "u": ("x", [0, 1], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + 
"no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset(strategy, attrs, expected, error): + a = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[0]) + b = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[1]) + c = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Dataset({"a": ("x", [0, 3])}, attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("data", "dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "data": lambda attrs, default: (attrs, default, default), + "dim": lambda attrs, default: (default, attrs, default), + "coord": lambda attrs, default: (default, default, attrs), + }.get(variant) + data_attrs, dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.Dataset( + {"a": ("x", [], data_attrs[0])}, + coords={"x": ("x", [], dim_attrs[0]), "u": ("x", [], coord_attrs[0])}, + ) + b = xr.Dataset( + {"a": ("x", [], data_attrs[1])}, + coords={"x": ("x", [], dim_attrs[1]), "u": ("x", [], coord_attrs[1])}, + ) + c = xr.Dataset( + {"a": ("x", [], data_attrs[2])}, + coords={"x": ("x", [], dim_attrs[2]), "u": ("x", [], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + data_attrs, dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.Dataset( + {"a": ("x", [], data_attrs)}, + coords={"x": ("x", [], dim_attrs), "u": ("x", [], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + def test_dataset_join(): ds0 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]})