Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF (string): move ArrowStringArrayNumpySemantics methods to base class #59501

Merged
merged 8 commits into from
Aug 15, 2024
15 changes: 13 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,14 @@ def _cmp_method(self, other, op) -> ArrowExtensionArray:
raise NotImplementedError(
f"{op.__name__} not implemented for {type(other)}"
)
return ArrowExtensionArray(result)
result = ArrowExtensionArray(result)
if self.dtype.na_value is np.nan: # type: ignore[comparison-overlap]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Semantically wouldn't this be better defined in ArrowStringArray? Since self.dtype must not be a ArrowDtype here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either way works; i preferred this since it was maximally high in the inheritance chain

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I would prefer this in ArrowStringArray since this logic is specific to that subclass.

# i.e. ArrowStringArray with Numpy Semantics
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
return result.to_numpy(np.bool_, na_value=False)
return result

def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
pa_type = self._pa_array.type
Expand Down Expand Up @@ -1529,7 +1536,11 @@ def value_counts(self, dropna: bool = True) -> Series:

index = Index(type(self)(values))

return Series(counts, index=index, name="count", copy=False)
result = Series(counts, index=index, name="count", copy=False)
if self.dtype.na_value is np.nan: # type: ignore[comparison-overlap]
# i.e. ArrowStringArray with Numpy Semantics
return Series(counts.to_numpy(), index=index, name="count", copy=False)
return result

@classmethod
def _concat_same_type(cls, to_concat) -> Self:
Expand Down
100 changes: 33 additions & 67 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from functools import partial
import operator
import re
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -62,8 +60,6 @@

from pandas.core.dtypes.dtypes import ExtensionDtype

from pandas import Series


ArrowStringScalarOrNAT = Union[str, libmissing.NAType]

Expand Down Expand Up @@ -216,12 +212,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
return self._dtype

def insert(self, loc: int, item) -> ArrowStringArray:
if self.dtype.na_value is np.nan and item is np.nan:
item = libmissing.NA
if not isinstance(item, str) and item is not libmissing.NA:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

@classmethod
def _result_converter(cls, values, na=None):
def _result_converter(self, values, na=None):
if self.dtype.na_value is np.nan:
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
Expand Down Expand Up @@ -496,11 +497,30 @@ def _str_get_dummies(self, sep: str = "|"):
return dummies.astype(np.int64, copy=False), labels

def _convert_int_dtype(self, result):
if self.dtype.na_value is np.nan:
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
else:
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
return result

return Int64Dtype().__from_arrow__(result)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
if self.dtype.na_value is np.nan and name in ["any", "all"]:
if not skipna:
nas = pc.is_null(self._pa_array)
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
else:
arr = pc.not_equal(self._pa_array, "")
return ArrowExtensionArray(arr)._reduce(
name, skipna=skipna, keepdims=keepdims, **kwargs
)

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_dtype(result)
Expand Down Expand Up @@ -533,65 +553,11 @@ def _rank(


class ArrowStringArrayNumpySemantics(ArrowStringArray):
_storage = "pyarrow"
_na_value = np.nan

@classmethod
def _result_converter(cls, values, na=None):
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

def __getattribute__(self, item):
# ArrowStringArray and we both inherit from ArrowExtensionArray, which
# creates inheritance problems (Diamond inheritance)
if item in ArrowStringArrayMixin.__dict__ and item not in (
"_pa_array",
"__dict__",
):
return partial(getattr(ArrowStringArrayMixin, item), self)
return super().__getattribute__(item)

def _convert_int_dtype(self, result):
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
else:
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
return result

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True) -> Series:
from pandas import Series

result = super().value_counts(dropna)
return Series(
result._values.to_numpy(), index=result.index, name=result.name, copy=False
)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
if name in ["any", "all"]:
if not skipna:
nas = pc.is_null(self._pa_array)
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
else:
arr = pc.not_equal(self._pa_array, "")
return ArrowExtensionArray(arr)._reduce(
name, skipna=skipna, keepdims=keepdims, **kwargs
)
else:
return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)

def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics:
if item is np.nan:
item = libmissing.NA
return super().insert(loc, item) # type: ignore[return-value]
_str_get = ArrowStringArrayMixin._str_get
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_pad = ArrowStringArrayMixin._str_pad
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace