Skip to content

Commit

Permalink
fix logic
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanfz98 committed Jan 16, 2022
1 parent 98216ab commit 962fa36
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
16 changes: 11 additions & 5 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,10 +926,11 @@ def autopad(
return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type)


def ensure_scalar_shape(x):
def ensure_scalar_shape(x, force_assert=True):
"""
Assume that `x` is a tensor with one element (regardless of tensor rank).
Return a version of that tensor with rank 0.
Return a version of that tensor with rank 0. If force_assert=True, throw an exception
when test fails, otherwise return x itself.
"""
x_shape = infer_shape(x)
x_rank = len(x_shape)
Expand All @@ -938,9 +939,14 @@ def ensure_scalar_shape(x):
return x

num_elem = np.prod(x_shape)
assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)

return _op.squeeze(x)
if num_elem == 1:
return _op.squeeze(x)
else:
if force_assert:
assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)
else:
return x



def try_resolve_var_to_const(x, graph_params):
Expand Down
5 changes: 1 addition & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3806,10 +3806,7 @@ def _impl_v10(cls, inputs, attr, params):
# requirements.
def try_resolve_to_const(x, dtype_override=None, allow1D=False):
x2 = try_resolve_var_to_const(x, params)
if allow1D:
x3 = x2
else:
x3 = ensure_scalar_shape(x2)
x3 = ensure_scalar_shape(x2, force_assert=allow1D)
x_dtype = infer_type(x).checked_type.dtype
if (dtype_override is not None) and (dtype_override != x_dtype):
x4 = _op.cast(x3, dtype_override)
Expand Down

0 comments on commit 962fa36

Please sign in to comment.