Skip to content

Commit

Permalink
multidimensional groupby
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Apr 6, 2016
1 parent 4fdf6d4 commit 807001e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
24 changes: 22 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,28 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
from .dataset import as_dataset

if group.ndim != 1:
# TODO: remove this limitation?
raise ValueError('`group` must be 1 dimensional')
# try to stack the dims of the group into a single dim
# TODO: figure out how to exclude dimensions from the stacking
# (e.g. group over space dims but leave time dim intact)
orig_dims = group.dims
stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
# the copy is necessary here
group = group.stack(**{stacked_dim_name: orig_dims}).copy()
# without it, an error is raised deep in pandas
########################
# xarray/core/groupby.py
# ---> 31 inverse, values = pd.factorize(ar, sort=True)
# pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
# --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
# pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
# pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
# pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
# ValueError: buffer source array is read-only
#######################
# seems related to
# https://github.com/pydata/pandas/issues/10043
# https://github.com/pydata/pandas/pull/10070
obj = obj.stack(**{stacked_dim_name: orig_dims})
if getattr(group, 'name', None) is None:
raise ValueError('`group` must have a name')
if not hasattr(group, 'dims'):
Expand Down
25 changes: 18 additions & 7 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,17 @@ def test_groupby_first_and_last(self):
expected = array # should be a no-op
self.assertDataArrayIdentical(expected, actual)

def test_groupby_multidim(self):
array = DataArray([[0,1],[2,3]],
coords={'lon': (['ny','nx'], [[30,40],[40,50]] ),
'lat': (['ny','nx'], [[10,10],[20,20]] ),},
dims=['ny','nx'])
for dim, expected_sum in [
('lon', DataArray([0, 3, 3], coords={'lon': [30,40,50]})),
('lat', DataArray([1,5], coords={'lat': [10,20]}))]:
actual_sum = array.groupby(dim).sum()
self.assertDataArrayIdentical(expected_sum, actual_sum)

def make_rolling_example_array(self):
times = pd.date_range('2000-01-01', freq='1D', periods=21)
values = np.random.random((21, 4))
Expand Down Expand Up @@ -1792,29 +1803,29 @@ def test_full_like(self):
actual = _full_like(DataArray([1, 2, 3]), fill_value=np.nan)
self.assertEqual(actual.dtype, np.float)
np.testing.assert_equal(actual.values, np.nan)

def test_dot(self):
x = np.linspace(-3, 3, 6)
y = np.linspace(-3, 3, 5)
z = range(4)
z = range(4)
da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z'])

dm_vals = range(4)
dm = DataArray(dm_vals, coords=[z], dims=['z'])

# nd dot 1d
actual = da.dot(dm)
expected_vals = np.tensordot(da_vals, dm_vals, [2, 0])
expected = DataArray(expected_vals, coords=[x, y], dims=['x', 'y'])
self.assertDataArrayEqual(expected, actual)

# all shared dims
actual = da.dot(da)
expected_vals = np.tensordot(da_vals, da_vals, axes=([0, 1, 2], [0, 1, 2]))
expected = DataArray(expected_vals)
self.assertDataArrayEqual(expected, actual)

# multiple shared dims
dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
j = np.linspace(-3, 3, 20)
Expand All @@ -1823,7 +1834,7 @@ def test_dot(self):
expected_vals = np.tensordot(da_vals, dm_vals, axes=([1, 2], [1, 2]))
expected = DataArray(expected_vals, coords=[x, j], dims=['x', 'j'])
self.assertDataArrayEqual(expected, actual)

with self.assertRaises(NotImplementedError):
da.dot(dm.to_dataset(name='dm'))
with self.assertRaises(TypeError):
Expand Down
2 changes: 0 additions & 2 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,8 +1545,6 @@ def test_groupby_iter(self):

def test_groupby_errors(self):
data = create_test_data()
with self.assertRaisesRegexp(ValueError, 'must be 1 dimensional'):
data.groupby('var1')
with self.assertRaisesRegexp(ValueError, 'must have a name'):
data.groupby(np.arange(10))
with self.assertRaisesRegexp(ValueError, 'length does not match'):
Expand Down

0 comments on commit 807001e

Please sign in to comment.