Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reidy-p committed Oct 20, 2018
1 parent b13db31 commit f7446b5
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 75 deletions.
8 changes: 6 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def equals(self, other):
self.closed == other.closed)

def _setop(op_name):
def func(self, other):
def func(self, other, sort=True):
other = self._as_like_interval_index(other)

# GH 19016: ensure set op will not return a prohibited dtype
Expand All @@ -1040,7 +1040,11 @@ def func(self, other):
'objects that have compatible dtypes')
raise TypeError(msg.format(op=op_name))

result = getattr(self._multiindex, op_name)(other._multiindex)
if op_name == 'difference':
result = getattr(self._multiindex, op_name)(other._multiindex,
sort)
else:
result = getattr(self._multiindex, op_name)(other._multiindex)
result_name = self.name if self.name == other.name else None

# GH 19101: ensure empty results have correct dtype
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2791,8 +2791,14 @@ def difference(self, other, sort=True):
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)

difference = set(self._ndarray_values) - set(other._ndarray_values)
this = self._get_unique_index()

indexer = this.get_indexer(other)
indexer = indexer.take((indexer != -1).nonzero()[0])

label_diff = np.setdiff1d(np.arange(this.size), indexer,
assume_unique=True)
difference = this.values.take(label_diff)
if sort:
difference = sorted(difference)

Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,13 @@ def test_union_base(self):
with tm.assert_raises_regex(TypeError, msg):
result = first.union([1, 2, 3])

def test_difference_base(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_base(self, sort):
for name, idx in compat.iteritems(self.indices):
first = idx[2:]
second = idx[:4]
answer = idx[4:]
result = first.difference(second)
result = first.difference(second, sort)

if isinstance(idx, CategoricalIndex):
pass
Expand All @@ -687,21 +688,21 @@ def test_difference_base(self):
if isinstance(idx, PeriodIndex):
msg = "can only call with other PeriodIndex-ed objects"
with tm.assert_raises_regex(ValueError, msg):
result = first.difference(case)
result = first.difference(case, sort)
elif isinstance(idx, CategoricalIndex):
pass
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
assert result.__class__ == answer.__class__
tm.assert_numpy_array_equal(result.sort_values().asi8,
answer.sort_values().asi8)
else:
result = first.difference(case)
result = first.difference(case, sort)
assert tm.equalContents(result, answer)

if isinstance(idx, MultiIndex):
msg = "other must be a MultiIndex or a list of tuples"
with tm.assert_raises_regex(TypeError, msg):
result = first.difference([1, 2, 3])
result = first.difference([1, 2, 3], sort)

def test_symmetric_difference(self):
for name, idx in compat.iteritems(self.indices):
Expand Down
34 changes: 21 additions & 13 deletions pandas/tests/indexes/datetimes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,47 +206,55 @@ def test_intersection_bug_1708(self):
assert len(result) == 0

@pytest.mark.parametrize("tz", tz)
def test_difference(self, tz):
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, tz, sort):
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
'1/5/2000']

rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)

rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)

rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
other3 = pd.DatetimeIndex([], tz=tz)
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)

for rng, other, expected in [(rng1, other1, expected1),
(rng2, other2, expected2),
(rng3, other3, expected3)]:
result_diff = rng.difference(other)
result_diff = rng.difference(other, sort)
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result_diff, expected)

def test_difference_freq(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_freq(self, sort):
# GH14323: difference of DatetimeIndex should not preserve frequency

index = date_range("20160920", "20160925", freq="D")
other = date_range("20160921", "20160924", freq="D")
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

other = date_range("20160922", "20160925", freq="D")
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

def test_datetimeindex_diff(self):
@pytest.mark.parametrize("sort", [True, False])
def test_datetimeindex_diff(self, sort):
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
periods=100)
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
periods=98)
assert len(dti1.difference(dti2)) == 2
assert len(dti1.difference(dti2, sort)) == 2

def test_datetimeindex_union_join_empty(self):
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')
Expand Down
17 changes: 12 additions & 5 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,19 +798,26 @@ def test_intersection(self, closed):
result = index.intersection(other)
tm.assert_index_equal(result, expected)

def test_difference(self, closed):
index = self.create_index(closed=closed)
tm.assert_index_equal(index.difference(index[:1]), index[1:])
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, closed, sort):
index = IntervalIndex.from_arrays([1, 0, 3, 2],
[1, 2, 3, 4],
closed=closed)
result = index.difference(index[:1], sort)
expected = index[1:]
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result, expected)

# GH 19101: empty result, same dtype
result = index.difference(index)
result = index.difference(index, sort)
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
other = IntervalIndex.from_arrays(index.left.astype('float64'),
index.right, closed=closed)
result = index.difference(other)
result = index.difference(other, sort)
tm.assert_index_equal(result, expected)

def test_symmetric_difference(self, closed):
Expand Down
38 changes: 23 additions & 15 deletions pandas/tests/indexes/multi/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pandas.util.testing as tm
from pandas import MultiIndex, Series
import pytest


def test_setops_errorcases(idx):
Expand Down Expand Up @@ -58,24 +59,25 @@ def test_union_base(idx):
result = first.union([1, 2, 3])


def test_difference_base(idx):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_base(idx, sort):
first = idx[2:]
second = idx[:4]
answer = idx[4:]
result = first.difference(second)
result = first.difference(second, sort)

assert tm.equalContents(result, answer)

# GH 10149
cases = [klass(second.values)
for klass in [np.array, Series, list]]
for case in cases:
result = first.difference(case)
result = first.difference(case, sort)
assert tm.equalContents(result, answer)

msg = "other must be a MultiIndex or a list of tuples"
with tm.assert_raises_regex(TypeError, msg):
result = first.difference([1, 2, 3])
result = first.difference([1, 2, 3], sort)


def test_symmetric_difference(idx):
Expand Down Expand Up @@ -103,11 +105,17 @@ def test_empty(idx):
assert idx[:0].empty


def test_difference(idx):
@pytest.mark.parametrize("sort", [True, False])
def test_difference(idx, sort):

first = idx
result = first.difference(idx[-3:])
expected = MultiIndex.from_tuples(sorted(idx[:-3].values),
result = first.difference(idx[-3:], sort)
vals = idx[:-3].values

if sort:
vals = sorted(vals)

expected = MultiIndex.from_tuples(vals,
sortorder=0,
names=idx.names)

Expand All @@ -116,44 +124,44 @@ def test_difference(idx):
assert result.names == idx.names

# empty difference: reflexive
result = idx.difference(idx)
result = idx.difference(idx, sort)
expected = idx[:0]
assert result.equals(expected)
assert result.names == idx.names

# empty difference: superset
result = idx[-3:].difference(idx)
result = idx[-3:].difference(idx, sort)
expected = idx[:0]
assert result.equals(expected)
assert result.names == idx.names

# empty difference: degenerate
result = idx[:0].difference(idx)
result = idx[:0].difference(idx, sort)
expected = idx[:0]
assert result.equals(expected)
assert result.names == idx.names

# names not the same
chunklet = idx[-3:]
chunklet.names = ['foo', 'baz']
result = first.difference(chunklet)
result = first.difference(chunklet, sort)
assert result.names == (None, None)

# empty, but non-equal
result = idx.difference(idx.sortlevel(1)[0])
result = idx.difference(idx.sortlevel(1)[0], sort)
assert len(result) == 0

# raise Exception called with non-MultiIndex
result = first.difference(first.values)
result = first.difference(first.values, sort)
assert result.equals(first[:0])

# name from empty array
result = first.difference([])
result = first.difference([], sort)
assert first.equals(result)
assert first.names == result.names

# name from non-empty array
result = first.difference([('foo', 'one')])
result = first.difference([('foo', 'one')], sort)
expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), (
'foo', 'two'), ('qux', 'one'), ('qux', 'two')])
expected.names = first.names
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/indexes/period/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,21 @@ def test_no_millisecond_field(self):
with pytest.raises(AttributeError):
DatetimeIndex([]).millisecond

def test_difference_freq(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_freq(self, sort):
# GH14323: difference of Period MUST preserve frequency
# but the ability to union results must be preserved

index = period_range("20160920", "20160925", freq="D")

other = period_range("20160921", "20160924", freq="D")
expected = PeriodIndex(["20160920", "20160925"], freq='D')
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

other = period_range("20160922", "20160925", freq="D")
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
expected = PeriodIndex(["20160920", "20160921"], freq='D')
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)
Expand Down
42 changes: 28 additions & 14 deletions pandas/tests/indexes/period/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,37 +204,49 @@ def test_intersection_cases(self):
result = rng.intersection(rng[0:0])
assert len(result) == 0

def test_difference(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, sort):
# diff
rng1 = pd.period_range('1/1/2000', freq='D', periods=5)
period_rng = ['1/3/2000', '1/2/2000', '1/1/2000', '1/5/2000',
'1/4/2000']
rng1 = pd.PeriodIndex(period_rng, freq='D')
other1 = pd.period_range('1/6/2000', freq='D', periods=5)
expected1 = pd.period_range('1/1/2000', freq='D', periods=5)
expected1 = rng1

rng2 = pd.period_range('1/1/2000', freq='D', periods=5)
rng2 = pd.PeriodIndex(period_rng, freq='D')
other2 = pd.period_range('1/4/2000', freq='D', periods=5)
expected2 = pd.period_range('1/1/2000', freq='D', periods=3)
expected2 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000'],
freq='D')

rng3 = pd.period_range('1/1/2000', freq='D', periods=5)
rng3 = pd.PeriodIndex(period_rng, freq='D')
other3 = pd.PeriodIndex([], freq='D')
expected3 = pd.period_range('1/1/2000', freq='D', periods=5)
expected3 = rng3

rng4 = pd.period_range('2000-01-01 09:00', freq='H', periods=5)
period_rng = ['2000-01-01 10:00', '2000-01-01 09:00',
'2000-01-01 12:00', '2000-01-01 11:00',
'2000-01-01 13:00']
rng4 = pd.PeriodIndex(period_rng, freq='H')
other4 = pd.period_range('2000-01-02 09:00', freq='H', periods=5)
expected4 = rng4

rng5 = pd.PeriodIndex(['2000-01-01 09:01', '2000-01-01 09:03',
rng5 = pd.PeriodIndex(['2000-01-01 09:03', '2000-01-01 09:01',
'2000-01-01 09:05'], freq='T')
other5 = pd.PeriodIndex(
['2000-01-01 09:01', '2000-01-01 09:05'], freq='T')
expected5 = pd.PeriodIndex(['2000-01-01 09:03'], freq='T')

rng6 = pd.period_range('2000-01-01', freq='M', periods=7)
period_rng = ['2000-02-01', '2000-01-01', '2000-06-01',
'2000-07-01', '2000-05-01', '2000-03-01',
'2000-04-01']
rng6 = pd.PeriodIndex(period_rng, freq='M')
other6 = pd.period_range('2000-04-01', freq='M', periods=7)
expected6 = pd.period_range('2000-01-01', freq='M', periods=3)
expected6 = pd.PeriodIndex(['2000-02-01', '2000-01-01', '2000-03-01'],
freq='M')

rng7 = pd.period_range('2003-01-01', freq='A', periods=5)
period_rng = ['2003', '2007', '2006', '2005', '2004']
rng7 = pd.PeriodIndex(period_rng, freq='A')
other7 = pd.period_range('1998-01-01', freq='A', periods=8)
expected7 = pd.period_range('2006-01-01', freq='A', periods=2)
expected7 = pd.PeriodIndex(['2007', '2006'], freq='A')

for rng, other, expected in [(rng1, other1, expected1),
(rng2, other2, expected2),
Expand All @@ -243,5 +255,7 @@ def test_difference(self):
(rng5, other5, expected5),
(rng6, other6, expected6),
(rng7, other7, expected7), ]:
result_union = rng.difference(other)
result_union = rng.difference(other, sort)
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result_union, expected)
Loading

0 comments on commit f7446b5

Please sign in to comment.