Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support index is tensor or list for getitem in static mode #33000

Merged
merged 1 commit into from
Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,75 @@ def _test_slice(self, place):
self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))

def test_slice(self):
place = fluid.CPUPlace()
self._test_slice(place)
def _test_slice_index_tensor(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [1, 0]
idx1 = [0, 1]
idx2 = [0, 0]
idx3 = [1, 1]

out0 = x[paddle.assign(np.array(idx0))]
out1 = x[paddle.assign(np.array(idx1))]
out2 = x[paddle.assign(np.array(idx2))]
out3 = x[paddle.assign(np.array(idx3))]

exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])

expected = [data[idx0], data[idx1], data[idx2], data[idx3]]

self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())

with self.assertRaises(IndexError):
one = paddle.ones(shape=[1])
res = x[one, [0, 0]]

def _test_slice_index_list(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [1, 0]
idx1 = [0, 1]
idx2 = [0, 0]
idx3 = [1, 1]

out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]

exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])

expected = [data[idx0], data[idx1], data[idx2], data[idx3]]

self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())

with self.assertRaises(IndexError):
res = x[[1, 0], [0, 0]]

with self.assertRaises(TypeError):
res = x[[1.2, 0]]

def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self._test_slice(core.CUDAPlace(0))
places.append(core.CUDAPlace(0))

for place in places:
self._test_slice(place)
self._test_slice_index_tensor(place)
self._test_slice_index_list(place)

def _tostring(self):
b = default_main_program().current_block()
Expand Down
27 changes: 26 additions & 1 deletion python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _getitem_impl_(var, item):
Returns:
Sliced variable
"""
from .framework import default_main_program
from .framework import default_main_program, Variable

if not isinstance(item, tuple):
item = (item, )
Expand Down Expand Up @@ -126,6 +126,31 @@ def _getitem_impl_(var, item):
start = 0 if start is None else start
end = MAX_INTEGER if end is None else end

elif isinstance(slice_item, list):
for i in slice_item:
if not isinstance(i, int):
raise TypeError("Only support int value in list")

if len(item) != 1:
raise IndexError(
"When index contains a list, its length must be 1, but received {}".
format(len(item)))

from .layers import assign
from ..tensor import index_select

idx = assign(np.array(slice_item))
return index_select(var, index=idx, axis=0)

elif isinstance(slice_item, Variable):
if len(item) != 1:
raise IndexError(
"When index contains a Tensor, its length must be 1, but received {}".
format(len(item)))

from ..tensor import index_select
return index_select(var, index=slice_item, axis=0)

else:
raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.".
Expand Down