From 68d0e0d25b498da14414af6e56e33f55ae4674f1 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 19 Aug 2020 22:32:36 +0200 Subject: [PATCH] annotate concat (#4346) * annotate concat * fix annotation: _from_temp_dataset * whats new entry * Update xarray/core/concat.py * revert faulty change --- doc/whats-new.rst | 2 + xarray/core/concat.py | 105 +++++++++++++++++++++++++++++---------- xarray/core/dataarray.py | 2 +- 3 files changed, 81 insertions(+), 28 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2652af5d9fd..91e6c8292f2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -85,6 +85,8 @@ Internal Changes By `Guido Imperiale <https://github.com/crusaderky>`_ - Only load resource files when running inside a Jupyter Notebook (:issue:`4294`) By `Guido Imperiale <https://github.com/crusaderky>`_ +- Enable type checking for :py:func:`concat` (:issue:`4238`) + By `Mathias Hauser <https://github.com/mathause>`_. .. _whats-new.0.16.0: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index fa3fac92277..b238bec40ba 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,3 +1,16 @@ +from typing import ( + TYPE_CHECKING, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + overload, +) + import pandas as pd from . import dtypes, utils @@ -7,6 +20,40 @@ from .variable import IndexVariable, Variable, as_variable from .variable import concat as concat_vars +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + + +@overload +def concat( + objs: Iterable["Dataset"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]] = "all", + coords: Union[str, List[str]] = "different", + compat: str = "equals", + positions: Optional[Iterable[int]] = None, + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "Dataset": + ... 
+ + +@overload +def concat( + objs: Iterable["DataArray"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]] = "all", + coords: Union[str, List[str]] = "different", + compat: str = "equals", + positions: Optional[Iterable[int]] = None, + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "DataArray": + ... + def concat( objs, @@ -285,13 +332,15 @@ def process_subset_opt(opt, subset): # determine dimensional coordinate names and a dict mapping name to DataArray -def _parse_datasets(datasets): +def _parse_datasets( + datasets: Iterable["Dataset"], +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, int], Set[Hashable], Set[Hashable]]: - dims = set() - all_coord_names = set() - data_vars = set() # list of data_vars - dim_coords = {} # maps dim name to variable - dims_sizes = {} # shared dimension sizes to expand variables + dims: Set[Hashable] = set() + all_coord_names: Set[Hashable] = set() + data_vars: Set[Hashable] = set() # list of data_vars + dim_coords: Dict[Hashable, Variable] = {} # maps dim name to variable + dims_sizes: Dict[Hashable, int] = {} # shared dimension sizes to expand variables for ds in datasets: dims_sizes.update(ds.dims) @@ -307,16 +356,16 @@ def _parse_datasets(datasets): def _dataset_concat( - datasets, - dim, - data_vars, - coords, - compat, - positions, - fill_value=dtypes.NA, - join="outer", - combine_attrs="override", -): + datasets: List["Dataset"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]], + coords: Union[str, List[str]], + compat: str, + positions: Optional[Iterable[int]], + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "Dataset": """ Concatenate a sequence of datasets along a new or existing dimension """ @@ -356,7 +405,9 @@ def _dataset_concat( result_vars = {} if variables_to_merge: - to_merge = {var: [] for var in variables_to_merge} + to_merge: Dict[Hashable, List[Variable]] 
= { + var: [] for var in variables_to_merge + } for ds in datasets: for var in variables_to_merge: @@ -427,16 +478,16 @@ def ensure_common_dims(vars): def _dataarray_concat( - arrays, - dim, - data_vars, - coords, - compat, - positions, - fill_value=dtypes.NA, - join="outer", - combine_attrs="override", -): + arrays: Iterable["DataArray"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]], + coords: Union[str, List[str]], + compat: str, + positions: Optional[Iterable[int]], + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "DataArray": arrays = list(arrays) if data_vars != "all": diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 426329e6a6e..ea0465e2f2d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -422,7 +422,7 @@ def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self, dataset: Dataset, name: Hashable = _default + self, dataset: Dataset, name: Union[Hashable, None, Default] = _default ) -> "DataArray": variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables