More Array API changes #7067

Merged 8 commits on Oct 18, 2022
Changes from all commits

xarray/conventions.py (2 changes: 1 addition & 1 deletion)

@@ -141,7 +141,7 @@ def maybe_encode_bools(var):
     ):
         dims, data, attrs, encoding = _var_as_tuple(var)
         attrs["dtype"] = "bool"
-        data = data.astype(dtype="i1", copy=True)
+        data = duck_array_ops.astype(data, dtype="i1", copy=True)
         var = Variable(dims, data, attrs, encoding)
     return var

xarray/core/duck_array_ops.py (62 changes: 47 additions & 15 deletions)

@@ -18,10 +17,
 from numpy import zeros_like  # noqa
 from numpy import around, broadcast_to  # noqa
 from numpy import concatenate as _concatenate
-from numpy import einsum, gradient, isclose, isin, isnan, isnat  # noqa
-from numpy import stack as _stack
-from numpy import take, tensordot, transpose, unravel_index  # noqa
-from numpy import where as _where
+from numpy import (  # noqa
+    einsum,
+    gradient,
+    isclose,
+    isin,
+    isnat,
+    take,
+    tensordot,
+    transpose,
+    unravel_index,
+)
 from numpy.lib.stride_tricks import sliding_window_view  # noqa

 from . import dask_array_ops, dtypes, nputils
@@ -36,6 +43,13 @@
 dask_array = None  # type: ignore


+def get_array_namespace(x):
+    if hasattr(x, "__array_namespace__"):
+        return x.__array_namespace__()
+    else:
+        return np
+
+
 def _dask_or_eager_func(
     name,
     eager_module=np,
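
The new helper falls back to plain numpy for anything that does not advertise an Array API namespace. A quick sketch of that dispatch, separate from the diff and assuming numpy >= 1.22 so that the experimental numpy.array_api reference implementation is importable (the helper name below is illustrative):

import numpy as np
import numpy.array_api as xp  # experimental reference namespace; warns on import


def namespace_of(x):
    # Mirrors duck_array_ops.get_array_namespace: use the array's own
    # namespace when it advertises one, otherwise assume numpy semantics.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


print(namespace_of(np.ones((3,))))  # <module 'numpy' ...>
print(namespace_of(xp.ones((3,))))  # <module 'numpy.array_api' ...>
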
@@ -108,7 +122,8 @@ def isnull(data):
         return isnat(data)
     elif issubclass(scalar_type, np.inexact):
         # float types use NaN for null
-        return isnan(data)
+        xp = get_array_namespace(data)
+        return xp.isnan(data)
     elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
         # these types cannot represent missing values
         return zeros_like(data, dtype=bool)
@@ -164,28 +179,31 @@ def cumulative_trapezoid(y, x, axis):


 def astype(data, dtype, **kwargs):
+    if hasattr(data, "__array_namespace__"):
+        xp = get_array_namespace(data)
+        return xp.astype(data, dtype, **kwargs)
     return data.astype(dtype, **kwargs)


 def asarray(data, xp=np):
     return data if is_duck_array(data) else xp.asarray(data)


-def as_shared_dtype(scalars_or_arrays):
+def as_shared_dtype(scalars_or_arrays, xp=np):
     """Cast a arrays to a shared dtype using xarray's type promotion rules."""

     if any(isinstance(x, cupy_array_type) for x in scalars_or_arrays):
         import cupy as cp

         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
     else:
-        arrays = [asarray(x) for x in scalars_or_arrays]
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     # Pass arrays directly instead of dtypes to result_type so scalars
     # get handled properly.
     # Note that result_type() safely gets the dtype from dask arrays without
     # evaluating them.
     out_type = dtypes.result_type(*arrays)
-    return [x.astype(out_type, copy=False) for x in arrays]
+    return [astype(x, out_type, copy=False) for x in arrays]


 def lazy_array_equiv(arr1, arr2):
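
astype() now prefers the namespace-level cast because the Array API standard defines astype() as a function on the namespace rather than a method on the array. A small illustration with the reference numpy.array_api implementation at the time of this PR (not part of the diff):

import numpy as np
import numpy.array_api as xp

a = xp.ones((2, 3))
print(hasattr(a, "astype"))           # expected False: no astype() method on standard arrays
print(xp.astype(a, xp.int64).dtype)   # int64, via the namespace function
print(np.ones((2, 3)).astype("i1").dtype)  # int8: the plain-numpy fallback path is unchanged
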
@@ -259,9 +277,20 @@ def count(data, axis=None):
     return np.sum(np.logical_not(isnull(data)), axis=axis)


+def sum_where(data, axis=None, dtype=None, where=None):
+    xp = get_array_namespace(data)
+    if where is not None:
+        a = where_method(xp.zeros_like(data), where, data)
+    else:
+        a = data
+    result = xp.sum(a, axis=axis, dtype=dtype)
+    return result
+
+
 def where(condition, x, y):
     """Three argument where() with better dtype promotion rules."""
-    return _where(condition, *as_shared_dtype([x, y]))
+    xp = get_array_namespace(condition)
+    return xp.where(condition, *as_shared_dtype([x, y], xp=xp))


 def where_method(data, cond, other=dtypes.NA):
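
sum_where() exists because the Array API standard has no nan-skipping reductions and its sum() takes no where= argument, so nanops builds the masked sum out of where() plus the namespace's plain sum(). The same computation done by hand with the reference implementation (a sketch, not part of the diff):

import numpy.array_api as xp

a = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, float("nan")]])
mask = xp.isnan(a)
# Zero-fill the null entries, then reduce with the namespace's own sum() --
# the shape of computation sum_where() performs when nanops passes a mask.
filled = xp.where(mask, xp.zeros_like(a), a)
print(float(xp.sum(filled)))  # 15.0
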
@@ -284,7 +313,13 @@ def concatenate(arrays, axis=0):

 def stack(arrays, axis=0):
     """stack() with better dtype promotion rules."""
-    return _stack(as_shared_dtype(arrays), axis=axis)
+    xp = get_array_namespace(arrays[0])
+    return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
+
+
+def reshape(array, shape):
+    xp = get_array_namespace(array)
+    return xp.reshape(array, shape)


 @contextlib.contextmanager
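
reshape() joins the module for the same reason as astype(): standard-conforming arrays are not guaranteed to expose a reshape() method, so reshaping and stacking both go through whichever namespace owns the data. A short check with the reference implementation (not part of the diff):

import numpy.array_api as xp

a = xp.ones((2, 3))
print(hasattr(a, "reshape"))            # expected False with numpy.array_api
print(xp.reshape(a, (3, 2)).shape)      # (3, 2)
print(xp.stack([a, a], axis=0).shape)   # (2, 2, 3)
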
@@ -323,11 +358,8 @@ def f(values, axis=None, skipna=None, **kwargs):
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)

-            if hasattr(values, "__array_namespace__"):
-                xp = values.__array_namespace__()
-                func = getattr(xp, name)
-            else:
-                func = getattr(np, name)
+            xp = get_array_namespace(values)
+            func = getattr(xp, name)

         try:
             with warnings.catch_warnings():
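
With the helper in place, the reduction wrapper simply resolves the reduction by name on whichever namespace the data belongs to. A standalone sketch of that lookup (the function name is illustrative, not from the diff):

import numpy as np
import numpy.array_api as xp


def dispatch_reduce(values, name):
    # The same lookup f() now performs: getattr on the data's namespace.
    ns = values.__array_namespace__() if hasattr(values, "__array_namespace__") else np
    return getattr(ns, name)(values)


print(float(dispatch_reduce(xp.asarray([1.0, 2.0, 3.0]), "sum")))  # 6.0
print(float(dispatch_reduce(np.asarray([1.0, 2.0, 3.0]), "max")))  # 3.0
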
xarray/core/nanops.py (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@
 import numpy as np

 from . import dtypes, nputils, utils
-from .duck_array_ops import count, fillna, isnull, where, where_method
+from .duck_array_ops import count, fillna, isnull, sum_where, where, where_method


 def _maybe_null_out(result, axis, mask, min_count=1):
@@ -84,7 +84,7 @@ def nanargmax(a, axis=None):

 def nansum(a, axis=None, dtype=None, out=None, min_count=None):
     mask = isnull(a)
-    result = np.nansum(a, axis=axis, dtype=dtype)
+    result = sum_where(a, axis=axis, dtype=dtype, where=mask)
     if min_count is not None:
         return _maybe_null_out(result, axis, mask, min_count)
     else:
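
A quick plain-numpy check that the new formulation reproduces what np.nansum() returned on the old path (a sketch, not part of the diff):

import numpy as np

a = np.array([[1.0, 2.0], [np.nan, 4.0]])
mask = np.isnan(a)
# Old path vs. the zero-fill-then-sum computation sum_where() performs:
assert np.nansum(a) == np.where(mask, 0.0, a).sum() == 7.0
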
xarray/core/variable.py (2 changes: 1 addition & 1 deletion)

@@ -1637,7 +1637,7 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable):
         reordered = self.transpose(*dim_order)

         new_shape = reordered.shape[: len(other_dims)] + (-1,)
-        new_data = reordered.data.reshape(new_shape)
+        new_data = duck_array_ops.reshape(reordered.data, new_shape)
         new_dims = reordered.dims[: len(other_dims)] + (new_dim,)

         return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)

xarray/tests/test_array_api.py (46 changes: 44 additions & 2 deletions)

@@ -17,8 +17,16 @@

 @pytest.fixture
 def arrays() -> tuple[xr.DataArray, xr.DataArray]:
-    np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
-    xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
+    np_arr = xr.DataArray(
+        np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
+        dims=("x", "y"),
+        coords={"x": [10, 20]},
+    )
+    xp_arr = xr.DataArray(
+        xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
+        dims=("x", "y"),
+        coords={"x": [10, 20]},
+    )
     assert isinstance(xp_arr.data, Array)
     return np_arr, xp_arr

@@ -32,13 +40,30 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:


 def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
     np_arr, xp_arr = arrays
     expected = np_arr.sum()
     actual = xp_arr.sum()
     assert isinstance(actual.data, Array)
     assert_equal(actual, expected)


+def test_aggregation_skipna(arrays) -> None:
+    np_arr, xp_arr = arrays
+    expected = np_arr.sum(skipna=False)
+    actual = xp_arr.sum(skipna=False)
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
+
+
+def test_astype(arrays) -> None:
+    np_arr, xp_arr = arrays
+    expected = np_arr.astype(np.int64)
+    actual = xp_arr.astype(np.int64)
+    assert actual.dtype == np.int64
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
+
+
 def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
     np_arr, xp_arr = arrays
     expected = np_arr[:, 0]
@@ -59,3 +84,20 @@ def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
     actual = xp_arr.transpose()
     assert isinstance(actual.data, Array)
     assert_equal(actual, expected)
+
+
+def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
+    np_arr, xp_arr = arrays
+    expected = np_arr.stack(z=("x", "y"))
+    actual = xp_arr.stack(z=("x", "y"))
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
+
+
+def test_where() -> None:
+    np_arr = xr.DataArray(np.array([1, 0]), dims="x")
+    xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")
+    expected = xr.where(np_arr, 1, 0)
+    actual = xr.where(xp_arr, 1, 0)
+    assert isinstance(actual.data, Array)
+    assert_equal(actual, expected)
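
An end-to-end sketch of what the new tests exercise, assuming this branch of xarray and a numpy recent enough to provide numpy.array_api (none of this is part of the diff):

import numpy as np
import numpy.array_api as xp
import xarray as xr

np_da = xr.DataArray(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]), dims=("x", "y"))
xp_da = xr.DataArray(xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]), dims=("x", "y"))

# The skipna sum now routes through sum_where() for both backends and agrees:
assert float(np_da.sum()) == float(xp_da.sum()) == 15.0

# Stacking reshapes via duck_array_ops.reshape() and keeps the Array API backend:
stacked = xp_da.stack(z=("x", "y"))
print(type(stacked.data))  # the Array API array type, not numpy.ndarray

This is the same behaviour test_aggregation and test_stack assert above, just with prints in place of assert_equal.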