Skip to content
forked from pydata/xarray

Commit

Permalink
GroupBy(multiple groupers)
Browse files Browse the repository at this point in the history
Closes pydata#924
Closes pydata#1056
Closes pydata#9332
xref pydata#324
  • Loading branch information
dcherian committed Aug 16, 2024
1 parent 3c19231 commit ee18848
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 67 deletions.
18 changes: 6 additions & 12 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6801,27 +6801,21 @@ def groupby(
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

grouper: Grouper
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
grouper = UniqueGrouper()
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if len(groupers) > 1:
raise ValueError("grouping by multiple variables is not supported yet.")
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
group, grouper = next(iter(groupers.items()))

rgrouper = ResolvedGrouper(grouper, group, self)
rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
)

return DataArrayGroupBy(
self,
(rgrouper,),
restore_coord_dims=restore_coord_dims,
)
return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

@_deprecate_positional_args("v2024.07.0")
def groupby_bins(
Expand Down
18 changes: 7 additions & 11 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10388,20 +10388,16 @@ def groupby(
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if len(groupers) > 1:
raise ValueError("Grouping by multiple variables is not supported yet.")
elif not groupers:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
for group, grouper in groupers.items():
rgrouper = ResolvedGrouper(grouper, group, self)
rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
)

return DatasetGroupBy(
self,
(rgrouper,),
restore_coord_dims=restore_coord_dims,
)
return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

@_deprecate_positional_args("v2024.07.0")
def groupby_bins(
Expand Down
93 changes: 60 additions & 33 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import copy
import functools
import math
import warnings
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
Expand Down Expand Up @@ -68,10 +70,11 @@ def check_reduce_dims(reduce_dims, dimensions):
)


def _codes_to_group_indices(inverse: np.ndarray, N: int) -> GroupIndices:
assert inverse.ndim == 1
def _codes_to_group_indices(codes: np.ndarray, N: int) -> GroupIndices:
"""Converts integer codes for groups to group indices."""
assert codes.ndim == 1
groups: GroupIndices = tuple([] for _ in range(N))
for n, g in enumerate(inverse):
for n, g in enumerate(codes):
if g >= 0:
groups[g].append(n)
return groups
Expand Down Expand Up @@ -448,7 +451,7 @@ class GroupBy(Generic[T_Xarray]):
"_codes",
)
_obj: T_Xarray
groupers: tuple[ResolvedGrouper]
groupers: tuple[ResolvedGrouper, ...]
_restore_coord_dims: bool

_original_obj: T_Xarray
Expand All @@ -464,7 +467,7 @@ class GroupBy(Generic[T_Xarray]):
def __init__(
self,
obj: T_Xarray,
groupers: tuple[ResolvedGrouper],
groupers: tuple[ResolvedGrouper, ...],
restore_coord_dims: bool = True,
) -> None:
"""Create a GroupBy object
Expand All @@ -483,16 +486,35 @@ def __init__(

self._original_obj = obj

(grouper,) = self.groupers
self._original_group = grouper.group
if len(groupers) > 1:
for grouper in groupers:
if grouper.group.ndim > 1:
raise NotImplementedError(
"Only grouping by multiple 1D variables is supported at the moment."
)
(grouper, *_) = self.groupers # FIXME
self._original_group = grouper.group # FIXME

# specification for the groupby operation
self._obj = grouper.stacked_obj
self._obj = grouper.stacked_obj # FIXME
self._restore_coord_dims = restore_coord_dims

# These should generalize to multiple groupers
self._group_indices = grouper.group_indices
self._codes = self._maybe_unstack(grouper.codes)
self._shape = tuple(grouper.size for grouper in groupers)
self._len = math.prod(self._shape)

self._codes = tuple(self._maybe_unstack(grouper.codes) for grouper in groupers)
self._flatcodes = np.ravel_multi_index(self._codes, self._shape, mode="wrap")
# NaNs; as well as values outside the bins are coded by -1
# Restore these after the raveling
mask = functools.reduce(np.logical_or, [(code == -1) for code in self._codes])
self._flatcodes[mask] = -1

if len(groupers) == 1:
# For ordered `group` we index into the array using slices.
# Preserve this optimization when grouping by a single variable
self._group_indices = self.groupers[0].group_indices
else:
self._group_indices = _codes_to_group_indices(self._flatcodes, self._len)

(self._group_dim,) = grouper.group1d.dims
# cached attributes
Expand Down Expand Up @@ -566,13 +588,16 @@ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]:
return zip(grouper.unique_coord.data, self._iter_grouped())

def __repr__(self) -> str:
    """Summarize the GroupBy: number of groupers, total groups, and the
    per-grouper group labels (truncated to 30 characters of labels each)."""
    text = (
        f"<{self.__class__.__name__}, "
        f"grouped over {len(self.groupers)} grouper(s),"
        f" {self._len} groups in total:"
    )
    for grouper in self.groupers:
        coord = grouper.unique_coord
        labels = ", ".join(format_array_flat(coord, 30).split())
        text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}"
    return text + ">"

def _iter_grouped(self) -> Iterator[T_Xarray]:
"""Iterate over each element in this group"""
Expand Down Expand Up @@ -609,7 +634,7 @@ def _binary_op(self, other, f, reflexive=False):
obj = self._original_obj
name = grouper.name
group = grouper.group
codes = self._codes
(codes,) = self._codes
dims = group.dims

if isinstance(group, _DummyGroup):
Expand Down Expand Up @@ -709,15 +734,16 @@ def _maybe_restore_empty_groups(self, combined):
def _maybe_unstack(self, obj):
    """This gets called if we are applying on an array with a
    multidimensional group."""
    # TODO: Is this really right?
    for grouper in self.groupers:
        stacked_dim = grouper.stacked_dim
        # Only unstack when this grouper actually stacked a dimension
        # that is still present on ``obj``.
        if stacked_dim is not None and stacked_dim in obj.dims:
            inserted_dims = grouper.inserted_dims
            obj = obj.unstack(stacked_dim)
            # Drop the helper coordinates the stacking inserted.
            for dim in inserted_dims:
                if dim in obj.coords:
                    del obj.coords[dim]
            obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
    return obj

def _flox_reduce(
Expand Down Expand Up @@ -1115,20 +1141,21 @@ def _concat_shortcut(self, applied, dim, positions=None):
return self._obj._replace_maybe_drop_dims(reordered)

def _restore_dim_order(self, stacked: DataArray) -> DataArray:
    """Transpose ``stacked`` so its dimensions follow the original object's
    dimension order; grouper-created dims are mapped back to the 1D
    dimension they were derived from."""

    def lookup_order(dimension):
        # A dimension named after a 1D grouper sorts where the grouped
        # dimension sat in the original object.
        for grouper in self.groupers:
            if dimension == grouper.name and grouper.group.ndim == 1:
                (dimension,) = grouper.group.dims
        if dimension in self._obj.dims:
            axis = self._obj.get_axis_num(dimension)
        else:
            axis = 1e6  # some arbitrarily high value
        return axis

    new_order = sorted(stacked.dims, key=lookup_order)
    stacked = stacked.transpose(
        *new_order, transpose_coords=self._restore_coord_dims
    )
    return stacked

def map(
self,
Expand Down
38 changes: 27 additions & 11 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,27 +556,28 @@ def test_da_groupby_assign_coords() -> None:
@pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")])
def test_groupby_repr(obj, dim) -> None:
    """Check the multi-grouper-aware repr for a single grouping variable."""
    actual = repr(obj.groupby(dim))
    N = len(np.unique(obj[dim]))
    expected = f"<{obj.__class__.__name__}GroupBy"
    expected += f", grouped over 1 grouper(s), {N} groups in total:"
    expected += f"\n\t{dim!r}: {N} groups with labels "
    if dim == "x":
        expected += "1, 2, 3, 4, 5>"
    elif dim == "y":
        expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19>"
    elif dim == "z":
        expected += "'a', 'b', 'c'>"
    elif dim == "month":
        expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>"
    assert actual == expected


@pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")])
def test_groupby_repr_datetime(obj) -> None:
    """Check the repr when grouping over a datetime accessor ('t.month')."""
    actual = repr(obj.groupby("t.month"))
    expected = f"<{obj.__class__.__name__}GroupBy"
    expected += ", grouped over 1 grouper(s), 12 groups in total:\n"
    expected += "\t'month': 12 groups with labels "
    expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>"
    assert actual == expected


Expand Down Expand Up @@ -2561,3 +2562,18 @@ def factorize(self, group) -> EncodedGroups:
obj.groupby("time.year", time=YearGrouper())
with pytest.raises(ValueError):
obj.groupby()


def test_multiple_groupers() -> None:
    """Smoke test: grouping by two 1D coordinates at once works end to end."""
    values = np.array([1, 2, 3, 0, 2, np.nan])
    first_labels = np.array(["a", "b", "c", "c", "b", "a"])
    second_labels = np.array(["x", "y", "z", "z", "y", "x"])
    da = xr.DataArray(
        values,
        dims="d",
        coords={
            "labels1": ("d", first_labels),
            "labels2": ("d", second_labels),
        },
    )

    gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
    repr(gb)
    gb.mean()

0 comments on commit ee18848

Please sign in to comment.