Skip to content

Commit

Permalink
added unstack in apply
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Apr 7, 2016
1 parent 807001e commit 562ba92
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
40 changes: 22 additions & 18 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,29 +101,24 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
"""
from .dataset import as_dataset

self._stacked_dim = None
group_name = group.name
if group.ndim != 1:
# try to stack the dims of the group into a single dim
# TODO: figure out how to exclude dimensions from the stacking
# (e.g. group over space dims but leave time dim intact)
orig_dims = group.dims
stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
# the copy is necessary here
# the copy is necessary here, otherwise read only array raises error
# in pandas: https://github.com/pydata/pandas/issues/12813
# Is there a performance penalty for calling copy?
group = group.stack(**{stacked_dim_name: orig_dims}).copy()
# without it, an error is raised deep in pandas
########################
# xarray/core/groupby.py
# ---> 31 inverse, values = pd.factorize(ar, sort=True)
# pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
# --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
# pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
# pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
# pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
# ValueError: buffer source array is read-only
#######################
# seems related to
# https://github.com/pydata/pandas/issues/10043
# https://github.com/pydata/pandas/pull/10070
obj = obj.stack(**{stacked_dim_name: orig_dims})
self._stacked_dim = stacked_dim_name
self._unstacked_dims = orig_dims
# we also need to rename the group name to avoid a conflict when
# concatenating
group_name += '_groups'
if getattr(group, 'name', None) is None:
raise ValueError('`group` must have a name')
if not hasattr(group, 'dims'):
Expand Down Expand Up @@ -157,7 +152,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
unique_coord = Coordinate(group.name, first_items.index)
elif group.name in obj.dims:
# assume that group already has sorted, unique values
if group.dims != (group.name,):
if group.dims != (group_name,):
raise ValueError('`group` is required to be a coordinate if '
'`group.name` is a dimension in `obj`')
group_indices = np.arange(group.size)
Expand All @@ -169,7 +164,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
else:
# look through group to find the unique values
unique_values, group_indices = unique_value_groups(group)
unique_coord = Coordinate(group.name, unique_values)
unique_coord = Coordinate(group_name, unique_values)

self.obj = obj
self.group = group
Expand Down Expand Up @@ -249,6 +244,13 @@ def _maybe_restore_empty_groups(self, combined):
combined = combined.reindex(**indexers)
return combined

def _maybe_unstack_array(self, arr):
"""This gets called if we are applying on an array with a
multidimensional group."""
if self._stacked_dim is not None and self._stacked_dim in arr.dims:
arr = arr.unstack(self._stacked_dim)
return arr

def fillna(self, value):
"""Fill missing values in this object by group.
Expand Down Expand Up @@ -399,7 +401,9 @@ def apply(self, func, shortcut=False, **kwargs):
grouped = self._iter_grouped_shortcut()
else:
grouped = self._iter_grouped()
applied = (maybe_wrap_array(arr, func(arr, **kwargs)) for arr in grouped)
applied = (self._maybe_unstack_array(
maybe_wrap_array(arr,func(arr, **kwargs)))
for arr in grouped)
combined = self._concat(applied, shortcut=shortcut)
result = self._maybe_restore_empty_groups(combined)
return result
Expand Down
6 changes: 4 additions & 2 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,8 +1250,10 @@ def test_groupby_multidim(self):
'lat': (['ny','nx'], [[10,10],[20,20]] ),},
dims=['ny','nx'])
for dim, expected_sum in [
('lon', DataArray([0, 3, 3], coords={'lon': [30,40,50]})),
('lat', DataArray([1,5], coords={'lat': [10,20]}))]:
('lon', DataArray([0, 3, 3],
coords={'lon_groups': [30,40,50]})),
('lat', DataArray([1,5],
coords={'lat_groups': [10,20]}))]:
actual_sum = array.groupby(dim).sum()
self.assertDataArrayIdentical(expected_sum, actual_sum)

Expand Down

0 comments on commit 562ba92

Please sign in to comment.