Skip to content

Commit

Permalink
Fix .replace(Index, Index) raising a TypeError (#16513)
Browse files Browse the repository at this point in the history
Since `cudf.Index` is list-like, passing `cudf.Index` objects to `.replace` should behave the same as replacing a list of values with a corresponding list of values.

Discovered while working on rapidsai/cuml#6019

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #16513
  • Loading branch information
mroeschke committed Aug 15, 2024
1 parent 89863a3 commit 2bcb7ec
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
14 changes: 7 additions & 7 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6469,7 +6469,7 @@ def _get_replacement_values_for_columns(
to_replace_columns = {col: [to_replace] for col in columns_dtype_map}
values_columns = {col: [value] for col in columns_dtype_map}
elif cudf.api.types.is_list_like(to_replace) or isinstance(
to_replace, ColumnBase
to_replace, (ColumnBase, BaseIndex)
):
if is_scalar(value):
to_replace_columns = {col: to_replace for col in columns_dtype_map}
Expand All @@ -6483,7 +6483,9 @@ def _get_replacement_values_for_columns(
)
for col in columns_dtype_map
}
elif cudf.api.types.is_list_like(value):
elif cudf.api.types.is_list_like(
value
) or cudf.utils.dtypes.is_column_like(value):
if len(to_replace) != len(value):
raise ValueError(
f"Replacement lists must be "
Expand All @@ -6495,9 +6497,6 @@ def _get_replacement_values_for_columns(
col: to_replace for col in columns_dtype_map
}
values_columns = {col: value for col in columns_dtype_map}
elif cudf.utils.dtypes.is_column_like(value):
to_replace_columns = {col: to_replace for col in columns_dtype_map}
values_columns = {col: value for col in columns_dtype_map}
else:
raise TypeError(
"value argument must be scalar, list-like or Series"
Expand Down Expand Up @@ -6592,12 +6591,13 @@ def _get_replacement_values_for_columns(
return all_na_columns, to_replace_columns, values_columns


def _is_series(obj):
def _is_series(obj: Any) -> bool:
"""
Checks if the `obj` is of type `cudf.Series`
instead of checking for isinstance(obj, cudf.Series)
to avoid circular imports.
"""
return isinstance(obj, Frame) and obj.ndim == 1 and obj.index is not None
return isinstance(obj, IndexedFrame) and obj.ndim == 1


@_performance_tracking
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,3 +1378,9 @@ def test_fillna_nan_and_null():
result = ser.fillna(2.2)
expected = cudf.Series([2.2, 2.2, 1.1])
assert_eq(result, expected)


def test_replace_with_index_objects():
    """Check that ``Series.replace`` accepts Index objects for both arguments.

    The cuDF result for ``replace(Index, Index)`` is compared against the
    result pandas produces for the equivalent call.
    """
    to_replace = cudf.Index([1])
    value = cudf.Index([2])
    result = cudf.Series([1, 2]).replace(to_replace, value)

    pd_to_replace = pd.Index([1])
    pd_value = pd.Index([2])
    expected = pd.Series([1, 2]).replace(pd_to_replace, pd_value)

    assert_eq(result, expected)

0 comments on commit 2bcb7ec

Please sign in to comment.