Skip to content

Commit

Permalink
BUG: Aggregation on arrow array return same type.
Browse files Browse the repository at this point in the history
Signed-off-by: Liang Yan <ckgppl_yan@sina.cn>
  • Loading branch information
liang3zy22 committed Jun 20, 2023
1 parent d06f2d3 commit 73ded8b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
8 changes: 7 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
if TYPE_CHECKING:
from pandas.core.generic import NDFrame

from pandas.arrays import ArrowExtensionArray


def check_result_array(obj, dtype):
# Our operation is supposed to be an aggregation/reduction. If
Expand Down Expand Up @@ -837,7 +839,11 @@ def agg_series(
# test_groupby_empty_with_category gets here with self.ngroups == 0
# and len(obj) > 0

if len(obj) > 0 and not isinstance(obj._values, np.ndarray):
if (
len(obj) > 0
and not isinstance(obj._values, np.ndarray)
and not isinstance(obj._values, ArrowExtensionArray)
):
# we can preserve a little bit more aggressively with EA dtype
# because maybe_cast_pointwise_result will do a try/except
# with _from_sequence. NB we are assuming here that _from_sequence
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
from functools import partial
import re
import typing

import numpy as np
import pytest
Expand Down Expand Up @@ -1603,3 +1604,23 @@ def test_agg_with_as_index_false_with_list():
columns=MultiIndex.from_tuples([("a1", ""), ("a2", ""), ("b", "sum")]),
)
tm.assert_frame_equal(result, expected)


@pytest.mark.skipif(
typing.TYPE_CHECKING, reason="TYPE_CHECKING must be True to import pyarrow"
)
def test_agg_arrow_type():
df = DataFrame.from_dict(
{
"category": ["A"] * 10 + ["B"] * 10,
"bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5,
}
)
df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]")
result = df.groupby("category").agg(lambda x: x.sum() / x.count())
expected = DataFrame(
[[0.5, 0.5], [0.5, 0.5]],
columns=["bool_numpy", "bool_arrow"],
index=Index(["A", "B"], name="category"),
)
tm.assert_frame_equal(result, expected)

0 comments on commit 73ded8b

Please sign in to comment.