Skip to content

Commit

Permalink
String dtype: propagate NaNs as False in predicate methods (eg .str.s…
Browse files Browse the repository at this point in the history
…tartswith) (pandas-dev#59616)
  • Loading branch information
jorisvandenbossche authored Oct 10, 2024
1 parent 8303af3 commit 88554d0
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 146 deletions.
44 changes: 27 additions & 17 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@

import numpy as np

from pandas._libs import lib
from pandas.compat import (
pa_version_under10p1,
pa_version_under11p0,
pa_version_under13p0,
pa_version_under17p0,
)

from pandas.core.dtypes.missing import isna

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc
Expand All @@ -38,7 +37,7 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError

Expand Down Expand Up @@ -212,7 +211,9 @@ def _str_removesuffix(self, suffix: str):
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_startswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
Expand All @@ -225,11 +226,11 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="startswith")

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_endswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
Expand All @@ -242,9 +243,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="endswith")

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
Expand Down Expand Up @@ -283,7 +282,12 @@ def _str_isupper(self):
return self._convert_bool_result(result)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
regex: bool = True,
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")
Expand All @@ -293,19 +297,25 @@ def _str_contains(
else:
pa_contains = pc.match_substring
result = pa_contains(self._pa_array, pat, ignore_case=not case)
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="contains")

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat: str,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,7 +2318,9 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
for chunk in self._pa_array.iterchunks()
]

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return type(self)(result)

def _convert_int_result(self, result):
Expand Down
20 changes: 16 additions & 4 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,16 +2679,28 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
# ------------------------------------------------------------------------
# String methods interface
def _str_map(
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
self, f, na_value=lib.no_default, dtype=np.dtype("object"), convert: bool = True
):
# Optimization to apply the callable `f` to the categories once
# and rebuild the result by `take`ing from the result with the codes.
# Returns the same type as the object-dtype implementation though.
from pandas.core.arrays import NumpyExtensionArray

categories = self.categories
codes = self.codes
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
if categories.dtype == "string":
result = categories.array._str_map(f, na_value, dtype) # type: ignore[attr-defined]
if (
categories.dtype.na_value is np.nan # type: ignore[union-attr]
and is_bool_dtype(dtype)
and (na_value is lib.no_default or isna(na_value))
):
# NaN propagates as False for functions with boolean return type
na_value = False
else:
from pandas.core.arrays import NumpyExtensionArray

result = NumpyExtensionArray(categories.to_numpy())._str_map(
f, na_value, dtype
)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
Expand Down
31 changes: 20 additions & 11 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
return cls._from_sequence(scalars, dtype=dtype)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
self,
f,
na_value=lib.no_default,
dtype: Dtype | None = None,
convert: bool = True,
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(f, na_value=na_value, dtype=dtype)
Expand All @@ -390,7 +394,7 @@ def _str_map(

if dtype is None:
dtype = self.dtype
if na_value is None:
if na_value is lib.no_default:
na_value = self.dtype.na_value

mask = isna(self)
Expand Down Expand Up @@ -459,11 +463,17 @@ def _str_map_str_or_object(
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map_nan_semantics(
self, f, na_value=lib.no_default, dtype: Dtype | None = None
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value
if na_value is lib.no_default:
if is_bool_dtype(dtype):
# NaN propagates as False
na_value = False
else:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
Expand All @@ -474,7 +484,8 @@ def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True
# NaN propagates as False
na_value = False

result = lib.map_infer_mask(
arr,
Expand All @@ -484,15 +495,13 @@ def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
# TODO: we could alternatively do this check before map_infer_mask
# and adjust the dtype/na_value we pass there. Which is more
# performant?
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result = result.astype("float64")
result[mask] = np.nan

return result

else:
Expand Down
40 changes: 26 additions & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,27 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _convert_bool_result(self, values):
def _convert_bool_result(self, values, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(na) and not isinstance(na, bool):
# GH#59561
warnings.warn(
f"Allowing a non-bool 'na' in obj.str.{method_name} is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

if self.dtype.na_value is np.nan:
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
if na is lib.no_default or isna(na):
# NaN propagates as False
values = values.fill_null(False)
else:
values = values.fill_null(na)
return values.to_numpy()
else:
if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues]
values = values.fill_null(na)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
Expand Down Expand Up @@ -306,22 +324,16 @@ def astype(self, dtype, copy: bool = True):
_str_slice = ArrowStringArrayMixin._str_slice

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na=lib.no_default,
regex: bool = True,
):
if flags:
return super()._str_contains(pat, case, flags, na, regex)

if not isna(na):
if not isinstance(na, bool):
# GH#59561
warnings.warn(
"Allowing a non-bool 'na' in obj.str.contains is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)

def _str_replace(
Expand Down
40 changes: 25 additions & 15 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,12 @@ def join(self, sep: str):

@forbid_nonstring_types(["bytes"])
def contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na=lib.no_default,
regex: bool = True,
):
r"""
Test if pattern or regex is contained within a string of a Series or Index.
Expand All @@ -1243,8 +1248,9 @@ def contains(
Flags to pass through to the re module, e.g. re.IGNORECASE.
na : scalar, optional
Fill value for missing values. The default depends on dtype of the
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
``pandas.NA`` is used.
array. For object-dtype, ``numpy.nan`` is used. For the nullable
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
``False`` is used.
regex : bool, default True
If True, assumes the pat is a regular expression.
Expand Down Expand Up @@ -1362,7 +1368,7 @@ def contains(
return self._wrap_result(result, fill_value=na, returns_string=False)

@forbid_nonstring_types(["bytes"])
def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
def match(self, pat: str, case: bool = True, flags: int = 0, na=lib.no_default):
"""
Determine if each string starts with a match of a regular expression.
Expand All @@ -1376,8 +1382,9 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
Regex module flags, e.g. re.IGNORECASE.
na : scalar, optional
Fill value for missing values. The default depends on dtype of the
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
``pandas.NA`` is used.
array. For object-dtype, ``numpy.nan`` is used. For the nullable
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
``False`` is used.
Returns
-------
Expand Down Expand Up @@ -1406,7 +1413,7 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
return self._wrap_result(result, fill_value=na, returns_string=False)

@forbid_nonstring_types(["bytes"])
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=lib.no_default):
"""
Determine if each string entirely matches a regular expression.
Expand All @@ -1420,8 +1427,9 @@ def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
Regex module flags, e.g. re.IGNORECASE.
na : scalar, optional
Fill value for missing values. The default depends on dtype of the
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
``pandas.NA`` is used.
array. For object-dtype, ``numpy.nan`` is used. For the nullable
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
``False`` is used.
Returns
-------
Expand Down Expand Up @@ -2612,7 +2620,7 @@ def count(self, pat, flags: int = 0):

@forbid_nonstring_types(["bytes"])
def startswith(
self, pat: str | tuple[str, ...], na: Scalar | None = None
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
) -> Series | Index:
"""
Test if the start of each string element matches a pattern.
Expand All @@ -2624,10 +2632,11 @@ def startswith(
pat : str or tuple[str, ...]
Character sequence or tuple of strings. Regular expressions are not
accepted.
na : object, default NaN
na : scalar, optional
Object shown if element tested is not a string. The default depends
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
For ``StringDtype``, ``pandas.NA`` is used.
For the nullable ``StringDtype``, ``pandas.NA`` is used.
For the ``"str"`` dtype, ``False`` is used.
Returns
-------
Expand Down Expand Up @@ -2682,7 +2691,7 @@ def startswith(

@forbid_nonstring_types(["bytes"])
def endswith(
self, pat: str | tuple[str, ...], na: Scalar | None = None
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
) -> Series | Index:
"""
Test if the end of each string element matches a pattern.
Expand All @@ -2694,10 +2703,11 @@ def endswith(
pat : str or tuple[str, ...]
Character sequence or tuple of strings. Regular expressions are not
accepted.
na : object, default NaN
na : scalar, optional
Object shown if element tested is not a string. The default depends
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
For ``StringDtype``, ``pandas.NA`` is used.
For the nullable ``StringDtype``, ``pandas.NA`` is used.
For the ``"str"`` dtype, ``False`` is used.
Returns
-------
Expand Down
Loading

0 comments on commit 88554d0

Please sign in to comment.