ENH: Arrow backed string array - implement factorize() method without casting to objects #38007

Merged: 15 commits, merged on Mar 2, 2021. Changes shown are from 12 commits.
14 changes: 13 additions & 1 deletion asv_bench/benchmarks/algorithms.py
@@ -5,6 +5,7 @@
from pandas._libs import lib

import pandas as pd
from pandas.core.arrays.string_arrow import ArrowStringDtype
Review comment (Member):
Can you do this in a try/except? (We need to be able to still run the benchmarks with slightly older pandas versions that might not have this import available.)
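
A minimal sketch of the suggested guard (assuming the import path stays pandas.core.arrays.string_arrow; not necessarily the final form used in this PR):

try:
    from pandas.core.arrays.string_arrow import ArrowStringDtype
except ImportError:
    # Older pandas versions do not ship this module; leave the name as None
    # so setup() can raise NotImplementedError and asv skips the benchmark.
    ArrowStringDtype = None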


from .pandas_vb_common import tm

@@ -43,23 +44,34 @@ class Factorize:
"datetime64[ns, tz]",
"Int64",
"boolean",
"string_arrow",
],
]
param_names = ["unique", "sort", "dtype"]

def setup(self, unique, sort, dtype):
N = 10 ** 5
string_index = tm.makeStringIndex(N)
try:
string_arrow = pd.array(string_index, dtype=ArrowStringDtype())
except ImportError:
string_arrow = None

if dtype == "string_arrow" and not string_arrow:
raise NotImplementedError

data = {
"int": pd.Int64Index(np.arange(N)),
"uint": pd.UInt64Index(np.arange(N)),
"float": pd.Float64Index(np.random.randn(N)),
"string": tm.makeStringIndex(N),
"string": string_index,
"datetime64[ns]": pd.date_range("2011-01-01", freq="H", periods=N),
"datetime64[ns, tz]": pd.date_range(
"2011-01-01", freq="H", periods=N, tz="Asia/Tokyo"
),
"Int64": pd.array(np.arange(N), dtype="Int64"),
"boolean": pd.array(np.random.randint(0, 2, N), dtype="boolean"),
"string_arrow": string_arrow,
}[dtype]
if not unique:
data = data.repeat(5)
21 changes: 18 additions & 3 deletions pandas/core/arrays/string_arrow.py
@@ -6,6 +6,7 @@
Any,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@@ -20,6 +21,7 @@
Dtype,
NpDtype,
)
from pandas.util._decorators import doc
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.base import ExtensionDtype
@@ -273,9 +275,22 @@ def __len__(self) -> int:
"""
return len(self._data)

@classmethod
def _from_factorized(cls, values, original):
return cls._from_sequence(values)
@doc(ExtensionArray.factorize)
def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, ExtensionArray]:
encoded = self._data.dictionary_encode()
indices = pa.chunked_array(
[c.indices for c in encoded.chunks], type=encoded.type.index_type
).to_pandas()
if indices.dtype.kind == "f":
indices[np.isnan(indices)] = na_sentinel
indices = indices.astype(np.int64, copy=False)
Review comment (Member):
Wondering, is the int64 needed here? (pyarrow will typically use int32 as default I think)

I suppose that we always return int64 from factorize for the indices. Short-term, casting to int64 might be best then (to ensure nothing else breaks because of not doing that), but long term we should maybe check if internally we require int64 or would be fine with int32 as well.

Reply (Member Author):
> Wondering, is the int64 needed here? (pyarrow will typically use int32 as default I think)

refactor in 0023f08 partially to address comments

but yes, we seem to be getting an int32 from pyarrow

also we could maybe work with numpy arrays here directly for the indices instead of pandas Series?
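
To illustrate the point under discussion (a standalone sketch, not the code under review): pyarrow's dictionary_encode() defaults to int32 indices, and a missing value surfaces as NaN once the indices are converted, which is why the sentinel fill and the cast to int64 are applied before returning the codes.

import pyarrow as pa

# dictionary_encode() on a string array defaults to int32 indices.
encoded = pa.array(["b", "b", None, "a"]).dictionary_encode()
print(encoded.type.index_type)  # int32
print(encoded.dictionary)       # uniques in order of first appearance: ["b", "a"]

# Converting the indices to pandas turns the null slot into NaN in a float
# array, hence the na_sentinel fill and the astype(np.int64) above.
indices = encoded.indices.to_pandas()
print(indices.dtype)            # float64 because of the missing value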


if encoded.num_chunks:
uniques = type(self)(encoded.chunk(0).dictionary)
else:
uniques = type(self)(pa.array([], type=encoded.type.value_type))

return indices.values, uniques

@classmethod
def _concat_same_type(cls, to_concat) -> ArrowStringArray:
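For context, a hedged usage sketch of the new factorize() path (assumes pyarrow is installed; ArrowStringArray and its private _from_sequence constructor are used the same way the extension tests below use them):

from pandas.core.arrays.string_arrow import ArrowStringArray

# Build a small Arrow-backed string array containing a missing value.
arr = ArrowStringArray._from_sequence(["b", "b", None, "a"])

# factorize() now runs on the Arrow data directly, without casting to object.
codes, uniques = arr.factorize()
print(codes)    # int64 codes, e.g. [ 0  0 -1  1], with -1 for the missing value
print(uniques)  # an ArrowStringArray of the uniques, e.g. ["b", "a"]
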
48 changes: 38 additions & 10 deletions pandas/tests/extension/test_string.py
@@ -26,6 +26,29 @@
from pandas.tests.extension import base


def split_array(arr):
if not isinstance(arr.dtype, ArrowStringDtype):
pytest.skip("chunked array n/a")

def _split_array(arr):
import pyarrow as pa

arrow_array = arr._data
split = len(arrow_array) // 2
arrow_array = pa.chunked_array(
[*arrow_array[:split].chunks, *arrow_array[split:].chunks]
)
assert arrow_array.num_chunks == 2
return type(arr)(arrow_array)

return _split_array(arr)


@pytest.fixture(params=[True, False])
def chunked(request):
return request.param


@pytest.fixture(
params=[
StringDtype,
@@ -39,28 +62,32 @@ def dtype(request):


@pytest.fixture
def data(dtype):
def data(dtype, chunked):
strings = np.random.choice(list(string.ascii_letters), size=100)
while strings[0] == strings[1]:
strings = np.random.choice(list(string.ascii_letters), size=100)

return dtype.construct_array_type()._from_sequence(strings)
arr = dtype.construct_array_type()._from_sequence(strings)
return split_array(arr) if chunked else arr


@pytest.fixture
def data_missing(dtype):
def data_missing(dtype, chunked):
"""Length 2 array with [NA, Valid]"""
return dtype.construct_array_type()._from_sequence([pd.NA, "A"])
arr = dtype.construct_array_type()._from_sequence([pd.NA, "A"])
return split_array(arr) if chunked else arr


@pytest.fixture
def data_for_sorting(dtype):
return dtype.construct_array_type()._from_sequence(["B", "C", "A"])
def data_for_sorting(dtype, chunked):
arr = dtype.construct_array_type()._from_sequence(["B", "C", "A"])
return split_array(arr) if chunked else arr


@pytest.fixture
def data_missing_for_sorting(dtype):
return dtype.construct_array_type()._from_sequence(["B", pd.NA, "A"])
def data_missing_for_sorting(dtype, chunked):
arr = dtype.construct_array_type()._from_sequence(["B", pd.NA, "A"])
return split_array(arr) if chunked else arr


@pytest.fixture
@@ -69,10 +96,11 @@ def na_value():


@pytest.fixture
def data_for_grouping(dtype):
return dtype.construct_array_type()._from_sequence(
def data_for_grouping(dtype, chunked):
arr = dtype.construct_array_type()._from_sequence(
["B", "B", pd.NA, pd.NA, "A", "A", "B", "C"]
)
return split_array(arr) if chunked else arr


class TestDtype(base.BaseDtypeTests):