Skip to content

Commit

Permalink
[Bugfix] reshape with zero input tensor (#35642)
Browse files Browse the repository at this point in the history
* reshape support zero-input

* add unitest

* revise error message
  • Loading branch information
JZ-LIANG authored Sep 13, 2021
1 parent ecfe837 commit cabc5f3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
17 changes: 17 additions & 0 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class ReshapeOp : public framework::OperatorWithKernel {
framework::make_ddim(shape), i, shape[i]));
}

// NOTE all non-zero values will be converted to True (include negative
// value)
capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] =
(shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
Expand Down Expand Up @@ -222,6 +224,21 @@ class ReshapeOp : public framework::OperatorWithKernel {
in_dims, in_size, framework::make_ddim(shape), capacity));
}
}

// support reshape with zero-input(input tensor with product(shape) == 0)
// by now we require that if the input tensor is zero shape, the target
// shape of output must be zero
if (in_size == 0) {
PADDLE_ENFORCE_EQ(
capacity, in_size,
platform::errors::InvalidArgument(
"The 'shape' in ReshapeOp is invalid. "
"The input tensor X's shape = [%s], X's capacity = %d."
"But the target shape of Out is [%s], the "
"capacity of 'Out' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity));
}

return framework::make_ddim(output_shape);
}

Expand Down
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,5 +464,19 @@ def executed_api(self):
self.reshape = paddle.reshape_


class TestReshapeZeroTensor(unittest.TestCase):
def test_reshape_zero_tensor_success(self):
zero_tensor = paddle.zeros([0, 2, 3])
# since we use "0" as the dimension copy semantically in reshape,
# we need to copy the 0 dim in the src tensor in order to make a successful zero tensor reshape
zero_tensor = zero_tensor.reshape([0, 6])
self.assertTrue(list(zero_tensor.shape) == [0, 6])

def test_reshape_zero_tensor_error(self):
zero_tensor = paddle.zeros([0, 2, 3])
with self.assertRaises(ValueError):
zero_tensor.reshape([2, 3])


if __name__ == "__main__":
unittest.main()

0 comments on commit cabc5f3

Please sign in to comment.