Skip to content

Commit

Permalink
fix reshape shape bug (#58865)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer authored Nov 12, 2023
1 parent e5f007f commit 4ff3c4d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3865,7 +3865,7 @@ def get_attr_shape(list_shape):
)
if isinstance(shape, (list, tuple)):
if paddle.utils._contain_var(shape):
new_shape = paddle.utils._convert_to_tensor_list(shape)
new_shape = paddle.utils.get_int_tensor_list(shape)
else:
new_shape = get_attr_shape(shape)
out = _C_ops.reshape(x, new_shape)
Expand Down
12 changes: 12 additions & 0 deletions test/legacy_test/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,18 @@ def test_static(self):
self.assertEqual(result[3].shape, (1,))


class TestReshapePirOpResultListShape(unittest.TestCase):
def test_opresult_list_shape(self):
with paddle.pir_utils.IrGuard():
x = paddle.static.data(
'x',
[3],
)
shape = [1, paddle.full([], 3)]
out = paddle.reshape(x, shape)
np.testing.assert_array_equal(tuple(out.shape), (-1, -1))


if __name__ == "__main__":
paddle.enable_static()
unittest.main()

0 comments on commit 4ff3c4d

Please sign in to comment.