From 4dcc628b66aedbcf4d942cc47fe48f4c250f0825 Mon Sep 17 00:00:00 2001 From: TomWildenhain-Microsoft <67606533+TomWildenhain-Microsoft@users.noreply.github.com> Date: Tue, 27 Apr 2021 17:12:58 -0400 Subject: [PATCH] Fix bug that renamed subgraph i/o twice (#1478) * Fix bug that renamed subgraph i/o twice Signed-off-by: Tom Wildenhain * Don't rename tensors in subgraphs at all Signed-off-by: Tom Wildenhain --- tf2onnx/tfonnx.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index 7ec35e6fd..ccea803db 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -511,8 +511,6 @@ def rename_tensors_in_nodes(onnx_nodes): for func in ordered_func: f_inputs_names = [t.name for t in func.inputs] f_output_names = [t.name for t in func.outputs] - f_inputs_names = rename_tensors_in_list(f_inputs_names) - f_output_names = rename_tensors_in_list(f_output_names) fg = process_tf_graph(func, continue_on_error, False, target, opset, custom_op_handlers, custom_rewriter, extra_opset, shape_override, inputs_as_nchw, @@ -524,12 +522,13 @@ def rename_tensors_in_nodes(onnx_nodes): check_io(input_names, output_names, output_shapes) - rename_tensors_in_nodes(onnx_nodes) - input_names = rename_tensors_in_list(input_names) - output_names = rename_tensors_in_list(output_names) - output_shapes = rename_tensors_in_dict(output_shapes) - dtypes = rename_tensors_in_dict(dtypes) - inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw) + if not is_subgraph: + rename_tensors_in_nodes(onnx_nodes) + input_names = rename_tensors_in_list(input_names) + output_names = rename_tensors_in_list(output_names) + output_shapes = rename_tensors_in_dict(output_shapes) + dtypes = rename_tensors_in_dict(dtypes) + inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw) g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph) g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target, output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)