From f38ae654c741c02a1b59b0d32e4fddafa3655ea2 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Sun, 11 Apr 2021 21:58:43 -0600 Subject: [PATCH] [ONNX] Support optional outputs for ONNX nodes (#7818) * Support optional outputs for ONNX nodes * add comments --- python/tvm/relay/frontend/onnx.py | 26 ++++++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 6 ----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 09525a64ac05..85fe01905b6e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3202,6 +3202,32 @@ def from_onnx(self, graph, opset, get_output_expr=False): outputs_num = 1 else: outputs_num = len(op) + if outputs_num > 1: + # ONNX supports optional outputs for some nodes. + # This block searches for missing outputs in the ONNX graph + # and removes any unneeded ops + valid_outputs = [False] * outputs_num + for i, output in enumerate(node_output): + if output != "": + valid_outputs[i] = True + # If we have outputs ONNX isn't expecting, we need to drop them + if not all(valid_outputs): + tup = op.astuple() + # TupleWrapper can also wrap ops with TupleType outputs + if isinstance(tup, _expr.Tuple): + # For tuples, we extract the fields instead of using GetTupleItem + outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid] + else: + # For call nodes, we need to GetTupleItem + outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid] + # Create the new op with valid outputs + if len(outputs) == 1: + op = outputs[0] + else: + op = _expr.TupleWrapper(outputs, len(outputs)) + # Drop invalid outputs for the onnx node + outputs_num = len(outputs) + node_output = [output for output in node_output if output != ""] assert ( len(node_output) == outputs_num ), "Number of output mismatch {} vs {} in {}.".format( diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a491ed130418..8a63bac33923 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4138,9 +4138,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): unsupported_onnx_tests = [ "test_basic_convinteger/", "test_cast_DOUBLE_to_FLOAT16/", - "test_cast_FLOAT16_to_DOUBLE/", - "test_cast_FLOAT16_to_FLOAT/", - "test_cast_FLOAT_to_FLOAT16/", "test_cast_FLOAT_to_STRING/", "test_cast_STRING_to_FLOAT/", "test_compress_0/", @@ -4171,9 +4168,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_hardmax_one_hot/", "test_isinf_negative/", "test_isinf_positive/", - "test_lstm_defaults/", - "test_lstm_with_initial_bias/", - "test_lstm_with_peepholes/", "test_matmulinteger/", "test_maxpool_2d_dilations/", "test_maxpool_2d_same_lower/",