Skip to content

Commit

Permalink
update algo.take to solve #59177 (#59181)
Browse files Browse the repository at this point in the history
* update algo.take to solve #59177

* forgot to update TestExtensionTake::test_take_coerces_list

* fixing pandas/tests/dtypes/test_generic.py::TestABCClasses::test_abc_hierarchy

* ABCExtensionArray set formatting

---------

Co-authored-by: Laurent Mutricy <laurent.mutricy@ekium.eu>
  • Loading branch information
mutricyl and Laurent Mutricy authored Jul 6, 2024
1 parent f6d06b8 commit 236d89b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 7 additions & 3 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
ABCExtensionArray,
ABCIndex,
ABCMultiIndex,
ABCNumpyExtensionArray,
ABCSeries,
ABCTimedeltaArray,
)
Expand Down Expand Up @@ -1161,11 +1162,14 @@ def take(
... )
array([ 10, 10, -10])
"""
if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)):
if not isinstance(
arr,
(np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries, ABCNumpyExtensionArray),
):
# GH#52981
raise TypeError(
"pd.api.extensions.take requires a numpy.ndarray, "
f"ExtensionArray, Index, or Series, got {type(arr).__name__}."
"pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, "
f"Index, Series, or NumpyExtensionArray got {type(arr).__name__}."
)

indices = ensure_platform_int(indices)
Expand Down
10 changes: 9 additions & 1 deletion pandas/tests/test_take.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pandas._libs import iNaT

from pandas import array
import pandas._testing as tm
import pandas.core.algorithms as algos

Expand Down Expand Up @@ -303,7 +304,14 @@ def test_take_coerces_list(self):
arr = [1, 2, 3]
msg = (
"pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, "
"Index, or Series, got list"
"Index, Series, or NumpyExtensionArray got list"
)
with pytest.raises(TypeError, match=msg):
algos.take(arr, [0, 0])

def test_take_NumpyExtensionArray(self):
# GH#59177
arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray)
assert algos.take(arr, [2]) == 2
arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray)
assert algos.take(arr, [2]) == 2

0 comments on commit 236d89b

Please sign in to comment.