Skip to content

Commit

Permalink
ENH: Add sort parameter to RangeIndex.union (pandas-dev#24471)
Browse files Browse the repository at this point in the history
  • Loading branch information
reidy-p committed Mar 18, 2019
1 parent e8d951d commit 05d1667
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,7 +2319,7 @@ def union(self, other, sort=None):
else:
rvals = other._values

if self.is_monotonic and other.is_monotonic:
if self.is_monotonic and other.is_monotonic and sort is None:
try:
result = self._outer_indexer(lvals, rvals)[0]
except TypeError:
Expand Down
8 changes: 4 additions & 4 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _extended_gcd(self, a, b):
old_t, t = t, old_t - quotient * t
return old_r, old_s, old_t

def union(self, other):
def union(self, other, sort=None):
"""
Form the union of two Index objects and sorts if possible
Expand All @@ -477,9 +477,9 @@ def union(self, other):
"""
self._assert_can_do_setop(other)
if len(other) == 0 or self.equals(other) or len(self) == 0:
return super(RangeIndex, self).union(other)
return super(RangeIndex, self).union(other, sort=sort)

if isinstance(other, RangeIndex):
if isinstance(other, RangeIndex) and sort is None:
start_s, step_s = self._start, self._step
end_s = self._start + self._step * (len(self) - 1)
start_o, step_o = other._start, other._step
Expand Down Expand Up @@ -516,7 +516,7 @@ def union(self, other):
(end_s - step_o <= end_o)):
return RangeIndex(start_r, end_r + step_o, step_o)

return self._int64index.union(other)
return self._int64index.union(other, sort=sort)

@Appender(_index_shared_docs['join'])
def join(self, other, how='left', level=None, return_indexers=False,
Expand Down
12 changes: 10 additions & 2 deletions pandas/tests/indexes/datetimes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def test_union_bug_1730(self, sort):
rng_b = date_range('1/1/2012', periods=4, freq='4H')

result = rng_a.union(rng_b, sort=sort)
exp = DatetimeIndex(sorted(set(list(rng_a)) | set(list(rng_b))))
exp = list(rng_a) + list(rng_b[1:])
if sort is None:
exp = DatetimeIndex(sorted(exp))
else:
exp = DatetimeIndex(exp)
tm.assert_index_equal(result, exp)

@pytest.mark.parametrize("sort", [None, False])
Expand All @@ -112,7 +116,11 @@ def test_union_bug_4564(self, sort):
right = left + DateOffset(minutes=15)

result = left.union(right, sort=sort)
exp = DatetimeIndex(sorted(set(list(left)) | set(list(right))))
exp = list(left) + list(right)
if sort is None:
exp = DatetimeIndex(sorted(exp))
else:
exp = DatetimeIndex(exp)
tm.assert_index_equal(result, exp)

@pytest.mark.parametrize("sort", [None, False])
Expand Down
11 changes: 9 additions & 2 deletions pandas/tests/indexes/period/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def test_union(self, sort):
# union
other1 = pd.period_range('1/1/2000', freq='D', periods=5)
rng1 = pd.period_range('1/6/2000', freq='D', periods=5)
expected1 = pd.period_range('1/1/2000', freq='D', periods=10)
expected1 = pd.PeriodIndex(['2000-01-06', '2000-01-07',
'2000-01-08', '2000-01-09',
'2000-01-10', '2000-01-01',
'2000-01-02', '2000-01-03',
'2000-01-04', '2000-01-05'],
freq='D')

rng2 = pd.period_range('1/1/2000', freq='D', periods=5)
other2 = pd.period_range('1/4/2000', freq='D', periods=5)
Expand Down Expand Up @@ -77,7 +82,9 @@ def test_union(self, sort):

rng7 = pd.period_range('2003-01-01', freq='A', periods=5)
other7 = pd.period_range('1998-01-01', freq='A', periods=8)
expected7 = pd.period_range('1998-01-01', freq='A', periods=10)
expected7 = pd.PeriodIndex(['2003', '2004', '2005', '2006', '2007',
'1998', '1999', '2000', '2001', '2002'],
freq='A')

rng8 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000',
'1/5/2000', '1/4/2000'], freq='D')
Expand Down
98 changes: 75 additions & 23 deletions pandas/tests/indexes/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,36 +586,88 @@ def test_union_noncomparable(self):
def test_union(self):
RI = RangeIndex
I64 = Int64Index
cases = [(RI(0, 10, 1), RI(0, 10, 1), RI(0, 10, 1)),
(RI(0, 10, 1), RI(5, 20, 1), RI(0, 20, 1)),
(RI(0, 10, 1), RI(10, 20, 1), RI(0, 20, 1)),
(RI(0, -10, -1), RI(0, -10, -1), RI(0, -10, -1)),
(RI(0, -10, -1), RI(-10, -20, -1), RI(-19, 1, 1)),
(RI(0, 10, 2), RI(1, 10, 2), RI(0, 10, 1)),
(RI(0, 11, 2), RI(1, 12, 2), RI(0, 12, 1)),
(RI(0, 21, 4), RI(-2, 24, 4), RI(-2, 24, 2)),
(RI(0, -20, -2), RI(-1, -21, -2), RI(-19, 1, 1)),
(RI(0, 100, 5), RI(0, 100, 20), RI(0, 100, 5)),
(RI(0, -100, -5), RI(5, -100, -20), RI(-95, 10, 5)),
(RI(0, -11, -1), RI(1, -12, -4), RI(-11, 2, 1)),
(RI(0), RI(0), RI(0)),
(RI(0, -10, -2), RI(0), RI(0, -10, -2)),
(RI(0, 100, 2), RI(100, 150, 200), RI(0, 102, 2)),
(RI(0, -100, -2), RI(-100, 50, 102), RI(-100, 4, 2)),
(RI(0, -100, -1), RI(0, -50, -3), RI(-99, 1, 1)),
(RI(0, 1, 1), RI(5, 6, 10), RI(0, 6, 5)),
(RI(0, 10, 5), RI(-5, -6, -20), RI(-5, 10, 5)),
(RI(0, 3, 1), RI(4, 5, 1), I64([0, 1, 2, 4])),
(RI(0, 10, 1), I64([]), RI(0, 10, 1)),
(RI(0), I64([1, 5, 6]), I64([1, 5, 6]))]
for idx1, idx2, expected in cases:

inputs = [(RI(0, 10, 1), RI(0, 10, 1)),
(RI(0, 10, 1), RI(5, 20, 1)),
(RI(0, 10, 1), RI(10, 20, 1)),
(RI(0, -10, -1), RI(0, -10, -1)),
(RI(0, -10, -1), RI(-10, -20, -1)),
(RI(0, 10, 2), RI(1, 10, 2)),
(RI(0, 11, 2), RI(1, 12, 2)),
(RI(0, 21, 4), RI(-2, 24, 4)),
(RI(0, -20, -2), RI(-1, -21, -2)),
(RI(0, 100, 5), RI(0, 100, 20)),
(RI(0, -100, -5), RI(5, -100, -20)),
(RI(0, -11, -1), RI(1, -12, -4)),
(RI(0), RI(0)),
(RI(0, -10, -2), RI(0)),
(RI(0, 100, 2), RI(100, 150, 200)),
(RI(0, -100, -2), RI(-100, 50, 102)),
(RI(0, -100, -1), RI(0, -50, -3)),
(RI(0, 1, 1), RI(5, 6, 10)),
(RI(0, 10, 5), RI(-5, -6, -20)),
(RI(0, 3, 1), RI(4, 5, 1)),
(RI(0, 10, 1), I64([])),
(RI(0), I64([1, 5, 6]))]

expected_sorted = [RI(0, 10, 1),
RI(0, 20, 1),
RI(0, 20, 1),
RI(0, -10, -1),
RI(-19, 1, 1),
RI(0, 10, 1),
RI(0, 12, 1),
RI(-2, 24, 2),
RI(-19, 1, 1),
RI(0, 100, 5),
RI(-95, 10, 5),
RI(-11, 2, 1),
RI(0),
RI(0, -10, -2),
RI(0, 102, 2),
RI(-100, 4, 2),
RI(-99, 1, 1),
RI(0, 6, 5),
RI(-5, 10, 5),
I64([0, 1, 2, 4]),
RI(0, 10, 1),
I64([1, 5, 6])]

for ((idx1, idx2), expected) in zip(inputs, expected_sorted):
res1 = idx1.union(idx2)
res2 = idx2.union(idx1)
res3 = idx1._int64index.union(idx2)
tm.assert_index_equal(res1, expected, exact=True)
tm.assert_index_equal(res2, expected, exact=True)
tm.assert_index_equal(res3, expected)

expected_notsorted = [RI(0, 10, 1),
I64(range(20)),
I64(range(20)),
RI(0, -10, -1),
I64(range(0, -20, -1)),
I64(list(range(0, 10, 2)) + list(range(1, 10, 2))),
I64(list(range(0, 11, 2)) + list(range(1, 12, 2))),
I64(list(range(0, 21, 4)) + list(range(-2, 24, 4))),
I64(list(range(0, -20, -2)) + list(range(-1, -21, -2))),
I64(range(0, 100, 5)),
I64(list(range(0, -100, -5)) + [5]),
I64(list(range(0, -11, -1)) + [1, -11]),
RI(0),
RI(0, -10, -2),
I64(range(0, 102, 2)),
I64(list(range(0, -100, -2)) + [-100, 2]),
I64(list(range(0, -100, -1))),
I64([0, 5]),
I64([0, 5, -5]),
I64([0, 1, 2, 4]),
RI(0, 10, 1),
I64([1, 5, 6])]

for ((idx1, idx2), expected) in zip(inputs, expected_notsorted):
res1 = idx1.union(idx2, sort=False)
tm.assert_index_equal(res1, expected, exact=True)

def test_nbytes(self):

# memory savings vs int index
Expand Down

0 comments on commit 05d1667

Please sign in to comment.