Skip to content

Commit

Permalink
ENH: NumericIndex for any numpy int/uint/float dtype (#41153)
Browse files Browse the repository at this point in the history
  • Loading branch information
topper-123 authored Aug 5, 2021
1 parent ae8321d commit fcadfb8
Show file tree
Hide file tree
Showing 16 changed files with 321 additions and 103 deletions.
3 changes: 3 additions & 0 deletions pandas/_libs/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ ctypedef fused join_t:
int16_t
int32_t
int64_t
uint8_t
uint16_t
uint32_t
uint64_t


Expand Down
3 changes: 2 additions & 1 deletion pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
use_numexpr,
with_csv_dialect,
)
from pandas.core.api import NumericIndex
from pandas.core.arrays import (
DatetimeArray,
PandasArray,
Expand Down Expand Up @@ -314,7 +315,7 @@ def makeNumericIndex(k=10, name=None, *, dtype):
else:
raise NotImplementedError(f"wrong dtype {dtype}")

return Index(values, dtype=dtype, name=name)
return NumericIndex(values, dtype=dtype, name=name)


def makeIntIndex(k=10, name=None):
Expand Down
15 changes: 14 additions & 1 deletion pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@ def _create_mi_with_dt64tz_level():
"uint": tm.makeUIntIndex(100),
"range": tm.makeRangeIndex(100),
"float": tm.makeFloatIndex(100),
"num_int64": tm.makeNumericIndex(100, dtype="int64"),
"num_int32": tm.makeNumericIndex(100, dtype="int32"),
"num_int16": tm.makeNumericIndex(100, dtype="int16"),
"num_int8": tm.makeNumericIndex(100, dtype="int8"),
"num_uint64": tm.makeNumericIndex(100, dtype="uint64"),
"num_uint32": tm.makeNumericIndex(100, dtype="uint32"),
"num_uint16": tm.makeNumericIndex(100, dtype="uint16"),
"num_uint8": tm.makeNumericIndex(100, dtype="uint8"),
"num_float64": tm.makeNumericIndex(100, dtype="float64"),
"num_float32": tm.makeNumericIndex(100, dtype="float32"),
"bool": tm.makeBoolIndex(10),
"categorical": tm.makeCategoricalIndex(100),
"interval": tm.makeIntervalIndex(100),
Expand Down Expand Up @@ -511,7 +521,10 @@ def index_flat(request):
params=[
key
for key in indices_dict
if key not in ["int", "uint", "range", "empty", "repeats"]
if not (
key in ["int", "uint", "range", "empty", "repeats"]
or key.startswith("num_")
)
and not isinstance(indices_dict[key], MultiIndex)
]
)
Expand Down
1 change: 1 addition & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
Int64Index,
IntervalIndex,
MultiIndex,
NumericIndex,
PeriodIndex,
RangeIndex,
TimedeltaIndex,
Expand Down
1 change: 1 addition & 0 deletions pandas/core/dtypes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _check(cls, inst) -> bool:
"rangeindex",
"float64index",
"uint64index",
"numericindex",
"multiindex",
"datetimeindex",
"timedeltaindex",
Expand Down
17 changes: 17 additions & 0 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
is_interval_dtype,
is_iterator,
is_list_like,
is_numeric_dtype,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
Expand Down Expand Up @@ -360,6 +361,11 @@ def _outer_indexer(
_can_hold_na: bool = True
_can_hold_strings: bool = True

# Whether this index is a NumericIndex, but not an Int64Index, Float64Index,
# UInt64Index or RangeIndex. Needed for backwards compat. Remove this attribute and
# associated code in pandas 2.0.
_is_backward_compat_public_numeric_index: bool = False

_engine_type: type[libindex.IndexEngine] = libindex.ObjectEngine
# whether we support partial string indexing. Overridden
# in DatetimeIndex and PeriodIndex
Expand Down Expand Up @@ -437,6 +443,12 @@ def __new__(
return Index._simple_new(data, name=name)

# index-like
elif (
isinstance(data, Index)
and data._is_backward_compat_public_numeric_index
and dtype is None
):
return data._constructor(data, name=name, copy=copy)
elif isinstance(data, (np.ndarray, Index, ABCSeries)):

if isinstance(data, ABCMultiIndex):
Expand Down Expand Up @@ -5726,6 +5738,11 @@ def map(self, mapper, na_action=None):
# empty
attributes["dtype"] = self.dtype

if self._is_backward_compat_public_numeric_index and is_numeric_dtype(
new_values.dtype
):
return self._constructor(new_values, **attributes)

return Index(new_values, **attributes)

# TODO: De-duplicate with map, xref GH#32349
Expand Down
25 changes: 25 additions & 0 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_scalar,
pandas_dtype,
)
from pandas.core.dtypes.missing import (
is_valid_na_for_dtype,
Expand Down Expand Up @@ -280,6 +281,30 @@ def _is_dtype_compat(self, other) -> Categorical:

return other

@doc(Index.astype)
def astype(self, dtype: Dtype, copy: bool = True) -> Index:
    from pandas.core.api import NumericIndex

    dtype = pandas_dtype(dtype)
    cats = self.categories

    # Categories that are not a "new-style" NumericIndex keep the generic
    # behaviour, which always yields Int64Index, UInt64Index or Float64Index.
    if not cats._is_backward_compat_public_numeric_index:
        return super().astype(dtype, copy=copy)

    assert isinstance(cats, NumericIndex)  # mypy complaint fix

    # When the categories are a NumericIndex (e.g. with dtype float32) and the
    # requested dtype is one that index type accepts, the result should be
    # built with the categories' own constructor rather than by the generic
    # path. A dtype rejected by the validation falls back to the generic path.
    try:
        cats._validate_dtype(dtype)
    except ValueError:
        return super().astype(dtype, copy=copy)

    converted = self._data.astype(dtype, copy=copy)
    # copy=False: any copying has already been done by the _data.astype call
    return cats._constructor(converted, name=self.name, copy=False)

def equals(self, other: object) -> bool:
"""
Determine if two CategoricalIndex objects contain the same elements.
Expand Down
51 changes: 42 additions & 9 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class NumericIndex(Index):
)
_is_numeric_dtype = True
_can_hold_strings = False
_is_backward_compat_public_numeric_index: bool = True

@cache_readonly
def _can_hold_na(self) -> bool:
Expand Down Expand Up @@ -165,7 +166,15 @@ def _ensure_array(cls, data, dtype, copy: bool):
dtype = cls._ensure_dtype(dtype)

if copy or not is_dtype_equal(data.dtype, dtype):
subarr = np.array(data, dtype=dtype, copy=copy)
# TODO: the try/except below is because it's difficult to predict the error
# and/or error message from different combinations of data and dtype.
# Efforts to avoid this try/except welcome.
# See https://github.com/pandas-dev/pandas/pull/41153#discussion_r676206222
try:
subarr = np.array(data, dtype=dtype, copy=copy)
cls._validate_dtype(subarr.dtype)
except (TypeError, ValueError):
raise ValueError(f"data is not compatible with {cls.__name__}")
cls._assert_safe_casting(data, subarr)
else:
subarr = data
Expand All @@ -189,12 +198,24 @@ def _validate_dtype(cls, dtype: Dtype | None) -> None:
)

@classmethod
def _ensure_dtype(
cls,
dtype: Dtype | None,
) -> np.dtype | None:
"""Ensure int64 dtype for Int64Index, etc. Assumed dtype is validated."""
return cls._default_dtype
def _ensure_dtype(cls, dtype: Dtype | None) -> np.dtype | None:
"""
Ensure int64 dtype for Int64Index etc. but allow int32 etc. for NumericIndex.
Assumes dtype has already been validated.
"""
if dtype is None:
return cls._default_dtype

dtype = pandas_dtype(dtype)
assert isinstance(dtype, np.dtype)

if cls._is_backward_compat_public_numeric_index:
# dtype for NumericIndex
return dtype
else:
# dtype for Int64Index, UInt64Index etc. Needed for backwards compat.
return cls._default_dtype

def __contains__(self, key) -> bool:
"""
Expand All @@ -214,8 +235,8 @@ def __contains__(self, key) -> bool:

@doc(Index.astype)
def astype(self, dtype, copy=True):
dtype = pandas_dtype(dtype)
if is_float_dtype(self.dtype):
dtype = pandas_dtype(dtype)
if needs_i8_conversion(dtype):
raise TypeError(
f"Cannot convert Float64Index to dtype {dtype}; integer "
Expand All @@ -225,7 +246,16 @@ def astype(self, dtype, copy=True):
# TODO(jreback); this can change once we have an EA Index type
# GH 13149
arr = astype_nansafe(self._values, dtype=dtype)
return Int64Index(arr, name=self.name)
if isinstance(self, Float64Index):
return Int64Index(arr, name=self.name)
else:
return NumericIndex(arr, name=self.name, dtype=dtype)
elif self._is_backward_compat_public_numeric_index:
# this block is needed so e.g. NumericIndex[int8].astype("int32") returns
# NumericIndex[int32] and not Int64Index with dtype int64.
# When Int64Index etc. are removed from the code base, remove this also.
if not is_extension_array_dtype(dtype) and is_numeric_dtype(dtype):
return self._constructor(self, dtype=dtype, copy=copy)

return super().astype(dtype, copy=copy)

Expand Down Expand Up @@ -335,6 +365,8 @@ class IntegerIndex(NumericIndex):
This is an abstract class for Int64Index, UInt64Index.
"""

_is_backward_compat_public_numeric_index: bool = False

@property
def asi8(self) -> np.ndarray:
# do not cache or you'll create a memory leak
Expand Down Expand Up @@ -399,3 +431,4 @@ class Float64Index(NumericIndex):
_engine_type = libindex.Float64Engine
_default_dtype = np.dtype(np.float64)
_dtype_validation_metadata = (is_float_dtype, "float")
_is_backward_compat_public_numeric_index: bool = False
1 change: 1 addition & 0 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class RangeIndex(NumericIndex):
_engine_type = libindex.Int64Engine
_dtype_validation_metadata = (is_signed_integer_dtype, "signed integer")
_range: range
_is_backward_compat_public_numeric_index: bool = False

# --------------------------------------------------------------------
# Constructors
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/base/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pandas as pd
import pandas._testing as tm
from pandas.core.api import NumericIndex
from pandas.tests.base.common import allow_na_ops


Expand All @@ -24,6 +25,9 @@ def test_unique(index_or_series_obj):
expected = pd.MultiIndex.from_tuples(unique_values)
expected.names = obj.names
tm.assert_index_equal(result, expected, exact=True)
elif isinstance(obj, pd.Index) and obj._is_backward_compat_public_numeric_index:
expected = NumericIndex(unique_values, dtype=obj.dtype)
tm.assert_index_equal(result, expected, exact=True)
elif isinstance(obj, pd.Index):
expected = pd.Index(unique_values, dtype=obj.dtype)
if is_datetime64tz_dtype(obj.dtype):
Expand Down Expand Up @@ -62,7 +66,10 @@ def test_unique_null(null_obj, index_or_series_obj):
unique_values_not_null = [val for val in unique_values_raw if not pd.isnull(val)]
unique_values = [null_obj] + unique_values_not_null

if isinstance(obj, pd.Index):
if isinstance(obj, pd.Index) and obj._is_backward_compat_public_numeric_index:
expected = NumericIndex(unique_values, dtype=obj.dtype)
tm.assert_index_equal(result, expected, exact=True)
elif isinstance(obj, pd.Index):
expected = pd.Index(unique_values, dtype=obj.dtype)
if is_datetime64tz_dtype(obj.dtype):
result = result.normalize()
Expand Down
Loading

0 comments on commit fcadfb8

Please sign in to comment.