Skip to content

Commit

Permalink
add array-api-package
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Oct 24, 2023
1 parent eb74944 commit d11016b
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 0 deletions.
Empty file.
9 changes: 9 additions & 0 deletions xarray/namedarray/array_api/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Constants
# https://data-apis.org/array-api/latest/API_specification/constants.html

import numpy as np

e = np.e
inf = np.inf
nan = np.nan
pi = np.pi
114 changes: 114 additions & 0 deletions xarray/namedarray/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import warnings
from typing import Any

import numpy as np

from xarray.namedarray._typing import _arrayapi, _DType, _ShapeType
from xarray.namedarray.core import NamedArray

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
r"The numpy.array_api submodule is still experimental",
category=UserWarning,
)
import numpy.array_api as nxp

# Pairs of types that, if both found, should be promoted to object dtype
# instead of following NumPy's own type-promotion rules. These type promotion
# rules match pandas instead. For reference, see the NumPy type hierarchy:
# https://numpy.org/doc/stable/reference/arrays.scalars.html
PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
(np.number, np.character), # numpy promotes to character
(np.bool_, np.character), # numpy promotes to character
(np.bytes_, np.str_), # numpy promotes to unicode
)


def astype(
x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
) -> NamedArray[_ShapeType, _DType]:
"""
Copies an array to a specified data type irrespective of Type Promotion Rules rules.
Parameters
----------
x : NamedArray
Array to cast.
dtype : _DType
Desired data type.
copy : bool, optional
Specifies whether to copy an array when the specified dtype matches the data
type of the input array x.
If True, a newly allocated array must always be returned.
If False and the specified dtype matches the data type of the input array,
the input array must be returned; otherwise, a newly allocated array must be
returned. Default: True.
Returns
-------
out : NamedArray
An array having the specified data type. The returned array must have the
same shape as x.
Examples
--------
>>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5]))
>>> narr
<xarray.NamedArray (x: 2)>
Array([1.5, 2.5], dtype=float64)
>>> astype(narr, np.dtype(np.int32))
<xarray.NamedArray (x: 2)>
Array([1, 2], dtype=int32)
"""
if isinstance(x._data, _arrayapi):
xp = x._data.__array_namespace__()
return x._new(data=xp.astype(x._data, dtype, copy=copy))

# np.astype doesn't exist yet:
return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined]


def can_cast(from_, to, /):
if isinstance(from_, NamedArray):
from_ = from_.dtype
return nxp.can_cast(from_, to)


def finfo(type, /):
return nxp.finfo(type)


def iinfo(type, /):
return nxp.iinfo(type)


def result_type(
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
) -> np.dtype[np.generic]:
"""Like np.result_type, but with type promotion rules matching pandas.
Examples of changed behavior:
number + string -> object (not string)
bytes + unicode -> object (not unicode)
Parameters
----------
*arrays_and_dtypes : list of arrays and dtypes
The dtype is extracted from both numpy and dask arrays.
Returns
-------
numpy.dtype for the result.
"""
types = {np.result_type(t).type for t in arrays_and_dtypes}

for left, right in PROMOTE_TO_OBJECT:
if any(issubclass(t, left) for t in types) and any(
issubclass(t, right) for t in types
):
return np.dtype(object)

return np.result_type(*arrays_and_dtypes)
23 changes: 23 additions & 0 deletions xarray/namedarray/array_api/_dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Use type code from numpy.array_api
from numpy.array_api._dtypes import ( # noqa: F401
_all_dtypes,
_boolean_dtypes,
_dtype_categories,
_floating_dtypes,
_integer_dtypes,
_integer_or_boolean_dtypes,
_numeric_dtypes,
_promotion_table,
_result_type,
bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
)
7 changes: 7 additions & 0 deletions xarray/namedarray/array_api/_searching_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Searching Functions
# https://data-apis.org/array-api/latest/API_specification/searching_functions.html
import xarray.namedarray.core as xrna


def nonzero(x: xrna.NameArray, /):
return x._nonzero()
Empty file.

0 comments on commit d11016b

Please sign in to comment.