diff --git a/xarray/conventions.py b/xarray/conventions.py
index 8bd316d199f..695bed3b365 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -141,7 +141,7 @@ def maybe_encode_bools(var):
     ):
         dims, data, attrs, encoding = _var_as_tuple(var)
         attrs["dtype"] = "bool"
-        data = data.astype(dtype="i1", copy=True)
+        data = duck_array_ops.astype(data, dtype="i1", copy=True)
         var = Variable(dims, data, attrs, encoding)
     return var
 
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 8c8f2443967..8c92bc4ee6e 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -18,10 +18,17 @@
 from numpy import zeros_like  # noqa
 from numpy import around, broadcast_to  # noqa
 from numpy import concatenate as _concatenate
-from numpy import einsum, gradient, isclose, isin, isnan, isnat  # noqa
-from numpy import stack as _stack
-from numpy import take, tensordot, transpose, unravel_index  # noqa
-from numpy import where as _where
+from numpy import (  # noqa
+    einsum,
+    gradient,
+    isclose,
+    isin,
+    isnat,
+    take,
+    tensordot,
+    transpose,
+    unravel_index,
+)
 from numpy.lib.stride_tricks import sliding_window_view  # noqa
 
 from . import dask_array_ops, dtypes, nputils
@@ -36,6 +43,13 @@
     dask_array = None  # type: ignore
 
 
+def get_array_namespace(x):
+    if hasattr(x, "__array_namespace__"):
+        return x.__array_namespace__()
+    else:
+        return np
+
+
 def _dask_or_eager_func(
     name,
     eager_module=np,
@@ -108,7 +122,8 @@ def isnull(data):
         return isnat(data)
     elif issubclass(scalar_type, np.inexact):
         # float types use NaN for null
-        return isnan(data)
+        xp = get_array_namespace(data)
+        return xp.isnan(data)
     elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
         # these types cannot represent missing values
         return zeros_like(data, dtype=bool)
@@ -164,6 +179,9 @@ def cumulative_trapezoid(y, x, axis):
 
 
 def astype(data, dtype, **kwargs):
+    if hasattr(data, "__array_namespace__"):
+        xp = get_array_namespace(data)
+        return xp.astype(data, dtype, **kwargs)
     return data.astype(dtype, **kwargs)
 
 
@@ -171,7 +189,7 @@ def asarray(data, xp=np):
     return data if is_duck_array(data) else xp.asarray(data)
 
 
-def as_shared_dtype(scalars_or_arrays):
+def as_shared_dtype(scalars_or_arrays, xp=np):
     """Cast a arrays to a shared dtype using xarray's type promotion rules."""
 
     if any(isinstance(x, cupy_array_type) for x in scalars_or_arrays):
@@ -179,13 +197,13 @@
 
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
     else:
-        arrays = [asarray(x) for x in scalars_or_arrays]
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     # Pass arrays directly instead of dtypes to result_type so scalars
     # get handled properly.
     # Note that result_type() safely gets the dtype from dask arrays without
     # evaluating them.
     out_type = dtypes.result_type(*arrays)
-    return [x.astype(out_type, copy=False) for x in arrays]
+    return [astype(x, out_type, copy=False) for x in arrays]
 
 
 def lazy_array_equiv(arr1, arr2):
@@ -259,9 +277,20 @@ def count(data, axis=None):
     return np.sum(np.logical_not(isnull(data)), axis=axis)
 
 
+def sum_where(data, axis=None, dtype=None, where=None):
+    xp = get_array_namespace(data)
+    if where is not None:
+        a = where_method(xp.zeros_like(data), where, data)
+    else:
+        a = data
+    result = xp.sum(a, axis=axis, dtype=dtype)
+    return result
+
+
 def where(condition, x, y):
     """Three argument where() with better dtype promotion rules."""
-    return _where(condition, *as_shared_dtype([x, y]))
+    xp = get_array_namespace(condition)
+    return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
 
 
 def where_method(data, cond, other=dtypes.NA):
@@ -284,7 +313,13 @@ def concatenate(arrays, axis=0):
 
 def stack(arrays, axis=0):
     """stack() with better dtype promotion rules."""
-    return _stack(as_shared_dtype(arrays), axis=axis)
+    xp = get_array_namespace(arrays[0])
+    return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
+
+
+def reshape(array, shape):
+    xp = get_array_namespace(array)
+    return xp.reshape(array, shape)
 
 
 @contextlib.contextmanager
@@ -323,11 +358,8 @@ def f(values, axis=None, skipna=None, **kwargs):
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)
 
-            if hasattr(values, "__array_namespace__"):
-                xp = values.__array_namespace__()
-                func = getattr(xp, name)
-            else:
-                func = getattr(np, name)
+            xp = get_array_namespace(values)
+            func = getattr(xp, name)
 
         try:
             with warnings.catch_warnings():
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 4c71658b577..651fd9aca17 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -5,7 +5,7 @@
 import numpy as np
 
 from . import dtypes, nputils, utils
-from .duck_array_ops import count, fillna, isnull, where, where_method
+from .duck_array_ops import count, fillna, isnull, sum_where, where, where_method
 
 
 def _maybe_null_out(result, axis, mask, min_count=1):
@@ -84,7 +84,7 @@ def nanargmax(a, axis=None):
 
 def nansum(a, axis=None, dtype=None, out=None, min_count=None):
     mask = isnull(a)
-    result = np.nansum(a, axis=axis, dtype=dtype)
+    result = sum_where(a, axis=axis, dtype=dtype, where=mask)
     if min_count is not None:
         return _maybe_null_out(result, axis, mask, min_count)
     else:
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 796c178f2a0..c70cd45b502 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1637,7 +1637,7 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable):
         reordered = self.transpose(*dim_order)
 
         new_shape = reordered.shape[: len(other_dims)] + (-1,)
-        new_data = reordered.data.reshape(new_shape)
+        new_data = duck_array_ops.reshape(reordered.data, new_shape)
         new_dims = reordered.dims[: len(other_dims)] + (new_dim,)
 
         return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)
diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py
index a15492028dd..7940c979249 100644
--- a/xarray/tests/test_array_api.py
+++ b/xarray/tests/test_array_api.py
@@ -17,8 +17,16 @@
 
 @pytest.fixture
 def arrays() -> tuple[xr.DataArray, xr.DataArray]:
-    np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
-    xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
+    np_arr = xr.DataArray(
+        np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
+        dims=("x", "y"),
+        coords={"x": [10, 20]},
+    )
+    xp_arr = xr.DataArray(
+        xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
+        dims=("x", "y"),
+        coords={"x": [10, 20]},
+    )
     assert isinstance(xp_arr.data, Array)
     return np_arr, xp_arr
 
@@ -32,6 +40,14 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
 
 
 def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
+    np_arr, xp_arr = arrays
+    expected = np_arr.sum()
+    actual = xp_arr.sum()
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
+
+
+def test_aggregation_skipna(arrays) -> None:
     np_arr, xp_arr = arrays
     expected = np_arr.sum(skipna=False)
     actual = xp_arr.sum(skipna=False)
@@ -39,6 +55,15 @@ def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
     assert_equal(actual, expected)
 
 
+def test_astype(arrays) -> None:
+    np_arr, xp_arr = arrays
+    expected = np_arr.astype(np.int64)
+    actual = xp_arr.astype(np.int64)
+    assert actual.dtype == np.int64
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
+
+
 def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
     np_arr, xp_arr = arrays
     expected = np_arr[:, 0]
@@ -59,3 +84,20 @@ def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> No
     actual = xp_arr.transpose()
     assert isinstance(actual.data, Array)
     assert_equal(actual, expected)
+
+
+def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
+    np_arr, xp_arr = arrays
+    expected = np_arr.stack(z=("x", "y"))
+    actual = xp_arr.stack(z=("x", "y"))
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
+
+
+def test_where() -> None:
+    np_arr = xr.DataArray(np.array([1, 0]), dims="x")
+    xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")
+    expected = xr.where(np_arr, 1, 0)
+    actual = xr.where(xp_arr, 1, 0)
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
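
Usage sketch (not part of the patch): the snippet below illustrates how the new duck_array_ops helpers dispatch on __array_namespace__. It assumes NumPy 1.22-1.26, where the experimental numpy.array_api module that the tests appear to use as xp is available, and it calls xarray's internal xarray.core.duck_array_ops module directly rather than public API; variable names are purely illustrative.

# Illustrative sketch only, not from the diff. Assumes NumPy 1.22-1.26
# (numpy.array_api was removed in NumPy 2.0) and an xarray checkout that
# includes this patch; duck_array_ops is internal, not public API.
import numpy as np
import numpy.array_api as xp

from xarray.core import duck_array_ops

np_data = np.array([[1.0, 2.0], [3.0, np.nan]])
xp_data = xp.asarray([[1.0, 2.0], [3.0, np.nan]])

# get_array_namespace() returns the array's own namespace when it defines
# __array_namespace__, and falls back to plain NumPy otherwise.
assert duck_array_ops.get_array_namespace(np_data) is np
assert duck_array_ops.get_array_namespace(xp_data) is xp

# astype() now routes through xp.astype(); strict array API objects have
# no .astype() method, so the old attribute call would fail on them.
as_float32 = duck_array_ops.astype(xp_data, xp.float32)

# sum_where() zeroes out masked entries before summing; nansum() uses it
# in place of np.nansum so skipna reductions keep the array API type.
mask = duck_array_ops.isnull(xp_data)
total = duck_array_ops.sum_where(xp_data, where=mask)

assert isinstance(as_float32, type(xp_data))
assert isinstance(total, type(xp_data))

This is the same path the new test_aggregation test exercises through xp_arr.sum(): with nansum() routed through sum_where(), the reduction no longer round-trips through np.nansum, so the result stays an array API object.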