Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: EA._from_scalars #53089

Merged
merged 18 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
cast,
overload,
)
import warnings

import numpy as np

Expand All @@ -33,6 +34,7 @@
Substitution,
cache_readonly,
)
from pandas.util._exceptions import find_stack_level
from pandas.util._validators import (
validate_bool_kwarg,
validate_fillna_kwargs,
Expand Down Expand Up @@ -83,6 +85,7 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
FillnaOptions,
InterpolateOptions,
NumpySorter,
Expand Down Expand Up @@ -280,6 +283,38 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal
"""
raise AbstractMethodError(cls)

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking aloud. Once we allow EA authors to specify their "is_scalar" passing equivalent, should that be incorporated here somehow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Assuming at some point we add to the EADtype something like a _is_recognized_scalar method, the default implementation of that method might look like (assuming away NAs for the moment)

def _is_recognized_scalar(self, scalar) -> bool:
    cls = self.construct_array_type()
    try:
        cls._from_scalars([scalar], dtype=self)
    except (TypeError, ValueError):
        return False
    return True

We could alternately implement a default from_scalars in terms of _is_recognized_scalar. _find_compatible_dtype from #53106 could also be the primitive from which the other two are constructed (though i dont think it can be constructed from either).

"""
Strict analogue to _from_sequence, allowing only sequences of scalars
that should be specifically inferred to the given dtype.

Parameters
----------
scalars : sequence
dtype : ExtensionDtype

Raises
------
TypeError or ValueError

Notes
-----
This is called in a try/except block when casting the result of a
pointwise operation.
"""
try:
return cls._from_sequence(scalars, dtype=dtype, copy=False)
except (ValueError, TypeError):
raise
except Exception:
warnings.warn(
"_from_scalars should only raise ValueError or TypeError. "
"Consider overriding _from_scalars where appropriate.",
stacklevel=find_stack_level(),
)
raise

@classmethod
def _from_sequence_of_strings(
cls, strings, *, dtype: Dtype | None = None, copy: bool = False
Expand Down
17 changes: 17 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
NpDtype,
Ordered,
Self,
Expand Down Expand Up @@ -509,6 +510,22 @@ def _from_sequence(
) -> Self:
return cls(scalars, dtype=dtype, copy=copy)

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
if dtype is None:
# The _from_scalars strictness doesn't make much sense in this case.
raise NotImplementedError

res = cls._from_sequence(scalars, dtype=dtype)

# if there are any non-category elements in scalars, these will be
# converted to NAs in res.
mask = isna(scalars)
if not (mask == res.isna()).all():
# Some non-category element in scalars got converted to NA in res.
raise ValueError
return res

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...
Expand Down
9 changes: 9 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@

from pandas._typing import (
DateTimeErrorChoices,
DtypeObj,
IntervalClosedType,
Self,
TimeAmbiguous,
Expand Down Expand Up @@ -263,6 +264,14 @@ def _scalar_type(self) -> type[Timestamp]:
_freq: BaseOffset | None = None
_default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
# TODO: require any NAs be valid-for-DTA
# TODO: if dtype is passed, check for tzawareness compat?
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

@classmethod
def _validate_dtype(cls, values, dtype):
# used in TimeLikeOps.__init__
Expand Down
9 changes: 9 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@
from pandas._typing import (
AxisInt,
Dtype,
DtypeObj,
NumpySorter,
NumpyValueArrayLike,
Scalar,
Self,
npt,
type_t,
)
Expand Down Expand Up @@ -228,6 +230,13 @@ def tolist(self):
return [x.tolist() for x in self]
return list(self.to_numpy())

@classmethod
def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
if lib.infer_dtype(scalars, skipna=True) != "string":
# TODO: require any NAs be valid-for-string
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
Expand Down
26 changes: 12 additions & 14 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,16 +461,11 @@ def maybe_cast_pointwise_result(
"""

if isinstance(dtype, ExtensionDtype):
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
# TODO: avoid this special-casing
# We have to special case categorical so as not to upcast
# things like counts back to categorical

cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)
cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)

elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
result = maybe_downcast_to_dtype(result, dtype)
Expand All @@ -495,11 +490,14 @@ def _maybe_cast_to_extension_array(
-------
ExtensionArray or obj
"""
from pandas.core.arrays.string_ import BaseStringArray
result: ArrayLike

# Everything can be converted to StringArrays, but we may not want to convert
if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string":
return obj
if dtype is not None:
try:
result = cls._from_scalars(obj, dtype=dtype)
except (TypeError, ValueError):
return obj
return result

try:
result = cls._from_sequence(obj, dtype=dtype)
Expand Down
9 changes: 8 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
CategoricalDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import ABCDataFrame
Expand All @@ -102,6 +103,7 @@
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.construction import (
extract_array,
sanitize_array,
Expand Down Expand Up @@ -3331,7 +3333,12 @@ def combine(

# try_float=False is to match agg_series
npvalues = lib.maybe_convert_objects(new_values, try_float=False)
res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
# same_dtype here is a kludge to avoid casting e.g. [True, False] to
# ["True", "False"]
same_dtype = isinstance(self.dtype, (StringDtype, CategoricalDtype))
res_values = maybe_cast_pointwise_result(
npvalues, self.dtype, same_dtype=same_dtype
)
return self._constructor(res_values, index=new_index, name=new_name, copy=False)

def combine_first(self, other) -> Series:
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/resample/test_timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_resample_categorical_data_with_timedeltaindex():
index=pd.TimedeltaIndex([0, 10], unit="s", freq="10s"),
)
expected = expected.reindex(["Group_obj", "Group"], axis=1)
expected["Group"] = expected["Group_obj"]
expected["Group"] = expected["Group_obj"].astype("category")
tm.assert_frame_equal(result, expected)


Expand Down