Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use NaN as na_value for new pyarrow_numpy StringDtype #54585

Merged
merged 23 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,14 @@ class StringDtype(StorageExtensionDtype):
# base class "StorageExtensionDtype") with class variable
name: ClassVar[str] = "string" # type: ignore[misc]

#: StringDtype().na_value uses pandas.NA
#: StringDtype().na_value uses pandas.NA except the implementation that
# follows NumPy semantics, which uses nan.
@property
def na_value(self) -> libmissing.NAType:
return libmissing.NA
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
if self.storage == "pyarrow_numpy":
return np.nan
else:
return libmissing.NA

_metadata = ("storage",)

Expand Down
41 changes: 26 additions & 15 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
)


def na_val(dtype):
if dtype.storage == "pyarrow_numpy":
return np.nan
else:
return pd.NA


@pytest.fixture
def dtype(string_storage):
"""Fixture giving StringDtype from parametrized 'string_storage'"""
Expand All @@ -31,26 +38,34 @@ def cls(dtype):

def test_repr(dtype):
df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
expected = " A\n0 a\n1 <NA>\n2 b"
if dtype.storage == "pyarrow_numpy":
expected = " A\n0 a\n1 NaN\n2 b"
else:
expected = " A\n0 a\n1 <NA>\n2 b"
assert repr(df) == expected

expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
if dtype.storage == "pyarrow_numpy":
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string"
else:
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
assert repr(df.A) == expected

if dtype.storage == "pyarrow":
arr_name = "ArrowStringArray"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
elif dtype.storage == "pyarrow_numpy":
arr_name = "ArrowStringArrayNumpySemantics"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
else:
arr_name = "StringArray"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
assert repr(df.A.array) == expected


def test_none_to_nan(cls):
a = cls._from_sequence(["a", None, "b"])
assert a[1] is not None
assert a[1] is pd.NA
assert a[1] is na_val(a.dtype)


def test_setitem_validates(cls):
Expand Down Expand Up @@ -213,13 +228,9 @@ def test_comparison_methods_scalar(comparison_op, dtype):
other = "a"
result = getattr(a, op_name)(other)
if dtype.storage == "pyarrow_numpy":
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
expected = (
pd.array(expected, dtype="boolean")
.to_numpy(na_value=False)
.astype(np.bool_)
)
tm.assert_numpy_array_equal(result, expected)
expected = np.array([getattr(item, op_name)(other) for item in a])
expected[1] = False
tm.assert_numpy_array_equal(result, expected.astype(np.bool_))
else:
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
Expand Down Expand Up @@ -415,7 +426,7 @@ def test_min_max(method, skipna, dtype, request):
expected = "a" if method == "min" else "c"
assert result == expected
else:
assert result is pd.NA
assert result is na_val(arr.dtype)


@pytest.mark.parametrize("method", ["min", "max"])
Expand Down Expand Up @@ -483,7 +494,7 @@ def test_arrow_roundtrip(dtype, string_storage2):
expected = df.astype(f"string[{string_storage2}]")
tm.assert_frame_equal(result, expected)
# ensure the missing value is represented by NA and not np.nan or None
assert result.loc[2, "a"] is pd.NA
assert result.loc[2, "a"] is na_val(result["a"].dtype)


def test_arrow_load_from_zero_chunks(dtype, string_storage2):
Expand Down Expand Up @@ -581,7 +592,7 @@ def test_astype_from_float_dtype(float_dtype, dtype):
def test_to_numpy_returns_pdna_default(dtype):
arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
result = np.array(arr)
expected = np.array(["a", pd.NA, "b"], dtype=object)
expected = np.array(["a", na_val(dtype), "b"], dtype=object)
tm.assert_numpy_array_equal(result, expected)


Expand Down Expand Up @@ -621,7 +632,7 @@ def test_setitem_scalar_with_mask_validation(dtype):
mask = np.array([False, True, False])

ser[mask] = None
assert ser.array[1] is pd.NA
assert ser.array[1] is na_val(ser.dtype)

# for other non-string we should also raise an error
ser = pd.Series(["a", "b", "c"], dtype=dtype)
Expand Down
9 changes: 6 additions & 3 deletions pandas/tests/strings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Needed for new arrow string dtype
import numpy as np

import pandas as pd

Expand All @@ -7,6 +7,9 @@

def _convert_na_value(ser, expected):
if ser.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
if ser.dtype.storage == "pyarrow_numpy":
expected = expected.fillna(np.nan)
else:
# GH#18463
expected = expected.fillna(pd.NA)
return expected
9 changes: 6 additions & 3 deletions pandas/tests/strings/test_split_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
Series,
_testing as tm,
)
from pandas.tests.strings import _convert_na_value
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
)


@pytest.mark.parametrize("method", ["split", "rsplit"])
Expand Down Expand Up @@ -113,8 +116,8 @@ def test_split_object_mixed(expand, method):
def test_split_n(any_string_dtype, method, n):
s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype)
expected = Series([["a", "b"], pd.NA, ["b", "c"]])

result = getattr(s.str, method)(" ", n=n)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand Down Expand Up @@ -381,7 +384,7 @@ def test_split_nan_expand(any_string_dtype):
# check that these are actually np.nan/pd.NA and not None
# TODO see GH 18463
# tm.assert_frame_equal does not differentiate
if any_string_dtype == "object":
if any_string_dtype in object_pyarrow_numpy:
assert all(np.isnan(x) for x in result.iloc[1])
else:
assert all(x is pd.NA for x in result.iloc[1])
Expand Down