Skip to content

Commit

Permalink
Support py-boolean in indexing (#58856)
Browse files Browse the repository at this point in the history
* fix single Py_Bool indexing and add unittest case

* fix ci error

* complex32 -> complex128
  • Loading branch information
zoooo0820 authored Nov 15, 2023
1 parent 6947474 commit 0b4fd0f
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 35 deletions.
4 changes: 3 additions & 1 deletion python/paddle/base/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,9 @@ def contain_tensor_or_list(item):
item = (item,)

for slice_item in item:
if isinstance(slice_item, (list, np.ndarray, Variable, range)):
if isinstance(
slice_item, (list, np.ndarray, Variable, range, bool)
):
return True

return False
Expand Down
20 changes: 11 additions & 9 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def replace_none(item):
def is_integer_or_scalar_tensor(ele):
from .framework import Variable

if isinstance(ele, int):
if type(ele) is int:
return True
elif isinstance(ele, Variable):
# NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True,
Expand Down Expand Up @@ -693,7 +693,8 @@ def parse_index(x, indices):
)

estimated_dim = 0
for dim, slice_item in enumerate(indices):
dim = 0
for i, slice_item in enumerate(indices):
start, end, step = None, None, None
if is_integer_or_scalar_tensor(slice_item):
if (
Expand All @@ -718,21 +719,22 @@ def parse_index(x, indices):
start = slice_item
step = 1
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
dim += 1
elif isinstance(slice_item, bool):
# single bool is advanced-indexing
none_axes.append(dim)
estimated_dim += 1
advanced_index[estimated_dim] = (
estimated_dim,
paddle.to_tensor(slice_item),
paddle.to_tensor([slice_item]),
)
has_advanced_index = True
estimated_dim += 1
elif isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
estimated_dim += 1

dim += 1
if start is None and end is None and step is None:
continue

Expand Down Expand Up @@ -760,7 +762,7 @@ def parse_index(x, indices):

has_advanced_index = True
estimated_dim += 1

dim += 1
elif isinstance(slice_item, paddle.base.Variable):
# In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
if (
Expand All @@ -780,7 +782,7 @@ def parse_index(x, indices):
advanced_index[estimated_dim] = (estimated_dim, slice_item)
has_advanced_index = True
estimated_dim += 1

dim += 1
elif isinstance(slice_item, paddle.pir.OpResult):
# In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
if slice_item.dtype == paddle.pir.core.DataType.BOOL:
Expand All @@ -797,7 +799,7 @@ def parse_index(x, indices):
advanced_index[estimated_dim] = (estimated_dim, slice_item)
has_advanced_index = True
estimated_dim += 1

dim += 1
else:
raise IndexError(
"Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format(
Expand All @@ -808,7 +810,7 @@ def parse_index(x, indices):
starts.append(start)
ends.append(end)
steps.append(step)
axes.append(dim)
axes.append(dim - 1)
use_strided_slice = (
True
if (
Expand Down
99 changes: 84 additions & 15 deletions test/indexing/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_combined_index_1(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[[0, 1], :, [1, 2]]
Expand All @@ -52,7 +52,7 @@ def test_combined_index_2(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

x = paddle.to_tensor(np_data, dtype=self.dtype)
Expand All @@ -70,7 +70,7 @@ def test_combined_index_3(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4]
Expand All @@ -89,7 +89,7 @@ def test_combined_index_4(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4]
Expand All @@ -107,7 +107,7 @@ def test_combined_index_5(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[::2, [1, 0], [2, 3], 0:4:2]
Expand All @@ -125,7 +125,7 @@ def test_combined_index_6(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]]
Expand All @@ -143,7 +143,7 @@ def test_combined_index_7(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]]
Expand All @@ -161,7 +161,7 @@ def test_combined_index_8(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[
Expand All @@ -181,7 +181,7 @@ def test_combined_index_9(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]]
Expand All @@ -199,7 +199,7 @@ def test_combined_index_10(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[:, [True, False, True, False], 4]
Expand All @@ -220,7 +220,7 @@ def test_combined_index_11(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[:, [False, False, False, False], 4]
Expand All @@ -240,7 +240,7 @@ def test_index_has_range(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[:, range(3), 4]
Expand All @@ -261,7 +261,7 @@ def test_indexing_with_bool_list1(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[[True, False, True], [False, False, False, True]]
Expand All @@ -282,7 +282,7 @@ def test_indexing_with_bool_list2(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_indexing_is_multi_dim_list(self):

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex32':
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[np.array([[2, 3, 4], [1, 2, 5]])]
Expand All @@ -326,6 +326,45 @@ def test_indexing_is_multi_dim_list(self):
np.testing.assert_allclose(y.numpy(), np_res)
np.testing.assert_allclose(y.numpy(), y_index_tensor.numpy())

def test_indexing_is_boolean_true(self):
# indexing is boolean, should improve rank of tensor and then treat it as advanced indexing.
np_data = (
np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)).astype(self.ndtype)
)

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[True]

x = paddle.to_tensor(np_data, dtype=self.dtype)
y = x[True]

if self.dtype == 'bfloat16':
y = paddle.cast(y, dtype='float32')

np.testing.assert_allclose(y.numpy(), np_res)

def test_indexing_is_boolean_false(self):
# indexing is boolean, should improve rank of tensor and then treat it as advanced indexing.
np_data = (
np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3)).astype(self.ndtype)
)

if self.dtype == 'bfloat16':
np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
if self.dtype == 'complex64' or self.dtype == 'complex128':
np_data = np_data + 1j * np_data

np_res = np_data[1, False, 0]

x = paddle.to_tensor(np_data, dtype=self.dtype)
y = x[1, False, 0]

np.testing.assert_allclose(y.numpy(), np_res)


@unittest.skipIf(
not core.is_compiled_with_cuda()
Expand Down Expand Up @@ -1002,6 +1041,36 @@ def test_indexing_is_multi_dim_list(self):
np.testing.assert_allclose(res[0], np_res)
np.testing.assert_allclose(res[1], np_res)

def test_indexing_is_boolean_true(self):
# indexing is boolean, should improve rank of tensor and then treat it as advanced indexing.
np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
np_res = np_data[True]

with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, True)

res = self.exe.run(fetch_list=[y.name])

np.testing.assert_allclose(res[0], np_res)

def test_indexing_is_boolean_false(self):
# indexing is boolean, should improve rank of tensor and then treat it as advanced indexing.
np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
np_res = np_data[1, False, 0]

with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, (1, False, 0))

res = self.exe.run(fetch_list=[y.name])

np.testing.assert_allclose(res[0], np_res)


class TestGetitemBasicIndexOutputView(unittest.TestCase):
def setUp(self):
Expand Down
Loading

0 comments on commit 0b4fd0f

Please sign in to comment.