Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF (string): move ArrowStringArrayNumpySemantics methods to base class #59501

Merged
merged 8 commits into from
Aug 15, 2024
109 changes: 48 additions & 61 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from functools import partial
import operator
import re
from typing import (
Expand Down Expand Up @@ -216,12 +215,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
return self._dtype

def insert(self, loc: int, item) -> ArrowStringArray:
if self.dtype.na_value is np.nan and item is np.nan:
item = libmissing.NA
if not isinstance(item, str) and item is not libmissing.NA:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

@classmethod
def _result_converter(cls, values, na=None):
def _result_converter(self, values, na=None):
if self.dtype.na_value is np.nan:
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
Expand Down Expand Up @@ -496,11 +500,30 @@ def _str_get_dummies(self, sep: str = "|"):
return dummies.astype(np.int64, copy=False), labels

def _convert_int_dtype(self, result):
if self.dtype.na_value is np.nan:
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
else:
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
return result

return Int64Dtype().__from_arrow__(result)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
if self.dtype.na_value is np.nan and name in ["any", "all"]:
if not skipna:
nas = pc.is_null(self._pa_array)
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
else:
arr = pc.not_equal(self._pa_array, "")
return ArrowExtensionArray(arr)._reduce(
name, skipna=skipna, keepdims=keepdims, **kwargs
)

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_dtype(result)
Expand Down Expand Up @@ -531,67 +554,31 @@ def _rank(
)
)


class ArrowStringArrayNumpySemantics(ArrowStringArray):
_storage = "pyarrow"
_na_value = np.nan

@classmethod
def _result_converter(cls, values, na=None):
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

def __getattribute__(self, item):
# ArrowStringArray and we both inherit from ArrowExtensionArray, which
# creates inheritance problems (Diamond inheritance)
if item in ArrowStringArrayMixin.__dict__ and item not in (
"_pa_array",
"__dict__",
):
return partial(getattr(ArrowStringArrayMixin, item), self)
return super().__getattribute__(item)

def _convert_int_dtype(self, result):
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
else:
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
def value_counts(self, dropna: bool = True) -> Series:
result = super().value_counts(dropna=dropna)
if self.dtype.na_value is np.nan:
res_values = result._values.to_numpy()
return result._constructor(
res_values, index=result.index, name=result.name, copy=False
)
return result

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True) -> Series:
from pandas import Series

result = super().value_counts(dropna)
return Series(
result._values.to_numpy(), index=result.index, name=result.name, copy=False
)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
if name in ["any", "all"]:
if not skipna:
nas = pc.is_null(self._pa_array)
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
if self.dtype.na_value is np.nan:
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
arr = pc.not_equal(self._pa_array, "")
return ArrowExtensionArray(arr)._reduce(
name, skipna=skipna, keepdims=keepdims, **kwargs
)
else:
return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
return result.to_numpy(np.bool_, na_value=False)
return result

def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics:
if item is np.nan:
item = libmissing.NA
return super().insert(loc, item) # type: ignore[return-value]

class ArrowStringArrayNumpySemantics(ArrowStringArray):
_na_value = np.nan
_str_get = ArrowStringArrayMixin._str_get
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_pad = ArrowStringArrayMixin._str_pad
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace