Skip to content

Commit

Permalink
ENH: RangeIndex._shallow_copy can return RangeIndex (pandas-dev#47557)
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke authored Jul 8, 2022
1 parent 8e6ca28 commit e915b0a
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 72 deletions.
4 changes: 2 additions & 2 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ Other enhancements
- Allow reading compressed SAS files with :func:`read_sas` (e.g., ``.sas7bdat.gz`` files)
- :meth:`DatetimeIndex.astype` now supports casting timezone-naive indexes to ``datetime64[s]``, ``datetime64[ms]``, and ``datetime64[us]``, and timezone-aware indexes to the corresponding ``datetime64[unit, tzname]`` dtypes (:issue:`47579`)
- :class:`Series` reducers (e.g. ``min``, ``max``, ``sum``, ``mean``) will now successfully operate when the dtype is numeric and ``numeric_only=True`` is provided; previously this would raise a ``NotImplementedError`` (:issue:`47500`)
-
- :meth:`RangeIndex.union` now can return a :class:`RangeIndex` instead of a :class:`Int64Index` if the resulting values are equally spaced (:issue:`47557`, :issue:`43885`)

.. ---------------------------------------------------------------------------
.. _whatsnew_150.notable_bug_fixes:
Expand Down Expand Up @@ -1009,7 +1009,7 @@ Reshaping
- Bug in :func:`concat` with identical key leads to error when indexing :class:`MultiIndex` (:issue:`46519`)
- Bug in :meth:`DataFrame.join` with a list when using suffixes to join DataFrames with duplicate column names (:issue:`46396`)
- Bug in :meth:`DataFrame.pivot_table` with ``sort=False`` results in sorted index (:issue:`17041`)
-
- Bug in :meth:`concat` when ``axis=1`` and ``sort=False`` where the resulting Index was a :class:`Int64Index` instead of a :class:`RangeIndex` (:issue:`46675`)

Sparse
^^^^^^
Expand Down
131 changes: 79 additions & 52 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
index as libindex,
lib,
)
from pandas._libs.algos import unique_deltas
from pandas._libs.lib import no_default
from pandas._typing import (
Dtype,
Expand Down Expand Up @@ -436,7 +437,15 @@ def _shallow_copy(self, values, name: Hashable = no_default):

if values.dtype.kind == "f":
return Float64Index(values, name=name)
return Int64Index._simple_new(values, name=name)
# GH 46675 & 43885: If values is equally spaced, return a
# more memory-compact RangeIndex instead of Int64Index
unique_diffs = unique_deltas(values)
if len(unique_diffs) == 1 and unique_diffs[0] != 0:
diff = unique_diffs[0]
new_range = range(values[0], values[-1] + diff, diff)
return type(self)._simple_new(new_range, name=name)
else:
return Int64Index._simple_new(values, name=name)

def _view(self: RangeIndex) -> RangeIndex:
result = type(self)._simple_new(self._range, name=self._name)
Expand Down Expand Up @@ -638,6 +647,17 @@ def _extended_gcd(self, a: int, b: int) -> tuple[int, int, int]:
old_t, t = t, old_t - quotient * t
return old_r, old_s, old_t

def _range_in_self(self, other: range) -> bool:
"""Check if other range is contained in self"""
# https://stackoverflow.com/a/32481015
if not other:
return True
if not self._range:
return False
if len(other) > 1 and other.step % self._range.step:
return False
return other.start in self._range and other[-1] in self._range

def _union(self, other: Index, sort):
"""
Form the union of two Index objects and sorts if possible
Expand All @@ -647,64 +667,71 @@ def _union(self, other: Index, sort):
other : Index or array-like
sort : False or None, default None
Whether to sort resulting index. ``sort=None`` returns a
monotonically increasing ``RangeIndex`` if possible or a sorted
``Int64Index`` if not. ``sort=False`` always returns an
unsorted ``Int64Index``
Whether to sort (monotonically increasing) the resulting index.
``sort=None`` returns a ``RangeIndex`` if possible or a sorted
``Int64Index`` if not.
``sort=False`` can return a ``RangeIndex`` if self is monotonically
increasing and other is fully contained in self. Otherwise, returns
an unsorted ``Int64Index``
.. versionadded:: 0.25.0
Returns
-------
union : Index
"""
if isinstance(other, RangeIndex) and sort is None:
start_s, step_s = self.start, self.step
end_s = self.start + self.step * (len(self) - 1)
start_o, step_o = other.start, other.step
end_o = other.start + other.step * (len(other) - 1)
if self.step < 0:
start_s, step_s, end_s = end_s, -step_s, start_s
if other.step < 0:
start_o, step_o, end_o = end_o, -step_o, start_o
if len(self) == 1 and len(other) == 1:
step_s = step_o = abs(self.start - other.start)
elif len(self) == 1:
step_s = step_o
elif len(other) == 1:
step_o = step_s
start_r = min(start_s, start_o)
end_r = max(end_s, end_o)
if step_o == step_s:
if (
(start_s - start_o) % step_s == 0
and (start_s - end_o) <= step_s
and (start_o - end_s) <= step_s
):
return type(self)(start_r, end_r + step_s, step_s)
if (
(step_s % 2 == 0)
and (abs(start_s - start_o) == step_s / 2)
and (abs(end_s - end_o) == step_s / 2)
):
# e.g. range(0, 10, 2) and range(1, 11, 2)
# but not range(0, 20, 4) and range(1, 21, 4) GH#44019
return type(self)(start_r, end_r + step_s / 2, step_s / 2)

elif step_o % step_s == 0:
if (
(start_o - start_s) % step_s == 0
and (start_o + step_s >= start_s)
and (end_o - step_s <= end_s)
):
return type(self)(start_r, end_r + step_s, step_s)
elif step_s % step_o == 0:
if (
(start_s - start_o) % step_o == 0
and (start_s + step_o >= start_o)
and (end_s - step_o <= end_o)
):
return type(self)(start_r, end_r + step_o, step_o)
if isinstance(other, RangeIndex):
if sort is None or (
sort is False and self.step > 0 and self._range_in_self(other._range)
):
# GH 47557: Can still return a RangeIndex
# if other range in self and sort=False
start_s, step_s = self.start, self.step
end_s = self.start + self.step * (len(self) - 1)
start_o, step_o = other.start, other.step
end_o = other.start + other.step * (len(other) - 1)
if self.step < 0:
start_s, step_s, end_s = end_s, -step_s, start_s
if other.step < 0:
start_o, step_o, end_o = end_o, -step_o, start_o
if len(self) == 1 and len(other) == 1:
step_s = step_o = abs(self.start - other.start)
elif len(self) == 1:
step_s = step_o
elif len(other) == 1:
step_o = step_s
start_r = min(start_s, start_o)
end_r = max(end_s, end_o)
if step_o == step_s:
if (
(start_s - start_o) % step_s == 0
and (start_s - end_o) <= step_s
and (start_o - end_s) <= step_s
):
return type(self)(start_r, end_r + step_s, step_s)
if (
(step_s % 2 == 0)
and (abs(start_s - start_o) == step_s / 2)
and (abs(end_s - end_o) == step_s / 2)
):
# e.g. range(0, 10, 2) and range(1, 11, 2)
# but not range(0, 20, 4) and range(1, 21, 4) GH#44019
return type(self)(start_r, end_r + step_s / 2, step_s / 2)

elif step_o % step_s == 0:
if (
(start_o - start_s) % step_s == 0
and (start_o + step_s >= start_s)
and (end_o - step_s <= end_s)
):
return type(self)(start_r, end_r + step_s, step_s)
elif step_s % step_o == 0:
if (
(start_s - start_o) % step_o == 0
and (start_s + step_o >= start_o)
and (end_s - step_o <= end_o)
):
return type(self)(start_r, end_r + step_o, step_o)

return super()._union(other, sort=sort)

Expand Down
38 changes: 20 additions & 18 deletions pandas/tests/indexes/ranges/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ def test_union_noncomparable(self, sort):
expected = Index(np.concatenate((other, index)))
tm.assert_index_equal(result, expected)

@pytest.fixture(
params=[
@pytest.mark.parametrize(
"idx1, idx2, expected_sorted, expected_notsorted",
[
(
RangeIndex(0, 10, 1),
RangeIndex(0, 10, 1),
Expand All @@ -157,13 +158,13 @@ def test_union_noncomparable(self, sort):
RangeIndex(0, 10, 1),
RangeIndex(5, 20, 1),
RangeIndex(0, 20, 1),
Int64Index(range(20)),
RangeIndex(0, 20, 1),
),
(
RangeIndex(0, 10, 1),
RangeIndex(10, 20, 1),
RangeIndex(0, 20, 1),
Int64Index(range(20)),
RangeIndex(0, 20, 1),
),
(
RangeIndex(0, -10, -1),
Expand All @@ -175,7 +176,7 @@ def test_union_noncomparable(self, sort):
RangeIndex(0, -10, -1),
RangeIndex(-10, -20, -1),
RangeIndex(-19, 1, 1),
Int64Index(range(0, -20, -1)),
RangeIndex(0, -20, -1),
),
(
RangeIndex(0, 10, 2),
Expand Down Expand Up @@ -205,7 +206,7 @@ def test_union_noncomparable(self, sort):
RangeIndex(0, 100, 5),
RangeIndex(0, 100, 20),
RangeIndex(0, 100, 5),
Int64Index(range(0, 100, 5)),
RangeIndex(0, 100, 5),
),
(
RangeIndex(0, -100, -5),
Expand All @@ -230,7 +231,7 @@ def test_union_noncomparable(self, sort):
RangeIndex(0, 100, 2),
RangeIndex(100, 150, 200),
RangeIndex(0, 102, 2),
Int64Index(range(0, 102, 2)),
RangeIndex(0, 102, 2),
),
(
RangeIndex(0, -100, -2),
Expand All @@ -242,13 +243,13 @@ def test_union_noncomparable(self, sort):
RangeIndex(0, -100, -1),
RangeIndex(0, -50, -3),
RangeIndex(-99, 1, 1),
Int64Index(list(range(0, -100, -1))),
RangeIndex(0, -100, -1),
),
(
RangeIndex(0, 1, 1),
RangeIndex(5, 6, 10),
RangeIndex(0, 6, 5),
Int64Index([0, 5]),
RangeIndex(0, 10, 5),
),
(
RangeIndex(0, 10, 5),
Expand All @@ -274,16 +275,17 @@ def test_union_noncomparable(self, sort):
Int64Index([1, 5, 6]),
Int64Index([1, 5, 6]),
),
]
# GH 43885
(
RangeIndex(0, 10),
RangeIndex(0, 5),
RangeIndex(0, 10),
RangeIndex(0, 10),
),
],
ids=lambda x: repr(x) if isinstance(x, RangeIndex) else x,
)
def unions(self, request):
"""Inputs and expected outputs for RangeIndex.union tests"""
return request.param

def test_union_sorted(self, unions):

idx1, idx2, expected_sorted, expected_notsorted = unions

def test_union_sorted(self, idx1, idx2, expected_sorted, expected_notsorted):
res1 = idx1.union(idx2, sort=None)
tm.assert_index_equal(res1, expected_sorted, exact=True)

Expand Down
22 changes: 22 additions & 0 deletions pandas/tests/reshape/concat/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,25 @@ def test_concat_index_find_common(self, dtype):
[[0, 1, 1.0], [0, 1, np.nan]], columns=Index([1, 2, 3], dtype="Int32")
)
tm.assert_frame_equal(result, expected)

def test_concat_axis_1_sort_false_rangeindex(self):
# GH 46675
s1 = Series(["a", "b", "c"])
s2 = Series(["a", "b"])
s3 = Series(["a", "b", "c", "d"])
s4 = Series([], dtype=object)
result = concat(
[s1, s2, s3, s4], sort=False, join="outer", ignore_index=False, axis=1
)
expected = DataFrame(
[
["a"] * 3 + [np.nan],
["b"] * 3 + [np.nan],
["c", np.nan] * 2,
[np.nan] * 2 + ["d"] + [np.nan],
],
dtype=object,
)
tm.assert_frame_equal(
result, expected, check_index_type=True, check_column_type=True
)

0 comments on commit e915b0a

Please sign in to comment.