From d11016b1f74ebe8b1707a7b5959ff6521de92129 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Tue, 24 Oct 2023 02:42:28 -0700 Subject: [PATCH] add array-api-package --- xarray/namedarray/array_api/__init__.py | 0 xarray/namedarray/array_api/_constants.py | 9 ++ .../array_api/_data_type_functions.py | 114 ++++++++++++++++++ xarray/namedarray/array_api/_dtypes.py | 23 ++++ .../array_api/_searching_functions.py | 7 ++ .../array_api/_statistical_functions.py | 0 6 files changed, 153 insertions(+) create mode 100644 xarray/namedarray/array_api/__init__.py create mode 100644 xarray/namedarray/array_api/_constants.py create mode 100644 xarray/namedarray/array_api/_data_type_functions.py create mode 100644 xarray/namedarray/array_api/_dtypes.py create mode 100644 xarray/namedarray/array_api/_searching_functions.py create mode 100644 xarray/namedarray/array_api/_statistical_functions.py diff --git a/xarray/namedarray/array_api/__init__.py b/xarray/namedarray/array_api/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/namedarray/array_api/_constants.py b/xarray/namedarray/array_api/_constants.py new file mode 100644 index 00000000000..8e3ba908436 --- /dev/null +++ b/xarray/namedarray/array_api/_constants.py @@ -0,0 +1,9 @@ +# Constants +# https://data-apis.org/array-api/latest/API_specification/constants.html + +import numpy as np + +e = np.e +inf = np.inf +nan = np.nan +pi = np.pi diff --git a/xarray/namedarray/array_api/_data_type_functions.py b/xarray/namedarray/array_api/_data_type_functions.py new file mode 100644 index 00000000000..60c021145fb --- /dev/null +++ b/xarray/namedarray/array_api/_data_type_functions.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import warnings +from typing import Any + +import numpy as np + +from xarray.namedarray._typing import _arrayapi, _DType, _ShapeType +from xarray.namedarray.core import NamedArray + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def astype( + x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True +) -> NamedArray[_ShapeType, _DType]: + """ + Copies an array to a specified data type irrespective of Type Promotion Rules rules. + + Parameters + ---------- + x : NamedArray + Array to cast. + dtype : _DType + Desired data type. + copy : bool, optional + Specifies whether to copy an array when the specified dtype matches the data + type of the input array x. + If True, a newly allocated array must always be returned. + If False and the specified dtype matches the data type of the input array, + the input array must be returned; otherwise, a newly allocated array must be + returned. Default: True. + + Returns + ------- + out : NamedArray + An array having the specified data type. The returned array must have the + same shape as x. + + Examples + -------- + >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) + >>> narr + + Array([1.5, 2.5], dtype=float64) + >>> astype(narr, np.dtype(np.int32)) + + Array([1, 2], dtype=int32) + """ + if isinstance(x._data, _arrayapi): + xp = x._data.__array_namespace__() + return x._new(data=xp.astype(x._data, dtype, copy=copy)) + + # np.astype doesn't exist yet: + return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] + + +def can_cast(from_, to, /): + if isinstance(from_, NamedArray): + from_ = from_.dtype + return nxp.can_cast(from_, to) + + +def finfo(type, /): + return nxp.finfo(type) + + +def iinfo(type, /): + return nxp.iinfo(type) + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype[np.generic]: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/namedarray/array_api/_dtypes.py b/xarray/namedarray/array_api/_dtypes.py new file mode 100644 index 00000000000..e5d830f9987 --- /dev/null +++ b/xarray/namedarray/array_api/_dtypes.py @@ -0,0 +1,23 @@ +# Use type code from numpy.array_api +from numpy.array_api._dtypes import ( # noqa: F401 + _all_dtypes, + _boolean_dtypes, + _dtype_categories, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + _promotion_table, + _result_type, + bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) diff --git a/xarray/namedarray/array_api/_searching_functions.py b/xarray/namedarray/array_api/_searching_functions.py new file mode 100644 index 00000000000..5beb2e65780 --- /dev/null +++ b/xarray/namedarray/array_api/_searching_functions.py @@ -0,0 +1,7 @@ +# Searching Functions +# https://data-apis.org/array-api/latest/API_specification/searching_functions.html +import xarray.namedarray.core as xrna + + +def nonzero(x: xrna.NameArray, /): + return x._nonzero() diff --git a/xarray/namedarray/array_api/_statistical_functions.py b/xarray/namedarray/array_api/_statistical_functions.py new file mode 100644 index 00000000000..e69de29bb2d