Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: EADtype._find_compatible_dtype #53106

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ExtensionDtype:
* _is_numeric
* _is_boolean
* _get_common_dtype
* _find_compatible_dtype

The `na_value` class attribute can be used to set the default NA value
for this type. :attr:`numpy.nan` is used by default.
Expand Down Expand Up @@ -418,6 +419,25 @@ def index_class(self) -> type_t[Index]:

return Index

def _find_compatible_dtype(self, item: Any) -> tuple[DtypeObj, Any]:
    """
    Find the minimal dtype that we need to upcast to in order to hold
    the given item.

    This is used when doing Series setitem-with-expansion on a Series
    with our dtype.

    Parameters
    ----------
    item : object
        The scalar value being inserted.

    Returns
    -------
    np.dtype or ExtensionDtype
        The dtype capable of holding both the existing values and ``item``.
    object
        ``item``, possibly converted to a form suitable for the returned
        dtype.

    Notes
    -----
    This base-class fallback always upcasts to ``object`` dtype, which can
    hold any Python object. Subclasses override this to retain their own
    dtype when ``item`` is compatible.
    """
    # Default: object dtype holds anything; return item unchanged.
    return np.dtype(object), item


class StorageExtensionDtype(ExtensionDtype):
"""ExtensionDtype that may be backed by more than one implementation."""
Expand Down
17 changes: 7 additions & 10 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
ensure_int16,
ensure_int32,
ensure_int64,
ensure_object,
ensure_str,
is_bool,
is_complex,
Expand Down Expand Up @@ -547,13 +546,13 @@ def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj:
}


def maybe_promote(dtype: np.dtype, fill_value=np.nan):
def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
"""
Find the minimal dtype that can hold both the given dtype and fill_value.

Parameters
----------
dtype : np.dtype
dtype : np.dtype or ExtensionDtype
fill_value : scalar, default np.nan

Returns
Expand Down Expand Up @@ -610,9 +609,13 @@ def _maybe_promote_cached(dtype, fill_value, fill_value_type):
return _maybe_promote(dtype, fill_value)


def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
def _maybe_promote(dtype: DtypeObj, fill_value=np.nan):
# The actual implementation of the function, use `maybe_promote` above for
# a cached version.

if not isinstance(dtype, np.dtype):
return dtype._find_compatible_dtype(fill_value)

if not is_scalar(fill_value):
# with object dtype there is nothing to promote, and the user can
# pass pretty much any weird fill_value they like
Expand All @@ -628,12 +631,6 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
fv = na_value_for_dtype(dtype)
return dtype, fv

elif isinstance(dtype, CategoricalDtype):
if fill_value in dtype.categories or isna(fill_value):
return dtype, fill_value
else:
return object, ensure_object(fill_value)

elif isna(fill_value):
dtype = _dtype_obj
if fill_value is None:
Expand Down
47 changes: 47 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,15 @@ def index_class(self) -> type_t[CategoricalIndex]:

return CategoricalIndex

def _find_compatible_dtype(self, item) -> tuple[DtypeObj, Any]:
from pandas.core.dtypes.missing import is_valid_na_for_dtype

if item in self.categories or is_valid_na_for_dtype(
item, self.categories.dtype
):
return self, item
return np.dtype(object), item


@register_extension_dtype
class DatetimeTZDtype(PandasExtensionDtype):
Expand Down Expand Up @@ -1590,6 +1599,18 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
except (KeyError, NotImplementedError):
return None

def _find_compatible_dtype(self, item) -> tuple[DtypeObj, Any]:
from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.missing import is_valid_na_for_dtype

if is_valid_na_for_dtype(item, self):
return self, item

dtype, item = maybe_promote(self.numpy_dtype, item)
if dtype.kind in "iufb":
return type(self).from_numpy_dtype(dtype), item
return dtype, item


@register_extension_dtype
class SparseDtype(ExtensionDtype):
Expand Down Expand Up @@ -2326,3 +2347,29 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
array_class = self.construct_array_type()
arr = array.cast(self.pyarrow_dtype, safe=True)
return array_class(arr)

def _find_compatible_dtype(self, item: Any) -> tuple[DtypeObj, Any]:
    """
    Find the dtype needed for setitem-with-expansion with ``item``.

    Returns the target dtype and ``item``, unwrapped to a Python scalar
    when it arrives as a pyarrow Scalar.
    """
    if isinstance(item, pa.Scalar):
        if not item.is_valid:
            # A typed pyarrow null: keep our dtype when the types line up.
            # TODO: ask joris for help making these checks more robust
            if item.type == self.pyarrow_dtype:
                return self, item.as_py()
            if item.type.to_pandas_dtype() == np.int64 and self.kind == "i":
                # FIXME: kludge
                return self, item.as_py()

        # Unwrap to a Python scalar for the numpy-based promotion below.
        item = item.as_py()

    elif item is None or item is libmissing.NA:
        # These NA-likes are always representable in an arrow array.
        # TODO: np.nan? use is_valid_na_for_dtype
        return self, item

    from pandas.core.dtypes.cast import maybe_promote

    promoted, fill = maybe_promote(self.numpy_dtype, item)

    if promoted == self.numpy_dtype:
        # No promotion needed; the arrow dtype can hold the item.
        return self, fill

    # TODO: implement from_numpy_dtype analogous to MaskedDtype.from_numpy_dtype
    return np.dtype(object), fill
27 changes: 5 additions & 22 deletions pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.missing import (
infer_fill_value,
is_valid_na_for_dtype,
isna,
na_value_for_dtype,
)
from pandas.core.dtypes.missing import infer_fill_value

from pandas.core import algorithms as algos
import pandas.core.common as com
Expand Down Expand Up @@ -2165,26 +2160,14 @@ def _setitem_with_indexer_missing(self, indexer, value):
return self._setitem_with_indexer(new_indexer, value, "loc")

# this preserves dtype of the value and of the object
if not is_scalar(value):
new_dtype = None

elif is_valid_na_for_dtype(value, self.obj.dtype):
if not is_object_dtype(self.obj.dtype):
# Every NA value is suitable for object, no conversion needed
value = na_value_for_dtype(self.obj.dtype, compat=False)

new_dtype = maybe_promote(self.obj.dtype, value)[0]

elif isna(value):
new_dtype = None
new_dtype = None
if is_list_like(value):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for ArrowDtype with pa.list_ type, we would want to treat value like a scalar e.g

ser = pd.Series([[1, 1]], dtype=pd.ArrowDtype(pa.list_(pa.int64())))
ser[0] = [1, 2]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yah, getting rid of this is_list_like check causes us to incorrectly raise on numpy non-object cases when using a list value (for which we don't have any tests). Can fix that in this PR or separately, as it is a bit more invasive.

pass
elif not self.obj.empty and not is_object_dtype(self.obj.dtype):
# We should not cast, if we have object dtype because we can
# set timedeltas into object series
curr_dtype = self.obj.dtype
curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype)
new_dtype = maybe_promote(curr_dtype, value)[0]
else:
new_dtype = None
new_dtype, value = maybe_promote(curr_dtype, value)

new_values = Series([value], dtype=new_dtype)._values

Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ def construct_array_type(cls) -> type_t[DecimalArray]:
def _is_numeric(self) -> bool:
return True

def _find_compatible_dtype(self, item):
if isinstance(item, decimal.Decimal):
# TODO: need to watch out for precision?
# TODO: allow non-decimal numeric?
# TODO: allow np.nan, pd.NA?
return (self, item)
return (np.dtype(object), item)


class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
__array_priority__ = 1000
Expand Down
9 changes: 9 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,12 @@ def test_array_copy_on_write(using_copy_on_write):
{"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype=DecimalDtype()
)
tm.assert_equal(df2.values, expected.values)


def test_setitem_with_expansion():
    # GH#32346: enlarging a Series backed by DecimalArray with a Decimal
    # scalar should retain DecimalDtype, not upcast to object.
    arr = DecimalArray(make_data())
    ser = pd.Series(arr[:3])

    # setitem at a missing label triggers the expansion path
    ser[3] = ser[0]
    assert ser.dtype == arr.dtype
23 changes: 23 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2754,6 +2754,29 @@ def test_describe_timedelta_data(pa_type):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "value, target_value, dtype",
    [
        (pa.scalar(4, type="int32"), 4, "int32[pyarrow]"),
        (pa.scalar(4, type="int64"), 4, "int32[pyarrow]"),
        # (pa.scalar(4.5, type="float64"), 4, "int32[pyarrow]"),
        (4, 4, "int32[pyarrow]"),
        (pd.NA, None, "int32[pyarrow]"),
        (None, None, "int32[pyarrow]"),
        (pa.scalar(None, type="int32"), None, "int32[pyarrow]"),
        (pa.scalar(None, type="int64"), None, "int32[pyarrow]"),
    ],
)
def test_series_setitem_with_enlargement(value, target_value, dtype):
    # GH#52235
    # similar to series/indexing/test_setitem.py::test_setitem_keep_precision
    # and test_setitem_enlarge_with_na, but for arrow dtypes
    ser = pd.Series([1, 2, 3], dtype=dtype)
    # setting a missing label enlarges the Series; dtype should be kept
    ser[3] = value
    expected = pd.Series([1, 2, 3, target_value], dtype=dtype)
    tm.assert_series_equal(ser, expected)


@pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES)
def test_describe_datetime_data(pa_type):
# GH53001
Expand Down
22 changes: 22 additions & 0 deletions pandas/tests/indexing/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,6 +1966,28 @@ def test_loc_drops_level(self):


class TestLocSetitemWithExpansion:
def test_series_loc_setitem_with_expansion_categorical(self):
    # NA value that can't be held in integer categories
    # NaT is not a valid NA for integer categories, so enlargement
    # must fall back to object dtype rather than keeping category.
    ser = Series([1, 2, 3], dtype="category")
    ser.loc[3] = pd.NaT
    assert ser.dtype == object

def test_series_loc_setitem_with_expansion_interval(self):
    idx = pd.interval_range(1, 3)
    ser2 = Series(idx)
    # np.nan is a valid NA for interval dtype: only the subtype is
    # promoted (int64 -> float64); the interval dtype is retained.
    ser2.loc[2] = np.nan
    assert ser2.dtype == "interval[float64, right]"

    # NaT is not valid for a numeric interval subtype, so this
    # enlargement upcasts all the way to object.
    ser2.loc[3] = pd.NaT
    assert ser2.dtype == object

def test_series_loc_setitem_with_expansion_list_object(self):
    # A list value (non-scalar) cannot be promoted into int64, so the
    # Series upcasts to object and the list is stored as a single element.
    ser3 = Series(range(3))
    ser3.loc[3] = []
    assert ser3.dtype == object
    item = ser3.loc[3]
    assert isinstance(item, list) and len(item) == 0

@pytest.mark.slow
def test_loc_setitem_with_expansion_large_dataframe(self):
# GH#10692
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/series/indexing/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,11 @@ def test_setitem_enlargement_object_none(self, nulls_fixture):
ser[3] = nulls_fixture
expected = Series(["a", "b", nulls_fixture], index=[0, 1, 3])
tm.assert_series_equal(ser, expected)
assert ser[3] is nulls_fixture
if isinstance(nulls_fixture, float):
# We retain the same type, but maybe not the same _object_
assert np.isnan(ser[3])
else:
assert ser[3] is nulls_fixture


def test_setitem_scalar_into_readonly_backing_data():
Expand Down