[ONNX] Support optional outputs for ONNX nodes (#7818)
* Support optional outputs for ONNX nodes

* add comments
Matthew Brookhart authored Apr 12, 2021
1 parent ab0dc2e commit f38ae65
Showing 2 changed files with 26 additions and 6 deletions.
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -3202,6 +3202,32 @@ def from_onnx(self, graph, opset, get_output_expr=False):
             if not isinstance(op, _expr.TupleWrapper):
                 outputs_num = 1
             else:
                 outputs_num = len(op)
+            if outputs_num > 1:
+                # ONNX supports optional outputs for some nodes.
+                # This block searches for missing outputs in the ONNX graph
+                # and removes any unneeded ops
+                valid_outputs = [False] * outputs_num
+                for i, output in enumerate(node_output):
+                    if output != "":
+                        valid_outputs[i] = True
+                # If we have outputs ONNX isn't expecting, we need to drop them
+                if not all(valid_outputs):
+                    tup = op.astuple()
+                    # TupleWrapper can also wrap ops with TupleType outputs
+                    if isinstance(tup, _expr.Tuple):
+                        # For tuples, we extract the fields instead of using GetTupleItem
+                        outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
+                    else:
+                        # For call nodes, we need to GetTupleItem
+                        outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
+                    # Create the new op with valid outputs
+                    if len(outputs) == 1:
+                        op = outputs[0]
+                    else:
+                        op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
+                    # Drop invalid outputs for the onnx node
+                    outputs_num = len(outputs)
+                    node_output = [output for output in node_output if output != ""]
             assert (
                 len(node_output) == outputs_num
             ), "Number of output mismatch {} vs {} in {}.".format(
6 changes: 0 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
@@ -4138,9 +4138,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
 unsupported_onnx_tests = [
     "test_basic_convinteger/",
     "test_cast_DOUBLE_to_FLOAT16/",
-    "test_cast_FLOAT16_to_DOUBLE/",
-    "test_cast_FLOAT16_to_FLOAT/",
-    "test_cast_FLOAT_to_FLOAT16/",
     "test_cast_FLOAT_to_STRING/",
     "test_cast_STRING_to_FLOAT/",
     "test_compress_0/",
@@ -4171,9 +4168,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_hardmax_one_hot/",
"test_isinf_negative/",
"test_isinf_positive/",
"test_lstm_defaults/",
"test_lstm_with_initial_bias/",
"test_lstm_with_peepholes/",
"test_matmulinteger/",
"test_maxpool_2d_dilations/",
"test_maxpool_2d_same_lower/",
