diff --git a/asv_bench/benchmarks/algorithms.py b/asv_bench/benchmarks/algorithms.py
index 823daa2e31529..aecc609df574e 100644
--- a/asv_bench/benchmarks/algorithms.py
+++ b/asv_bench/benchmarks/algorithms.py
@@ -28,23 +28,38 @@ class Factorize:
             "datetime64[ns, tz]",
             "Int64",
             "boolean",
+            "string_arrow",
         ],
     ]
     param_names = ["unique", "sort", "dtype"]
 
     def setup(self, unique, sort, dtype):
         N = 10 ** 5
+        string_index = tm.makeStringIndex(N)
+        try:
+            # The Arrow-backed dtype may be unimportable; fall back to None then.
+            from pandas.core.arrays.string_arrow import ArrowStringDtype
+
+            string_arrow = pd.array(string_index, dtype=ArrowStringDtype())
+        except ImportError:
+            string_arrow = None
+
+        if dtype == "string_arrow" and string_arrow is None:
+            # asv skips a benchmark whose setup raises NotImplementedError.
+            raise NotImplementedError
+
         data = {
             "int": pd.Int64Index(np.arange(N)),
             "uint": pd.UInt64Index(np.arange(N)),
             "float": pd.Float64Index(np.random.randn(N)),
-            "string": tm.makeStringIndex(N),
+            "string": string_index,
             "datetime64[ns]": pd.date_range("2011-01-01", freq="H", periods=N),
             "datetime64[ns, tz]": pd.date_range(
                 "2011-01-01", freq="H", periods=N, tz="Asia/Tokyo"
             ),
             "Int64": pd.array(np.arange(N), dtype="Int64"),
             "boolean": pd.array(np.random.randint(0, 2, N), dtype="boolean"),
+            "string_arrow": string_arrow,
         }[dtype]
         if not unique:
             data = data.repeat(5)
diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py
index 8441b324515f3..26fe6338118b6 100644
--- a/pandas/core/arrays/string_arrow.py
+++ b/pandas/core/arrays/string_arrow.py
@@ -6,6 +6,7 @@
     Any,
     Optional,
     Sequence,
+    Tuple,
     Type,
     Union,
 )
@@ -20,6 +21,7 @@
     Dtype,
     NpDtype,
 )
+from pandas.util._decorators import doc
 from pandas.util._validators import validate_fillna_kwargs
 
 from pandas.core.dtypes.base import ExtensionDtype
@@ -273,9 +275,27 @@ def __len__(self) -> int:
         """
         return len(self._data)
 
-    @classmethod
-    def _from_factorized(cls, values, original):
-        return cls._from_sequence(values)
+    @doc(ExtensionArray.factorize)
+    def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, ExtensionArray]:
+        encoded = self._data.dictionary_encode()
+        # Gather the per-chunk dictionary indices into a single pandas object.
+        indices = pa.chunked_array(
+            [c.indices for c in encoded.chunks], type=encoded.type.index_type
+        ).to_pandas()
+        if indices.dtype.kind == "f":
+            # A float dtype carries NaN entries (presumably nulls converted by
+            # to_pandas); replace them with the sentinel before the int64 cast.
+            indices[np.isnan(indices)] = na_sentinel
+        indices = indices.astype(np.int64, copy=False)
+
+        if encoded.num_chunks:
+            # NOTE(review): assumes every chunk shares chunk 0's dictionary.
+            uniques = type(self)(encoded.chunk(0).dictionary)
+        else:
+            # No chunks: build an empty array of the dictionary's value type.
+            uniques = type(self)(pa.array([], type=encoded.type.value_type))
+
+        return indices.values, uniques
 
     @classmethod
     def _concat_same_type(cls, to_concat) -> ArrowStringArray:
diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py
index d0a3ef17afdbc..49aee76e10f6a 100644
--- a/pandas/tests/extension/test_string.py
+++ b/pandas/tests/extension/test_string.py
@@ -26,6 +26,32 @@
 from pandas.tests.extension import base
 
 
+def split_array(arr):
+    """Return ``arr`` re-wrapped around a two-chunk pyarrow ChunkedArray."""
+    if not isinstance(arr.dtype, ArrowStringDtype):
+        # Chunking only applies to ArrowStringDtype-backed arrays.
+        pytest.skip("chunked array n/a")
+
+    def _split_array(arr):
+        import pyarrow as pa
+
+        arrow_array = arr._data
+        split = len(arrow_array) // 2
+        arrow_array = pa.chunked_array(
+            [*arrow_array[:split].chunks, *arrow_array[split:].chunks]
+        )
+        assert arrow_array.num_chunks == 2
+        return type(arr)(arrow_array)
+
+    return _split_array(arr)
+
+
+@pytest.fixture(params=[True, False])
+def chunked(request):
+    """Whether the data fixtures should return a chunked array."""
+    return request.param
+
+
 @pytest.fixture(
     params=[
         StringDtype,
@@ -39,28 +65,32 @@ def dtype(request):
 
 
 @pytest.fixture
-def data(dtype):
+def data(dtype, chunked):
     strings = np.random.choice(list(string.ascii_letters), size=100)
     while strings[0] == strings[1]:
         strings = np.random.choice(list(string.ascii_letters), size=100)
 
-    return dtype.construct_array_type()._from_sequence(strings)
+    arr = dtype.construct_array_type()._from_sequence(strings)
+    return split_array(arr) if chunked else arr
 
 
 @pytest.fixture
-def data_missing(dtype):
+def data_missing(dtype, chunked):
     """Length 2 array with [NA, Valid]"""
-    return dtype.construct_array_type()._from_sequence([pd.NA, "A"])
+    arr = dtype.construct_array_type()._from_sequence([pd.NA, "A"])
+    return split_array(arr) if chunked else arr
 
 
 @pytest.fixture
-def data_for_sorting(dtype):
-    return dtype.construct_array_type()._from_sequence(["B", "C", "A"])
+def data_for_sorting(dtype, chunked):
+    arr = dtype.construct_array_type()._from_sequence(["B", "C", "A"])
+    return split_array(arr) if chunked else arr
 
 
 @pytest.fixture
-def data_missing_for_sorting(dtype):
-    return dtype.construct_array_type()._from_sequence(["B", pd.NA, "A"])
+def data_missing_for_sorting(dtype, chunked):
+    arr = dtype.construct_array_type()._from_sequence(["B", pd.NA, "A"])
+    return split_array(arr) if chunked else arr
 
 
 @pytest.fixture
@@ -69,10 +99,11 @@ def na_value():
 
 
 @pytest.fixture
-def data_for_grouping(dtype):
-    return dtype.construct_array_type()._from_sequence(
+def data_for_grouping(dtype, chunked):
+    arr = dtype.construct_array_type()._from_sequence(
         ["B", "B", pd.NA, pd.NA, "A", "A", "B", "C"]
     )
+    return split_array(arr) if chunked else arr
 
 
 class TestDtype(base.BaseDtypeTests):