Skip to content

Commit

Permalink
Use NaN as na_value for new pyarrow_numpy StringDtype (pandas-dev#54585)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored and mroeschke committed Sep 11, 2023
1 parent a47a4a8 commit 8e3963b
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 24 deletions.
10 changes: 7 additions & 3 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,14 @@ class StringDtype(StorageExtensionDtype):
# base class "StorageExtensionDtype") with class variable
name: ClassVar[str] = "string" # type: ignore[misc]

#: StringDtype().na_value uses pandas.NA
#: StringDtype().na_value uses pandas.NA except the implementation that
# follows NumPy semantics, which uses nan.
@property
def na_value(self) -> libmissing.NAType:
return libmissing.NA
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
if self.storage == "pyarrow_numpy":
return np.nan
else:
return libmissing.NA

_metadata = ("storage",)

Expand Down
41 changes: 26 additions & 15 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
)


def na_val(dtype):
if dtype.storage == "pyarrow_numpy":
return np.nan
else:
return pd.NA


@pytest.fixture
def dtype(string_storage):
"""Fixture giving StringDtype from parametrized 'string_storage'"""
Expand All @@ -31,26 +38,34 @@ def cls(dtype):

def test_repr(dtype):
df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
expected = " A\n0 a\n1 <NA>\n2 b"
if dtype.storage == "pyarrow_numpy":
expected = " A\n0 a\n1 NaN\n2 b"
else:
expected = " A\n0 a\n1 <NA>\n2 b"
assert repr(df) == expected

expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
if dtype.storage == "pyarrow_numpy":
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string"
else:
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
assert repr(df.A) == expected

if dtype.storage == "pyarrow":
arr_name = "ArrowStringArray"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
elif dtype.storage == "pyarrow_numpy":
arr_name = "ArrowStringArrayNumpySemantics"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
else:
arr_name = "StringArray"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
assert repr(df.A.array) == expected


def test_none_to_nan(cls):
a = cls._from_sequence(["a", None, "b"])
assert a[1] is not None
assert a[1] is pd.NA
assert a[1] is na_val(a.dtype)


def test_setitem_validates(cls):
Expand Down Expand Up @@ -213,13 +228,9 @@ def test_comparison_methods_scalar(comparison_op, dtype):
other = "a"
result = getattr(a, op_name)(other)
if dtype.storage == "pyarrow_numpy":
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
expected = (
pd.array(expected, dtype="boolean")
.to_numpy(na_value=False)
.astype(np.bool_)
)
tm.assert_numpy_array_equal(result, expected)
expected = np.array([getattr(item, op_name)(other) for item in a])
expected[1] = False
tm.assert_numpy_array_equal(result, expected.astype(np.bool_))
else:
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
Expand Down Expand Up @@ -415,7 +426,7 @@ def test_min_max(method, skipna, dtype, request):
expected = "a" if method == "min" else "c"
assert result == expected
else:
assert result is pd.NA
assert result is na_val(arr.dtype)


@pytest.mark.parametrize("method", ["min", "max"])
Expand Down Expand Up @@ -483,7 +494,7 @@ def test_arrow_roundtrip(dtype, string_storage2):
expected = df.astype(f"string[{string_storage2}]")
tm.assert_frame_equal(result, expected)
# ensure the missing value is represented by NA and not np.nan or None
assert result.loc[2, "a"] is pd.NA
assert result.loc[2, "a"] is na_val(result["a"].dtype)


def test_arrow_load_from_zero_chunks(dtype, string_storage2):
Expand Down Expand Up @@ -581,7 +592,7 @@ def test_astype_from_float_dtype(float_dtype, dtype):
def test_to_numpy_returns_pdna_default(dtype):
arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
result = np.array(arr)
expected = np.array(["a", pd.NA, "b"], dtype=object)
expected = np.array(["a", na_val(dtype), "b"], dtype=object)
tm.assert_numpy_array_equal(result, expected)


Expand Down Expand Up @@ -621,7 +632,7 @@ def test_setitem_scalar_with_mask_validation(dtype):
mask = np.array([False, True, False])

ser[mask] = None
assert ser.array[1] is pd.NA
assert ser.array[1] is na_val(ser.dtype)

# for other non-string we should also raise an error
ser = pd.Series(["a", "b", "c"], dtype=dtype)
Expand Down
9 changes: 6 additions & 3 deletions pandas/tests/strings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Needed for new arrow string dtype
import numpy as np

import pandas as pd

Expand All @@ -7,6 +7,9 @@

def _convert_na_value(ser, expected):
if ser.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
if ser.dtype.storage == "pyarrow_numpy":
expected = expected.fillna(np.nan)
else:
# GH#18463
expected = expected.fillna(pd.NA)
return expected
9 changes: 6 additions & 3 deletions pandas/tests/strings/test_split_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
Series,
_testing as tm,
)
from pandas.tests.strings import _convert_na_value
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
)


@pytest.mark.parametrize("method", ["split", "rsplit"])
Expand Down Expand Up @@ -113,8 +116,8 @@ def test_split_object_mixed(expand, method):
def test_split_n(any_string_dtype, method, n):
s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype)
expected = Series([["a", "b"], pd.NA, ["b", "c"]])

result = getattr(s.str, method)(" ", n=n)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand Down Expand Up @@ -381,7 +384,7 @@ def test_split_nan_expand(any_string_dtype):
# check that these are actually np.nan/pd.NA and not None
# TODO see GH 18463
# tm.assert_frame_equal does not differentiate
if any_string_dtype == "object":
if any_string_dtype in object_pyarrow_numpy:
assert all(np.isnan(x) for x in result.iloc[1])
else:
assert all(x is pd.NA for x in result.iloc[1])
Expand Down

0 comments on commit 8e3963b

Please sign in to comment.