Skip to content

Commit

Permalink
CLN: Implement multiindex handling for get_op_result_name (pandas-dev…
Browse files Browse the repository at this point in the history
…#38323)

* CLN: Implement multiindex handling for get_op_result_name

* Change import order

* Override method

* Move import

* Remove import

* Fix merge issue

* Move methods
  • Loading branch information
phofl authored and luckyvs1 committed Jan 20, 2021
1 parent 2caad13 commit cede451
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
1 change: 0 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2580,7 +2580,6 @@ def __nonzero__(self):
# --------------------------------------------------------------------
# Set Operation Methods

@final
def _get_reconciled_name_object(self, other):
"""
If the result of a set operation will be self,
Expand Down
34 changes: 31 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3593,6 +3593,34 @@ def _union(self, other, sort):
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
return is_object_dtype(dtype)

def _get_reconciled_name_object(self, other):
    """
    Return ``self`` as the result of a set operation, renamed if needed.

    Used when the result of a set operation will be ``self``. The names of
    ``self`` are reconciled with those of ``other`` through
    ``_maybe_match_names``. If the reconciled names differ from
    ``self.names``, a renamed shallow copy is returned via ``self.rename``;
    otherwise ``self`` itself is returned.

    Parameters
    ----------
    other : object
        The other operand of the set operation; forwarded to
        ``_maybe_match_names``.

    Returns
    -------
    ``self`` when the names are unchanged, else ``self.rename(names)``.
    """
    names = self._maybe_match_names(other)
    if names is None:
        # ``_maybe_match_names`` yields None when the two objects carry a
        # different number of names, i.e. no level has a consensus name.
        # ``rename`` expects one entry per level, so spell that out as a
        # list of None instead of forwarding a bare None.
        names = [None] * len(self.names)
    if self.names != names:
        return self.rename(names)
    return self

def _maybe_match_names(self, other):
    """
    Find the per-level names shared by ``self`` and ``other``.

    Parameters
    ----------
    other : object with a ``names`` attribute
        The object whose names are compared position by position with
        ``self.names``.

    Returns
    -------
    list or None
        None when ``self`` and ``other`` have a different number of names.
        Otherwise a list with one entry per position: the common name
        where both sides compare equal, and None where they differ.
    """
    if len(self.names) != len(other.names):
        return None
    # TODO: what if they both have np.nan for their names?
    # (``nan == nan`` is False, so such a pair currently yields None.)
    return [
        a_name if a_name == b_name else None
        for a_name, b_name in zip(self.names, other.names)
    ]

def intersection(self, other, sort=False):
"""
Form the intersection of two MultiIndex objects.
Expand All @@ -3616,12 +3644,12 @@ def intersection(self, other, sort=False):
"""
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other, result_names = self._convert_can_do_setop(other)
other, _ = self._convert_can_do_setop(other)

if self.equals(other):
if self.has_duplicates:
return self.unique().rename(result_names)
return self.rename(result_names)
return self.unique()._get_reconciled_name_object(other)
return self._get_reconciled_name_object(other)

return self._intersection(other, sort=sort)

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _maybe_match_name(a, b):
"""
Try to find a name to attach to the result of an operation between
a and b. If only one of these has a `name` attribute, return that
name. Otherwise return a consensus name if they match of None if
name. Otherwise return a consensus name if they match or None if
they have different names.
Parameters
Expand Down
31 changes: 31 additions & 0 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,29 @@ def test_intersect_with_duplicates(tuples, exp_tuples):
tm.assert_index_equal(result, expected)


@pytest.mark.parametrize(
    "data, names, expected",
    [
        ((1,), None, None),
        ((1,), ["a"], None),
        ((1,), ["b"], None),
        ((1, 2), ["c", "d"], [None, None]),
        ((1, 2), ["b", "a"], [None, None]),
        ((1, 2, 3), ["a", "b", "c"], None),
        ((1, 2), ["a", "c"], ["a", None]),
        ((1, 2), ["c", "b"], [None, "b"]),
        ((1, 2), ["a", "b"], ["a", "b"]),
        ((1, 2), [None, "b"], [None, "b"]),
    ],
)
def test_maybe_match_names(data, names, expected):
    """Compare ``_maybe_match_names`` output against an index named ``["a", "b"]``."""
    # GH#38323
    reference = pd.MultiIndex.from_tuples([], names=["a", "b"])
    candidate = pd.MultiIndex.from_tuples([data], names=names)
    assert reference._maybe_match_names(candidate) == expected


def test_intersection_equal_different_names():
# GH#30302
mi1 = MultiIndex.from_arrays([[1, 2], [3, 4]], names=["c", "b"])
Expand All @@ -429,3 +452,11 @@ def test_intersection_equal_different_names():
result = mi1.intersection(mi2)
expected = MultiIndex.from_arrays([[1, 2], [3, 4]], names=[None, "b"])
tm.assert_index_equal(result, expected)


def test_intersection_different_names():
    """Intersect a named index with an equal-valued one built without names."""
    # GH#38323
    named = MultiIndex.from_arrays([[1], [3]], names=["c", "b"])
    unnamed = MultiIndex.from_arrays([[1], [3]])
    # The intersection is expected to compare equal to the index built
    # without names.
    tm.assert_index_equal(named.intersection(unnamed), unnamed)

0 comments on commit cede451

Please sign in to comment.