From 746e5eee860b6e143c33c9b985e095dac2e42677 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Mon, 16 Oct 2023 11:50:28 -0700 Subject: [PATCH] ENH: EA._from_scalars (#53089) * ENH: BaseStringArray._from_scalars * WIP: EA._from_scalars * ENH: implement EA._from_scalars * Fix StringDtype/CategoricalDtype combine * mypy fixup --- pandas/core/arrays/base.py | 33 +++++++++++++++++++++++++ pandas/core/arrays/categorical.py | 17 +++++++++++++ pandas/core/arrays/datetimes.py | 9 +++++++ pandas/core/arrays/string_.py | 8 ++++++ pandas/core/dtypes/cast.py | 26 +++++++++---------- pandas/core/series.py | 13 ++++++++-- pandas/tests/resample/test_timedelta.py | 2 +- 7 files changed, 91 insertions(+), 17 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index e7b7ecba60e0b..05e6fc09a5ef6 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -86,6 +86,7 @@ AstypeArg, AxisInt, Dtype, + DtypeObj, FillnaOptions, InterpolateOptions, NumpySorter, @@ -293,6 +294,38 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal """ raise AbstractMethodError(cls) + @classmethod + def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self: + """ + Strict analogue to _from_sequence, allowing only sequences of scalars + that should be specifically inferred to the given dtype. + + Parameters + ---------- + scalars : sequence + dtype : ExtensionDtype + + Raises + ------ + TypeError or ValueError + + Notes + ----- + This is called in a try/except block when casting the result of a + pointwise operation. + """ + try: + return cls._from_sequence(scalars, dtype=dtype, copy=False) + except (ValueError, TypeError): + raise + except Exception: + warnings.warn( + "_from_scalars should only raise ValueError or TypeError. " + "Consider overriding _from_scalars where appropriate.", + stacklevel=find_stack_level(), + ) + raise + @classmethod def _from_sequence_of_strings( cls, strings, *, dtype: Dtype | None = None, copy: bool = False diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index e19635af6ab0b..154d78a8cd85a 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -101,6 +101,7 @@ AstypeArg, AxisInt, Dtype, + DtypeObj, NpDtype, Ordered, Self, @@ -509,6 +510,22 @@ def _from_sequence( ) -> Self: return cls(scalars, dtype=dtype, copy=copy) + @classmethod + def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self: + if dtype is None: + # The _from_scalars strictness doesn't make much sense in this case. + raise NotImplementedError + + res = cls._from_sequence(scalars, dtype=dtype) + + # if there are any non-category elements in scalars, these will be + # converted to NAs in res. + mask = isna(scalars) + if not (mask == res.isna()).all(): + # Some non-category element in scalars got converted to NA in res. + raise ValueError + return res + @overload def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray: ... diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index a2742aed31e4c..b94cbe9c3fc60 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -77,6 +77,7 @@ from pandas._typing import ( DateTimeErrorChoices, + DtypeObj, IntervalClosedType, Self, TimeAmbiguous, @@ -266,6 +267,14 @@ def _scalar_type(self) -> type[Timestamp]: _freq: BaseOffset | None = None _default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__ + @classmethod + def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self: + if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]: + # TODO: require any NAs be valid-for-DTA + # TODO: if dtype is passed, check for tzawareness compat? + raise ValueError + return cls._from_sequence(scalars, dtype=dtype) + @classmethod def _validate_dtype(cls, values, dtype): # used in TimeLikeOps.__init__ diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 693ebad0ca16f..ce0f0eec7f37c 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -56,6 +56,7 @@ from pandas._typing import ( AxisInt, Dtype, + DtypeObj, NumpySorter, NumpyValueArrayLike, Scalar, @@ -253,6 +254,13 @@ def tolist(self): return [x.tolist() for x in self] return list(self.to_numpy()) + @classmethod + def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self: + if lib.infer_dtype(scalars, skipna=True) != "string": + # TODO: require any NAs be valid-for-string + raise ValueError + return cls._from_sequence(scalars, dtype=dtype) + # error: Definition of "_concat_same_type" in base class "NDArrayBacked" is # incompatible with definition in base class "ExtensionArray" diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 3208a742738a3..28e83fd70519c 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -464,16 +464,11 @@ def maybe_cast_pointwise_result( """ if isinstance(dtype, ExtensionDtype): - if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)): - # TODO: avoid this special-casing - # We have to special case categorical so as not to upcast - # things like counts back to categorical - - cls = dtype.construct_array_type() - if same_dtype: - result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) - else: - result = _maybe_cast_to_extension_array(cls, result) + cls = dtype.construct_array_type() + if same_dtype: + result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) + else: + result = _maybe_cast_to_extension_array(cls, result) elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) @@ -498,11 +493,14 @@ def _maybe_cast_to_extension_array( ------- ExtensionArray or obj """ - from pandas.core.arrays.string_ import BaseStringArray + result: ArrayLike - # Everything can be converted to StringArrays, but we may not want to convert - if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string": - return obj + if dtype is not None: + try: + result = cls._from_scalars(obj, dtype=dtype) + except (TypeError, ValueError): + return obj + return result try: result = cls._from_sequence(obj, dtype=dtype) diff --git a/pandas/core/series.py b/pandas/core/series.py index fdd03debf6de4..b02cee3c1fcf3 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -75,7 +75,10 @@ pandas_dtype, validate_all_hashable, ) -from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + ExtensionDtype, +) from pandas.core.dtypes.generic import ABCDataFrame from pandas.core.dtypes.inference import is_hashable from pandas.core.dtypes.missing import ( @@ -100,6 +103,7 @@ from pandas.core.arrays.arrow import StructAccessor from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.arrays.sparse import SparseAccessor +from pandas.core.arrays.string_ import StringDtype from pandas.core.construction import ( extract_array, sanitize_array, @@ -3377,7 +3381,12 @@ def combine( # try_float=False is to match agg_series npvalues = lib.maybe_convert_objects(new_values, try_float=False) - res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False) + # same_dtype here is a kludge to avoid casting e.g. [True, False] to + # ["True", "False"] + same_dtype = isinstance(self.dtype, (StringDtype, CategoricalDtype)) + res_values = maybe_cast_pointwise_result( + npvalues, self.dtype, same_dtype=same_dtype + ) return self._constructor(res_values, index=new_index, name=new_name, copy=False) def combine_first(self, other) -> Series: diff --git a/pandas/tests/resample/test_timedelta.py b/pandas/tests/resample/test_timedelta.py index 606403ba56494..79f9492062063 100644 --- a/pandas/tests/resample/test_timedelta.py +++ b/pandas/tests/resample/test_timedelta.py @@ -103,7 +103,7 @@ def test_resample_categorical_data_with_timedeltaindex(): index=pd.TimedeltaIndex([0, 10], unit="s", freq="10s"), ) expected = expected.reindex(["Group_obj", "Group"], axis=1) - expected["Group"] = expected["Group_obj"] + expected["Group"] = expected["Group_obj"].astype("category") tm.assert_frame_equal(result, expected)