Skip to content

Commit

Permalink
Add expand_dims (#8407)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 1, 2023
1 parent c93b31a commit 4550a01
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 20 deletions.
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,7 +2596,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
"""
Use sparse-array as backend.
"""
from xarray.namedarray.utils import _default as _default_named
from xarray.namedarray._typing import _default as _default_named

if sparse_format is _default:
sparse_format = _default_named
Expand Down
52 changes: 52 additions & 0 deletions xarray/namedarray/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import numpy as np

from xarray.namedarray._typing import (
Default,
_arrayapi,
_Axis,
_default,
_Dim,
_DType,
_ScalarType,
_ShapeType,
Expand Down Expand Up @@ -144,3 +148,51 @@ def real(
xp = _get_data_namespace(x)
out = x._new(data=xp.real(x._data))
return out


# %% Manipulation functions
def expand_dims(
x: NamedArray[Any, _DType],
/,
*,
dim: _Dim | Default = _default,
axis: _Axis = 0,
) -> NamedArray[Any, _DType]:
"""
Expands the shape of an array by inserting a new dimension of size one at the
position specified by dims.
Parameters
----------
x :
Array to expand.
dim :
Dimension name. New dimension will be stored in the axis position.
axis :
(Not recommended) Axis position (zero-based). Default is 0.
Returns
-------
out :
An expanded output array having the same data type as x.
Examples
--------
>>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]]))
>>> expand_dims(x)
<xarray.NamedArray (dim_2: 1, x: 2, y: 2)>
Array([[[1., 2.],
[3., 4.]]], dtype=float64)
>>> expand_dims(x, dim="z")
<xarray.NamedArray (z: 1, x: 2, y: 2)>
Array([[[1., 2.],
[3., 4.]]], dtype=float64)
"""
xp = _get_data_namespace(x)
dims = x.dims
if dim is _default:
dim = f"dim_{len(dims)}"
d = list(dims)
d.insert(axis, dim)
out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
return out
14 changes: 14 additions & 0 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Mapping, Sequence
from enum import Enum
from types import ModuleType
from typing import (
Any,
Callable,
Final,
Protocol,
SupportsIndex,
TypeVar,
Expand All @@ -15,6 +17,14 @@

import numpy as np


# Singleton type, as per https://github.com/python/typing/pull/240
class Default(Enum):
token: Final = 0


_default = Default.token

# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
Expand Down Expand Up @@ -49,6 +59,10 @@ def dtype(self) -> _DType_co:
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True)

_Axis = int
_Axes = tuple[_Axis, ...]
_AxisLike = Union[_Axis, _Axes]

_Chunks = tuple[_Shape, ...]

_Dim = Hashable
Expand Down
5 changes: 3 additions & 2 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_arrayapi,
_arrayfunction_or_api,
_chunkedarray,
_default,
_dtype,
_DType_co,
_ScalarType_co,
Expand All @@ -33,13 +34,14 @@
_SupportsImag,
_SupportsReal,
)
from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array
from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array

if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray

from xarray.core.types import Dims
from xarray.namedarray._typing import (
Default,
_AttrsLike,
_Chunks,
_Dim,
Expand All @@ -52,7 +54,6 @@
_ShapeType,
duckarray,
)
from xarray.namedarray.utils import Default

try:
from dask.typing import (
Expand Down
15 changes: 1 addition & 14 deletions xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@

import sys
from collections.abc import Hashable
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Final,
)
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -31,14 +26,6 @@
DaskCollection: Any = NDArray # type: ignore


# Singleton type, as per https://github.com/python/typing/pull/240
class Default(Enum):
token: Final = 0


_default = Default.token


def module_available(module: str) -> bool:
"""Checks whether a module is installed without importing it.
Expand Down
10 changes: 7 additions & 3 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,27 @@
import pytest

from xarray.core.indexing import ExplicitlyIndexed
from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co
from xarray.namedarray._typing import (
_arrayfunction_or_api,
_default,
_DType_co,
_ShapeType_co,
)
from xarray.namedarray.core import NamedArray, from_array
from xarray.namedarray.utils import _default

if TYPE_CHECKING:
from types import ModuleType

from numpy.typing import ArrayLike, DTypeLike, NDArray

from xarray.namedarray._typing import (
Default,
_AttrsLike,
_DimsLike,
_DType,
_Shape,
duckarray,
)
from xarray.namedarray.utils import Default


class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):
Expand Down

0 comments on commit 4550a01

Please sign in to comment.