Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up Dims type annotation #8606

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ jobs:
python xarray/util/print_versions.py
- name: Install mypy
run: |
python -m pip install "mypy<1.8" --force-reinstall
python -m pip install "mypy<1.9" --force-reinstall

- name: Run mypy
run: |
Expand Down Expand Up @@ -174,7 +174,7 @@ jobs:
python xarray/util/print_versions.py
- name: Install mypy
run: |
python -m pip install "mypy<1.8" --force-reinstall
python -m pip install "mypy<1.9" --force-reinstall

- name: Run mypy
run: |
Expand Down
12 changes: 5 additions & 7 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from xarray.core.parallelcompat import get_chunked_array_type
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar
from xarray.core.utils import is_dict_like, is_scalar, parse_dims
from xarray.core.variable import Variable
from xarray.util.deprecation_helpers import deprecate_dims

Expand Down Expand Up @@ -1875,16 +1875,14 @@ def dot(
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dim is ...:
dim = all_dims
elif isinstance(dim, str):
dim = (dim,)
elif dim is None:
# find dimensions that occur more than one times
if dim is None:
# find dimensions that occur more than once
dim_counts: Counter = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dim = tuple(d for d, c in dim_counts.items() if c > 1)
else:
dim = parse_dims(dim, all_dims=tuple(all_dims))

dot_dims: set[Hashable] = set(dim)

Expand Down
7 changes: 4 additions & 3 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import sys
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -182,8 +182,9 @@ def copy(
DsCompatible = Union["Dataset", "DaCompatible"]
GroupByCompatible = Union["Dataset", "DataArray"]

Dims = Union[str, Iterable[Hashable], "ellipsis", None]
OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]
# Don't change to Hashable | Collection[Hashable]
# Read: https://github.com/pydata/xarray/issues/6142
Dims = Union[str, Collection[Hashable], "ellipsis", None]

# FYI in some cases we don't allow `None`, which this doesn't take account of.
T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
Expand Down
26 changes: 11 additions & 15 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
Mapping,
MutableMapping,
MutableSet,
Sequence,
ValuesView,
)
from enum import Enum
Expand All @@ -76,7 +75,7 @@
import pandas as pd

if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray
from xarray.core.types import Dims, ErrorOptionsWithWarn, T_DuckArray

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -983,12 +982,9 @@ def drop_missing_dims(
)


T_None = TypeVar("T_None", None, "ellipsis")


@overload
def parse_dims(
dim: str | Iterable[Hashable] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand All @@ -999,12 +995,12 @@ def parse_dims(

@overload
def parse_dims(
dim: str | Iterable[Hashable] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
replace_none: Literal[False],
) -> tuple[Hashable, ...] | T_None:
) -> tuple[Hashable, ...] | None | ellipsis:
...


Expand Down Expand Up @@ -1051,7 +1047,7 @@ def parse_dims(

@overload
def parse_ordered_dims(
dim: str | Sequence[Hashable | ellipsis] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand All @@ -1062,17 +1058,17 @@ def parse_ordered_dims(

@overload
def parse_ordered_dims(
dim: str | Sequence[Hashable | ellipsis] | T_None,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
replace_none: Literal[False],
) -> tuple[Hashable, ...] | T_None:
) -> tuple[Hashable, ...] | None | ellipsis:
...


def parse_ordered_dims(
dim: OrderedDims,
dim: Dims,
all_dims: tuple[Hashable, ...],
*,
check_exists: bool = True,
Expand Down Expand Up @@ -1126,9 +1122,9 @@ def parse_ordered_dims(
)


def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None:
wrong_dims = dim - all_dims
if wrong_dims and wrong_dims != {...}:
def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
wrong_dims = (dim - all_dims) - {...}
if wrong_dims:
wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)
raise ValueError(
f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}"
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,8 @@ def test_interpolate_chunk_1d(
if chunked:
dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
dest[dim] = dest[dim].chunk(2)
actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore
expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore
actual = da.interp(method=method, **dest, kwargs=kwargs)
expected = da.compute().interp(method=method, **dest, kwargs=kwargs)

assert_identical(actual, expected)

Expand Down
18 changes: 8 additions & 10 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Sequence
from collections.abc import Hashable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -257,17 +257,18 @@ def test_infix_dims_errors(supplied, all_):
pytest.param("a", ("a",), id="str"),
pytest.param(["a", "b"], ("a", "b"), id="list_of_str"),
pytest.param(["a", 1], ("a", 1), id="list_mixed"),
pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"),
pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"),
pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"),
pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"),
pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"),
pytest.param((), (), id="empty_tuple"),
pytest.param(set(), (), id="empty_collection"),
pytest.param(None, None, id="None"),
pytest.param(..., ..., id="ellipsis"),
],
)
def test_parse_dims(
dim: str | Iterable[Hashable] | None,
expected: tuple[Hashable, ...],
) -> None:
def test_parse_dims(dim, expected):
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
actual = utils.parse_dims(dim, all_dims, replace_none=False)
assert actual == expected
Expand Down Expand Up @@ -297,7 +298,7 @@ def test_parse_dims_replace_none(dim: None | ellipsis) -> None:
pytest.param(["x", 2], id="list_missing_all"),
],
)
def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
def test_parse_dims_raises(dim):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove these annotations? IIUC these get mypy to use the tests to test our annotations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing that they validate is that the signature on the test matches the signature declared for the function being tested - quite pointless IMHO. Notably, they don't validate the parameters being passed from the lines above.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the signature on the test matches the signature declared for the function being tested - quite pointless IMHO

I'm not completely sure what you mean by this. But without check_untyped_defs, having the -> None is the only way we test whether our annotations are correct (or am I wrong there?)

I think you should revert removing the annotations, at least the -> None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I was trying to say is that the signature in the unit test causes mypy to verify that the signature declared in the unit test, dim: str | Iterable[Hashable] is indeed compatible with the signature declared in the parse_dims function, dim: Dims. Which IMHO is pointless.

It would have been useful if it verified that the values with which dim is actually populated in the test are legal for the Dims type, but it does not do that.

For example, in the future someone may short-sightedly change the definition of Dims from str | Collection[Hashable] to str | Sequence[Hashable]. One of the parameters in the test is

        pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"),

Nothing will trip, with or without annotations in the test signature, unless someone changes the actual implementation of parse_dims to crash if you pass a set to it.

Annotating the test signature would become useful if we rewrote it without parametrization:

def test_parse_dims() -> None:
    all_dims = ("a", "b", 1, ("b", "c"))  # selection of different Hashables

    # non-sequence collection
    actual = utils.parse_dims({"a", 1}, all_dims, replace_none=False)
    assert actual == tuple({"a", 1})

    ... # repeat for all other use cases

which would be a valid choice, but it would fail on the first bad use case instead of moving on and it would be less immediately obvious what broke. You win some, you lose some.

Would you like me to open a new PR where I rewrite the unit test without parametrization and with annotations?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I was trying to say is that the signature in the unit test causes mypy to verify that the signature declared in the unit test, dim: str | Iterable[Hashable] is indeed compatible with the signature declared in the parse_dims function, dim: Dims. Which IMHO is pointless.

Totally agree!

Annotating the test signature would become useful if we rewrote it without parametrization:

Yes. In this specific case it's still slightly useful to have annotations on this function.

  • The dim typing check isn't that useful, because we supply it atm
  • all_dims will be checked

Would you like me to open a new PR where I rewrite the unit test without parametrization and with annotations?

Sorry, no. (and to the extent you're interpreting my comment as arguably bad suggestions, I apologize, I wasn't suggesting doing this)

The thing that I do think we should do is get to a point where tests checking as many annotations as possible. There are two ways to do this:

  • check_untyped_defs=True
  • -> None on test functions

So even though in this case it's only slightly useful:

  • There's no downside
  • The principle of having -> None is helpful, and gets us closer to having a blanket check_typed_defs

...so I think we should restore -> None (and nothing else)


Just to put this in perspective, I'm not trying to be difficult / antagonistic. We previously had lots of incorrect annotations! And so I have done a decent amount of work moving xarray on this — you can see the progress we've made at converting the library to test annotations here, and notably this file is currently excluded. I started by adding -> None to lots of functions, so hence my protest that we're now undoing them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
with pytest.raises(ValueError, match="'x'"):
utils.parse_dims(dim, all_dims, check_exists=True)
Expand All @@ -313,10 +314,7 @@ def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"),
],
)
def test_parse_ordered_dims(
dim: str | Sequence[Hashable | ellipsis],
expected: tuple[Hashable, ...],
) -> None:
def test_parse_ordered_dims(dim, expected):
all_dims = ("a", "b", "c")
actual = utils.parse_ordered_dims(dim, all_dims)
assert actual == expected
Expand Down
Loading