Skip to content

Commit

Permalink
add IndexError to solve timeout problem in static-mode
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Jul 11, 2023
1 parent f1aba67 commit d6f9a2c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
22 changes: 22 additions & 0 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,10 +929,32 @@ def parse_index(x, indices):
indices = replace_ellipsis(x, indices)
indices, none_axes = replace_none(indices)

is_tensor_array = (
hasattr(x, "desc")
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
)

estimated_dim = 0
for dim, slice_item in enumerate(indices):
start, end, step = None, None, None
if is_integer_or_scalar_tensor(slice_item):
if (
not is_tensor_array
and isinstance(slice_item, int)
and x.shape[dim] is not None
and x.shape[dim] >= 0
and slice_item >= x.shape[dim]
):
# For python, if users write a, b = var, the __getitem__
# method will iterate through 0, 1, 2 ... until __getitem__
# throws an IndexError, then stop. The var[0], var[1] will
# be given to a, b respectively. If more values are given,
# the unpack size would cause error.
# We raises IndexError here to support grammar like `a, b = var`
raise IndexError(
"slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d"
% (slice_item, dim, dim, x.shape[dim])
)
# not calculate result to reduce call times for slice OP.
decrease_axes.append(dim)
start = slice_item
Expand Down
4 changes: 2 additions & 2 deletions test/legacy_test/test_while_loop_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,9 @@ def body(z, i):
startup_program = Program()
with program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=[-1, 5], dtype='int32')
z = paddle.tensor.fill_constant([1], 'int32', 0)
z = paddle.tensor.fill_constant([], 'int32', 0)
x_shape = paddle.shape(x)
i = paddle.tensor.fill_constant([1], 'int32', 0)
i = paddle.tensor.fill_constant([], 'int32', 0)
z, _ = paddle.static.nn.while_loop(cond, body, [z, i])

place = (
Expand Down

0 comments on commit d6f9a2c

Please sign in to comment.