Skip to content

Commit

Permalink
String dtype: map builtin str alias to StringDtype (pandas-dev#59685)
Browse files Browse the repository at this point in the history
* String dtype: map builtin str alias to StringDtype

* fix tests

* fix datetimelike astype and more tests

* remove xfails

* try fix typing

* fix copy_view tests

* fix remaining tests with infer_string enabled

* ignore typing issue for now

* move to common.py

* simplify Categorical._str_get_dummies

* small cleanup

* fix ensure_string_array to not modify extension arrays inplace

* fix ensure_string_array once more + fix is_extension_array_dtype for str

* still xfail TestArrowArray::test_astype_str when not using infer_string

* ensure maybe_convert_objects copies object dtype input array when inferring StringDtype

* update test_1d_object_array_does_not_copy test

* update constructor copy test + do not copy in maybe_convert_objects?

* skip str.get_dummies test for now

* use pandas_dtype() instead of registry.find

* fix corner cases for calling pandas_dtype

* add TODO comment in ensure_string_array
  • Loading branch information
jorisvandenbossche committed Oct 10, 2024
1 parent 4ff2c68 commit 2789338
Show file tree
Hide file tree
Showing 31 changed files with 183 additions and 112 deletions.
9 changes: 8 additions & 1 deletion pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,14 @@ cpdef ndarray[object] ensure_string_array(

if hasattr(arr, "to_numpy"):

if hasattr(arr, "dtype") and arr.dtype.kind in "mM":
if (
hasattr(arr, "dtype")
and arr.dtype.kind in "mM"
# TODO: add a custom ArrowExtensionArray.astype implementation that
# handles astype(str) itself, so that ArrowEA never ends up here;
# the check for `_pa_array` (for ArrowEA) below can then be removed
and not hasattr(arr, "_pa_array")
):
# dtype check to exclude DataFrame
# GH#41409 TODO: not a great place for this
out = arr.astype(str).astype(object)
Expand Down
2 changes: 1 addition & 1 deletion pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
if using_string_dtype():
STRING_DTYPES: list[Dtype] = [str, "U"]
STRING_DTYPES: list[Dtype] = ["U"]
else:
STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2691,7 +2691,9 @@ def _str_get_dummies(self, sep: str = "|"):
# sep may not be in categories. Just bail on this.
from pandas.core.arrays import NumpyExtensionArray

return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
return NumpyExtensionArray(self.to_numpy(str, na_value="NaN"))._str_get_dummies(
sep
)

# ------------------------------------------------------------------------
# GroupBy Methods
Expand Down
10 changes: 8 additions & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,16 @@ def astype(self, dtype, copy: bool = True):

return self._box_values(self.asi8.ravel()).reshape(self.shape)

elif is_string_dtype(dtype):
if isinstance(dtype, ExtensionDtype):
arr_object = self._format_native_types(na_rep=dtype.na_value) # type: ignore[arg-type]
cls = dtype.construct_array_type()
return cls._from_sequence(arr_object, dtype=dtype, copy=False)
else:
return self._format_native_types()

elif isinstance(dtype, ExtensionDtype):
return super().astype(dtype, copy=copy)
elif is_string_dtype(dtype):
return self._format_native_types()
elif dtype.kind in "iu":
# we deliberately ignore int32 vs. int64 here.
# See https://github.com/pandas-dev/pandas/issues/24381 for more.
Expand Down
18 changes: 17 additions & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numpy as np

from pandas._config import using_string_dtype

from pandas._libs import (
Interval,
Period,
Expand Down Expand Up @@ -1325,7 +1327,15 @@ def is_extension_array_dtype(arr_or_dtype) -> bool:
elif isinstance(dtype, np.dtype):
return False
else:
return registry.find(dtype) is not None
try:
with warnings.catch_warnings():
# pandas_dtype(..) can emit a UserWarning for class input
warnings.simplefilter("ignore", UserWarning)
dtype = pandas_dtype(dtype)
except (TypeError, ValueError):
# np.dtype(..) can raise ValueError
return False
return isinstance(dtype, ExtensionDtype)


def is_ea_or_datetimelike_dtype(dtype: DtypeObj | None) -> bool:
Expand Down Expand Up @@ -1620,6 +1630,12 @@ def pandas_dtype(dtype) -> DtypeObj:
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
return dtype

# builtin aliases
if dtype is str and using_string_dtype():
from pandas.core.arrays.string_ import StringDtype

return StringDtype(na_value=np.nan)

# registered extension types
result = registry.find(dtype)
if result is not None:
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6415,7 +6415,11 @@ def _should_compare(self, other: Index) -> bool:
return False

dtype = _unpack_nested_dtype(other)
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
return (
self._is_comparable_dtype(dtype)
or is_object_dtype(dtype)
or is_string_dtype(dtype)
)

def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
"""
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
is_number,
is_object_dtype,
is_scalar,
is_string_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -699,7 +700,7 @@ def _get_indexer(
# left/right get_indexer, compare elementwise, equality -> match
indexer = self._get_indexer_unique_sides(target)

elif not is_object_dtype(target.dtype):
elif not (is_object_dtype(target.dtype) or is_string_dtype(target.dtype)):
# homogeneous scalar index: use IntervalTree
# we should always have self._should_partial_index(target) here
target = self._maybe_convert_i8(target)
Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/arrays/floating/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ def test_astype_str(using_infer_string):

if using_infer_string:
expected = pd.array(["0.1", "0.2", None], dtype=pd.StringDtype(na_value=np.nan))
tm.assert_extension_array_equal(a.astype("str"), expected)

# TODO(infer_string) this should also be a string array like above
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
tm.assert_numpy_array_equal(a.astype(str), expected)
tm.assert_extension_array_equal(a.astype(str), expected)
tm.assert_extension_array_equal(a.astype("str"), expected)
else:
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")

Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/arrays/integer/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,9 @@ def test_astype_str(using_infer_string):

if using_infer_string:
expected = pd.array(["1", "2", None], dtype=pd.StringDtype(na_value=np.nan))
tm.assert_extension_array_equal(a.astype("str"), expected)

# TODO(infer_string) this should also be a string array like above
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
tm.assert_numpy_array_equal(a.astype(str), expected)
tm.assert_extension_array_equal(a.astype(str), expected)
tm.assert_extension_array_equal(a.astype("str"), expected)
else:
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")

Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/arrays/sparse/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def test_astype_all(self, any_real_numpy_dtype):
),
(
SparseArray([0, 1, 10]),
str,
SparseArray(["0", "1", "10"], dtype=SparseDtype(str, "0")),
np.str_,
SparseArray(["0", "1", "10"], dtype=SparseDtype(np.str_, "0")),
),
(SparseArray(["10", "20"]), float, SparseArray([10.0, 20.0])),
(
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/sparse/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_construct_from_string_fill_value_raises(string):
[
(SparseDtype(int, 0), float, SparseDtype(float, 0.0)),
(SparseDtype(int, 1), float, SparseDtype(float, 1.0)),
(SparseDtype(int, 1), str, SparseDtype(object, "1")),
(SparseDtype(int, 1), np.str_, SparseDtype(object, "1")),
(SparseDtype(float, 1.5), int, SparseDtype(int, 1)),
],
)
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,11 +810,23 @@ def test_pandas_dtype_string_dtypes(string_storage):
"pyarrow" if HAS_PYARROW else "python", na_value=np.nan
)

with pd.option_context("future.infer_string", True):
# with the default string_storage setting
result = pandas_dtype(str)
assert result == pd.StringDtype(
"pyarrow" if HAS_PYARROW else "python", na_value=np.nan
)

with pd.option_context("future.infer_string", True):
with pd.option_context("string_storage", string_storage):
result = pandas_dtype("str")
assert result == pd.StringDtype(string_storage, na_value=np.nan)

with pd.option_context("future.infer_string", True):
with pd.option_context("string_storage", string_storage):
result = pandas_dtype(str)
assert result == pd.StringDtype(string_storage, na_value=np.nan)

with pd.option_context("future.infer_string", False):
with pd.option_context("string_storage", string_storage):
result = pandas_dtype("str")
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/base/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def test_tolist(self, data):
assert result == expected

def test_astype_str(self, data):
result = pd.Series(data[:5]).astype(str)
expected = pd.Series([str(x) for x in data[:5]], dtype=str)
result = pd.Series(data[:2]).astype(str)
expected = pd.Series([str(x) for x in data[:2]], dtype=str)
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize(
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/json/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,8 @@ def astype(self, dtype, copy=True):
return self.copy()
return self
elif isinstance(dtype, StringDtype):
value = self.astype(str) # numpy doesn't like nested dicts
arr_cls = dtype.construct_array_type()
return arr_cls._from_sequence(value, dtype=dtype, copy=False)
return arr_cls._from_sequence(self, dtype=dtype, copy=False)
elif not copy:
return np.asarray([dict(x) for x in self], dtype=dtype)
else:
Expand Down
29 changes: 5 additions & 24 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
pa_version_under13p0,
pa_version_under14p0,
)
import pandas.util._test_decorators as td

from pandas.core.dtypes.dtypes import (
ArrowDtype,
Expand Down Expand Up @@ -286,43 +285,25 @@ def test_map(self, data_missing, na_action):
expected = data_missing.to_numpy()
tm.assert_numpy_array_equal(result, expected)

def test_astype_str(self, data, request):
def test_astype_str(self, data, request, using_infer_string):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_binary(pa_dtype):
request.applymarker(
pytest.mark.xfail(
reason=f"For {pa_dtype} .astype(str) decodes.",
)
)
elif (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
) or pa.types.is_duration(pa_dtype):
elif not using_infer_string and (
(pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
or pa.types.is_duration(pa_dtype)
):
request.applymarker(
pytest.mark.xfail(
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
)
)
super().test_astype_str(data)

@pytest.mark.parametrize(
"nullable_string_dtype",
[
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
],
)
def test_astype_string(self, data, nullable_string_dtype, request):
pa_dtype = data.dtype.pyarrow_dtype
if (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
) or pa.types.is_duration(pa_dtype):
request.applymarker(
pytest.mark.xfail(
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
)
)
super().test_astype_string(data, nullable_string_dtype)

def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
Expand Down
17 changes: 9 additions & 8 deletions pandas/tests/frame/methods/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,21 @@ def test_astype_str(self):
"d": list(map(str, d._values)),
"e": list(map(str, e._values)),
},
dtype="object",
dtype="str",
)

tm.assert_frame_equal(result, expected)

def test_astype_str_float(self):
def test_astype_str_float(self, using_infer_string):
# see GH#11302
result = DataFrame([np.nan]).astype(str)
expected = DataFrame(["nan"], dtype="object")
expected = DataFrame([np.nan if using_infer_string else "nan"], dtype="str")

tm.assert_frame_equal(result, expected)
result = DataFrame([1.12345678901234567890]).astype(str)

val = "1.1234567890123457"
expected = DataFrame([val], dtype="object")
expected = DataFrame([val], dtype="str")
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("dtype_class", [dict, Series])
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_astype_duplicate_col_series_arg(self):
result = df.astype(dtypes)
expected = DataFrame(
{
0: Series(vals[:, 0].astype(str), dtype=object),
0: Series(vals[:, 0].astype(str), dtype="str"),
1: vals[:, 1],
2: pd.array(vals[:, 2], dtype="Float64"),
3: vals[:, 3],
Expand Down Expand Up @@ -666,25 +666,26 @@ def test_astype_dt64tz(self, timezone_frame):
# dt64tz->dt64 deprecated
timezone_frame.astype("datetime64[ns]")

def test_astype_dt64tz_to_str(self, timezone_frame):
def test_astype_dt64tz_to_str(self, timezone_frame, using_infer_string):
# str formatting
result = timezone_frame.astype(str)
na_value = np.nan if using_infer_string else "NaT"
expected = DataFrame(
[
[
"2013-01-01",
"2013-01-01 00:00:00-05:00",
"2013-01-01 00:00:00+01:00",
],
["2013-01-02", "NaT", "NaT"],
["2013-01-02", na_value, na_value],
[
"2013-01-03",
"2013-01-03 00:00:00-05:00",
"2013-01-03 00:00:00+01:00",
],
],
columns=timezone_frame.columns,
dtype="object",
dtype="str",
)
tm.assert_frame_equal(result, expected)

Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/frame/methods/test_select_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def test_select_dtypes_include_using_list_like(self, using_infer_string):
ei = df[["a"]]
tm.assert_frame_equal(ri, ei)

ri = df.select_dtypes(include=[str])
tm.assert_frame_equal(ri, ei)

def test_select_dtypes_exclude_using_list_like(self):
df = DataFrame(
{
Expand Down Expand Up @@ -358,7 +361,7 @@ def test_select_dtypes_datetime_with_tz(self):
@pytest.mark.parametrize("dtype", [str, "str", np.bytes_, "S1", np.str_, "U1"])
@pytest.mark.parametrize("arg", ["include", "exclude"])
def test_select_dtypes_str_raises(self, dtype, arg, using_infer_string):
if using_infer_string and dtype == "str":
if using_infer_string and (dtype == "str" or dtype is str):
# this is tested below
pytest.skip("Selecting string columns works with future strings")
df = DataFrame(
Expand Down
Loading

0 comments on commit 2789338

Please sign in to comment.