Skip to content

Commit

Permalink
Annotations for .data_vars() and .coords() (#3207)
Browse files Browse the repository at this point in the history
* Annotations for .data_vars() and .coords()

* Finish annotations for coordinates.py
  • Loading branch information
crusaderky authored Aug 12, 2019
1 parent fc44bae commit 14f1a97
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 58 deletions.
132 changes: 82 additions & 50 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import collections.abc
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Hashable, Mapping, Iterator, Union, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Mapping,
Iterator,
Union,
Set,
Tuple,
Sequence,
cast,
)

import pandas as pd

from . import formatting, indexing
from .indexes import Indexes
from .merge import (
expand_and_merge_variables,
merge_coords,
Expand All @@ -23,49 +34,58 @@
_THIS_ARRAY = ReprObject("<this-array>")


class AbstractCoordinates(collections.abc.Mapping):
def __getitem__(self, key):
raise NotImplementedError
class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
_data = None # type: Union["DataArray", "Dataset"]

def __setitem__(self, key, value):
def __getitem__(self, key: Hashable) -> "DataArray":
    # Abstract hook: concrete subclasses implement the actual coordinate
    # lookup (see DatasetCoordinates / DataArrayCoordinates overrides).
    raise NotImplementedError()

def __setitem__(self, key: Hashable, value: Any) -> None:
    # Route single-item assignment through update() so that the shared
    # coordinate-merge logic is applied uniformly for every write path.
    self.update({key: value})

@property
def indexes(self):
def _names(self) -> Set[Hashable]:
raise NotImplementedError()

@property
def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
    # Abstract: Dataset-backed coordinates expose a name -> size mapping,
    # DataArray-backed coordinates a tuple of dimension names, hence the
    # Union in the signature.
    raise NotImplementedError()

@property
def indexes(self) -> Indexes:
    # Delegate to the wrapped object (``_data`` is a DataArray or Dataset).
    return self._data.indexes

@property
def variables(self):
raise NotImplementedError
raise NotImplementedError()

def _update_coords(self, coords):
raise NotImplementedError
raise NotImplementedError()

def __iter__(self):
def __iter__(self) -> Iterator["Hashable"]:
    """Iterate over coordinate names."""
    # Iterate self.variables (not _names directly) so the yield order
    # matches the order of the dataset's variables.
    for k in self.variables:
        if k in self._names:
            yield k

def __len__(self):
def __len__(self) -> int:
    # Number of coordinate names.
    return len(self._names)

def __contains__(self, key):
def __contains__(self, key: Hashable) -> bool:
    # Direct membership test on the name set; overrides the default
    # Mapping fallback, which would go through __getitem__.
    return key in self._names

def __repr__(self):
def __repr__(self) -> str:
    # Rendering is delegated to the shared formatting helpers.
    return formatting.coords_repr(self)

@property
def dims(self):
return self._data.dims
def to_dataset(self) -> "Dataset":
    # Abstract: convert these coordinates into a standalone Dataset.
    raise NotImplementedError()

def to_index(self, ordered_dims=None):
def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
"""Convert all index coordinates into a :py:class:`pandas.Index`.
Parameters
----------
ordered_dims : sequence, optional
ordered_dims : sequence of hashable, optional
Possibly reordered version of this object's dimensions indicating
the order in which dimensions should appear on the result.
Expand All @@ -77,7 +97,7 @@ def to_index(self, ordered_dims=None):
than more dimension.
"""
if ordered_dims is None:
ordered_dims = self.dims
ordered_dims = list(self.dims)
elif set(ordered_dims) != set(self.dims):
raise ValueError(
"ordered_dims must match dims, but does not: "
Expand All @@ -94,7 +114,7 @@ def to_index(self, ordered_dims=None):
names = list(ordered_dims)
return pd.MultiIndex.from_product(indexes, names=names)

def update(self, other):
def update(self, other: Mapping[Hashable, Any]) -> None:
other_vars = getattr(other, "variables", other)
coords = merge_coords(
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
Expand Down Expand Up @@ -127,7 +147,7 @@ def _merge_inplace(self, other):
yield
self._update_coords(variables)

def merge(self, other):
def merge(self, other: "AbstractCoordinates") -> "Dataset":
"""Merge two sets of coordinates to create a new Dataset
The method implements the logic used for joining coordinates in the
Expand Down Expand Up @@ -167,32 +187,38 @@ class DatasetCoordinates(AbstractCoordinates):
objects.
"""

def __init__(self, dataset):
_data = None # type: Dataset

def __init__(self, dataarray: "Dataset"):
    # Store a reference to the owning Dataset; this object is a live
    # view onto its state, not a copy.
    self._data = dataset

@property
def _names(self):
def _names(self) -> Set[Hashable]:
return self._data._coord_names

@property
def variables(self):
def dims(self) -> Mapping[Hashable, int]:
return self._data.dims

@property
def variables(self) -> Mapping[Hashable, Variable]:
return Frozen(
OrderedDict(
(k, v) for k, v in self._data.variables.items() if k in self._names
)
)

def __getitem__(self, key):
def __getitem__(self, key: Hashable) -> "DataArray":
if key in self._data.data_vars:
raise KeyError(key)
return self._data[key]
return cast("DataArray", self._data[key])

def to_dataset(self):
def to_dataset(self) -> "Dataset":
"""Convert these coordinates into a new Dataset
"""
return self._data._copy_listed(self._names)

def _update_coords(self, coords):
def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
from .dataset import calculate_dimensions

variables = self._data._variables.copy()
Expand All @@ -210,7 +236,7 @@ def _update_coords(self, coords):
self._data._dims = dims
self._data._indexes = None

def __delitem__(self, key):
def __delitem__(self, key: Hashable) -> None:
if key in self:
del self._data[key]
else:
Expand All @@ -232,17 +258,23 @@ class DataArrayCoordinates(AbstractCoordinates):
dimensions and the values given by corresponding DataArray objects.
"""

def __init__(self, dataarray):
_data = None # type: DataArray

def __init__(self, dataarray: "DataArray"):
    # Store a reference to the owning DataArray; this object is a live
    # view onto its state, not a copy.
    self._data = dataarray

@property
def _names(self):
def dims(self) -> Tuple[Hashable, ...]:
return self._data.dims

@property
def _names(self) -> Set[Hashable]:
    # Coordinate names are the keys of the DataArray's ``_coords`` dict;
    # a fresh set is built on every access.
    return set(self._data._coords)

def __getitem__(self, key):
def __getitem__(self, key: Hashable) -> "DataArray":
    # Delegate to the DataArray's coordinate accessor (which raises
    # KeyError for unknown names — presumably; confirm in dataarray.py).
    return self._data._getitem_coord(key)

def _update_coords(self, coords):
def _update_coords(self, coords) -> None:
from .dataset import calculate_dimensions

coords_plus_data = coords.copy()
Expand All @@ -259,19 +291,15 @@ def _update_coords(self, coords):
def variables(self):
return Frozen(self._data._coords)

def _to_dataset(self, shallow_copy=True):
def to_dataset(self) -> "Dataset":
from .dataset import Dataset

coords = OrderedDict(
(k, v.copy(deep=False) if shallow_copy else v)
for k, v in self._data._coords.items()
(k, v.copy(deep=False)) for k, v in self._data._coords.items()
)
return Dataset._from_vars_and_coord_names(coords, set(coords))

def to_dataset(self):
return self._to_dataset()

def __delitem__(self, key):
def __delitem__(self, key: Hashable) -> None:
    # Remove the coordinate from the backing dict; a missing key raises
    # KeyError, matching normal mapping semantics.
    del self._data._coords[key]

def _ipython_key_completions_(self):
Expand Down Expand Up @@ -300,9 +328,10 @@ def __len__(self) -> int:
return len(self._data._level_coords)


def assert_coordinate_consistent(obj, coords):
""" Maeke sure the dimension coordinate of obj is
consistent with coords.
def assert_coordinate_consistent(
obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable]
) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.
obj: DataArray or Dataset
coords: Dict-like of variables
Expand All @@ -320,17 +349,20 @@ def assert_coordinate_consistent(obj, coords):


def remap_label_indexers(
obj, indexers=None, method=None, tolerance=None, **indexers_kwargs
):
"""
Remap **indexers from obj.coords.
If indexer is an instance of DataArray and it has coordinate, then this
coordinate will be attached to pos_indexers.
obj: Union["DataArray", "Dataset"],
indexers: Mapping[Hashable, Any] = None,
method: str = None,
tolerance=None,
**indexers_kwargs: Any
) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing
"""Remap indexers from obj.coords.
If indexer is an instance of DataArray and it has coordinate, then this coordinate
will be attached to pos_indexers.
Returns
-------
pos_indexers: Same type of indexers.
np.ndarray or Variable or DataArra
np.ndarray or Variable or DataArray
new_indexes: mapping of new dimensional-coordinate.
"""
from .dataarray import DataArray
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __setitem__(self, key, value) -> None:
labels = indexing.expanded_indexer(key, self.data_array.ndim)
key = dict(zip(self.data_array.dims, labels))

pos_indexers, _ = remap_label_indexers(self.data_array, **key)
pos_indexers, _ = remap_label_indexers(self.data_array, key)
self.data_array[pos_indexers] = value


Expand Down
13 changes: 6 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def as_dataset(obj: Any) -> "Dataset":
return obj


class DataVariables(Mapping[Hashable, "Union[DataArray, Dataset]"]):
class DataVariables(Mapping[Hashable, "DataArray"]):
def __init__(self, dataset: "Dataset"):
self._dataset = dataset

Expand All @@ -349,14 +349,13 @@ def __iter__(self) -> Iterator[Hashable]:
def __len__(self) -> int:
    # Data variables are all variables that are not coordinates.
    # NOTE(review): assumes _coord_names is a subset of _variables — the
    # subtraction undercounts otherwise; confirm that invariant holds.
    return len(self._dataset._variables) - len(self._dataset._coord_names)

def __contains__(self, key) -> bool:
def __contains__(self, key: Hashable) -> bool:
    # A key names a data variable iff it is a variable and not a coordinate.
    return key in self._dataset._variables and key not in self._dataset._coord_names

def __getitem__(self, key) -> "Union[DataArray, Dataset]":
def __getitem__(self, key: Hashable) -> "DataArray":
if key not in self._dataset._coord_names:
return self._dataset[key]
else:
raise KeyError(key)
return cast("DataArray", self._dataset[key])
raise KeyError(key)

def __repr__(self) -> str:
return formatting.data_vars_repr(self)
Expand Down Expand Up @@ -1317,7 +1316,7 @@ def identical(self, other: "Dataset") -> bool:
return False

@property
def indexes(self) -> "Mapping[Any, pd.Index]":
def indexes(self) -> Indexes:
"""Mapping of pandas.Index objects used for label based indexing
"""
if self._indexes is None:
Expand Down

0 comments on commit 14f1a97

Please sign in to comment.