Skip to content

Commit

Permalink
Add optional sort parameter to difference method in subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
reidy-p committed Oct 11, 2018
1 parent 39715d6 commit 8f881c8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ Other Enhancements
- :meth:`Index.to_frame` now supports overriding column name(s) (:issue:`22580`).
- New attribute :attr:`__git_version__` will return git commit sha of current build (:issue:`21295`).
- Compatibility with Matplotlib 3.0 (:issue:`22790`).
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)

.. _whatsnew_0240.api_breaking:

Expand Down
20 changes: 15 additions & 5 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2764,10 +2764,18 @@ def intersection(self, other):
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def difference(self, other):
def difference(self, other, sort=True):
"""
Compute sorted set difference of two MultiIndex objects
Parameters
----------
other : MultiIndex
sort : bool, default True
Sort the resulting MultiIndex if possible
.. versionadded:: 0.24.0
Returns
-------
diff : MultiIndex
Expand All @@ -2780,11 +2788,13 @@ def difference(self, other):

if self.equals(other):
return MultiIndex(levels=self.levels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)

difference = set(self._ndarray_values) - set(other._ndarray_values)

difference = sorted(set(self._ndarray_values) -
set(other._ndarray_values))
if sort:
difference = sorted(difference)

if len(difference) == 0:
return MultiIndex(levels=[[]] * self.nlevels,
Expand Down
30 changes: 20 additions & 10 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,15 +1047,17 @@ def test_iadd_string(self):

@pytest.mark.parametrize("second_name,expected", [
(None, None), ('name', 'name')])
def test_difference_name_preservation(self, second_name, expected):
@pytest.mark.parametrize("sort", [
(True, False)])
def test_difference_name_preservation(self, second_name, expected, sort):
# TODO: replace with fixturesult
first = self.strIndex[5:20]
second = self.strIndex[:10]
answer = self.strIndex[10:20]

first.name = 'name'
second.name = second_name
result = first.difference(second)
result = first.difference(second, sort)

assert tm.equalContents(result, answer)

Expand All @@ -1064,18 +1066,22 @@ def test_difference_name_preservation(self, second_name, expected):
else:
assert result.name == expected

def test_difference_empty_arg(self):
@pytest.mark.parametrize("sort", [
(True, False)])
def test_difference_empty_arg(self, sort):
first = self.strIndex[5:20]
first.name == 'name'
result = first.difference([])
result = first.difference([], sort=sort)

assert tm.equalContents(result, first)
assert result.name == first.name

def test_difference_identity(self):
@pytest.mark.parametrize("sort", [
(True, False)])
def test_difference_identity(self, sort):
first = self.strIndex[5:20]
first.name == 'name'
result = first.difference(first)
result = first.difference(first, sort)

assert len(result) == 0
assert result.name == first.name
Expand Down Expand Up @@ -1124,13 +1130,15 @@ def test_symmetric_difference_non_index(self):
assert tm.equalContents(result, expected)
assert result.name == 'new_name'

def test_difference_type(self):
@pytest.mark.parametrize("sort", [
(True, False)])
def test_difference_type(self, sort):
# GH 20040
# If taking difference of a set and itself, it
# needs to preserve the type of the index
skip_index_keys = ['repeats']
for key, index in self.generate_index_types(skip_index_keys):
result = index.difference(index)
result = index.difference(index, sort)
expected = index.drop(index)
tm.assert_index_equal(result, expected)

Expand Down Expand Up @@ -2344,13 +2352,15 @@ def test_intersection_different_type_base(self, klass):
result = first.intersection(klass(second.values))
assert tm.equalContents(result, second)

def test_difference_base(self):
@pytest.mark.parametrize("sort", [
(True, False)])
def test_difference_base(self, sort):
# (same results for py2 and py3 but sortedness not tested elsewhere)
index = self.create_index()
first = index[:4]
second = index[3:]

result = first.difference(second)
result = first.difference(second, sort)
expected = Index([0, 1, 'a'])
tm.assert_index_equal(result, expected)

Expand Down

0 comments on commit 8f881c8

Please sign in to comment.