Improve concat performance #7824

Merged: 61 commits into main from improve_concat, Jun 2, 2023. (The diff below shows changes from 46 of the 61 commits.)

Commits
670d970
1. var_idx very slow
Illviljan May 7, 2023
1370a0e
2. slow any
Illviljan May 7, 2023
8ab83fa
Add test
Illviljan May 7, 2023
b6e1881
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2023
487e958
3. Slow array_type called multiple times
Illviljan May 7, 2023
fbb5430
4. Can use fastpath for variable.concat?
Illviljan May 7, 2023
e6fe2c4
Merge branch 'improve_concat' of https://github.com/Illviljan/xarray …
Illviljan May 7, 2023
dc5f0e6
5. slow init of pd.unique
Illviljan May 7, 2023
38aa169
typos
Illviljan May 7, 2023
5fd6bcb
Update concat.py
Illviljan May 7, 2023
43dcff2
Update merge.py
Illviljan May 7, 2023
7ef0e5d
6. Avoid recalculating in loops
Illviljan May 8, 2023
b2c067d
7. No need to transpose 1d arrays.
Illviljan May 8, 2023
914481b
Merge branch 'main' into improve_concat
Illviljan May 8, 2023
ad048b6
8. speed up dask_dataframe
Illviljan May 10, 2023
d609883
Update dataset.py
Illviljan May 10, 2023
4005f6f
Update dataset.py
Illviljan May 10, 2023
d23c833
Update dataset.py
Illviljan May 11, 2023
6c6b5c7
Add dask combine test with many variables
Illviljan May 11, 2023
068ba55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2023
1bb9dd5
Merge branch 'main' into improve_concat
Illviljan May 11, 2023
6828fb6
Merge branch 'improve_concat' of https://github.com/Illviljan/xarray …
Illviljan May 11, 2023
5670331
Update combine.py
Illviljan May 11, 2023
b11afe8
Update combine.py
Illviljan May 11, 2023
e1938a8
Update combine.py
Illviljan May 11, 2023
43fd7a2
list not needed
Illviljan May 14, 2023
a59635b
dim is usually string, might be faster to check for that
Illviljan May 14, 2023
70be8c9
first_var.dims doesn't change and can be defined 1 time
Illviljan May 14, 2023
68a0c08
mask bad points rather than append good points
Illviljan May 14, 2023
2ebf78a
reduce duplicated code
Illviljan May 14, 2023
dd81325
don't think id() is required here.
Illviljan May 14, 2023
8805ab1
Merge branch 'main' into improve_concat
Illviljan May 14, 2023
8d3d152
get dtype directly instead of through result_dtype
Illviljan May 14, 2023
476b1e0
seems better to delete rather than append,
Illviljan May 14, 2023
674638a
use internal fastpath if it's a dataset, values should be fine then
Illviljan May 14, 2023
05206f8
Change isinstance order.
Illviljan May 14, 2023
e0dae6d
use fastpath if already xarray object
Illviljan May 14, 2023
9cc6c2d
Update variable.py
Illviljan May 14, 2023
f03154c
Update dtypes.py
Illviljan May 14, 2023
529e386
typing fixes
Illviljan May 14, 2023
b7492ca
more typing fixes
Illviljan May 14, 2023
cf51f16
test undoing as_compatible_data
Illviljan May 14, 2023
4b6a4c6
undo concat_dim_length deletion
Illviljan May 14, 2023
2cd984d
Update xarray/core/concat.py
Illviljan May 14, 2023
86eb72a
Remove .copy and sum
Illviljan May 16, 2023
b2a498a
Merge branch 'main' into improve_concat
Illviljan May 16, 2023
ee2c3f6
Update concat.py
Illviljan May 18, 2023
4092d1b
Merge branch 'main' into improve_concat
Illviljan May 18, 2023
1e514f6
Use OrderedSet
Illviljan May 21, 2023
59bf15b
Merge branch 'main' into improve_concat
Illviljan May 21, 2023
b4198a3
Merge branch 'main' into improve_concat
Illviljan May 25, 2023
0d0b76e
Apply suggestions from code review
Illviljan May 25, 2023
51768bd
Merge branch 'main' into improve_concat
Illviljan May 30, 2023
15e2783
Update whats-new.rst
Illviljan May 30, 2023
6f86e98
Update xarray/core/concat.py
Illviljan May 31, 2023
db25a0b
no need to check arrays if cupy isn't even installed
Illviljan May 31, 2023
da959a8
Merge branch 'main' into improve_concat
Illviljan May 31, 2023
629cdea
Update whats-new.rst
Illviljan Jun 1, 2023
cfaa866
Add concat comment
Illviljan Jun 1, 2023
fe62336
minimize diff
Illviljan Jun 1, 2023
4e90504
revert sketchy
Illviljan Jun 1, 2023
43 changes: 42 additions & 1 deletion asv_bench/benchmarks/combine.py
@@ -2,8 +2,49 @@

import xarray as xr

from . import requires_dask

class Combine:

class Combine1d:
"""Benchmark concatenating and merging large datasets"""

def setup(self):
"""Create 2 datasets with two different variables"""

t_size = 8000
t = np.arange(t_size)
data = np.random.randn(t_size)

self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))})
self.dsA1 = xr.Dataset(
{"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))}
)

def time_combine_by_coords(self):
"""Also has to load and arrange t coordinate"""
datasets = [self.dsA0, self.dsA1]

xr.combine_by_coords(datasets)


class Combine1dDask(Combine1d):
"""Benchmark concatenating and merging large datasets"""

def setup(self):
"""Create 2 datasets with two different variables"""
requires_dask()

t_size = 8000
t = np.arange(t_size)
var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk()

data_vars = {f"long_name_{v}": ("T", var) for v in range(500)}

self.dsA0 = xr.Dataset(data_vars, coords={"T": t})
self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size})


class Combine3d:
"""Benchmark concatenating and merging large datasets"""

def setup(self):
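For a quick check outside the ASV harness, the `Combine1d` workload above can be timed directly; a minimal sketch mirroring the benchmark's setup (sizes and names follow the code above):

```python
import timeit

import numpy as np
import xarray as xr

t_size = 8000
t = np.arange(t_size)
data = np.random.randn(t_size)

# Two datasets that line up end-to-end along "T", as in Combine1d.setup.
dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims="T")})
dsA1 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t + t_size}, dims="T")})

# combine_by_coords has to load and arrange the "T" coordinate each call.
elapsed = timeit.timeit(lambda: xr.combine_by_coords([dsA0, dsA1]), number=100)
print(f"{elapsed / 100 * 1e3:.3f} ms per combine_by_coords call")
```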
10 changes: 5 additions & 5 deletions xarray/core/combine.py
@@ -970,18 +970,18 @@ def combine_by_coords(

# Perform the multidimensional combine on each group of data variables
# before merging back together
concatenated_grouped_by_data_vars = []
for vars, datasets_with_same_vars in grouped_by_vars:
concatenated = _combine_single_variable_hypercube(
list(datasets_with_same_vars),
concatenated_grouped_by_data_vars = tuple(
_combine_single_variable_hypercube(
tuple(datasets_with_same_vars),
fill_value=fill_value,
data_vars=data_vars,
coords=coords,
compat=compat,
join=join,
combine_attrs=combine_attrs,
)
concatenated_grouped_by_data_vars.append(concatenated)
for vars, datasets_with_same_vars in grouped_by_vars
)

return merge(
concatenated_grouped_by_data_vars,
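The rewrite above turns a build-by-append loop into one tuple comprehension, removing the per-iteration `append` lookup and call; the shape of the change, with illustrative names rather than the PR's:

```python
datasets_grouped = [("a", [1, 2]), ("b", [3, 4])]

def combine(group):  # stand-in for _combine_single_variable_hypercube
    return sum(group)

# Before: grow a list one element at a time.
combined = []
for _vars, group in datasets_grouped:
    combined.append(combine(group))

# After: one expression, no append method lookups in the loop body.
combined = tuple(combine(group) for _vars, group in datasets_grouped)
print(combined)  # (3, 7)
```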
2 changes: 1 addition & 1 deletion xarray/core/common.py
@@ -209,7 +209,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
int or tuple of int
Axis number or numbers corresponding to the given dimensions.
"""
if isinstance(dim, Iterable) and not isinstance(dim, str):
if not isinstance(dim, str) and isinstance(dim, Iterable):
return tuple(self._get_axis_num(d) for d in dim)
else:
return self._get_axis_num(dim)
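Swapping the operand order matters because `dim` is usually a plain string: testing the concrete `str` type first lets `and` short-circuit past the `Iterable` check, which dispatches through the ABC machinery (`__instancecheck__`) and is markedly slower. A rough illustration; timings are machine-dependent:

```python
import timeit
from collections.abc import Iterable

dim = "time"  # the common case

# Concrete type: a fast C-level check.
t_str = timeit.timeit(lambda: isinstance(dim, str), number=1_000_000)

# ABC: routed through Iterable.__instancecheck__.
t_abc = timeit.timeit(lambda: isinstance(dim, Iterable), number=1_000_000)

print(f"isinstance(dim, str):      {t_str:.3f} s")
print(f"isinstance(dim, Iterable): {t_abc:.3f} s")
```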
32 changes: 19 additions & 13 deletions xarray/core/concat.py
@@ -3,6 +3,7 @@
from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, Union, cast, overload

import numpy as np
import pandas as pd

from xarray.core import dtypes, utils
@@ -517,7 +518,7 @@ def _dataset_concat(
if variables_to_merge:
grouped = {
k: v
for k, v in collect_variables_and_indexes(list(datasets)).items()
for k, v in collect_variables_and_indexes(datasets).items()
if k in variables_to_merge
}
merged_vars, merged_indexes = merge_collected(
@@ -543,7 +544,7 @@ def ensure_common_dims(vars, concat_dim_lengths):
# ensure each variable with the given name shares the same
# dimensions and the same shape for all of them except along the
# concat dimension
common_dims = tuple(pd.unique([d for v in vars for d in v.dims]))
common_dims = tuple(dict.fromkeys(d for v in vars for d in v.dims))
if dim not in common_dims:
common_dims = (dim,) + common_dims
for var, dim_len in zip(vars, concat_dim_lengths):
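Both spellings above deduplicate dimension names while keeping first-seen order; `dict.fromkeys` wins because `pd.unique` first coerces its input to an array, a fixed cost paid once per variable in this loop. The equivalence in miniature:

```python
import pandas as pd

dims = ["T", "X", "T", "Y", "X"]

# pd.unique preserves order but converts the list to an ndarray first.
print(tuple(pd.unique(dims)))      # ('T', 'X', 'Y')

# dicts preserve insertion order, so fromkeys deduplicates with far
# less setup cost and no array conversion.
print(tuple(dict.fromkeys(dims)))  # ('T', 'X', 'Y')
```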
@@ -568,38 +569,43 @@ def get_indexes(name):
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# create concatenation index, needed for later reindexing
concat_index = list(range(sum(concat_dim_lengths)))
file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
concat_index = np.arange(file_start_indexes[-1])
concat_index_size = concat_index.size
variable_index_mask = np.ones(concat_index_size, dtype=bool)

# stack up each variable and/or index to fill-out the dataset (in order)
# n.b. this loop preserves variable order, needed for groupby.
ndatasets = len(datasets)
for name in vars_order:
if name in concat_over and name not in result_indexes:
variables = []
variable_index = []
variable_index_mask.fill(True)
var_concat_dim_length = []
for i, ds in enumerate(datasets):
if name in ds.variables:
variables.append(ds[name].variable)
# add to variable index, needed for reindexing
var_idx = [
sum(concat_dim_lengths[:i]) + k
for k in range(concat_dim_lengths[i])
]
variable_index.extend(var_idx)
var_concat_dim_length.append(len(var_idx))
var_concat_dim_length.append(concat_dim_lengths[i])
else:
# raise if coordinate not in all datasets
if name in coord_names:
raise ValueError(
f"coordinate {name!r} not present in all datasets."
)

# Mask out the indexes without the name:
start = file_start_indexes[i]
end = file_start_indexes[i + 1]
variable_index_mask[slice(start, end)] = False

variable_index = concat_index[variable_index_mask]
vars = ensure_common_dims(variables, var_concat_dim_length)

# Try to concatenate the indexes, concatenate the variables when no index
# is found on all datasets.
indexes: list[Index] = list(get_indexes(name))
if indexes:
if len(indexes) < len(datasets):
if len(indexes) < ndatasets:
raise ValueError(
f"{name!r} must have either an index or no index in all datasets, "
f"found {len(indexes)}/{len(datasets)} datasets with an index."
Expand All @@ -623,7 +629,7 @@ def get_indexes(name):
vars, dim, positions, combine_attrs=combine_attrs
)
# reindex if variable is not present in all datasets
if len(variable_index) < len(concat_index):
if len(variable_index) < concat_index_size:
combined_var = reindex_variables(
variables={name: combined_var},
dim_pos_indexers={
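The reindexing bookkeeping above (commit "mask bad points rather than append good points") replaces a Python-level loop that appended each kept index with array operations: one `cumsum` gives every dataset's start offset, and datasets missing the variable simply blank out their slice of a boolean mask. A standalone sketch of that logic; the membership flags are illustrative:

```python
import numpy as np

concat_dim_lengths = [3, 2, 4]      # length each dataset contributes
has_variable = [True, False, True]  # does each dataset carry the variable?

# Dataset boundaries along the concat dimension: [0, 3, 5, 9].
file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
concat_index = np.arange(file_start_indexes[-1])

# Start from "present everywhere", then mask out absent datasets.
variable_index_mask = np.ones(concat_index.size, dtype=bool)
for i, present in enumerate(has_variable):
    if not present:
        start, end = file_start_indexes[i], file_start_indexes[i + 1]
        variable_index_mask[start:end] = False

variable_index = concat_index[variable_index_mask]
print(variable_index)  # [0 1 2 5 6 7 8] -> reindexing fills the gap at 3-4
```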
25 changes: 20 additions & 5 deletions xarray/core/dataset.py
@@ -607,7 +607,7 @@ def __init__(
)

if isinstance(coords, Dataset):
coords = coords.variables
coords = coords._variables

variables, coord_names, dims, indexes, _ = merge_data_and_coords(
data_vars, coords, compat="broadcast_equals"
@@ -1356,11 +1356,14 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:

needed_dims = set(variable.dims)

coord_name = self._coord_names
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self.variables[k].dims) <= needed_dims:
coords[k] = self.variables[k]
if k in coord_name:
var = self._variables[k]
if set(var.dims) <= needed_dims:
coords[k] = var

indexes = filter_indexes_from_coords(self._indexes, set(coords))

@@ -6403,7 +6406,14 @@ def to_dask_dataframe(
columns.extend(k for k in self.coords if k not in self.dims)
columns.extend(self.data_vars)

has_many_dims = len(ordered_dims) > 1
if has_many_dims:
ds_chunks = self.chunks
else:
ds_chunks = {}

series_list = []
df_meta = pd.DataFrame()
for name in columns:
try:
var = self.variables[name]
Expand All @@ -6422,8 +6432,13 @@ def to_dask_dataframe(
if not is_duck_dask_array(var._data):
var = var.chunk()

dask_array = var.set_dims(ordered_dims).chunk(self.chunks).data
series = dd.from_array(dask_array.reshape(-1), columns=[name])
if has_many_dims:
# Broadcast then flatten the array:
var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks)
dask_array = var_new_dims._data.reshape(-1)
else:
dask_array = var._data
series = dd.from_dask_array(dask_array, columns=name, meta=df_meta)
series_list.append(series)

df = dd.concat(series_list, axis=1)
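For the common one-dimensional dataset, `to_dask_dataframe` now skips the broadcast/rechunk/`reshape(-1)` pipeline and hands each variable's dask array straight to `dd.from_dask_array`. A minimal sketch of that fast path, mirroring the call in the diff (requires dask with its dataframe extra):

```python
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd

# A 1-D variable's data: already the right shape for a column.
dask_array = da.from_array(np.random.randn(8000), chunks=1000)

# meta describes the output schema so dask needn't compute to infer it.
df_meta = pd.DataFrame()
series = dd.from_dask_array(dask_array, columns="A", meta=df_meta)

df = dd.concat([series], axis=1)
print(df.head())
```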
14 changes: 8 additions & 6 deletions xarray/core/dtypes.py
@@ -37,11 +37,11 @@ def __eq__(self, other):
# instead of following NumPy's own type-promotion rules. These type promotion
# rules match pandas instead. For reference, see the NumPy type hierarchy:
# https://numpy.org/doc/stable/reference/arrays.scalars.html
PROMOTE_TO_OBJECT = [
{np.number, np.character}, # numpy promotes to character
{np.bool_, np.character}, # numpy promotes to character
{np.bytes_, np.unicode_}, # numpy promotes to unicode
]
PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
(np.number, np.character), # numpy promotes to character
(np.bool_, np.character), # numpy promotes to character
(np.bytes_, np.unicode_), # numpy promotes to unicode
)


def maybe_promote(dtype):
@@ -156,7 +156,9 @@ def is_datetime_like(dtype):
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def result_type(*arrays_and_dtypes):
def result_type(
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
) -> np.dtype:
"""Like np.result_type, but with type promotion rules matching pandas.

Examples of changed behavior:
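`PROMOTE_TO_OBJECT` is what makes `result_type` diverge from `np.result_type`: pairs like number/character promote to `object`, matching pandas, where NumPy has historically promoted to a string dtype. A quick illustration via the internal API (anything under `xarray.core` is subject to change):

```python
import numpy as np

from xarray.core import dtypes

# number + character hits the first PROMOTE_TO_OBJECT pair, so the
# result is object rather than NumPy's character promotion.
promoted = dtypes.result_type(np.array([1, 2]), np.array(["a", "b"]))
print(promoted)  # object
```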
3 changes: 2 additions & 1 deletion xarray/core/duck_array_ops.py
@@ -192,7 +192,8 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
array_type_cupy = array_type("cupy")
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
import cupy as cp

arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
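The one-line fix hoists `array_type("cupy")` out of the generator, so the registry lookup runs once rather than once per element scanned by `any`. The same hoisting pattern in miniature (the lookup function here is a stand-in):

```python
def expensive_lookup() -> type:
    # Stand-in for array_type("cupy"): any lookup with real per-call cost.
    return complex

values = ["not-a-number"] * 10_000  # nothing matches, so any() scans it all

# Before: expensive_lookup() re-runs for every element.
any(isinstance(x, expensive_lookup()) for x in values)

# After: one lookup up front; the loop only pays for isinstance.
looked_up = expensive_lookup()
any(isinstance(x, looked_up) for x in values)
```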
4 changes: 2 additions & 2 deletions xarray/core/indexes.py
@@ -1495,11 +1495,11 @@ def filter_indexes_from_coords(
of coordinate names.

"""
filtered_indexes: dict[Any, Index] = dict(**indexes)
filtered_indexes: dict[Any, Index] = dict(indexes)

index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set)
for name, idx in indexes.items():
index_coord_names[id(idx)].add(name)
index_coord_names[idx].add(name)

for idx_coord_names in index_coord_names.values():
if not idx_coord_names <= filtered_coord_names:
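Two small wins here: `dict(indexes)` copies a mapping without the keyword-unpacking detour of `dict(**indexes)`, and the grouping dict can key on the `Index` object itself, since Python's default hash is already identity-based, so `id(idx)` added a call for nothing. A sketch of the grouping with stand-in index objects:

```python
from collections import defaultdict

class StubIndex:
    """Stand-in for xarray's Index; uses default identity hashing."""

shared = StubIndex()  # one index object backing two coordinates
other = StubIndex()
indexes = {"x": shared, "x2": shared, "y": other}

# Group coordinate names by the index object they share; keying on the
# object itself behaves like keying on id(idx), minus the id() call.
index_coord_names: defaultdict = defaultdict(set)
for name, idx in indexes.items():
    index_coord_names[idx].add(name)

print([sorted(names) for names in index_coord_names.values()])
# [['x', 'x2'], ['y']]
```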
29 changes: 20 additions & 9 deletions xarray/core/merge.py
@@ -195,12 +195,12 @@


def merge_collected(
grouped: dict[Hashable, list[MergeElement]],
grouped: dict[Any, list[MergeElement]],
prioritized: Mapping[Any, MergeElement] | None = None,
compat: CompatOptions = "minimal",
combine_attrs: CombineAttrsOptions = "override",
equals: dict[Hashable, bool] | None = None,
) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
equals: dict[Any, bool] | None = None,
) -> tuple[dict[Any, Variable], dict[Any, Index]]:
"""Merge dicts of variables, while resolving conflicts appropriately.

Parameters
@@ -306,7 +306,7 @@


def collect_variables_and_indexes(
list_of_mappings: list[DatasetLike],
list_of_mappings: Iterable[DatasetLike],
indexes: Mapping[Any, Any] | None = None,
) -> dict[Hashable, list[MergeElement]]:
"""Collect variables and indexes from list of mappings of xarray objects.
@@ -556,7 +556,12 @@ def merge_coords(
return variables, out_indexes


def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"):
def merge_data_and_coords(
data_vars: Mapping[Any, Any],
coords: Mapping[Any, Any],
compat: CompatOptions = "broadcast_equals",
join: JoinOptions = "outer",
) -> _MergeResult:
"""Used in Dataset.__init__."""
indexes, coords = _create_indexes_from_coords(coords, data_vars)
objects = [data_vars, coords]
@@ -570,7 +575,9 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"):
)


def _create_indexes_from_coords(coords, data_vars=None):
def _create_indexes_from_coords(
coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None
) -> tuple[dict, dict]:
"""Maybe create default indexes from a mapping of coordinates.

Return those indexes and updated coordinates.
@@ -605,7 +612,11 @@ def _create_indexes_from_coords(coords, data_vars=None):
return indexes, updated_coords


def assert_valid_explicit_coords(variables, dims, explicit_coords):
def assert_valid_explicit_coords(
variables: Mapping[Any, Any],
dims: Mapping[Any, int],
explicit_coords: Iterable[Hashable],
) -> None:
"""Validate explicit coordinate names/dims.

Raise a MergeError if an explicit coord shares a name with a dimension
@@ -688,7 +699,7 @@ def merge_core(
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
priority_arg: int | None = None,
explicit_coords: Sequence | None = None,
explicit_coords: Iterable[Hashable] | None = None,
indexes: Mapping[Any, Any] | None = None,
fill_value: object = dtypes.NA,
) -> _MergeResult:
@@ -1035,7 +1046,7 @@ def dataset_merge_method(
# method due for backwards compatibility
# TODO: consider deprecating it?

if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str):
if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable):
overwrite_vars = set(overwrite_vars)
else:
overwrite_vars = {overwrite_vars}
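Most of the merge.py edits widen annotations (`list` to `Iterable`, `Sequence` to `Iterable[Hashable]`, `Hashable` keys to `Any`), which is what lets call sites such as `collect_variables_and_indexes(datasets)` drop their defensive `list(...)` copies. The principle, with a hypothetical helper:

```python
from collections.abc import Hashable, Iterable


def collect(names: Iterable[Hashable]) -> list[Hashable]:
    # Accepting any iterable lets callers pass tuples, dict views, or
    # generators directly -- no list(...) copy needed at the call site.
    return list(names)


collect(("a", "b"))               # tuple: fine
collect({"a": 1, "b": 2}.keys())  # dict view: fine
collect(n for n in "abc")         # generator: fine
```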