From 3ba1a4bc3e86b54cac47a1c46177fe0de4d6b41e Mon Sep 17 00:00:00 2001 From: Liang Yan Date: Fri, 7 Jul 2023 16:08:09 +0800 Subject: [PATCH] Fix convert_dtype issue. Signed-off-by: Liang Yan --- pandas/core/dtypes/cast.py | 20 ++++++++----------- .../tests/groupby/aggregate/test_aggregate.py | 6 ++---- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index e53dda15f9954..a8ecd2d98f1eb 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -461,27 +461,23 @@ def maybe_cast_pointwise_result( """ if isinstance(dtype, ExtensionDtype): - if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)): + if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)): # TODO: avoid this special-casing # We have to special case categorical so as not to upcast # things like counts back to categorical - - cls = dtype.construct_array_type() - if same_dtype: - result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) - else: - result = _maybe_cast_to_extension_array(cls, result) - elif isinstance(dtype, ArrowDtype): - pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") - if isinstance(pyarrow_type, ExtensionDtype): - cls = pyarrow_type.construct_array_type() - result = _maybe_cast_to_extension_array(cls, result) + if isinstance(dtype, ArrowDtype): + pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") else: + pyarrow_type = np.dtype("object") + if not isinstance(pyarrow_type, ExtensionDtype): cls = dtype.construct_array_type() if same_dtype: result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) else: result = _maybe_cast_to_extension_array(cls, result) + else: + cls = pyarrow_type.construct_array_type() + result = _maybe_cast_to_extension_array(cls, result) elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 6836a43569f77..c96411e8fdb93 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -5,7 +5,6 @@ import functools from functools import partial import re -import typing import numpy as np import pytest @@ -25,6 +24,7 @@ ) import pandas._testing as tm from pandas.core.groupby.grouper import Grouping +from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow def test_groupby_agg_no_extra_calls(): @@ -1610,9 +1610,7 @@ def test_agg_with_as_index_false_with_list(): tm.assert_frame_equal(result, expected) -# @pytest.mark.skipif( -# not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py" -# ) +@skip_if_no_pyarrow def test_agg_arrow_type(): df = DataFrame.from_dict( {