added unstack in apply
fix dataset test bug

fixed apply
rabernat committed Apr 7, 2016
1 parent 807001e commit b0db8a8
Showing 2 changed files with 35 additions and 21 deletions.
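For orientation, a minimal sketch of the workflow this commit enables, adapted from the test added below; the data and names are illustrative only. Grouping over a two-dimensional coordinate now works with apply(), and the combined result is unstacked back onto the original dimensions.

import xarray as xr

# Illustrative data mirroring the new test: a 2-D array with a 2-D 'lon'
# coordinate defined over the dims ('ny', 'nx').
array = xr.DataArray(
    [[0, 1], [2, 3]],
    coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]])},
    dims=['ny', 'nx'],
)

# Group over the multidimensional 'lon' coordinate and demean each group.
# With this commit, apply() unstacks the combined result, so the original
# ('ny', 'nx') dimensions come back instead of a stacked dimension.
demeaned = array.groupby('lon').apply(lambda x: x - x.mean(), shortcut=False)
print(demeaned.values)  # [[ 0.  -0.5]
                        #  [ 0.5  0. ]]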
xarray/core/groupby.py: 45 changes, 24 additions & 21 deletions
@@ -101,31 +101,22 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
"""
from .dataset import as_dataset

if getattr(group, 'name', None) is None:
raise ValueError('`group` must have a name')
self._stacked_dim = None
if group.ndim != 1:
# try to stack the dims of the group into a single dim
# TODO: figure out how to exclude dimensions from the stacking
# (e.g. group over space dims but leave time dim intact)
orig_dims = group.dims
stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
# the copy is necessary here
# the copy is necessary here, otherwise read only array raises error
# in pandas: https://github.com/pydata/pandas/issues/12813
# Is there a performance penalty for calling copy?
group = group.stack(**{stacked_dim_name: orig_dims}).copy()
# without it, an error is raised deep in pandas
########################
# xarray/core/groupby.py
# ---> 31 inverse, values = pd.factorize(ar, sort=True)
# pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
# --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
# pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
# pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
# pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
# ValueError: buffer source array is read-only
#######################
# seems related to
# https://github.com/pydata/pandas/issues/10043
# https://github.com/pydata/pandas/pull/10070
obj = obj.stack(**{stacked_dim_name: orig_dims})
if getattr(group, 'name', None) is None:
raise ValueError('`group` must have a name')
self._stacked_dim = stacked_dim_name
self._unstacked_dims = orig_dims
if not hasattr(group, 'dims'):
raise ValueError("`group` must have a 'dims' attribute")
group_dim, = group.dims
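As a side note on the stacking step above, a minimal sketch (illustrative names only) of what group.stack(...).copy() produces: the multidimensional group coordinate is collapsed into a single dimension named 'stacked_' + '_'.join(orig_dims), and .copy() hands pandas a writable buffer, since factorizing a read-only stacked view can raise "ValueError: buffer source array is read-only" on the pandas versions referenced in the comments.

import xarray as xr

# A 2-D coordinate over dims ('ny', 'nx'), as in the new test below.
lon = xr.DataArray([[30, 40], [40, 50]], dims=['ny', 'nx'], name='lon')

# Collapse both dims into one stacked dim, following the naming scheme
# used in __init__ ('stacked_' + '_'.join(orig_dims)); .copy() avoids
# passing a read-only view to pandas when the group is factorized.
stacked = lon.stack(stacked_ny_nx=('ny', 'nx')).copy()
print(stacked.dims)    # ('stacked_ny_nx',)
print(stacked.values)  # [30 40 40 50]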
@@ -249,6 +240,13 @@ def _maybe_restore_empty_groups(self, combined):
combined = combined.reindex(**indexers)
return combined

def _maybe_unstack_array(self, arr):
"""This gets called if we are applying on an array with a
multidimensional group."""
if self._stacked_dim is not None and self._stacked_dim in arr.dims:
arr = arr.unstack(self._stacked_dim)
return arr

def fillna(self, value):
"""Fill missing values in this object by group.
@@ -358,6 +356,11 @@ def lookup_order(dimension):
new_order = sorted(stacked.dims, key=lookup_order)
return stacked.transpose(*new_order)

def _restore_multiindex(self, combined):
if self._stacked_dim is not None and self._stacked_dim in combined.dims:
combined[self._stacked_dim] = self.group[self._stacked_dim]
return combined

def apply(self, func, shortcut=False, **kwargs):
"""Apply a function over each array in the group and concatenate them
together into a new array.
@@ -399,23 +402,23 @@ def apply(self, func, shortcut=False, **kwargs):
grouped = self._iter_grouped_shortcut()
else:
grouped = self._iter_grouped()
applied = (maybe_wrap_array(arr, func(arr, **kwargs)) for arr in grouped)
combined = self._concat(applied, shortcut=shortcut)
result = self._maybe_restore_empty_groups(combined)
result = self._maybe_restore_empty_groups(
self._maybe_unstack_array(combined))
return result

def _concat(self, applied, shortcut=False):
# peek at applied to determine which coordinate to stack over
applied_example, applied = peek_at(applied)
concat_dim, positions = self._infer_concat_args(applied_example)

if shortcut:
combined = self._concat_shortcut(applied, concat_dim, positions)
else:
combined = concat(applied, concat_dim, positions=positions)

if isinstance(combined, type(self.obj)):
combined = self._restore_dim_order(combined)
combined = self._restore_multiindex(combined)
return combined

def reduce(self, func, dim=None, axis=None, keep_attrs=False,
xarray/test/test_dataarray.py: 11 changes, 11 additions & 0 deletions
@@ -1255,6 +1255,17 @@ def test_groupby_multidim(self):
actual_sum = array.groupby(dim).sum()
self.assertDataArrayIdentical(expected_sum, actual_sum)

def test_groupby_multidim_apply(self):
array = DataArray([[0, 1], [2, 3]],
coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]]),
'lat': (['ny', 'nx'], [[10, 10], [20, 20]])},
dims=['ny', 'nx'])
actual = array.groupby('lon').apply(
lambda x: x - x.mean(), shortcut=False)
expected = DataArray([[0., -0.5], [0.5, 0.]],
coords=array.coords, dims=array.dims)
self.assertDataArrayIdentical(expected, actual)

def make_rolling_example_array(self):
times = pd.date_range('2000-01-01', freq='1D', periods=21)
values = np.random.random((21, 4))
