Skip to content

Commit

Permalink
String dtype: allow string dtype for non-raw apply with numba engine (#…
Browse files Browse the repository at this point in the history
…59854)

* String dtype: allow string dtype for non-raw apply with numba engine

* remove xfails

* clean-up
  • Loading branch information
jorisvandenbossche authored Sep 25, 2024
1 parent 7e5282f commit c8a6740
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 11 deletions.
3 changes: 2 additions & 1 deletion pandas/core/_numba/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
@contextmanager
def set_numba_data(index: Index):
numba_data = index._data
if numba_data.dtype == object:
if numba_data.dtype in (object, "string"):
numba_data = np.asarray(numba_data)
if not lib.is_string_array(numba_data):
raise ValueError(
"The numba engine only supports using string or numeric column names"
Expand Down
5 changes: 0 additions & 5 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,12 +1172,7 @@ def apply_with_numba(self) -> dict[int, Any]:
from pandas.core._numba.extensions import set_numba_data

index = self.obj.index
if index.dtype == "string":
index = index.astype(object)

columns = self.obj.columns
if columns.dtype == "string":
columns = columns.astype(object)

# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def test_apply(float_frame, engine, request):
assert result.index is float_frame.index


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("raw", [True, False])
@pytest.mark.parametrize("nopython", [True, False])
Expand Down
4 changes: 0 additions & 4 deletions pandas/tests/apply/test_numba.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas.util._test_decorators as td

import pandas as pd
Expand All @@ -20,7 +18,6 @@ def apply_axis(request):
return request.param


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_numba_vs_python_noop(float_frame, apply_axis):
func = lambda x: x
result = float_frame.apply(func, engine="numba", axis=apply_axis)
Expand All @@ -43,7 +40,6 @@ def test_numba_vs_python_string_index():
)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_numba_vs_python_indexing():
frame = DataFrame(
{"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},
Expand Down

0 comments on commit c8a6740

Please sign in to comment.