diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index f47de953abc..83fa8b5c138 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -101,29 +101,24 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
         """
         from .dataset import as_dataset
 
+        self._stacked_dim = None
+        group_name = group.name
         if group.ndim != 1:
             # try to stack the dims of the group into a single dim
             # TODO: figure out how to exclude dimensions from the stacking
             #   (e.g. group over space dims but leave time dim intact)
             orig_dims = group.dims
             stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
-            # the copy is necessary here
+            # the copy is necessary here; otherwise a read-only array raises
+            # an error in pandas: https://github.com/pydata/pandas/issues/12813
+            # Is there a performance penalty for calling copy?
             group = group.stack(**{stacked_dim_name: orig_dims}).copy()
-            # without it, an error is raised deep in pandas
-            ########################
-            # xarray/core/groupby.py
-            # ---> 31 inverse, values = pd.factorize(ar, sort=True)
-            # pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
-            # --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
-            # pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
-            # pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
-            # pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
-            # ValueError: buffer source array is read-only
-            #######################
-            # seems related to
-            # https://github.com/pydata/pandas/issues/10043
-            # https://github.com/pydata/pandas/pull/10070
             obj = obj.stack(**{stacked_dim_name: orig_dims})
+            self._stacked_dim = stacked_dim_name
+            self._unstacked_dims = orig_dims
+            # we also need to rename the group to avoid a conflict when
+            # concatenating
+            group_name += '_groups'
         if getattr(group, 'name', None) is None:
             raise ValueError('`group` must have a name')
         if not hasattr(group, 'dims'):
@@ -157,7 +152,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
             unique_coord = Coordinate(group.name, first_items.index)
         elif group.name in obj.dims:
             # assume that group already has sorted, unique values
-            if group.dims != (group.name,):
+            if group.dims != (group_name,):
                 raise ValueError('`group` is required to be a coordinate if '
                                  '`group.name` is a dimension in `obj`')
             group_indices = np.arange(group.size)
@@ -169,7 +164,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
         else:
             # look through group to find the unique values
             unique_values, group_indices = unique_value_groups(group)
-            unique_coord = Coordinate(group.name, unique_values)
+            unique_coord = Coordinate(group_name, unique_values)
 
         self.obj = obj
         self.group = group
@@ -249,6 +244,13 @@ def _maybe_restore_empty_groups(self, combined):
             combined = combined.reindex(**indexers)
         return combined
 
+    def _maybe_unstack_array(self, arr):
+        """This gets called if we are applying on an array with a
+        multidimensional group."""
+        if self._stacked_dim is not None and self._stacked_dim in arr.dims:
+            arr = arr.unstack(self._stacked_dim)
+        return arr
+
     def fillna(self, value):
         """Fill missing values in this object by group.
@@ -399,7 +401,9 @@ def apply(self, func, shortcut=False, **kwargs):
             grouped = self._iter_grouped_shortcut()
         else:
             grouped = self._iter_grouped()
-        applied = (maybe_wrap_array(arr, func(arr, **kwargs)) for arr in grouped)
+        applied = (self._maybe_unstack_array(
+            maybe_wrap_array(arr, func(arr, **kwargs)))
+            for arr in grouped)
         combined = self._concat(applied, shortcut=shortcut)
         result = self._maybe_restore_empty_groups(combined)
         return result
diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py
index d0f83457e68..aa118dec663 100644
--- a/xarray/test/test_dataarray.py
+++ b/xarray/test/test_dataarray.py
@@ -1250,8 +1250,10 @@ def test_groupby_multidim(self):
                                   'lat': (['ny','nx'], [[10,10],[20,20]] ),},
                           dims=['ny','nx'])
         for dim, expected_sum in [
-                ('lon', DataArray([0, 3, 3], coords={'lon': [30,40,50]})),
-                ('lat', DataArray([1,5], coords={'lat': [10,20]}))]:
+                ('lon', DataArray([0, 3, 3],
+                        coords={'lon_groups': [30,40,50]})),
+                ('lat', DataArray([1,5],
+                        coords={'lat_groups': [10,20]}))]:
             actual_sum = array.groupby(dim).sum()
             self.assertDataArrayIdentical(expected_sum, actual_sum)
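A minimal usage sketch of what this change enables, mirroring the test case above. The data values and the 2-D 'lon' coordinate below are assumed for illustration (they are consistent with the expected sums in the test, but the full array construction is not visible in the hunk); the renamed 'lon_groups' / 'lat_groups' result coordinates come directly from the diff.

from xarray import DataArray

# a 2x2 array with two-dimensional 'lon' and 'lat' coordinates
# (values assumed here; see note above)
array = DataArray([[0, 1], [2, 3]],
                  coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]]),
                          'lat': (['ny', 'nx'], [[10, 10], [20, 20]])},
                  dims=['ny', 'nx'])

# grouping over a multidimensional coordinate stacks ('ny', 'nx') into a
# single dimension internally and applies the function per unique value;
# _maybe_unstack_array only unstacks results that still carry the stacked
# dimension, and the group coordinate is renamed to avoid conflicting with
# the original 2-D 'lon'/'lat' coordinate
lon_sum = array.groupby('lon').sum()  # values [0, 3, 3], coord 'lon_groups': [30, 40, 50]
lat_sum = array.groupby('lat').sum()  # values [1, 5], coord 'lat_groups': [10, 20]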