Skip to content

Commit

Permalink
Fix convert_dtype issue.
Browse files Browse the repository at this point in the history
Signed-off-by: Liang Yan <ckgppl_yan@sina.cn>
  • Loading branch information
liang3zy22 committed Jul 19, 2023
1 parent 0a18797 commit 3ba1a4b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
20 changes: 8 additions & 12 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,27 +461,23 @@ def maybe_cast_pointwise_result(
"""

if isinstance(dtype, ExtensionDtype):
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
# TODO: avoid this special-casing
# We have to special case categorical so as not to upcast
# things like counts back to categorical

cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)
elif isinstance(dtype, ArrowDtype):
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
if isinstance(pyarrow_type, ExtensionDtype):
cls = pyarrow_type.construct_array_type()
result = _maybe_cast_to_extension_array(cls, result)
if isinstance(dtype, ArrowDtype):
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
else:
pyarrow_type = np.dtype("object")
if not isinstance(pyarrow_type, ExtensionDtype):
cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)
else:
cls = pyarrow_type.construct_array_type()
result = _maybe_cast_to_extension_array(cls, result)
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
result = maybe_downcast_to_dtype(result, dtype)

Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import functools
from functools import partial
import re
import typing

import numpy as np
import pytest
Expand All @@ -25,6 +24,7 @@
)
import pandas._testing as tm
from pandas.core.groupby.grouper import Grouping
from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow


def test_groupby_agg_no_extra_calls():
Expand Down Expand Up @@ -1610,9 +1610,7 @@ def test_agg_with_as_index_false_with_list():
tm.assert_frame_equal(result, expected)


# @pytest.mark.skipif(
# not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py"
# )
@skip_if_no_pyarrow
def test_agg_arrow_type():
df = DataFrame.from_dict(
{
Expand Down

0 comments on commit 3ba1a4b

Please sign in to comment.