Skip to content

Commit

Permalink
Clean up Dims type annotation (#8606)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jan 16, 2024
1 parent 53fdfca commit 1580c2c
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 39 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ jobs:
python xarray/util/print_versions.py
- name: Install mypy
run: |
python -m pip install "mypy<1.8" --force-reinstall
python -m pip install "mypy<1.9" --force-reinstall
- name: Run mypy
run: |
Expand Down Expand Up @@ -174,7 +174,7 @@ jobs:
python xarray/util/print_versions.py
- name: Install mypy
run: |
python -m pip install "mypy<1.8" --force-reinstall
python -m pip install "mypy<1.9" --force-reinstall
- name: Run mypy
run: |
Expand Down
12 changes: 5 additions & 7 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from xarray.core.parallelcompat import get_chunked_array_type
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar
from xarray.core.utils import is_dict_like, is_scalar, parse_dims
from xarray.core.variable import Variable
from xarray.util.deprecation_helpers import deprecate_dims

Expand Down Expand Up @@ -1875,16 +1875,14 @@ def dot(
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dim is ...:
dim = all_dims
elif isinstance(dim, str):
dim = (dim,)
elif dim is None:
# find dimensions that occur more than one times
if dim is None:
# find dimensions that occur more than once
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dim = tuple(d for d, c in dim_counts.items() if c > 1)
else:
dim = parse_dims(dim, all_dims=tuple(all_dims))

dot_dims: set[Hashable] = set(dim)

Expand Down
7 changes: 4 additions & 3 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import sys
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -182,8 +182,9 @@ def copy(
DsCompatible = Union["Dataset", "DaCompatible"]
GroupByCompatible = Union["Dataset", "DataArray"]

Dims = Union[str, Iterable[Hashable], "ellipsis", None]
OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]
# Don't change to Hashable | Collection[Hashable]
# Read: https://github.com/pydata/xarray/issues/6142
Dims = Union[str, Collection[Hashable], "ellipsis", None]

# FYI in some cases we don't allow `None`, which this doesn't take account of.
T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
Expand Down
26 changes: 11 additions & 15 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
Mapping,
MutableMapping,
MutableSet,
Sequence,
ValuesView,
)
from enum import Enum
Expand All @@ -76,7 +75,7 @@
import pandas as pd

if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray
from xarray.core.types import Dims, ErrorOptionsWithWarn, T_DuckArray

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -983,12 +982,9 @@ def drop_missing_dims(
)


T_None = TypeVar("T_None", None, "ellipsis")


@overload
def parse_dims(
dim: str | Iterable[Hashable] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand All @@ -999,12 +995,12 @@ def parse_dims(

@overload
def parse_dims(
dim: str | Iterable[Hashable] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
replace_none: Literal[False],
) -> tuple[Hashable, ...] | T_None:
) -> tuple[Hashable, ...] | None | ellipsis:
...


Expand Down Expand Up @@ -1051,7 +1047,7 @@ def parse_dims(

@overload
def parse_ordered_dims(
dim: str | Sequence[Hashable | ellipsis] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand All @@ -1062,17 +1058,17 @@ def parse_ordered_dims(

@overload
def parse_ordered_dims(
dim: str | Sequence[Hashable | ellipsis] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
replace_none: Literal[False],
) -> tuple[Hashable, ...] | T_None:
) -> tuple[Hashable, ...] | None | ellipsis:
...


def parse_ordered_dims(
dim: OrderedDims,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand Down Expand Up @@ -1126,9 +1122,9 @@ def parse_ordered_dims(
)


def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None:
wrong_dims = dim - all_dims
if wrong_dims and wrong_dims != {...}:
def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
wrong_dims = (dim - all_dims) - {...}
if wrong_dims:
wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)
raise ValueError(
f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}"
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,8 @@ def test_interpolate_chunk_1d(
if chunked:
dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
dest[dim] = dest[dim].chunk(2)
actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore
expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore
actual = da.interp(method=method, **dest, kwargs=kwargs)
expected = da.compute().interp(method=method, **dest, kwargs=kwargs)

assert_identical(actual, expected)

Expand Down
18 changes: 8 additions & 10 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Sequence
from collections.abc import Hashable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -257,17 +257,18 @@ def test_infix_dims_errors(supplied, all_):
pytest.param("a", ("a",), id="str"),
pytest.param(["a", "b"], ("a", "b"), id="list_of_str"),
pytest.param(["a", 1], ("a", 1), id="list_mixed"),
pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"),
pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"),
pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"),
pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"),
pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"),
pytest.param((), (), id="empty_tuple"),
pytest.param(set(), (), id="empty_collection"),
pytest.param(None, None, id="None"),
pytest.param(..., ..., id="ellipsis"),
],
)
def test_parse_dims(
dim: str | Iterable[Hashable] | None,
expected: tuple[Hashable, ...],
) -> None:
def test_parse_dims(dim, expected):
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
actual = utils.parse_dims(dim, all_dims, replace_none=False)
assert actual == expected
Expand Down Expand Up @@ -297,7 +298,7 @@ def test_parse_dims_replace_none(dim: None | ellipsis) -> None:
pytest.param(["x", 2], id="list_missing_all"),
],
)
def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
def test_parse_dims_raises(dim):
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
with pytest.raises(ValueError, match="'x'"):
utils.parse_dims(dim, all_dims, check_exists=True)
Expand All @@ -313,10 +314,7 @@ def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"),
],
)
def test_parse_ordered_dims(
dim: str | Sequence[Hashable | ellipsis],
expected: tuple[Hashable, ...],
) -> None:
def test_parse_ordered_dims(dim, expected):
all_dims = ("a", "b", "c")
actual = utils.parse_ordered_dims(dim, all_dims)
assert actual == expected
Expand Down

0 comments on commit 1580c2c

Please sign in to comment.