Skip to content

Commit

Permalink
BUG: Aggregation on arrow array returns same type.
Browse files Browse the repository at this point in the history
Signed-off-by: Liang Yan <ckgppl_yan@sina.cn>
  • Loading branch information
liang3zy22 committed Jun 21, 2023
1 parent f0d3301 commit d979c10
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
15 changes: 12 additions & 3 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,17 +458,26 @@ def maybe_cast_pointwise_result(
"""

if isinstance(dtype, ExtensionDtype):
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
# TODO: avoid this special-casing
# We have to special case categorical so as not to upcast
# things like counts back to categorical

cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)

elif isinstance(dtype, ArrowDtype):
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
if isinstance(pyarrow_type, ExtensionDtype):
cls = pyarrow_type.construct_array_type()
result = _maybe_cast_to_extension_array(cls, result)
else:
cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
result = maybe_downcast_to_dtype(result, dtype)

Expand Down
23 changes: 23 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
from functools import partial
import re
import typing

import numpy as np
import pytest
Expand Down Expand Up @@ -1603,3 +1604,25 @@ def test_agg_with_as_index_false_with_list():
columns=MultiIndex.from_tuples([("a1", ""), ("a2", ""), ("b", "sum")]),
)
tm.assert_frame_equal(result, expected)


def test_agg_arrow_type():
    """Aggregate a pyarrow-backed bool column with a ratio-producing UDF.

    The frame holds the same boolean data twice: once as a NumPy bool column
    and once cast to ``bool[pyarrow]``. Grouping by ``category`` and applying
    ``sum() / count()`` yields 0.5 per group; the expected frame requires the
    NumPy column to hold plain floats and the arrow column to come back as
    ``double[pyarrow]`` rather than being cast back to its input bool dtype.
    """
    # ``typing.TYPE_CHECKING`` is always False at runtime, so a
    # ``skipif(not typing.TYPE_CHECKING)`` guard skips the test on every run.
    # Skip only when pyarrow genuinely cannot be imported.
    pytest.importorskip("pyarrow")

    df = DataFrame.from_dict(
        {
            "category": ["A"] * 10 + ["B"] * 10,
            "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5,
        }
    )
    df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]")

    result = df.groupby("category").agg(lambda x: x.sum() / x.count())

    expected = DataFrame(
        {
            "bool_numpy": [0.5, 0.5],
            "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values,
        },
        index=Index(["A", "B"], name="category"),
    )
    tm.assert_frame_equal(result, expected)

0 comments on commit d979c10

Please sign in to comment.