From 4550a01c9dca27dd043d734bab1a78ef972be68b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 1 Dec 2023 19:52:10 +0100 Subject: [PATCH] Add expand_dims (#8407) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> --- xarray/core/variable.py | 2 +- xarray/namedarray/_array_api.py | 52 +++++++++++++++++++++++++++++++++ xarray/namedarray/_typing.py | 14 +++++++++ xarray/namedarray/core.py | 5 ++-- xarray/namedarray/utils.py | 15 +--------- xarray/tests/test_namedarray.py | 10 +++++-- 6 files changed, 78 insertions(+), 20 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d9102dc9e0a..3add7a1441e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2596,7 +2596,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: """ Use sparse-array as backend. """ - from xarray.namedarray.utils import _default as _default_named + from xarray.namedarray._typing import _default as _default_named if sparse_format is _default: sparse_format = _default_named diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index e205c4d4efe..b5c320e0b96 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -7,7 +7,11 @@ import numpy as np from xarray.namedarray._typing import ( + Default, _arrayapi, + _Axis, + _default, + _Dim, _DType, _ScalarType, _ShapeType, @@ -144,3 +148,51 @@ def real( xp = _get_data_namespace(x) out = x._new(data=xp.real(x._data)) return out + + +# %% Manipulation functions +def expand_dims( + x: NamedArray[Any, _DType], + /, + *, + dim: _Dim | Default = _default, + axis: _Axis = 0, +) -> NamedArray[Any, _DType]: + """ + Expands the shape of an array by inserting a new dimension of size one at the + position specified by dims. + + Parameters + ---------- + x : + Array to expand. + dim : + Dimension name. New dimension will be stored in the axis position. + axis : + (Not recommended) Axis position (zero-based). Default is 0. + + Returns + ------- + out : + An expanded output array having the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> expand_dims(x) + + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + >>> expand_dims(x, dim="z") + + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + """ + xp = _get_data_namespace(x) + dims = x.dims + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) + d.insert(axis, dim) + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) + return out diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 0b972e19539..670a2076eb1 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,10 +1,12 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum from types import ModuleType from typing import ( Any, Callable, + Final, Protocol, SupportsIndex, TypeVar, @@ -15,6 +17,14 @@ import numpy as np + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) @@ -49,6 +59,10 @@ def dtype(self) -> _DType_co: _ShapeType = TypeVar("_ShapeType", bound=Any) _ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) +_Axis = int +_Axes = tuple[_Axis, ...] +_AxisLike = Union[_Axis, _Axes] + _Chunks = tuple[_Shape, ...] _Dim = Hashable diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 002afe96358..b9ad27b6679 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -25,6 +25,7 @@ _arrayapi, _arrayfunction_or_api, _chunkedarray, + _default, _dtype, _DType_co, _ScalarType_co, @@ -33,13 +34,14 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array +from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray from xarray.core.types import Dims from xarray.namedarray._typing import ( + Default, _AttrsLike, _Chunks, _Dim, @@ -52,7 +54,6 @@ _ShapeType, duckarray, ) - from xarray.namedarray.utils import Default try: from dask.typing import ( diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 03eb0134231..4bd20931189 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -2,12 +2,7 @@ import sys from collections.abc import Hashable -from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Final, -) +from typing import TYPE_CHECKING, Any import numpy as np @@ -31,14 +26,6 @@ DaskCollection: Any = NDArray # type: ignore -# Singleton type, as per https://github.com/python/typing/pull/240 -class Default(Enum): - token: Final = 0 - - -_default = Default.token - - def module_available(module: str) -> bool: """Checks whether a module is installed without importing it. diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index deeb5ce753a..c75b01e9e50 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -10,9 +10,13 @@ import pytest from xarray.core.indexing import ExplicitlyIndexed -from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co +from xarray.namedarray._typing import ( + _arrayfunction_or_api, + _default, + _DType_co, + _ShapeType_co, +) from xarray.namedarray.core import NamedArray, from_array -from xarray.namedarray.utils import _default if TYPE_CHECKING: from types import ModuleType @@ -20,13 +24,13 @@ from numpy.typing import ArrayLike, DTypeLike, NDArray from xarray.namedarray._typing import ( + Default, _AttrsLike, _DimsLike, _DType, _Shape, duckarray, ) - from xarray.namedarray.utils import Default class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):