Skip to content

Commit

Permalink
dataset fix and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
0x0L authored and 0x0L committed Dec 10, 2017
1 parent b990042 commit cd152a4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
10 changes: 5 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3278,11 +3278,11 @@ def rank(self, dim, pct=False, keep_attrs=False):

variables = OrderedDict()
for name, var in iteritems(self.variables):
if name in self.data_vars and dim in var.dims:
variables[name] = var.rank(dim, pct=pct)
variables.update({
k: self.variables[k] for k in var.dims
if k not in variables and k in self.variables})
if name in self.data_vars:
if dim in var.dims:
variables[name] = var.rank(dim, pct=pct)
else:
variables[name] = var

coord_names = set(k for k in self.coords if k in variables)
attrs = self.attrs if keep_attrs else None
Expand Down
12 changes: 11 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3417,9 +3417,19 @@ def test_quantile(self):
@requires_bottleneck
def test_rank(self):
ds = create_test_data(seed=1234)
x = ds.rank('dim3').var3
# only ds.var3 depends on dim3
z = ds.rank('dim3')
self.assertItemsEqual(['var3'], list(z.data_vars))
# same as dataarray version
x = z.var3
y = ds.var3.rank('dim3')
self.assertDataArrayEqual(x, y)
# coordinates stick
self.assertItemsEqual(list(z.coords), list(ds.coords))
self.assertItemsEqual(list(x.coords), list(y.coords))
# invalid dim
with raises_regex(ValueError, 'does not contain'):
x.rank('invalid_dim')

def test_count(self):
ds = Dataset({'x': ('a', [np.nan, 1]), 'y': 0, 'z': np.nan})
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,13 @@ def test_quantile_dask_raises(self):
with raises_regex(TypeError, 'arrays stored as dask'):
v.quantile(0.5, dim='x')

@requires_dask
@requires_bottleneck
def test_rank_dask_raises(self):
v = Variable(['x'], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2)
with raises_regex(TypeError, 'arrays stored as dask'):
v.rank('x')

@requires_bottleneck
def test_rank(self):
import bottleneck as bn
Expand All @@ -1376,6 +1383,9 @@ def test_rank(self):
v = Variable(['x'], [3.0, 1.0, np.nan, 2.0, 4.0])
v_expect = Variable(['x'], [0.75, 0.25, np.nan, 0.5, 1.0])
self.assertVariableEqual(v.rank('x', pct=True), v_expect)
# invalid dim
with raises_regex(ValueError, 'not found'):
v.rank('y')

def test_big_endian_reduce(self):
# regression test for GH489
Expand Down

0 comments on commit cd152a4

Please sign in to comment.