Skip to content

Commit

Permalink
annotate concat (#4346)
Browse files Browse the repository at this point in the history
* annotate concat

* fix annotation: _from_temp_dataset

* whats new entry

* Update xarray/core/concat.py

* revert faulty change
  • Loading branch information
mathause authored Aug 19, 2020
1 parent d9ebcaf commit 68d0e0d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 28 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ Internal Changes
By `Guido Imperiale <https://github.com/crusaderky>`_
- Only load resource files when running inside a Jupyter Notebook
(:issue:`4294`) By `Guido Imperiale <https://github.com/crusaderky>`_
- Enable type checking for :py:func:`concat` (:issue:`4238`)
By `Mathias Hauser <https://github.com/mathause>`_.


.. _whats-new.0.16.0:
Expand Down
105 changes: 78 additions & 27 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
from typing import (
TYPE_CHECKING,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
overload,
)

import pandas as pd

from . import dtypes, utils
Expand All @@ -7,6 +20,40 @@
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars

if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset


# Typing-only overload (the body is never executed; mypy dispatches on it):
# when every element of ``objs`` is a Dataset, ``concat`` returns a Dataset.
# Defaults here mirror the runtime implementation's defaults — keep in sync.
@overload
def concat(
    objs: Iterable["Dataset"],
    dim: Union[str, "DataArray", pd.Index],
    data_vars: Union[str, List[str]] = "all",
    coords: Union[str, List[str]] = "different",
    compat: str = "equals",
    positions: Optional[Iterable[int]] = None,
    fill_value: object = dtypes.NA,
    join: str = "outer",
    combine_attrs: str = "override",
) -> "Dataset":
    ...


# Typing-only overload (the body is never executed; mypy dispatches on it):
# when every element of ``objs`` is a DataArray, ``concat`` returns a DataArray.
# Defaults here mirror the runtime implementation's defaults — keep in sync.
@overload
def concat(
    objs: Iterable["DataArray"],
    dim: Union[str, "DataArray", pd.Index],
    data_vars: Union[str, List[str]] = "all",
    coords: Union[str, List[str]] = "different",
    compat: str = "equals",
    positions: Optional[Iterable[int]] = None,
    fill_value: object = dtypes.NA,
    join: str = "outer",
    combine_attrs: str = "override",
) -> "DataArray":
    ...


def concat(
objs,
Expand Down Expand Up @@ -285,13 +332,15 @@ def process_subset_opt(opt, subset):


# determine dimensional coordinate names and a dict mapping name to DataArray
def _parse_datasets(datasets):
def _parse_datasets(
datasets: Iterable["Dataset"],
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, int], Set[Hashable], Set[Hashable]]:

dims = set()
all_coord_names = set()
data_vars = set() # list of data_vars
dim_coords = {} # maps dim name to variable
dims_sizes = {} # shared dimension sizes to expand variables
dims: Set[Hashable] = set()
all_coord_names: Set[Hashable] = set()
data_vars: Set[Hashable] = set() # list of data_vars
dim_coords: Dict[Hashable, Variable] = {} # maps dim name to variable
dims_sizes: Dict[Hashable, int] = {} # shared dimension sizes to expand variables

for ds in datasets:
dims_sizes.update(ds.dims)
Expand All @@ -307,16 +356,16 @@ def _parse_datasets(datasets):


def _dataset_concat(
datasets,
dim,
data_vars,
coords,
compat,
positions,
fill_value=dtypes.NA,
join="outer",
combine_attrs="override",
):
datasets: List["Dataset"],
dim: Union[str, "DataArray", pd.Index],
data_vars: Union[str, List[str]],
coords: Union[str, List[str]],
compat: str,
positions: Optional[Iterable[int]],
fill_value: object = dtypes.NA,
join: str = "outer",
combine_attrs: str = "override",
) -> "Dataset":
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
Expand Down Expand Up @@ -356,7 +405,9 @@ def _dataset_concat(

result_vars = {}
if variables_to_merge:
to_merge = {var: [] for var in variables_to_merge}
to_merge: Dict[Hashable, List[Variable]] = {
var: [] for var in variables_to_merge
}

for ds in datasets:
for var in variables_to_merge:
Expand Down Expand Up @@ -427,16 +478,16 @@ def ensure_common_dims(vars):


def _dataarray_concat(
arrays,
dim,
data_vars,
coords,
compat,
positions,
fill_value=dtypes.NA,
join="outer",
combine_attrs="override",
):
arrays: Iterable["DataArray"],
dim: Union[str, "DataArray", pd.Index],
data_vars: Union[str, List[str]],
coords: Union[str, List[str]],
compat: str,
positions: Optional[Iterable[int]],
fill_value: object = dtypes.NA,
join: str = "outer",
combine_attrs: str = "override",
) -> "DataArray":
arrays = list(arrays)

if data_vars != "all":
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def _to_temp_dataset(self) -> Dataset:
return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False)

def _from_temp_dataset(
self, dataset: Dataset, name: Hashable = _default
self, dataset: Dataset, name: Union[Hashable, None, Default] = _default
) -> "DataArray":
variable = dataset._variables.pop(_THIS_ARRAY)
coords = dataset._variables
Expand Down

0 comments on commit 68d0e0d

Please sign in to comment.