diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py
index a4f8db2786b..772d888306c 100644
--- a/asv_bench/benchmarks/combine.py
+++ b/asv_bench/benchmarks/combine.py
@@ -2,8 +2,49 @@
 
 import xarray as xr
 
+from . import requires_dask
 
-class Combine:
+
+class Combine1d:
+    """Benchmark concatenating and merging large datasets"""
+
+    def setup(self) -> None:
+        """Create 2 datasets with two different variables"""
+
+        t_size = 8000
+        t = np.arange(t_size)
+        data = np.random.randn(t_size)
+
+        self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))})
+        self.dsA1 = xr.Dataset(
+            {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))}
+        )
+
+    def time_combine_by_coords(self) -> None:
+        """Also has to load and arrange t coordinate"""
+        datasets = [self.dsA0, self.dsA1]
+
+        xr.combine_by_coords(datasets)
+
+
+class Combine1dDask(Combine1d):
+    """Benchmark concatenating and merging large datasets"""
+
+    def setup(self) -> None:
+        """Create 2 datasets with two different variables"""
+        requires_dask()
+
+        t_size = 8000
+        t = np.arange(t_size)
+        var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk()
+
+        data_vars = {f"long_name_{v}": ("T", var) for v in range(500)}
+
+        self.dsA0 = xr.Dataset(data_vars, coords={"T": t})
+        self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size})
+
+
+class Combine3d:
     """Benchmark concatenating and merging large datasets"""
 
     def setup(self):
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 55fbda9b096..b03388cb551 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -37,7 +37,8 @@ Deprecations
 
 Performance
 ~~~~~~~~~~~
-
+- Improve concatenation performance (:issue:`7833`, :pull:`7824`).
+  By `Jimmy Westling `_.
 
 Bug fixes
 ~~~~~~~~~
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 8106c295f5a..cee27300beb 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -970,10 +970,9 @@ def combine_by_coords(
 
     # Perform the multidimensional combine on each group of data variables
     # before merging back together
-    concatenated_grouped_by_data_vars = []
-    for vars, datasets_with_same_vars in grouped_by_vars:
-        concatenated = _combine_single_variable_hypercube(
-            list(datasets_with_same_vars),
+    concatenated_grouped_by_data_vars = tuple(
+        _combine_single_variable_hypercube(
+            tuple(datasets_with_same_vars),
             fill_value=fill_value,
             data_vars=data_vars,
             coords=coords,
@@ -981,7 +980,8 @@ def combine_by_coords(
             join=join,
             combine_attrs=combine_attrs,
         )
-        concatenated_grouped_by_data_vars.append(concatenated)
+        for vars, datasets_with_same_vars in grouped_by_vars
+    )
 
     return merge(
         concatenated_grouped_by_data_vars,
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 397d6de226a..5dd4c4dbd96 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -211,7 +211,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, .
         int or tuple of int
             Axis number or numbers corresponding to the given dimensions.
         """
-        if isinstance(dim, Iterable) and not isinstance(dim, str):
+        if not isinstance(dim, str) and isinstance(dim, Iterable):
             return tuple(self._get_axis_num(d) for d in dim)
         else:
             return self._get_axis_num(dim)
diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index dcf2a23d311..d7aad8c7188 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -3,6 +3,7 @@
 from collections.abc import Hashable, Iterable
 from typing import TYPE_CHECKING, Any, Union, cast, overload
 
+import numpy as np
 import pandas as pd
 
 from xarray.core import dtypes, utils
@@ -517,7 +518,7 @@ def _dataset_concat(
     if variables_to_merge:
         grouped = {
             k: v
-            for k, v in collect_variables_and_indexes(list(datasets)).items()
+            for k, v in collect_variables_and_indexes(datasets).items()
             if k in variables_to_merge
         }
         merged_vars, merged_indexes = merge_collected(
@@ -543,7 +544,7 @@ def ensure_common_dims(vars, concat_dim_lengths):
         # ensure each variable with the given name shares the same
         # dimensions and the same shape for all of them except along the
         # concat dimension
-        common_dims = tuple(pd.unique([d for v in vars for d in v.dims]))
+        common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims))
         if dim not in common_dims:
             common_dims = (dim,) + common_dims
         for var, dim_len in zip(vars, concat_dim_lengths):
@@ -568,38 +569,45 @@ def get_indexes(name):
                     yield PandasIndex(data, dim, coord_dtype=var.dtype)
 
     # create concatenation index, needed for later reindexing
-    concat_index = list(range(sum(concat_dim_lengths)))
+    file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
+    concat_index = np.arange(file_start_indexes[-1])
+    concat_index_size = concat_index.size
+    variable_index_mask = np.ones(concat_index_size, dtype=bool)
 
     # stack up each variable and/or index to fill-out the dataset (in order)
     # n.b. this loop preserves variable order, needed for groupby.
+    ndatasets = len(datasets)
     for name in vars_order:
         if name in concat_over and name not in result_indexes:
             variables = []
-            variable_index = []
+            # Initialize the mask to all True then set False if any name is missing in
+            # the datasets:
+            variable_index_mask.fill(True)
             var_concat_dim_length = []
             for i, ds in enumerate(datasets):
                 if name in ds.variables:
                     variables.append(ds[name].variable)
-                    # add to variable index, needed for reindexing
-                    var_idx = [
-                        sum(concat_dim_lengths[:i]) + k
-                        for k in range(concat_dim_lengths[i])
-                    ]
-                    variable_index.extend(var_idx)
-                    var_concat_dim_length.append(len(var_idx))
+                    var_concat_dim_length.append(concat_dim_lengths[i])
                 else:
                     # raise if coordinate not in all datasets
                     if name in coord_names:
                         raise ValueError(
                             f"coordinate {name!r} not present in all datasets."
                         )
+
+                    # Mask out the indexes without the name:
+                    start = file_start_indexes[i]
+                    end = file_start_indexes[i + 1]
+                    variable_index_mask[slice(start, end)] = False
+
+            variable_index = concat_index[variable_index_mask]
             vars = ensure_common_dims(variables, var_concat_dim_length)
 
             # Try to concatenate the indexes, concatenate the variables when no index
             # is found on all datasets.
             indexes: list[Index] = list(get_indexes(name))
             if indexes:
-                if len(indexes) < len(datasets):
+                if len(indexes) < ndatasets:
                     raise ValueError(
                         f"{name!r} must have either an index or no index in all datasets, "
                         f"found {len(indexes)}/{len(datasets)} datasets with an index."
@@ -623,7 +631,7 @@ def get_indexes(name):
                     vars, dim, positions, combine_attrs=combine_attrs
                 )
                 # reindex if variable is not present in all datasets
-                if len(variable_index) < len(concat_index):
+                if len(variable_index) < concat_index_size:
                     combined_var = reindex_variables(
                         variables={name: combined_var},
                         dim_pos_indexers={
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 04ad1118124..81860bede95 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -647,7 +647,7 @@ def __init__(
             )
 
         if isinstance(coords, Dataset):
-            coords = coords.variables
+            coords = coords._variables
 
         variables, coord_names, dims, indexes, _ = merge_data_and_coords(
             data_vars, coords, compat="broadcast_equals"
@@ -1399,8 +1399,8 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
 
         coords: dict[Hashable, Variable] = {}
         # preserve ordering
         for k in self._variables:
-            if k in self._coord_names and set(self.variables[k].dims) <= needed_dims:
-                coords[k] = self.variables[k]
+            if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
+                coords[k] = self._variables[k]
 
         indexes = filter_indexes_from_coords(self._indexes, set(coords))
diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py
index 4d8583cfe65..f5a184289c9 100644
--- a/xarray/core/dtypes.py
+++ b/xarray/core/dtypes.py
@@ -37,11 +37,11 @@ def __eq__(self, other):
 # instead of following NumPy's own type-promotion rules. These type promotion
 # rules match pandas instead. For reference, see the NumPy type hierarchy:
 # https://numpy.org/doc/stable/reference/arrays.scalars.html
-PROMOTE_TO_OBJECT = [
-    {np.number, np.character},  # numpy promotes to character
-    {np.bool_, np.character},  # numpy promotes to character
-    {np.bytes_, np.unicode_},  # numpy promotes to unicode
-]
+PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
+    (np.number, np.character),  # numpy promotes to character
+    (np.bool_, np.character),  # numpy promotes to character
+    (np.bytes_, np.unicode_),  # numpy promotes to unicode
+)
 
 
 def maybe_promote(dtype):
@@ -156,7 +156,9 @@ def is_datetime_like(dtype):
     return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
 
 
-def result_type(*arrays_and_dtypes):
+def result_type(
+    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+) -> np.dtype:
     """Like np.result_type, but with type promotion rules matching pandas.
 
     Examples of changed behavior:
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 4d7998e1475..4f245e59f73 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -194,7 +194,10 @@ def asarray(data, xp=np):
 
 def as_shared_dtype(scalars_or_arrays, xp=np):
     """Cast a arrays to a shared dtype using xarray's type promotion rules."""
-    if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
+    array_type_cupy = array_type("cupy")
+    if array_type_cupy and any(
+        isinstance(x, array_type_cupy) for x in scalars_or_arrays
+    ):
         import cupy as cp
 
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 93e9e535fe3..9ee9bc374d4 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1495,7 +1495,7 @@ def filter_indexes_from_coords(
     of coordinate names.
 
     """
-    filtered_indexes: dict[Any, Index] = dict(**indexes)
+    filtered_indexes: dict[Any, Index] = dict(indexes)
 
     index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set)
     for name, idx in indexes.items():
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index bf7288ad7ed..56e51256ba1 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -195,11 +195,11 @@ def _assert_prioritized_valid(
 
 
 def merge_collected(
-    grouped: dict[Hashable, list[MergeElement]],
+    grouped: dict[Any, list[MergeElement]],
     prioritized: Mapping[Any, MergeElement] | None = None,
     compat: CompatOptions = "minimal",
     combine_attrs: CombineAttrsOptions = "override",
-    equals: dict[Hashable, bool] | None = None,
+    equals: dict[Any, bool] | None = None,
 ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
     """Merge dicts of variables, while resolving conflicts appropriately.
 
@@ -306,7 +306,7 @@ def merge_collected(
 
 
 def collect_variables_and_indexes(
-    list_of_mappings: list[DatasetLike],
+    list_of_mappings: Iterable[DatasetLike],
     indexes: Mapping[Any, Any] | None = None,
 ) -> dict[Hashable, list[MergeElement]]:
     """Collect variables and indexes from list of mappings of xarray objects.
@@ -556,7 +556,12 @@ def merge_coords(
     return variables, out_indexes
 
 
-def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"):
+def merge_data_and_coords(
+    data_vars: Mapping[Any, Any],
+    coords: Mapping[Any, Any],
+    compat: CompatOptions = "broadcast_equals",
+    join: JoinOptions = "outer",
+) -> _MergeResult:
     """Used in Dataset.__init__."""
     indexes, coords = _create_indexes_from_coords(coords, data_vars)
     objects = [data_vars, coords]
@@ -570,7 +575,9 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="ou
     )
 
 
-def _create_indexes_from_coords(coords, data_vars=None):
+def _create_indexes_from_coords(
+    coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None
+) -> tuple[dict, dict]:
     """Maybe create default indexes from a mapping of coordinates.
 
     Return those indexes and updated coordinates.
@@ -605,7 +612,11 @@ def _create_indexes_from_coords(coords, data_vars=None):
     return indexes, updated_coords
 
 
-def assert_valid_explicit_coords(variables, dims, explicit_coords):
+def assert_valid_explicit_coords(
+    variables: Mapping[Any, Any],
+    dims: Mapping[Any, int],
+    explicit_coords: Iterable[Hashable],
+) -> None:
     """Validate explicit coordinate names/dims.
 
     Raise a MergeError if an explicit coord shares a name with a dimension
@@ -688,7 +699,7 @@ def merge_core(
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
     priority_arg: int | None = None,
-    explicit_coords: Sequence | None = None,
+    explicit_coords: Iterable[Hashable] | None = None,
     indexes: Mapping[Any, Any] | None = None,
     fill_value: object = dtypes.NA,
 ) -> _MergeResult:
@@ -1035,7 +1046,7 @@ def dataset_merge_method(
     # method due for backwards compatibility
     # TODO: consider deprecating it?
 
-    if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str):
+    if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable):
         overwrite_vars = set(overwrite_vars)
     else:
         overwrite_vars = {overwrite_vars}
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index 84a1ec70c53..9af5d693170 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -61,14 +61,26 @@ def __init__(self, mod: ModType) -> None:
         self.available = duck_array_module is not None
 
 
+_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}
+
+
+def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
+    if mod not in _cached_duck_array_modules:
+        duckmod = DuckArrayModule(mod)
+        _cached_duck_array_modules[mod] = duckmod
+        return duckmod
+    else:
+        return _cached_duck_array_modules[mod]
+
+
 def array_type(mod: ModType) -> DuckArrayTypes:
     """Quick wrapper to get the array class of the module."""
-    return DuckArrayModule(mod).type
+    return _get_cached_duck_array_module(mod).type
 
 
 def mod_version(mod: ModType) -> Version:
     """Quick wrapper to get the version of the module."""
-    return DuckArrayModule(mod).version
+    return _get_cached_duck_array_module(mod).version
 
 
 def is_dask_collection(x):
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index f4813af9782..83ccbc9a1cf 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -238,7 +238,7 @@ def _possibly_convert_datetime_or_timedelta_index(data):
     return data
 
 
-def as_compatible_data(data, fastpath=False):
+def as_compatible_data(data, fastpath: bool = False):
     """Prepare and wrap data to put in a Variable.
 
     - If data does not have the necessary attributes, convert it to ndarray.
@@ -677,7 +677,8 @@ def dims(self, value: str | Iterable[Hashable]) -> None:
     def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]:
         if isinstance(dims, str):
             dims = (dims,)
-        dims = tuple(dims)
+        else:
+            dims = tuple(dims)
         if len(dims) != self.ndim:
             raise ValueError(
                 f"dimensions {dims} must have the same length as the "
@@ -2102,12 +2103,13 @@ def concat(
         # twice
         variables = list(variables)
         first_var = variables[0]
+        first_var_dims = first_var.dims
 
-        arrays = [v.data for v in variables]
+        arrays = [v._data for v in variables]
 
-        if dim in first_var.dims:
+        if dim in first_var_dims:
             axis = first_var.get_axis_num(dim)
-            dims = first_var.dims
+            dims = first_var_dims
             data = duck_array_ops.concatenate(arrays, axis=axis)
             if positions is not None:
                 # TODO: deprecate this option -- we don't need it for groupby
@@ -2116,7 +2118,7 @@ def concat(
                 data = duck_array_ops.take(data, indices, axis=axis)
         else:
             axis = 0
-            dims = (dim,) + first_var.dims
+            dims = (dim,) + first_var_dims
             data = duck_array_ops.stack(arrays, axis=axis)
 
         attrs = merge_attrs(
@@ -2125,12 +2127,12 @@ def concat(
         encoding = dict(first_var.encoding)
         if not shortcut:
             for var in variables:
-                if var.dims != first_var.dims:
+                if var.dims != first_var_dims:
                     raise ValueError(
-                        f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var.dims)}"
+                        f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var_dims)}"
                     )
 
-        return cls(dims, data, attrs, encoding)
+        return cls(dims, data, attrs, encoding, fastpath=True)
 
     def equals(self, other, equiv=duck_array_ops.array_equiv):
         """True if two Variables have the same dimensions and values;
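
A rough usage sketch of the workload this patch targets (not part of the patch itself; it is modeled on the Combine1d benchmark above, and the sizes and the extra variable "B" are illustrative assumptions):

    import numpy as np
    import xarray as xr

    t_size = 8000  # same order of magnitude as the Combine1d benchmark
    t = np.arange(t_size)
    data = np.random.randn(t_size)

    # Two 1-D datasets sharing variable "A" on consecutive "T" coordinates.
    ds0 = xr.Dataset({"A": ("T", data)}, coords={"T": t})
    ds1 = xr.Dataset({"A": ("T", data)}, coords={"T": t + t_size})

    # combine_by_coords orders the pieces by their "T" coordinate and then runs
    # the concat/merge machinery touched by this patch.
    combined = xr.combine_by_coords([ds0, ds1])
    assert combined.sizes["T"] == 2 * t_size

    # A data variable present in only one input exercises the reindexing path
    # whose index bookkeeping now uses a NumPy boolean mask instead of Python lists.
    ds1_extra = ds1.assign(B=("T", np.ones(t_size)))
    combined2 = xr.concat([ds0, ds1_extra], dim="T")
    assert int(combined2["B"].isnull().sum()) == t_size  # B is NaN-filled over ds0's range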