"
- f"{ICONS_SVG}"
+ f"{icons_svg}"
f"
{escape(repr(obj))}
"
"
"
f"{header}"
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 04c0fabae6a..5087390ecc0 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -64,8 +64,8 @@ def unique_value_groups(ar, sort=True):
def _dummy_copy(xarray_obj):
- from .dataset import Dataset
from .dataarray import DataArray
+ from .dataset import Dataset
if isinstance(xarray_obj, Dataset):
res = Dataset(
@@ -310,7 +310,8 @@ def __init__(
if not hashable(group):
raise TypeError(
"`group` must be an xarray.DataArray or the "
- "name of an xarray variable or dimension"
+ "name of an xarray variable or dimension. "
+ f"Received {group!r} instead."
)
group = obj[group]
if len(group) == 0:
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index ab049a0a4b4..28ed2cfb16f 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -50,8 +50,8 @@ def _expand_slice(slice_, size):
def _sanitize_slice_element(x):
- from .variable import Variable
from .dataarray import DataArray
+ from .variable import Variable
if isinstance(x, (Variable, DataArray)):
x = x.values
diff --git a/xarray/core/missing.py b/xarray/core/missing.py
index 59d4f777c73..a6bed408164 100644
--- a/xarray/core/missing.py
+++ b/xarray/core/missing.py
@@ -544,15 +544,6 @@ def _get_valid_fill_mask(arr, dim, limit):
) <= limit
-def _assert_single_chunk(var, axes):
- for axis in axes:
- if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]:
- raise NotImplementedError(
- "Chunking along the dimension to be interpolated "
- "({}) is not yet supported.".format(axis)
- )
-
-
def _localize(var, indexes_coords):
""" Speed up for linear and nearest neighbor method.
Only consider a subspace that is needed for the interpolation
@@ -617,49 +608,42 @@ def interp(var, indexes_coords, method, **kwargs):
if not indexes_coords:
return var.copy()
- # simple speed up for the local interpolation
- if method in ["linear", "nearest"]:
- var, indexes_coords = _localize(var, indexes_coords)
-
# default behavior
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
- # check if the interpolation can be done in orthogonal manner
- if (
- len(indexes_coords) > 1
- and method in ["linear", "nearest"]
- and all(dest[1].ndim == 1 for dest in indexes_coords.values())
- and len(set([d[1].dims[0] for d in indexes_coords.values()]))
- == len(indexes_coords)
- ):
- # interpolate sequentially
- for dim, dest in indexes_coords.items():
- var = interp(var, {dim: dest}, method, **kwargs)
- return var
-
- # target dimensions
- dims = list(indexes_coords)
- x, new_x = zip(*[indexes_coords[d] for d in dims])
- destination = broadcast_variables(*new_x)
-
- # transpose to make the interpolated axis to the last position
- broadcast_dims = [d for d in var.dims if d not in dims]
- original_dims = broadcast_dims + dims
- new_dims = broadcast_dims + list(destination[0].dims)
- interped = interp_func(
- var.transpose(*original_dims).data, x, destination, method, kwargs
- )
+ result = var
+ # decompose the interpolation into a succession of independent interpolations
+ for indexes_coords in decompose_interp(indexes_coords):
+ var = result
+
+ # simple speed up for the local interpolation
+ if method in ["linear", "nearest"]:
+ var, indexes_coords = _localize(var, indexes_coords)
+
+ # target dimensions
+ dims = list(indexes_coords)
+ x, new_x = zip(*[indexes_coords[d] for d in dims])
+ destination = broadcast_variables(*new_x)
+
+ # transpose to make the interpolated axis to the last position
+ broadcast_dims = [d for d in var.dims if d not in dims]
+ original_dims = broadcast_dims + dims
+ new_dims = broadcast_dims + list(destination[0].dims)
+ interped = interp_func(
+ var.transpose(*original_dims).data, x, destination, method, kwargs
+ )
- result = Variable(new_dims, interped, attrs=var.attrs)
+ result = Variable(new_dims, interped, attrs=var.attrs)
- # dimension of the output array
- out_dims = OrderedSet()
- for d in var.dims:
- if d in dims:
- out_dims.update(indexes_coords[d][1].dims)
- else:
- out_dims.add(d)
- return result.transpose(*tuple(out_dims))
+ # dimension of the output array
+ out_dims = OrderedSet()
+ for d in var.dims:
+ if d in dims:
+ out_dims.update(indexes_coords[d][1].dims)
+ else:
+ out_dims.add(d)
+ result = result.transpose(*tuple(out_dims))
+ return result
def interp_func(var, x, new_x, method, kwargs):
@@ -706,21 +690,51 @@ def interp_func(var, x, new_x, method, kwargs):
if isinstance(var, dask_array_type):
import dask.array as da
- _assert_single_chunk(var, range(var.ndim - len(x), var.ndim))
- chunks = var.chunks[: -len(x)] + new_x[0].shape
- drop_axis = range(var.ndim - len(x), var.ndim)
- new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim)
- return da.map_blocks(
- _interpnd,
+ nconst = var.ndim - len(x)
+
+ out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim))
+
+ # blockwise args format
+ x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)]
+ x_arginds = [item for pair in x_arginds for item in pair]
+ new_x_arginds = [
+ [_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x
+ ]
+ new_x_arginds = [item for pair in new_x_arginds for item in pair]
+
+ args = (
var,
- x,
- new_x,
- func,
- kwargs,
+ range(var.ndim),
+ *x_arginds,
+ *new_x_arginds,
+ )
+
+ _, rechunked = da.unify_chunks(*args)
+
+ args = tuple([elem for pair in zip(rechunked, args[1::2]) for elem in pair])
+
+ new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
+
+ new_axes = {
+ var.ndim + i: new_x[0].chunks[i]
+ if new_x[0].chunks is not None
+ else new_x[0].shape[i]
+ for i in range(new_x[0].ndim)
+ }
+
+ # if useful, re-use localize for each chunk of new_x
+ localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None)
+
+ return da.blockwise(
+ _dask_aware_interpnd,
+ out_ind,
+ *args,
+ interp_func=func,
+ interp_kwargs=kwargs,
+ localize=localize,
+ concatenate=True,
dtype=var.dtype,
- chunks=chunks,
- new_axis=new_axis,
- drop_axis=drop_axis,
+ new_axes=new_axes,
)
return _interpnd(var, x, new_x, func, kwargs)
@@ -751,3 +765,67 @@ def _interpnd(var, x, new_x, func, kwargs):
# move back the interpolation axes to the last position
rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)
+
+
+def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
+ """Wrapper for `_interpnd` through `blockwise`
+
+ The first half arrays in `coords` are original coordinates,
+ the other half are destination coordinates
+ """
+ n_x = len(coords) // 2
+ nconst = len(var.shape) - n_x
+
+ # _interpnd expects coords to be Variables
+ x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
+ new_x = [
+ Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x)
+ for _x in coords[n_x:]
+ ]
+
+ if localize:
+ # _localize expects var to be a Variable
+ var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var)
+
+ indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)}
+
+ # simple speed up for the local interpolation
+ var, indexes_coords = _localize(var, indexes_coords)
+ x, new_x = zip(*[indexes_coords[d] for d in indexes_coords])
+
+ # put var back as a ndarray
+ var = var.data
+
+ return _interpnd(var, x, new_x, interp_func, interp_kwargs)
+
+
+def decompose_interp(indexes_coords):
+ """Decompose the interpolation into a succession of independent interpolations, keeping the order"""
+
+ dest_dims = [
+ dest[1].dims if dest[1].ndim > 0 else [dim]
+ for dim, dest in indexes_coords.items()
+ ]
+ partial_dest_dims = []
+ partial_indexes_coords = {}
+ for i, index_coords in enumerate(indexes_coords.items()):
+ partial_indexes_coords.update([index_coords])
+
+ if i == len(dest_dims) - 1:
+ break
+
+ partial_dest_dims += [dest_dims[i]]
+ other_dims = dest_dims[i + 1 :]
+
+ s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims}
+ s_other_dims = {dim for dims in other_dims for dim in dims}
+
+ if not s_partial_dest_dims.intersection(s_other_dims):
+ # this interpolation is orthogonal to the rest
+
+ yield partial_indexes_coords
+
+ partial_dest_dims = []
+ partial_indexes_coords = {}
+
+ yield partial_indexes_coords
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index f9989c2c8c9..41c8d258d7a 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -6,6 +6,7 @@
try:
import dask.array as dask_array
+
from . import dask_array_compat
except ImportError:
dask_array = None
@@ -118,7 +119,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None):
def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
""" In house nanmean. ddof argument will be used in _nanvar method """
- from .duck_array_ops import count, fillna, _dask_or_eager_func, where_method
+ from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method
valid_count = count(value, axis=axis)
value = fillna(value, 0)
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index fa6df63e0ea..4f592eb3c5c 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -135,14 +135,22 @@ def __setitem__(self, key, value):
def rolling_window(a, axis, window, center, fill_value):
""" rolling window with padding. """
pads = [(0, 0) for s in a.shape]
- if center:
- start = int(window / 2) # 10 -> 5, 9 -> 4
- end = window - 1 - start
- pads[axis] = (start, end)
- else:
- pads[axis] = (window - 1, 0)
+ if not hasattr(axis, "__len__"):
+ axis = [axis]
+ window = [window]
+ center = [center]
+
+ for ax, win, cent in zip(axis, window, center):
+ if cent:
+ start = int(win / 2) # 10 -> 5, 9 -> 4
+ end = win - 1 - start
+ pads[ax] = (start, end)
+ else:
+ pads[ax] = (win - 1, 0)
a = np.pad(a, pads, mode="constant", constant_values=fill_value)
- return _rolling_window(a, window, axis)
+ for ax, win in zip(axis, window):
+ a = _rolling_window(a, win, ax)
+ return a
def _rolling_window(a, window, axis=-1):
diff --git a/xarray/core/ops.py b/xarray/core/ops.py
index d4aeea37aad..3675317977f 100644
--- a/xarray/core/ops.py
+++ b/xarray/core/ops.py
@@ -90,12 +90,7 @@
Parameters
----------
-{extra_args}
-skipna : bool, optional
- If True, skip missing values (as marked by NaN). By default, only
- skips missing values for float dtypes; other dtypes either do not
- have a sentinel missing value (int) or skipna=True has not been
- implemented (object, datetime64 or timedelta64).{min_count_docs}
+{extra_args}{skip_na_docs}{min_count_docs}
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
@@ -111,6 +106,13 @@
indicated dimension(s) removed.
"""
+_SKIPNA_DOCSTRING = """
+skipna : bool, optional
+ If True, skip missing values (as marked by NaN). By default, only
+ skips missing values for float dtypes; other dtypes either do not
+ have a sentinel missing value (int) or skipna=True has not been
+ implemented (object, datetime64 or timedelta64)."""
+
_MINCOUNT_DOCSTRING = """
min_count : int, default None
The required number of valid values to perform the operation.
@@ -260,6 +262,7 @@ def inject_reduce_methods(cls):
for name, f, include_skipna in methods:
numeric_only = getattr(f, "numeric_only", False)
available_min_count = getattr(f, "available_min_count", False)
+ skip_na_docs = _SKIPNA_DOCSTRING if include_skipna else ""
min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else ""
func = cls._reduce_method(f, include_skipna, numeric_only)
@@ -268,6 +271,7 @@ def inject_reduce_methods(cls):
name=name,
cls=cls.__name__,
extra_args=cls._reduce_extra_args_docstring.format(name=name),
+ skip_na_docs=skip_na_docs,
min_count_docs=min_count_docs,
)
setattr(cls, name, func)
diff --git a/xarray/core/options.py b/xarray/core/options.py
index 5d81ca40a6e..bb1b1c47840 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -132,7 +132,15 @@ def __init__(self, **kwargs):
% (k, set(OPTIONS))
)
if k in _VALIDATORS and not _VALIDATORS[k](v):
- raise ValueError(f"option {k!r} given an invalid value: {v!r}")
+ if k == ARITHMETIC_JOIN:
+ expected = f"Expected one of {_JOIN_OPTIONS!r}"
+ elif k == DISPLAY_STYLE:
+ expected = f"Expected one of {_DISPLAY_OPTIONS!r}"
+ else:
+ expected = ""
+ raise ValueError(
+ f"option {k!r} given an invalid value: {v!r}. " + expected
+ )
self.old[k] = OPTIONS[k]
self._apply_update(kwargs)
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 86044e72dd2..6d5456f77f7 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -2,6 +2,7 @@
import dask
import dask.array
from dask.highlevelgraph import HighLevelGraph
+
from .dask_array_compat import meta_from_array
except ImportError:
@@ -234,11 +235,14 @@ def map_blocks(
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
+ >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
>>> np.random.seed(123)
>>> array = xr.DataArray(
- ... np.random.rand(len(time)), dims="time", coords=[time]
+ ... np.random.rand(len(time)),
+ ... dims=["time"],
+ ... coords={"time": time, "month": month},
... ).chunk()
- >>> xr.map_blocks(calculate_anomaly, array, template=array).compute()
+ >>> array.map_blocks(calculate_anomaly, template=array).compute()
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
@@ -247,25 +251,20 @@ def map_blocks(
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
+ month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:
- >>> xr.map_blocks(
- ... calculate_anomaly,
- ... array,
- ... kwargs={"groupby_type": "time.year"},
- ... template=array,
- ... )
+ >>> array.map_blocks(
+ ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
+ ... ) # doctest: +ELLIPSIS
- array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
- -0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
- -0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
- 0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
- 0.14482397, 0.35985481, 0.23487834, 0.12144652])
+ dask.array
Coordinates:
- * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
- """
+ * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
+ month (time) int64 dask.array
+ """
def _wrapper(
func: Callable,
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index aaf52b9f295..dcb78d17cf8 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -17,3 +17,11 @@
sparse_array_type = (sparse.SparseArray,)
except ImportError: # pragma: no cover
sparse_array_type = ()
+
+try:
+ # solely for isinstance checks
+ import cupy
+
+ cupy_array_type = (cupy.ndarray,)
+except ImportError: # pragma: no cover
+ cupy_array_type = ()
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index ecba5307680..fb38c0c7fe6 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -75,40 +75,32 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
-------
rolling : type of input argument
"""
- if len(windows) != 1:
- raise ValueError("exactly one dim/window should be provided")
-
- dim, window = next(iter(windows.items()))
-
- if window <= 0:
- raise ValueError("window must be > 0")
-
+ self.dim, self.window = [], []
+ for d, w in windows.items():
+ self.dim.append(d)
+ if w <= 0:
+ raise ValueError("window must be > 0")
+ self.window.append(w)
+
+ self.center = self._mapping_to_list(center, default=False)
self.obj = obj
# attributes
- self.window = window
if min_periods is not None and min_periods <= 0:
raise ValueError("min_periods must be greater than zero or None")
- self.min_periods = min_periods
- self.center = center
- self.dim = dim
+ self.min_periods = np.prod(self.window) if min_periods is None else min_periods
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
self.keep_attrs = keep_attrs
- @property
- def _min_periods(self):
- return self.min_periods if self.min_periods is not None else self.window
-
def __repr__(self):
"""provide a nice str repr of our rolling object"""
attrs = [
- "{k}->{v}".format(k=k, v=getattr(self, k))
- for k in self._attributes
- if getattr(self, k, None) is not None
+ "{k}->{v}{c}".format(k=k, v=w, c="(center)" if c else "")
+ for k, w, c in zip(self.dim, self.window, self.center)
]
return "{klass} [{attrs}]".format(
klass=self.__class__.__name__, attrs=",".join(attrs)
@@ -143,11 +135,31 @@ def method(self, **kwargs):
def count(self):
rolling_count = self._counts()
- enough_periods = rolling_count >= self._min_periods
+ enough_periods = rolling_count >= self.min_periods
return rolling_count.where(enough_periods)
count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count")
+ def _mapping_to_list(
+ self, arg, default=None, allow_default=True, allow_allsame=True
+ ):
+ if utils.is_dict_like(arg):
+ if allow_default:
+ return [arg.get(d, default) for d in self.dim]
+ else:
+ for d in self.dim:
+ if d not in arg:
+ raise KeyError("argument has no key {}.".format(d))
+ return [arg[d] for d in self.dim]
+ elif allow_allsame: # for single argument
+ return [arg] * len(self.dim)
+ elif len(self.dim) == 1:
+ return [arg]
+ else:
+ raise ValueError(
+ "Mapping argument is necessary for {}d-rolling.".format(len(self.dim))
+ )
+
class DataArrayRolling(Rolling):
__slots__ = ("window_labels",)
@@ -196,33 +208,41 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
)
- self.window_labels = self.obj[self.dim]
+ # TODO legacy attribute
+ self.window_labels = self.obj[self.dim[0]]
def __iter__(self):
+ if len(self.dim) > 1:
+ raise ValueError("__iter__ is only supported for 1d-rolling")
stops = np.arange(1, len(self.window_labels) + 1)
- starts = stops - int(self.window)
- starts[: int(self.window)] = 0
+ starts = stops - int(self.window[0])
+ starts[: int(self.window[0])] = 0
for (label, start, stop) in zip(self.window_labels, starts, stops):
- window = self.obj.isel(**{self.dim: slice(start, stop)})
+ window = self.obj.isel(**{self.dim[0]: slice(start, stop)})
- counts = window.count(dim=self.dim)
- window = window.where(counts >= self._min_periods)
+ counts = window.count(dim=self.dim[0])
+ window = window.where(counts >= self.min_periods)
yield (label, window)
- def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
+ def construct(
+ self, window_dim=None, stride=1, fill_value=dtypes.NA, **window_dim_kwargs
+ ):
"""
Convert this rolling object to xr.DataArray,
where the window dimension is stacked as a new dimension
Parameters
----------
- window_dim: str
- New name of the window dimension.
- stride: integer, optional
+ window_dim: str or a mapping, optional
+ A mapping from dimension name to the new window dimension names.
+ Just a string can be used for 1d-rolling.
+ stride: integer or a mapping, optional
Size of stride for the rolling window.
fill_value: optional. Default dtypes.NA
Filling value to match the dimension size.
+ **window_dim_kwargs : {dim: new_name, ...}, optional
+ The keyword arguments form of ``window_dim``.
Returns
-------
@@ -251,13 +271,27 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
from .dataarray import DataArray
+ if window_dim is None:
+ if len(window_dim_kwargs) == 0:
+ raise ValueError(
+ "Either window_dim or window_dim_kwargs need to be specified."
+ )
+ window_dim = {d: window_dim_kwargs[d] for d in self.dim}
+
+ window_dim = self._mapping_to_list(
+ window_dim, allow_default=False, allow_allsame=False
+ )
+ stride = self._mapping_to_list(stride, default=1)
+
window = self.obj.variable.rolling_window(
self.dim, self.window, window_dim, self.center, fill_value=fill_value
)
result = DataArray(
- window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords
+ window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords
+ )
+ return result.isel(
+ **{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
)
- return result.isel(**{self.dim: slice(None, None, stride)})
def reduce(self, func, **kwargs):
"""Reduce the items in this group by applying `func` along some
@@ -300,27 +334,36 @@ def reduce(self, func, **kwargs):
[ 4., 9., 15., 18.]])
"""
- rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
+ rolling_dim = {
+ d: utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d))
+ for d in self.dim
+ }
windows = self.construct(rolling_dim)
- result = windows.reduce(func, dim=rolling_dim, **kwargs)
+ result = windows.reduce(func, dim=list(rolling_dim.values()), **kwargs)
# Find valid windows based on count.
counts = self._counts()
- return result.where(counts >= self._min_periods)
+ return result.where(counts >= self.min_periods)
def _counts(self):
""" Number of non-nan entries in each rolling window. """
- rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
+ rolling_dim = {
+ d: utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d))
+ for d in self.dim
+ }
# We use False as the fill_value instead of np.nan, since boolean
# array is faster to be reduced than object array.
# The use of skipna==False is also faster since it does not need to
# copy the strided array.
counts = (
self.obj.notnull()
- .rolling(center=self.center, **{self.dim: self.window})
+ .rolling(
+ center={d: self.center[i] for i, d in enumerate(self.dim)},
+ **{d: w for d, w in zip(self.dim, self.window)},
+ )
.construct(rolling_dim, fill_value=False)
- .sum(dim=rolling_dim, skipna=False)
+ .sum(dim=list(rolling_dim.values()), skipna=False)
)
return counts
@@ -329,39 +372,40 @@ def _bottleneck_reduce(self, func, **kwargs):
# bottleneck doesn't allow min_count to be 0, although it should
# work the same as if min_count = 1
+ # Note bottleneck only works with 1d-rolling.
if self.min_periods is not None and self.min_periods == 0:
min_count = 1
else:
min_count = self.min_periods
- axis = self.obj.get_axis_num(self.dim)
+ axis = self.obj.get_axis_num(self.dim[0])
padded = self.obj.variable
- if self.center:
+ if self.center[0]:
if isinstance(padded.data, dask_array_type):
# Workaround to make the padded chunk size is larger than
# self.window-1
- shift = -(self.window + 1) // 2
- offset = (self.window - 1) // 2
+ shift = -(self.window[0] + 1) // 2
+ offset = (self.window[0] - 1) // 2
valid = (slice(None),) * axis + (
slice(offset, offset + self.obj.shape[axis]),
)
else:
- shift = (-self.window // 2) + 1
+ shift = (-self.window[0] // 2) + 1
valid = (slice(None),) * axis + (slice(-shift, None),)
- padded = padded.pad({self.dim: (0, -shift)}, mode="constant")
+ padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")
if isinstance(padded.data, dask_array_type):
raise AssertionError("should not be reachable")
values = dask_rolling_wrapper(
- func, padded.data, window=self.window, min_count=min_count, axis=axis
+ func, padded.data, window=self.window[0], min_count=min_count, axis=axis
)
else:
values = func(
- padded.data, window=self.window, min_count=min_count, axis=axis
+ padded.data, window=self.window[0], min_count=min_count, axis=axis
)
- if self.center:
+ if self.center[0]:
values = values[valid]
result = DataArray(values, self.obj.coords)
@@ -378,8 +422,10 @@ def _numpy_or_bottleneck_reduce(
)
del kwargs["dim"]
- if bottleneck_move_func is not None and not isinstance(
- self.obj.data, dask_array_type
+ if (
+ bottleneck_move_func is not None
+ and not isinstance(self.obj.data, dask_array_type)
+ and len(self.dim) == 1
):
# TODO: renable bottleneck with dask after the issues
# underlying https://github.com/pydata/xarray/issues/2940 are
@@ -412,7 +458,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
- center : boolean, default False
+ center : boolean, or a mapping from dimension name to boolean, default False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
@@ -431,15 +477,22 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
DataArray.groupby
"""
super().__init__(obj, windows, min_periods, center, keep_attrs)
- if self.dim not in self.obj.dims:
+ if any(d not in self.obj.dims for d in self.dim):
raise KeyError(self.dim)
# Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on slf.dim
- if self.dim in da.dims:
+ dims, center = [], {}
+ for i, d in enumerate(self.dim):
+ if d in da.dims:
+ dims.append(d)
+ center[d] = self.center[i]
+
+ if len(dims) > 0:
+ w = {d: windows[d] for d in dims}
self.rollings[key] = DataArrayRolling(
- da, windows, min_periods, center, keep_attrs
+ da, w, min_periods, center, keep_attrs
)
def _dataset_implementation(self, func, **kwargs):
@@ -447,7 +500,7 @@ def _dataset_implementation(self, func, **kwargs):
reduced = {}
for key, da in self.obj.data_vars.items():
- if self.dim in da.dims:
+ if any(d in da.dims for d in self.dim):
reduced[key] = func(self.rollings[key], **kwargs)
else:
reduced[key] = self.obj[key]
@@ -491,19 +544,29 @@ def _numpy_or_bottleneck_reduce(
**kwargs,
)
- def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None):
+ def construct(
+ self,
+ window_dim=None,
+ stride=1,
+ fill_value=dtypes.NA,
+ keep_attrs=None,
+ **window_dim_kwargs,
+ ):
"""
Convert this rolling object to xr.Dataset,
where the window dimension is stacked as a new dimension
Parameters
----------
- window_dim: str
- New name of the window dimension.
+ window_dim: str or a mapping, optional
+ A mapping from dimension name to the new window dimension names.
+ Just a string can be used for 1d-rolling.
stride: integer, optional
size of stride for the rolling window.
fill_value: optional. Default dtypes.NA
Filling value to match the dimension size.
+ **window_dim_kwargs : {dim: new_name, ...}, optional
+ The keyword arguments form of ``window_dim``.
Returns
-------
@@ -512,19 +575,35 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None)
from .dataset import Dataset
+ if window_dim is None:
+ if len(window_dim_kwargs) == 0:
+ raise ValueError(
+ "Either window_dim or window_dim_kwargs need to be specified."
+ )
+ window_dim = {d: window_dim_kwargs[d] for d in self.dim}
+
+ window_dim = self._mapping_to_list(
+ window_dim, allow_default=False, allow_allsame=False
+ )
+ stride = self._mapping_to_list(stride, default=1)
+
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
dataset = {}
for key, da in self.obj.data_vars.items():
- if self.dim in da.dims:
+ # keep rollings only for the data variables that depend on self.dim
+ dims = [d for d in self.dim if d in da.dims]
+ if len(dims) > 0:
+ wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims}
+ st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims}
dataset[key] = self.rollings[key].construct(
- window_dim, fill_value=fill_value
+ window_dim=wi, fill_value=fill_value, stride=st
)
else:
dataset[key] = da
return Dataset(dataset, coords=self.obj.coords).isel(
- **{self.dim: slice(None, None, stride)}
+ **{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
)
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index c505c749557..1f86a40348c 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -33,7 +33,7 @@
)
from .npcompat import IS_NEP18_ACTIVE
from .options import _get_keep_attrs
-from .pycompat import dask_array_type, integer_types
+from .pycompat import cupy_array_type, dask_array_type, integer_types
from .utils import (
OrderedSet,
_default,
@@ -45,9 +45,8 @@
)
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
- indexing.ExplicitlyIndexed,
- pd.Index,
-) + dask_array_type
+ (indexing.ExplicitlyIndexed, pd.Index,) + dask_array_type + cupy_array_type
+)
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore
@@ -257,7 +256,10 @@ def _as_array_or_item(data):
TODO: remove this (replace with np.asarray) once these issues are fixed
"""
- data = np.asarray(data)
+ if isinstance(data, cupy_array_type):
+ data = data.get()
+ else:
+ data = np.asarray(data)
if data.ndim == 0:
if data.dtype.kind == "M":
data = np.datetime64(data, "ns")
@@ -1054,7 +1056,7 @@ def isel(
missing_dims : {"raise", "warn", "ignore"}, default "raise"
What to do if dimensions that should be selected from are not present in the
DataArray:
- - "exception": raise an exception
+ - "raise": raise an exception
- "warning": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
@@ -1881,11 +1883,14 @@ def rolling_window(
Parameters
----------
dim: str
- Dimension over which to compute rolling_window
+ Dimension over which to compute rolling_window.
+ For nd-rolling, should be list of dimensions.
window: int
Window size of the rolling
+ For nd-rolling, should be list of integers.
window_dim: str
New name of the window dimension.
+ For nd-rolling, should be list of strings.
center: boolean. default False.
If True, pad fill_value for both ends. Otherwise, pad in the head
of the axis.
@@ -1919,15 +1924,21 @@ def rolling_window(
dtype = self.dtype
array = self.data
- new_dims = self.dims + (window_dim,)
+ if isinstance(dim, list):
+ assert len(dim) == len(window)
+ assert len(dim) == len(window_dim)
+ assert len(dim) == len(center)
+ else:
+ dim = [dim]
+ window = [window]
+ window_dim = [window_dim]
+ center = [center]
+ axis = [self.get_axis_num(d) for d in dim]
+ new_dims = self.dims + tuple(window_dim)
return Variable(
new_dims,
duck_array_ops.rolling_window(
- array,
- axis=self.get_axis_num(dim),
- window=window,
- center=center,
- fill_value=fill_value,
+ array, axis=axis, window=window, center=center, fill_value=fill_value
),
)
diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py
index a987d302202..e0cbdb7377a 100644
--- a/xarray/tests/test_accessor_str.py
+++ b/xarray/tests/test_accessor_str.py
@@ -596,7 +596,7 @@ def test_wrap():
)
# expected values
- xp = xr.DataArray(
+ expected = xr.DataArray(
[
"hello world",
"hello world!",
@@ -610,15 +610,29 @@ def test_wrap():
]
)
- rs = values.str.wrap(12, break_long_words=True)
- assert_equal(rs, xp)
+ result = values.str.wrap(12, break_long_words=True)
+ assert_equal(result, expected)
# test with pre and post whitespace (non-unicode), NaN, and non-ascii
# Unicode
values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"])
- xp = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"])
- rs = values.str.wrap(6)
- assert_equal(rs, xp)
+ expected = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"])
+ result = values.str.wrap(6)
+ assert_equal(result, expected)
+
+
+def test_wrap_kwargs_passed():
+ # GH4334
+
+ values = xr.DataArray(" hello world ")
+
+ result = values.str.wrap(7)
+ expected = xr.DataArray(" hello\nworld")
+ assert_equal(result, expected)
+
+ result = values.str.wrap(7, drop_whitespace=False)
+ expected = xr.DataArray(" hello\n world\n ")
+ assert_equal(result, expected)
def test_get(dtype):
@@ -642,6 +656,15 @@ def test_get(dtype):
assert_equal(result, expected)
+def test_get_default(dtype):
+ # GH4334
+ values = xr.DataArray(["a_b", "c", ""]).astype(dtype)
+
+ result = values.str.get(2, "default")
+ expected = xr.DataArray(["b", "default", "default"]).astype(dtype)
+ assert_equal(result, expected)
+
+
def test_encode_decode():
data = xr.DataArray(["a", "b", "a\xe4"])
encoded = data.str.encode("utf-8")
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 6a840e6303e..9f987e7fdf2 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -3207,7 +3207,7 @@ def test_load_dataarray(self):
@pytest.mark.filterwarnings("ignore:The binary mode of fromstring is deprecated")
class TestPydap:
def convert_to_pydap_dataset(self, original):
- from pydap.model import GridType, BaseType, DatasetType
+ from pydap.model import BaseType, DatasetType, GridType
ds = DatasetType("bears", **original.attrs)
for key, var in original.data_vars.items():
@@ -3747,9 +3747,10 @@ def test_platecarree(self):
def test_notransform(self):
# regression test for https://github.com/pydata/xarray/issues/1686
- import rasterio
import warnings
+ import rasterio
+
# Create a geotiff file
with warnings.catch_warnings():
# rasterio throws a NotGeoreferencedWarning here, which is
@@ -4097,8 +4098,8 @@ def test_rasterio_vrt_with_transform_and_size(self):
# Test open_rasterio() support of WarpedVRT with transform, width and
# height (issue #2864)
import rasterio
- from rasterio.warp import calculate_default_transform
from affine import Affine
+ from rasterio.warp import calculate_default_transform
with create_tmp_geotiff() as (tmp_file, expected):
with rasterio.open(tmp_file) as src:
diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py
index 745ae341370..642609ba059 100644
--- a/xarray/tests/test_cftimeindex.py
+++ b/xarray/tests/test_cftimeindex.py
@@ -1,4 +1,5 @@
from datetime import timedelta
+from textwrap import dedent
import numpy as np
import pandas as pd
@@ -884,6 +885,120 @@ def test_cftimeindex_shift_invalid_freq():
index.shift(1, 1)
+@requires_cftime
+@pytest.mark.parametrize(
+ ("calendar", "expected"),
+ [
+ ("noleap", "noleap"),
+ ("365_day", "noleap"),
+ ("360_day", "360_day"),
+ ("julian", "julian"),
+ ("gregorian", "gregorian"),
+ ("proleptic_gregorian", "proleptic_gregorian"),
+ ],
+)
+def test_cftimeindex_calendar_property(calendar, expected):
+ index = xr.cftime_range(start="2000", periods=3, calendar=calendar)
+ assert index.calendar == expected
+
+
+@requires_cftime
+@pytest.mark.parametrize(
+ ("calendar", "expected"),
+ [
+ ("noleap", "noleap"),
+ ("365_day", "noleap"),
+ ("360_day", "360_day"),
+ ("julian", "julian"),
+ ("gregorian", "gregorian"),
+ ("proleptic_gregorian", "proleptic_gregorian"),
+ ],
+)
+def test_cftimeindex_calendar_repr(calendar, expected):
+ """Test that cftimeindex has calendar property in repr."""
+ index = xr.cftime_range(start="2000", periods=3, calendar=calendar)
+ repr_str = index.__repr__()
+ assert f" calendar='{expected}'" in repr_str
+ assert "2000-01-01 00:00:00, 2000-01-02 00:00:00" in repr_str
+
+
+@requires_cftime
+@pytest.mark.parametrize("periods", [2, 40])
+def test_cftimeindex_periods_repr(periods):
+ """Test that cftimeindex has periods property in repr."""
+ index = xr.cftime_range(start="2000", periods=periods)
+ repr_str = index.__repr__()
+ assert f" length={periods}" in repr_str
+
+
+@requires_cftime
+@pytest.mark.parametrize(
+ "periods,expected",
+ [
+ (
+ 2,
+ """\
+CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00],
+ dtype='object', length=2, calendar='gregorian')""",
+ ),
+ (
+ 4,
+ """\
+CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00, 2000-01-03 00:00:00,
+ 2000-01-04 00:00:00],
+ dtype='object', length=4, calendar='gregorian')""",
+ ),
+ (
+ 101,
+ """\
+CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00, 2000-01-03 00:00:00,
+ 2000-01-04 00:00:00, 2000-01-05 00:00:00, 2000-01-06 00:00:00,
+ 2000-01-07 00:00:00, 2000-01-08 00:00:00, 2000-01-09 00:00:00,
+ 2000-01-10 00:00:00,
+ ...
+ 2000-04-01 00:00:00, 2000-04-02 00:00:00, 2000-04-03 00:00:00,
+ 2000-04-04 00:00:00, 2000-04-05 00:00:00, 2000-04-06 00:00:00,
+ 2000-04-07 00:00:00, 2000-04-08 00:00:00, 2000-04-09 00:00:00,
+ 2000-04-10 00:00:00],
+ dtype='object', length=101, calendar='gregorian')""",
+ ),
+ ],
+)
+def test_cftimeindex_repr_formatting(periods, expected):
+ """Test that cftimeindex.__repr__ is formatted similar to pd.Index.__repr__."""
+ index = xr.cftime_range(start="2000", periods=periods)
+ expected = dedent(expected)
+ assert expected == repr(index)
+
+
+@requires_cftime
+@pytest.mark.parametrize("display_width", [40, 80, 100])
+@pytest.mark.parametrize("periods", [2, 3, 4, 100, 101])
+def test_cftimeindex_repr_formatting_width(periods, display_width):
+ """Test that cftimeindex is sensitive to OPTIONS['display_width']."""
+ index = xr.cftime_range(start="2000", periods=periods)
+ len_intro_str = len("CFTimeIndex(")
+ with xr.set_options(display_width=display_width):
+ repr_str = index.__repr__()
+ splitted = repr_str.split("\n")
+ for i, s in enumerate(splitted):
+ # check that lines are not longer than OPTIONS['display_width']
+ assert len(s) <= display_width, f"{len(s)} {s} {display_width}"
+ if i > 0:
+ # check for initial spaces
+ assert s[:len_intro_str] == " " * len_intro_str
+
+
+@requires_cftime
+@pytest.mark.parametrize("periods", [22, 50, 100])
+def test_cftimeindex_repr_101_shorter(periods):
+ index_101 = xr.cftime_range(start="2000", periods=101)
+ index_periods = xr.cftime_range(start="2000", periods=periods)
+ index_101_repr_str = index_101.__repr__()
+ index_periods_repr_str = index_periods.__repr__()
+ assert len(index_101_repr_str) < len(index_periods_repr_str)
+
+
@requires_cftime
def test_parse_array_of_cftime_strings():
from cftime import DatetimeNoLeap
diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py
index 1efd4b02bf8..457e68f5593 100644
--- a/xarray/tests/test_coding_times.py
+++ b/xarray/tests/test_coding_times.py
@@ -222,9 +222,10 @@ def test_decode_non_standard_calendar_inside_timestamp_range(calendar):
@requires_cftime
@pytest.mark.parametrize("calendar", _ALL_CALENDARS)
def test_decode_dates_outside_timestamp_range(calendar):
- import cftime
from datetime import datetime
+ import cftime
+
units = "days since 0001-01-01"
times = [datetime(1, 4, 1, h) for h in range(1, 5)]
time = cftime.date2num(times, units, calendar=calendar)
@@ -358,9 +359,10 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(calend
@requires_cftime
@pytest.mark.parametrize("calendar", _ALL_CALENDARS)
def test_decode_multidim_time_outside_timestamp_range(calendar):
- import cftime
from datetime import datetime
+ import cftime
+
units = "days since 0001-01-01"
times1 = [datetime(1, 4, day) for day in range(1, 6)]
times2 = [datetime(1, 5, day) for day in range(1, 6)]
diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py
index 624e78d9271..0276b8ebc08 100644
--- a/xarray/tests/test_cupy.py
+++ b/xarray/tests/test_cupy.py
@@ -48,3 +48,13 @@ def test_check_data_stays_on_gpu(toy_weather_data):
"""Perform some operations and check the data stays on the GPU."""
freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time")
assert isinstance(freeze.data, cp.core.core.ndarray)
+
+
+def test_where():
+ from xarray.core.duck_array_ops import where
+
+ data = cp.zeros(10)
+
+ output = where(data < 1, 1, data).all()
+ assert output
+ assert isinstance(output, cp.ndarray)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 793090cc122..84455d320cb 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -3147,7 +3147,8 @@ def test_upsample_interpolate_regression_1605(self):
@requires_dask
@requires_scipy
- def test_upsample_interpolate_dask(self):
+ @pytest.mark.parametrize("chunked_time", [True, False])
+ def test_upsample_interpolate_dask(self, chunked_time):
from scipy.interpolate import interp1d
xs = np.arange(6)
@@ -3158,6 +3159,8 @@ def test_upsample_interpolate_dask(self):
data = np.tile(z, (6, 3, 1))
array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time"))
chunks = {"x": 2, "y": 1}
+ if chunked_time:
+ chunks["time"] = 3
expected_times = times.to_series().resample("1H").asfreq().index
# Split the times into equal sub-intervals to simulate the 6 hour
@@ -3185,13 +3188,6 @@ def test_upsample_interpolate_dask(self):
# done here due to floating point arithmetic
assert_allclose(expected, actual, rtol=1e-16)
- # Check that an error is raised if an attempt is made to interpolate
- # over a chunked dimension
- with raises_regex(
- NotImplementedError, "Chunking along the dimension to be interpolated"
- ):
- array.chunk({"time": 1}).resample(time="1H").interpolate("linear")
-
def test_align(self):
array = DataArray(
np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"]
@@ -3467,15 +3463,18 @@ def test_to_pandas(self):
def test_to_dataframe(self):
# regression test for #260
- arr = DataArray(
- np.random.randn(3, 4), [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo"
- )
+ arr_np = np.random.randn(3, 4)
+
+ arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo")
expected = arr.to_series()
actual = arr.to_dataframe()["foo"]
assert_array_equal(expected.values, actual.values)
assert_array_equal(expected.name, actual.name)
assert_array_equal(expected.index.values, actual.index.values)
+ actual = arr.to_dataframe(dim_order=["A", "B"])["foo"]
+ assert_array_equal(arr_np.transpose().reshape(-1), actual.values)
+
# regression test for coords with different dimensions
arr.coords["C"] = ("B", [-1, -2, -3])
expected = arr.to_series().to_frame()
@@ -3486,6 +3485,9 @@ def test_to_dataframe(self):
assert_array_equal(expected.columns.values, actual.columns.values)
assert_array_equal(expected.index.values, actual.index.values)
+ with pytest.raises(ValueError, match="does not match the set of dimensions"):
+ arr.to_dataframe(dim_order=["B", "A", "C"])
+
arr.name = None # unnamed
with raises_regex(ValueError, "unnamed"):
arr.to_dataframe()
@@ -6180,6 +6182,16 @@ def test_rolling_iter(da):
)
+@pytest.mark.parametrize("da", (1,), indirect=True)
+def test_rolling_repr(da):
+ rolling_obj = da.rolling(time=7)
+ assert repr(rolling_obj) == "DataArrayRolling [time->7]"
+ rolling_obj = da.rolling(time=7, center=True)
+ assert repr(rolling_obj) == "DataArrayRolling [time->7(center)]"
+ rolling_obj = da.rolling(time=7, x=3, center=True)
+ assert repr(rolling_obj) == "DataArrayRolling [time->7(center),x->3(center)]"
+
+
def test_rolling_doc(da):
rolling_obj = da.rolling(time=7)
@@ -6193,8 +6205,6 @@ def test_rolling_properties(da):
assert rolling_obj.obj.get_axis_num("time") == 1
# catching invalid args
- with pytest.raises(ValueError, match="exactly one dim/window should"):
- da.rolling(time=7, x=2)
with pytest.raises(ValueError, match="window must be > 0"):
da.rolling(time=-2)
with pytest.raises(ValueError, match="min_periods must be greater than zero"):
@@ -6399,6 +6409,47 @@ def test_rolling_count_correct():
assert_equal(result, expected)
+@pytest.mark.parametrize("da", (1,), indirect=True)
+@pytest.mark.parametrize("center", (True, False))
+@pytest.mark.parametrize("min_periods", (None, 1))
+@pytest.mark.parametrize("name", ("sum", "mean", "max"))
+def test_ndrolling_reduce(da, center, min_periods, name):
+ rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods)
+
+ actual = getattr(rolling_obj, name)()
+ expected = getattr(
+ getattr(
+ da.rolling(time=3, center=center, min_periods=min_periods), name
+ )().rolling(x=2, center=center, min_periods=min_periods),
+ name,
+ )()
+
+ assert_allclose(actual, expected)
+ assert actual.dims == expected.dims
+
+
+@pytest.mark.parametrize("center", (True, False, (True, False)))
+@pytest.mark.parametrize("fill_value", (np.nan, 0.0))
+def test_ndrolling_construct(center, fill_value):
+ da = DataArray(
+ np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float),
+ dims=["x", "y", "z"],
+ coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)},
+ )
+ actual = da.rolling(x=3, z=2, center=center).construct(
+ x="x1", z="z1", fill_value=fill_value
+ )
+ if not isinstance(center, tuple):
+ center = (center, center)
+ expected = (
+ da.rolling(x=3, center=center[0])
+ .construct(x="x1", fill_value=fill_value)
+ .rolling(z=2, center=center[1])
+ .construct(z="z1", fill_value=fill_value)
+ )
+ assert_allclose(actual, expected)
+
+
def test_raise_no_warning_for_nan_in_binary_ops():
with pytest.warns(None) as record:
xr.DataArray([1, 2, np.NaN]) > 0
@@ -6417,8 +6468,8 @@ def test_name_in_masking():
class TestIrisConversion:
@requires_iris
def test_to_and_from_iris(self):
- import iris
import cf_units # iris requirement
+ import iris
# to iris
coord_dict = {}
@@ -6488,9 +6539,9 @@ def test_to_and_from_iris(self):
@requires_iris
@requires_dask
def test_to_and_from_iris_dask(self):
+ import cf_units # iris requirement
import dask.array as da
import iris
- import cf_units # iris requirement
coord_dict = {}
coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"})
@@ -6623,8 +6674,8 @@ def test_da_name_from_cube(self, std_name, long_name, var_name, name, attrs):
],
)
def test_da_coord_name_from_cube(self, std_name, long_name, var_name, name, attrs):
- from iris.cube import Cube
from iris.coords import DimCoord
+ from iris.cube import Cube
latitude = DimCoord(
[-90, 0, 90], standard_name=std_name, var_name=var_name, long_name=long_name
@@ -6637,8 +6688,8 @@ def test_da_coord_name_from_cube(self, std_name, long_name, var_name, name, attr
@requires_iris
def test_prevent_duplicate_coord_names(self):
- from iris.cube import Cube
from iris.coords import DimCoord
+ from iris.cube import Cube
# Iris enforces unique coordinate names. Because we use a different
# name resolution order a valid iris Cube with coords that have the
@@ -6659,8 +6710,8 @@ def test_prevent_duplicate_coord_names(self):
[["IA", "IL", "IN"], [0, 2, 1]], # non-numeric values # non-monotonic values
)
def test_fallback_to_iris_AuxCoord(self, coord_values):
- from iris.cube import Cube
from iris.coords import AuxCoord
+ from iris.cube import Cube
data = [0, 0, 0]
da = xr.DataArray(data, coords=[coord_values], dims=["space"])
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 3e8c60c045e..3a61bc29896 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -3954,6 +3954,33 @@ def test_to_and_from_dataframe(self):
# check roundtrip
assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual))
+ # Check multiindex reordering
+ new_order = ["x", "y"]
+ actual = ds.to_dataframe(dim_order=new_order)
+ assert expected.equals(actual)
+
+ new_order = ["y", "x"]
+ exp_index = pd.MultiIndex.from_arrays(
+ [["a", "a", "b", "b", "c", "c"], [0, 1, 0, 1, 0, 1]], names=["y", "x"]
+ )
+ expected = pd.DataFrame(
+ w.transpose().reshape(-1), columns=["w"], index=exp_index
+ )
+ actual = ds.to_dataframe(dim_order=new_order)
+ assert expected.equals(actual)
+
+ invalid_order = ["x"]
+ with pytest.raises(
+ ValueError, match="does not match the set of dimensions of this"
+ ):
+ ds.to_dataframe(dim_order=invalid_order)
+
+ invalid_order = ["x", "z"]
+ with pytest.raises(
+ ValueError, match="does not match the set of dimensions of this"
+ ):
+ ds.to_dataframe(dim_order=invalid_order)
+
# check pathological cases
df = pd.DataFrame([1])
actual = Dataset.from_dataframe(df)
@@ -5896,8 +5923,6 @@ def test_rolling_keep_attrs():
def test_rolling_properties(ds):
# catching invalid args
- with pytest.raises(ValueError, match="exactly one dim/window should"):
- ds.rolling(time=7, x=2)
with pytest.raises(ValueError, match="window must be > 0"):
ds.rolling(time=-2)
with pytest.raises(ValueError, match="min_periods must be greater than zero"):
@@ -6022,6 +6047,66 @@ def test_rolling_reduce(ds, center, min_periods, window, name):
assert src_var.dims == actual[key].dims
+@pytest.mark.parametrize("ds", (2,), indirect=True)
+@pytest.mark.parametrize("center", (True, False))
+@pytest.mark.parametrize("min_periods", (None, 1))
+@pytest.mark.parametrize("name", ("sum", "max"))
+@pytest.mark.parametrize("dask", (True, False))
+def test_ndrolling_reduce(ds, center, min_periods, name, dask):
+ if dask and has_dask:
+ ds = ds.chunk({"x": 4})
+
+ rolling_obj = ds.rolling(time=4, x=3, center=center, min_periods=min_periods)
+
+ actual = getattr(rolling_obj, name)()
+ expected = getattr(
+ getattr(
+ ds.rolling(time=4, center=center, min_periods=min_periods), name
+ )().rolling(x=3, center=center, min_periods=min_periods),
+ name,
+ )()
+ assert_allclose(actual, expected)
+ assert actual.dims == expected.dims
+
+ # Do it in the opposite order
+ expected = getattr(
+ getattr(
+ ds.rolling(x=3, center=center, min_periods=min_periods), name
+ )().rolling(time=4, center=center, min_periods=min_periods),
+ name,
+ )()
+
+ assert_allclose(actual, expected)
+ assert actual.dims == expected.dims
+
+
+@pytest.mark.parametrize("center", (True, False, (True, False)))
+@pytest.mark.parametrize("fill_value", (np.nan, 0.0))
+@pytest.mark.parametrize("dask", (True, False))
+def test_ndrolling_construct(center, fill_value, dask):
+ da = DataArray(
+ np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float),
+ dims=["x", "y", "z"],
+ coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)},
+ )
+ ds = xr.Dataset({"da": da})
+ if dask and has_dask:
+ ds = ds.chunk({"x": 4})
+
+ actual = ds.rolling(x=3, z=2, center=center).construct(
+ x="x1", z="z1", fill_value=fill_value
+ )
+ if not isinstance(center, tuple):
+ center = (center, center)
+ expected = (
+ ds.rolling(x=3, center=center[0])
+ .construct(x="x1", fill_value=fill_value)
+ .rolling(z=2, center=center[1])
+ .construct(z="z1", fill_value=fill_value)
+ )
+ assert_allclose(actual, expected)
+
+
def test_raise_no_warning_for_nan_in_binary_ops():
with pytest.warns(None) as record:
Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0
diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py
index feedcd27164..7d54aac36f8 100644
--- a/xarray/tests/test_duck_array_ops.py
+++ b/xarray/tests/test_duck_array_ops.py
@@ -33,6 +33,7 @@
arm_xfail,
assert_array_equal,
has_dask,
+ has_scipy,
raises_regex,
requires_cftime,
requires_dask,
@@ -332,6 +333,40 @@ def test_cftime_datetime_mean():
assert_equal(result, expected)
+@requires_cftime
+def test_cftime_datetime_mean_long_time_period():
+ import cftime
+
+ times = np.array(
+ [
+ [
+ cftime.DatetimeNoLeap(400, 12, 31, 0, 0, 0, 0),
+ cftime.DatetimeNoLeap(520, 12, 31, 0, 0, 0, 0),
+ ],
+ [
+ cftime.DatetimeNoLeap(520, 12, 31, 0, 0, 0, 0),
+ cftime.DatetimeNoLeap(640, 12, 31, 0, 0, 0, 0),
+ ],
+ [
+ cftime.DatetimeNoLeap(640, 12, 31, 0, 0, 0, 0),
+ cftime.DatetimeNoLeap(760, 12, 31, 0, 0, 0, 0),
+ ],
+ ]
+ )
+
+ da = DataArray(times, dims=["time", "d2"])
+ result = da.mean("d2")
+ expected = DataArray(
+ [
+ cftime.DatetimeNoLeap(460, 12, 31, 0, 0, 0, 0),
+ cftime.DatetimeNoLeap(580, 12, 31, 0, 0, 0, 0),
+ cftime.DatetimeNoLeap(700, 12, 31, 0, 0, 0, 0),
+ ],
+ dims=["time"],
+ )
+ assert_equal(result, expected)
+
+
@requires_cftime
@requires_dask
def test_cftime_datetime_mean_dask_error():
@@ -767,8 +802,8 @@ def test_timedelta_to_numeric(td):
@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("skipna", [True, False])
def test_least_squares(use_dask, skipna):
- if use_dask and not has_dask:
- pytest.skip("requires dask")
+ if use_dask and (not has_dask or not has_scipy):
+ pytest.skip("requires dask and scipy")
lhs = np.array([[1, 2], [1, 2], [3, 2]])
rhs = DataArray(np.array([3, 5, 7]), dims=("y",))
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index 82de8080c80..1cc91266421 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -7,6 +7,7 @@
import xarray as xr
from xarray.core import formatting
+from xarray.core.npcompat import IS_NEP18_ACTIVE
from . import raises_regex
@@ -391,6 +392,44 @@ def test_array_repr(self):
assert actual == expected
+@pytest.mark.skipif(not IS_NEP18_ACTIVE, reason="requires __array_function__")
+def test_inline_variable_array_repr_custom_repr():
+ class CustomArray:
+ def __init__(self, value, attr):
+ self.value = value
+ self.attr = attr
+
+ def _repr_inline_(self, width):
+ formatted = f"({self.attr}) {self.value}"
+ if len(formatted) > width:
+ formatted = f"({self.attr}) ..."
+
+ return formatted
+
+ def __array_function__(self, *args, **kwargs):
+ return NotImplemented
+
+ @property
+ def shape(self):
+ return self.value.shape
+
+ @property
+ def dtype(self):
+ return self.value.dtype
+
+ @property
+ def ndim(self):
+ return self.value.ndim
+
+ value = CustomArray(np.array([20, 40]), "m")
+ variable = xr.Variable("x", value)
+
+ max_width = 10
+ actual = formatting.inline_variable_array_repr(variable, max_width=10)
+
+ assert actual == value._repr_inline_(max_width)
+
+
def test_set_numpy_options():
original_options = np.get_printoptions()
with formatting.set_numpy_options(threshold=10):
diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py
index 7a0dda216e2..17e418c3731 100644
--- a/xarray/tests/test_interp.py
+++ b/xarray/tests/test_interp.py
@@ -1,9 +1,18 @@
+from itertools import combinations, permutations
+
import numpy as np
import pandas as pd
import pytest
import xarray as xr
-from xarray.tests import assert_allclose, assert_equal, requires_cftime, requires_scipy
+from xarray.tests import (
+ assert_allclose,
+ assert_equal,
+ assert_identical,
+ requires_cftime,
+ requires_dask,
+ requires_scipy,
+)
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from . import has_dask, has_scipy
@@ -63,12 +72,6 @@ def test_interpolate_1d(method, dim, case):
da = get_example_data(case)
xdest = np.linspace(0.0, 0.9, 80)
-
- if dim == "y" and case == 1:
- with pytest.raises(NotImplementedError):
- actual = da.interp(method=method, **{dim: xdest})
- pytest.skip("interpolation along chunked dimension is " "not yet supported")
-
actual = da.interp(method=method, **{dim: xdest})
# scipy interpolation for the reference
@@ -376,8 +379,6 @@ def test_errors(use_dask):
# invalid method
with pytest.raises(ValueError):
da.interp(x=[2, 0], method="boo")
- with pytest.raises(ValueError):
- da.interp(x=[2, 0], y=2, method="cubic")
with pytest.raises(ValueError):
da.interp(y=[2, 0], method="boo")
@@ -717,3 +718,120 @@ def test_decompose(method):
actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y"))
expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y"))
assert_allclose(actual, expected)
+
+
+@requires_scipy
+@requires_dask
+@pytest.mark.parametrize(
+ "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]
+)
+@pytest.mark.parametrize("chunked", [True, False])
+@pytest.mark.parametrize(
+ "data_ndim,interp_ndim,nscalar",
+ [
+ (data_ndim, interp_ndim, nscalar)
+ for data_ndim in range(1, 4)
+ for interp_ndim in range(1, data_ndim + 1)
+ for nscalar in range(0, interp_ndim + 1)
+ ],
+)
+def test_interpolate_chunk_1d(method, data_ndim, interp_ndim, nscalar, chunked):
+ """Interpolate nd array with multiple independent indexers
+
+ It should do a series of 1d interpolations
+ """
+
+ # 3d non chunked data
+ x = np.linspace(0, 1, 5)
+ y = np.linspace(2, 4, 7)
+ z = np.linspace(-0.5, 0.5, 11)
+ da = xr.DataArray(
+ data=np.sin(x[:, np.newaxis, np.newaxis])
+ * np.cos(y[:, np.newaxis])
+ * np.exp(z),
+ coords=[("x", x), ("y", y), ("z", z)],
+ )
+ kwargs = {"fill_value": "extrapolate"}
+
+ # choose the data dimensions
+ for data_dims in permutations(da.dims, data_ndim):
+
+ # select only data_ndim dim
+ da = da.isel( # take the middle line
+ {dim: len(da.coords[dim]) // 2 for dim in da.dims if dim not in data_dims}
+ )
+
+ # chunk data
+ da = da.chunk(chunks={dim: i + 1 for i, dim in enumerate(da.dims)})
+
+ # choose the interpolation dimensions
+ for interp_dims in permutations(da.dims, interp_ndim):
+ # choose the scalar interpolation dimensions
+ for scalar_dims in combinations(interp_dims, nscalar):
+ dest = {}
+ for dim in interp_dims:
+ if dim in scalar_dims:
+ # take the middle point
+ dest[dim] = 0.5 * (da.coords[dim][0] + da.coords[dim][-1])
+ else:
+ # pick some points, including outside the domain
+ before = 2 * da.coords[dim][0] - da.coords[dim][1]
+ after = 2 * da.coords[dim][-1] - da.coords[dim][-2]
+
+ dest[dim] = np.linspace(before, after, len(da.coords[dim]) * 13)
+ if chunked:
+ dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
+ dest[dim] = dest[dim].chunk(2)
+ actual = da.interp(method=method, **dest, kwargs=kwargs)
+ expected = da.compute().interp(method=method, **dest, kwargs=kwargs)
+
+ assert_identical(actual, expected)
+
+ # all the combinations are usually not necessary
+ break
+ break
+ break
+
+
+@requires_scipy
+@requires_dask
+@pytest.mark.parametrize("method", ["linear", "nearest"])
+def test_interpolate_chunk_advanced(method):
+ """Interpolate nd array with an nd indexer sharing coordinates."""
+ # Create original array
+ x = np.linspace(-1, 1, 5)
+ y = np.linspace(-1, 1, 7)
+ z = np.linspace(-1, 1, 11)
+ t = np.linspace(0, 1, 13)
+ q = np.linspace(0, 1, 17)
+ da = xr.DataArray(
+ data=np.sin(x[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis])
+ * np.cos(y[:, np.newaxis, np.newaxis, np.newaxis])
+ * np.exp(z[:, np.newaxis, np.newaxis])
+ * t[:, np.newaxis]
+ + q,
+ dims=("x", "y", "z", "t", "q"),
+ coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "dummy_attr"},
+ )
+
+ # Create indexer into `da` with shared coordinate ("full-twist" Möbius strip)
+ theta = np.linspace(0, 2 * np.pi, 5)
+ w = np.linspace(-0.25, 0.25, 7)
+ r = xr.DataArray(
+ data=1 + w[:, np.newaxis] * np.cos(theta), coords=[("w", w), ("theta", theta)],
+ )
+
+ x = r * np.cos(theta)
+ y = r * np.sin(theta)
+ z = xr.DataArray(
+ data=w[:, np.newaxis] * np.sin(theta), coords=[("w", w), ("theta", theta)],
+ )
+
+ kwargs = {"fill_value": None}
+ expected = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method)
+
+ da = da.chunk(2)
+ x = x.chunk(1)
+ z = z.chunk(3)
+ actual = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method)
+ assert_identical(actual, expected)
diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py
index 1002a9dd9e3..ccb825dc7e9 100644
--- a/xarray/tests/test_nputils.py
+++ b/xarray/tests/test_nputils.py
@@ -1,4 +1,5 @@
import numpy as np
+import pytest
from numpy.testing import assert_array_equal
from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window
@@ -47,3 +48,19 @@ def test_rolling():
actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0)
expected = np.stack([expected, expected * 1.1], axis=0)
assert_array_equal(actual, expected)
+
+
+@pytest.mark.parametrize("center", [[True, True], [False, False]])
+@pytest.mark.parametrize("axis", [(0, 1), (1, 2), (2, 0)])
+def test_nd_rolling(center, axis):
+ x = np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float)
+ window = [3, 3]
+ actual = rolling_window(
+ x, axis=axis, window=window, center=center, fill_value=np.nan
+ )
+ expected = x
+ for ax, win, cent in zip(axis, window, center):
+ expected = rolling_window(
+ expected, axis=ax, window=win, center=cent, fill_value=np.nan
+ )
+ assert_array_equal(actual, expected)
diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py
index 788c26f3b39..5a32e454222 100644
--- a/xarray/tests/test_plot.py
+++ b/xarray/tests/test_plot.py
@@ -2251,6 +2251,7 @@ def test_datetime_line_plot(self):
self.darray.plot.line()
+@pytest.mark.filterwarnings("ignore:setting an array element with a sequence")
@requires_nc_time_axis
@requires_cftime
class TestCFDatetimePlot(PlotTestCase):
diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py
index 39ad250246b..adc29a3cc92 100644
--- a/xarray/tests/test_testing.py
+++ b/xarray/tests/test_testing.py
@@ -37,7 +37,7 @@ def test_allclose_regression():
"obj1,obj2",
(
pytest.param(
- xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable",
+ xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable"
),
pytest.param(
xr.DataArray([1e-17, 2], dims="x"),