ENH: EADtype._find_compatible_dtype #53106

Closed (wants to merge 18 commits)

Changes from 3 commits
31 changes: 30 additions & 1 deletion pandas/core/arrays/arrow/dtype.py
@@ -8,10 +8,14 @@
)
from decimal import Decimal
import re
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
)

import numpy as np

from pandas._libs import missing as libmissing
from pandas._libs.tslibs import (
Timedelta,
Timestamp,
@@ -23,6 +27,7 @@
StorageExtensionDtype,
register_extension_dtype,
)
from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.dtypes import CategoricalDtypeType

if not pa_version_under7p0:
@@ -321,3 +326,27 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
array_class = self.construct_array_type()
arr = array.cast(self.pyarrow_dtype, safe=True)
return array_class(arr)

def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]:
[mroeschke marked this conversation as resolved.]
if isinstance(item, pa.Scalar):
if not item.is_valid:
Review comment (Member): Shouldn't NA always be able to be inserted into ArrowExtensionArray?

Reply (Member, Author): I'm not sure. pyarrow nulls are typed, so we could plausibly want to disallow e.g. <pyarrow.TimestampScalar: None> in a pyarrow integer dtype.

# TODO: ask joris for help making these checks more robust
if item.type == self.pyarrow_dtype:
return self, item.as_py()
if item.type.to_pandas_dtype() == np.int64 and self.kind == "i":
Review comment (Member): Why is this needed specifically?

Reply (Member, Author): This was just to get the tests copied from #52833 passing.

# FIXME: kludge
return self, item.as_py()

item = item.as_py()

elif item is None or item is libmissing.NA:
# TODO: np.nan? use is_valid_na_for_dtype
Review comment (Member): Since pyarrow supports nan vs NA, possibly we want to allow nan if pa.types.is_floating(self.pyarrow_dtype).

Reply (Member, Author): What to do here depends on making a decision about when/how to distinguish between np.nan and pd.NA (which I hope to finally nail down at the sprint). Doing this The Right Way would involve something like implementing EA._is_valid_na_for_dtype.

return self, item

dtype, item = maybe_promote(self.numpy_dtype, item)

if dtype == self.numpy_dtype:
return self, item

# TODO: implement from_numpy_dtype analogous to MaskedDtype.from_numpy_dtype
return np.dtype(object), item
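
Taken together, a rough sketch of the behavior this hook gives ArrowDtype, assuming this PR's diff is applied (illustrative, not exhaustive; consistent with the new tests further down):

    import pyarrow as pa
    import pandas as pd

    dtype = pd.ArrowDtype(pa.int32())
    # pyarrow scalar of the matching type: keep the dtype, unbox to a Python value
    dtype._maybe_promote(pa.scalar(4, type="int32"))  # -> (int32[pyarrow], 4)
    # missing values keep the dtype as well
    dtype._maybe_promote(pd.NA)                       # -> (int32[pyarrow], <NA>)
    # anything the numpy-level maybe_promote cannot fit falls back to object
    dtype._maybe_promote("foo")                       # -> (np.dtype("O"), "foo")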
20 changes: 20 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ExtensionDtype:
* _is_numeric
* _is_boolean
* _get_common_dtype
* _maybe_promote

The `na_value` class attribute can be used to set the default NA value
for this type. :attr:`numpy.nan` is used by default.
Expand Down Expand Up @@ -391,6 +392,25 @@ def _can_hold_na(self) -> bool:
"""
return True

def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]:
"""
Find the minimal dtype that we need to upcast to in order to hold
the given item.

This is used when doing Series setitem-with-expansion on a Series
with our dtype.

Parameters
----------
item : object

Returns
-------
np.dtype or ExtensionDtype
object
"""
return np.dtype(object), item
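
For extension dtypes that do not override this hook, everything promotes to object. A minimal sketch of the default, assuming this PR is applied (MinimalDtype is hypothetical, defined only to show the fallback):

    import numpy as np
    from pandas.core.dtypes.base import ExtensionDtype

    class MinimalDtype(ExtensionDtype):
        # hypothetical dtype that keeps the base-class default _maybe_promote
        name = "minimal"
        type = object

        @classmethod
        def construct_array_type(cls):
            raise NotImplementedError  # not needed for this sketch

    dtype, item = MinimalDtype()._maybe_promote(42)
    assert dtype == np.dtype(object) and item == 42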


class StorageExtensionDtype(ExtensionDtype):
"""ExtensionDtype that may be backed by more than one implementation."""
17 changes: 7 additions & 10 deletions pandas/core/dtypes/cast.py
@@ -46,7 +46,6 @@
ensure_int16,
ensure_int32,
ensure_int64,
ensure_object,
ensure_str,
is_bool,
is_complex,
@@ -539,13 +538,13 @@ def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj:
}


def maybe_promote(dtype: np.dtype, fill_value=np.nan):
def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
"""
Find the minimal dtype that can hold both the given dtype and fill_value.

Parameters
----------
dtype : np.dtype
dtype : np.dtype or ExtensionDtype
fill_value : scalar, default np.nan

Returns
@@ -593,9 +592,13 @@ def _maybe_promote_cached(dtype, fill_value, fill_value_type):
return _maybe_promote(dtype, fill_value)


def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
def _maybe_promote(dtype: DtypeObj, fill_value=np.nan):
# The actual implementation of the function, use `maybe_promote` above for
# a cached version.

if not isinstance(dtype, np.dtype):
return dtype._maybe_promote(fill_value)

if not is_scalar(fill_value):
# with object dtype there is nothing to promote, and the user can
# pass pretty much any weird fill_value they like
@@ -611,12 +614,6 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
fv = na_value_for_dtype(dtype)
return dtype, fv

elif isinstance(dtype, CategoricalDtype):
if fill_value in dtype.categories or isna(fill_value):
return dtype, fill_value
else:
return object, ensure_object(fill_value)

elif isna(fill_value):
dtype = _dtype_obj
if fill_value is None:
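With the early dispatch above in place (and this PR's diffs applied), maybe_promote accepts either kind of dtype; a sketch:

    import numpy as np
    import pandas as pd
    from pandas.core.dtypes.cast import maybe_promote

    # numpy dtypes keep the existing path: int64 cannot hold NaN -> float64
    maybe_promote(np.dtype("int64"), np.nan)  # -> (dtype('float64'), nan)

    # ExtensionDtypes now dispatch to their own _maybe_promote hook
    maybe_promote(pd.Int64Dtype(), pd.NA)     # -> (Int64Dtype(), <NA>)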
21 changes: 21 additions & 0 deletions pandas/core/dtypes/dtypes.py
@@ -635,6 +635,15 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:

return find_common_type(non_cat_dtypes)

def _maybe_promote(self, item) -> tuple[DtypeObj, Any]:
from pandas.core.dtypes.missing import is_valid_na_for_dtype

if item in self.categories or is_valid_na_for_dtype(
item, self.categories.dtype
):
return self, item
return np.dtype(object), item
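
Concretely, under the lines above (a sketch; again assumes this PR is applied):

    import numpy as np
    import pandas as pd

    dtype = pd.CategoricalDtype(["a", "b"])
    dtype._maybe_promote("a")     # existing category -> (dtype, "a")
    dtype._maybe_promote(np.nan)  # NA is always storable -> (dtype, nan)
    dtype._maybe_promote("z")     # unknown category -> (np.dtype("O"), "z")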


@register_extension_dtype
class DatetimeTZDtype(PandasExtensionDtype):
@@ -1500,3 +1509,15 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
return type(self).from_numpy_dtype(new_dtype)
except (KeyError, NotImplementedError):
return None

def _maybe_promote(self, item) -> tuple[DtypeObj, Any]:
from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.missing import is_valid_na_for_dtype

if is_valid_na_for_dtype(item, self):
return self, item

dtype, item = maybe_promote(self.numpy_dtype, item)
if dtype.kind in "iufb":
return type(self).from_numpy_dtype(dtype), item
return dtype, item
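
For masked dtypes this reuses numpy's promotion rules and then maps the result back into the masked family where possible; a sketch under the same assumption:

    import pandas as pd

    dtype = pd.Int64Dtype()
    dtype._maybe_promote(1.5)  # int64 -> float64, mapped back -> (Float64Dtype(), 1.5)
    dtype._maybe_promote("x")  # no masked object dtype -> (np.dtype("O"), "x")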
27 changes: 5 additions & 22 deletions pandas/core/indexing.py
@@ -50,12 +50,7 @@
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.missing import (
infer_fill_value,
is_valid_na_for_dtype,
isna,
na_value_for_dtype,
)
from pandas.core.dtypes.missing import infer_fill_value

from pandas.core import algorithms as algos
import pandas.core.common as com
@@ -2091,26 +2086,14 @@ def _setitem_with_indexer_missing(self, indexer, value):
return self._setitem_with_indexer(new_indexer, value, "loc")

# this preserves dtype of the value and of the object
if not is_scalar(value):
new_dtype = None

elif is_valid_na_for_dtype(value, self.obj.dtype):
if not is_object_dtype(self.obj.dtype):
# Every NA value is suitable for object, no conversion needed
value = na_value_for_dtype(self.obj.dtype, compat=False)

new_dtype = maybe_promote(self.obj.dtype, value)[0]

elif isna(value):
new_dtype = None
new_dtype = None
if is_list_like(value):
Review comment (Member): Note that for ArrowDtype with a pa.list_ type, we would want to treat value like a scalar, e.g.:

    ser = pd.Series([[1, 1]], dtype=pd.ArrowDtype(pa.list_(pa.int64())))
    ser[0] = [1, 2]

Reply (Member, Author): Yeah, getting rid of this is_list_like check causes us to incorrectly raise on numpy non-object cases when using a list value (for which we don't have any tests). Can fix that in this PR or separately, as it is a bit more invasive.

pass
elif not self.obj.empty and not is_object_dtype(self.obj.dtype):
# We should not cast, if we have object dtype because we can
# set timedeltas into object series
curr_dtype = self.obj.dtype
curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype)
new_dtype = maybe_promote(curr_dtype, value)[0]
else:
new_dtype = None
new_dtype, value = maybe_promote(curr_dtype, value)

new_values = Series([value], dtype=new_dtype)._values

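The user-visible effect of routing enlargement through maybe_promote is that compatible values no longer force an upcast; a sketch mirroring the new tests below:

    import pandas as pd

    ser = pd.Series([1, 2, 3], dtype="int32[pyarrow]")
    ser[3] = 4      # compatible scalar: dtype stays int32[pyarrow]
    ser[4] = pd.NA  # missing value: dtype still preserved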
23 changes: 23 additions & 0 deletions pandas/tests/extension/test_arrow.py
@@ -2855,6 +2855,29 @@ def test_describe_timedelta_data(pa_type):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
"value, target_value, dtype",
[
(pa.scalar(4, type="int32"), 4, "int32[pyarrow]"),
(pa.scalar(4, type="int64"), 4, "int32[pyarrow]"),
# (pa.scalar(4.5, type="float64"), 4, "int32[pyarrow]"),
Review comment (Member): What happens here? Also, what happens with an int64 scalar and an int32 dtype?

Reply (Member, Author): I'd want to follow the same logic we do for numpy dtypes, but was punting here in expectation of doing it in a follow-up (likely involving joris expressing an opinion).

(4, 4, "int32[pyarrow]"),
(pd.NA, None, "int32[pyarrow]"),
(None, None, "int32[pyarrow]"),
(pa.scalar(None, type="int32"), None, "int32[pyarrow]"),
(pa.scalar(None, type="int64"), None, "int32[pyarrow]"),
],
)
def test_series_setitem_with_enlargement(value, target_value, dtype):
# GH#52235
# similar to series/indexing/test_setitem.py::test_setitem_keep_precision
# and test_setitem_enlarge_with_na, but for arrow dtypes
ser = pd.Series([1, 2, 3], dtype=dtype)
ser[3] = value
expected = pd.Series([1, 2, 3, target_value], dtype=dtype)
tm.assert_series_equal(ser, expected)


@pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES)
def test_describe_datetime_data(pa_type):
# GH53001
6 changes: 5 additions & 1 deletion pandas/tests/series/indexing/test_setitem.py
@@ -570,7 +570,11 @@ def test_setitem_enlargement_object_none(self, nulls_fixture):
ser[3] = nulls_fixture
expected = Series(["a", "b", nulls_fixture], index=[0, 1, 3])
tm.assert_series_equal(ser, expected)
assert ser[3] is nulls_fixture
if isinstance(nulls_fixture, float):
# We retain the same type, but maybe not the same _object_
assert np.isnan(ser[3])
else:
assert ser[3] is nulls_fixture


def test_setitem_scalar_into_readonly_backing_data():