From b0db8a8db4dd28e9c5415ca39c4d22a8ea166b5b Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Thu, 7 Apr 2016 12:57:11 -0400
Subject: [PATCH] added unstack in apply

fix dataset test bug

fixed apply
---
 xarray/core/groupby.py        | 45 +++++++++++++++++++----------------
 xarray/test/test_dataarray.py | 11 +++++++++
 2 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index f47de953abc..9b94ab3eed7 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -101,31 +101,22 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
         """
         from .dataset import as_dataset
 
+        if getattr(group, 'name', None) is None:
+            raise ValueError('`group` must have a name')
+        self._stacked_dim = None
         if group.ndim != 1:
             # try to stack the dims of the group into a single dim
             # TODO: figure out how to exclude dimensions from the stacking
             # (e.g. group over space dims but leave time dim intact)
             orig_dims = group.dims
             stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
-            # the copy is necessary here
+            # the copy is necessary here, otherwise read only array raises error
+            # in pandas: https://github.com/pydata/pandas/issues/12813
+            # Is there a performance penalty for calling copy?
             group = group.stack(**{stacked_dim_name: orig_dims}).copy()
-            # without it, an error is raised deep in pandas
-            ########################
-            # xarray/core/groupby.py
-            # ---> 31 inverse, values = pd.factorize(ar, sort=True)
-            # pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
-            # --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
-            # pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
-            # pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
-            # pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
-            # ValueError: buffer source array is read-only
-            #######################
-            # seems related to
-            # https://github.com/pydata/pandas/issues/10043
-            # https://github.com/pydata/pandas/pull/10070
             obj = obj.stack(**{stacked_dim_name: orig_dims})
-        if getattr(group, 'name', None) is None:
-            raise ValueError('`group` must have a name')
+            self._stacked_dim = stacked_dim_name
+            self._unstacked_dims = orig_dims
         if not hasattr(group, 'dims'):
             raise ValueError("`group` must have a 'dims' attribute")
         group_dim, = group.dims
@@ -249,6 +240,13 @@ def _maybe_restore_empty_groups(self, combined):
             combined = combined.reindex(**indexers)
         return combined
 
+    def _maybe_unstack_array(self, arr):
+        """This gets called if we are applying on an array with a
+        multidimensional group."""
+        if self._stacked_dim is not None and self._stacked_dim in arr.dims:
+            arr = arr.unstack(self._stacked_dim)
+        return arr
+
     def fillna(self, value):
         """Fill missing values in this object by group.
 
@@ -358,6 +356,11 @@ def lookup_order(dimension):
         new_order = sorted(stacked.dims, key=lookup_order)
         return stacked.transpose(*new_order)
 
+    def _restore_multiindex(self, combined):
+        if self._stacked_dim is not None and self._stacked_dim in combined.dims:
+            combined[self._stacked_dim] = self.group[self._stacked_dim]
+        return combined
+
     def apply(self, func, shortcut=False, **kwargs):
         """Apply a function over each array in the group and concatenate them
         together into a new array.
@@ -399,23 +402,23 @@ def apply(self, func, shortcut=False, **kwargs):
             grouped = self._iter_grouped_shortcut()
         else:
             grouped = self._iter_grouped()
-        applied = (maybe_wrap_array(arr, func(arr, **kwargs)) for arr in grouped)
+        applied = (maybe_wrap_array(arr,func(arr, **kwargs)) for arr in grouped)
         combined = self._concat(applied, shortcut=shortcut)
-        result = self._maybe_restore_empty_groups(combined)
+        result = self._maybe_restore_empty_groups(
+            self._maybe_unstack_array(combined))
         return result
 
     def _concat(self, applied, shortcut=False):
         # peek at applied to determine which coordinate to stack over
         applied_example, applied = peek_at(applied)
         concat_dim, positions = self._infer_concat_args(applied_example)
-
         if shortcut:
             combined = self._concat_shortcut(applied, concat_dim, positions)
         else:
             combined = concat(applied, concat_dim, positions=positions)
-
         if isinstance(combined, type(self.obj)):
             combined = self._restore_dim_order(combined)
+        combined = self._restore_multiindex(combined)
         return combined
 
     def reduce(self, func, dim=None, axis=None, keep_attrs=False,
diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py
index d0f83457e68..853b051e425 100644
--- a/xarray/test/test_dataarray.py
+++ b/xarray/test/test_dataarray.py
@@ -1255,6 +1255,17 @@ def test_groupby_multidim(self):
             actual_sum = array.groupby(dim).sum()
             self.assertDataArrayIdentical(expected_sum, actual_sum)
 
+    def test_groupby_multidim_apply(self):
+        array = DataArray([[0,1],[2,3]],
+                coords={'lon': (['ny','nx'], [[30,40],[40,50]] ),
+                        'lat': (['ny','nx'], [[10,10],[20,20]] ),},
+                dims=['ny','nx'])
+        actual = array.groupby('lon').apply(
+            lambda x : x - x.mean(), shortcut=False)
+        expected = DataArray([[0.,-0.5],[0.5,0.]],
+                coords=array.coords, dims=array.dims)
+        self.assertDataArrayIdentical(expected, actual)
+
     def make_rolling_example_array(self):
         times = pd.date_range('2000-01-01', freq='1D', periods=21)
         values = np.random.random((21, 4))
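
Usage sketch (outside the diff itself): the snippet below illustrates the behavior exercised by the new test_groupby_multidim_apply, assuming a build of xarray that includes this change; the variable names (da, demeaned) are illustrative only.

import xarray as xr

# Same data as the new test: a 2x2 array with two-dimensional 'lon'/'lat' coordinates.
da = xr.DataArray([[0, 1], [2, 3]],
                  coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]]),
                          'lat': (['ny', 'nx'], [[10, 10], [20, 20]])},
                  dims=['ny', 'nx'])

# Grouping over the multidimensional 'lon' coordinate stacks ('ny', 'nx') into a
# temporary 'stacked_ny_nx' dimension; after apply, _maybe_unstack_array restores
# the original dimensions.
demeaned = da.groupby('lon').apply(lambda x: x - x.mean(), shortcut=False)

# Per the new test, the de-meaned values are [[0.0, -0.5], [0.5, 0.0]] on dims ('ny', 'nx').
print(demeaned)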