diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index f65cb94df293e..53a52a7d1c2a7 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -458,17 +458,26 @@ def maybe_cast_pointwise_result( """ if isinstance(dtype, ExtensionDtype): - if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)): + if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)): # TODO: avoid this special-casing # We have to special case categorical so as not to upcast # things like counts back to categorical - cls = dtype.construct_array_type() if same_dtype: result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) else: result = _maybe_cast_to_extension_array(cls, result) - + elif isinstance(dtype, ArrowDtype): + pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") + if isinstance(pyarrow_type, ExtensionDtype): + cls = pyarrow_type.construct_array_type() + result = _maybe_cast_to_extension_array(cls, result) + else: + cls = dtype.construct_array_type() + if same_dtype: + result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) + else: + result = _maybe_cast_to_extension_array(cls, result) elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 3558377907931..6fcdee2d898be 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -5,6 +5,7 @@ import functools from functools import partial import re +import typing import numpy as np import pytest @@ -1603,3 +1604,25 @@ def test_agg_with_as_index_false_with_list(): columns=MultiIndex.from_tuples([("a1", ""), ("a2", ""), ("b", "sum")]), ) tm.assert_frame_equal(result, expected) + + +@pytest.mark.skipif( + not typing.TYPE_CHECKING, reason="TYPE_CHECKING must be True to import pyarrow" +) +def test_agg_arrow_type(): + df = DataFrame.from_dict( + { + "category": ["A"] * 10 + ["B"] * 10, + "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5, + } + ) + df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]") + result = df.groupby("category").agg(lambda x: x.sum() / x.count()) + expected = DataFrame( + { + "bool_numpy": [0.5, 0.5], + "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values, + }, + index=Index(["A", "B"], name="category"), + ) + tm.assert_frame_equal(result, expected)