Skip to content

Commit

Permalink
fix UT when np >= 1.23 (#51466)
Browse files Browse the repository at this point in the history
* fix UT when np >= 1.24

* optimize decription of this change
  • Loading branch information
zoooo0820 authored Mar 14, 2023
1 parent 9cd99f7 commit 9658f49
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
17 changes: 15 additions & 2 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,22 @@ def _test_list_index(self):
array = np.arange(120).reshape([6, 5, 4])
x = paddle.to_tensor(array)
py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]]

# note(chenjianye):
# Non-tuple sequence for multidimensional indexing is supported in numpy < 1.23.
# For List case, the outermost `[]` will be treated as tuple `()` in version less than 1.23,
# which is used to wrap index elements for multiple axes.
# And from 1.23, this will be treat as a whole and only works on one axis.
#
# e.g. x[[[0],[1]]] == x[([0],[1])] == x[[0],[1]] (in version < 1.23)
# x[[[0],[1]]] == x[array([[0],[1]])] (in version >= 1.23)
#
# Here, we just modify the code to remove the impact of numpy version changes,
# changing x[[[0],[1]]] to x[tuple([[0],[1]])] == x[([0],[1])] == x[[0],[1]].
# Whether the paddle behavior in this case will change is still up for debate.
idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])]
np.testing.assert_array_equal(x[idx].numpy(), array[py_idx])
np.testing.assert_array_equal(x[py_idx].numpy(), array[py_idx])
np.testing.assert_array_equal(x[idx].numpy(), array[tuple(py_idx)])
np.testing.assert_array_equal(x[py_idx].numpy(), array[tuple(py_idx)])
# case2:
tensor_x = paddle.to_tensor(
np.zeros(12).reshape(2, 6).astype(np.float32)
Expand Down
23 changes: 20 additions & 3 deletions python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,19 @@ def test_slice(self):


class TestListIndex(unittest.TestCase):
# note(chenjianye):
# Non-tuple sequence for multidimensional indexing is supported in numpy < 1.23.
# For List case, the outermost `[]` will be treated as tuple `()` in version less than 1.23,
# which is used to wrap index elements for multiple axes.
# And from 1.23, this will be treat as a whole and only works on one axis.
#
# e.g. x[[[0],[1]]] == x[([0],[1])] == x[[0],[1]] (in version < 1.23)
# x[[[0],[1]]] == x[array([[0],[1]])] (in version >= 1.23)
#
# Here, we just modify the code to remove the impact of numpy version changes,
# changing x[[[0],[1]]] to x[tuple([[0],[1]])] == x[([0],[1])] == x[[0],[1]].
# Whether the paddle behavior in this case will change is still up for debate.

def setUp(self):
np.random.seed(2022)

Expand Down Expand Up @@ -637,7 +650,7 @@ def test_static_graph_list_index(self):
exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]

getitem_np = array[index_mod]
getitem_np = array[tuple(index_mod)]
getitem_pp = exe.run(
prog, feed={x.name: array}, fetch_list=fetch_list
)
Expand All @@ -659,7 +672,7 @@ def test_dygraph_list_index(self):
pt = paddle.to_tensor(array)
index_mod = (index % (array.shape[-1])).tolist()
try:
getitem_np = array[index_mod]
getitem_np = array[tuple(index_mod)]

except:
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -844,8 +857,12 @@ def run_setitem_list_index(self, array, index, value_np):
exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
array2 = array.copy()

try:
index = (
tuple(index)
if isinstance(index, list) and isinstance(index[0], list)
else index
)
array2[index] = value_np
except:
with self.assertRaises(ValueError):
Expand Down

0 comments on commit 9658f49

Please sign in to comment.