Skip to content

Commit

Permalink
fix multi-index selection regression
Browse files Browse the repository at this point in the history
See #5691
  • Loading branch information
benbovy committed Aug 12, 2021
1 parent 1bb61d9 commit a551c7f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
22 changes: 15 additions & 7 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def _is_nested_tuple(possible_tuple):
)


def normalize_label(value, extract_scalar=False):
if getattr(value, "ndim", 1) <= 1:
value = _asarray_tuplesafe(value)
if extract_scalar:
# see https://github.com/pydata/xarray/pull/4292 for details
value = value[()] if value.dtype.kind in "mM" else value.item()
return value


def get_indexer_nd(index, labels, method=None, tolerance=None):
"""Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional
labels
Expand Down Expand Up @@ -207,14 +216,9 @@ def query(self, labels, method=None, tolerance=None):
"a dimension that does not have a MultiIndex"
)
else:
label = (
label
if getattr(label, "ndim", 1) > 1 # vectorized-indexing
else _asarray_tuplesafe(label)
)
label = normalize_label(label)
if label.ndim == 0:
# see https://github.com/pydata/xarray/pull/4292 for details
label_value = label[()] if label.dtype.kind in "mM" else label.item()
label_value = normalize_label(label, extract_scalar=True)
if isinstance(self.index, pd.CategoricalIndex):
if method is not None:
raise ValueError(
Expand Down Expand Up @@ -336,6 +340,10 @@ def query(self, labels, method=None, tolerance=None):
# label(s) given for multi-index level(s)
if all([lbl in self.index.names for lbl in labels]):
is_nested_vals = _is_nested_tuple(tuple(labels.values()))
labels = {
k: normalize_label(v, extract_scalar=True) for k, v in labels.items()
}

if len(labels) == self.index.nlevels and not is_nested_vals:
indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names))
else:
Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,20 @@ def test_sel_float(self):
assert_equal(expected_scalar, actual_scalar)
assert_equal(expected_16, actual_16)

def test_sel_float_multiindex(self):
# regression test https://github.com/pydata/xarray/issues/5691
midx = pd.MultiIndex.from_arrays(
[["a", "a", "b", "b"], [0.1, 0.2, 0.3, 0.4]], names=["lvl1", "lvl2"]
)
da = xr.DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x")

actual = da.sel(lvl1="a", lvl2=0.1)
expected = da.isel(x=0)

assert_equal(actual, expected)

# TODO: test multi-index created from coordinates, one with dtype=float32

def test_sel_no_index(self):
array = DataArray(np.arange(10), dims="x")
assert_identical(array[0], array.sel(x=0))
Expand Down

0 comments on commit a551c7f

Please sign in to comment.