
Use the same syntax to replace a node input (2) + optimize replace_all_inputs #1060

Merged: 28 commits, Sep 4, 2020

Conversation

Collaborator @xadupre commented Aug 12, 2020

Measurements for EfficientNetB2:

Before:

21.492 profile  profile_conversion_time.py:58
`- 21.492 convert  profile_conversion_time.py:53
   `- 21.492 spy_convert  profile_conversion_time.py:30
      |- 20.699 spy_convert_in  profile_conversion_time.py:35
      |  `- 20.698 process_tf_graph  tf2onnx\tfonnx.py:335
      |     |- 10.522 tensorflow_onnx_mapping  tf2onnx\tfonnx.py:211
      |     |  |- 3.105 version_1  tf2onnx\onnx_opset\nn.py:329
      |     |  |  `- 3.077 conv_convert_inputs  tf2onnx\onnx_opset\nn.py:64
      |     |  |     |- 1.601 insert_new_node_on_output  tf2onnx\graph.py:1191
      |     |  |     |  `- 1.522 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |     |     |- 0.824 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |     |     |  |- 0.483 get_body_graphs  tf2onnx\graph.py:324
      |     |  |     |     |  |  `- 0.268 _graph_check  tf2onnx\graph.py:386
      |     |  |     |     |  `- 0.255 [self]  
      |     |  |     |     `- 0.411 get_body_graphs  tf2onnx\graph.py:324
      |     |  |     |        `- 0.233 _graph_check  tf2onnx\graph.py:386
      |     |  |     `- 1.234 find_output_consumers  tf2onnx\graph.py:1214
      |     |  |        |- 0.672 find_output_consumers  tf2onnx\graph.py:1214
      |     |  |        |  `- 0.486 get_body_graphs  tf2onnx\graph.py:324
      |     |  |        |     `- 0.269 _graph_check  tf2onnx\graph.py:386
      |     |  |        `- 0.396 get_body_graphs  tf2onnx\graph.py:324
      |     |  |           `- 0.226 _graph_check  tf2onnx\graph.py:386
      |     |  |- 2.523 version_1  tf2onnx\onnx_opset\tensor.py:124
      |     |  |  `- 2.518 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |     |- 1.469 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |     |  |- 0.714 get_body_graphs  tf2onnx\graph.py:324
      |     |  |     |  |  |- 0.393 _graph_check  tf2onnx\graph.py:386
      |     |  |     |  |  `- 0.242 [self]  
      |     |  |     |  `- 0.633 [self]  
      |     |  |     |- 0.610 get_body_graphs  tf2onnx\graph.py:324
      |     |  |     |  `- 0.343 _graph_check  tf2onnx\graph.py:386
      |     |  |     `- 0.320 [self]  
      |     |  |- 1.162 version_1  tf2onnx\onnx_opset\tensor.py:109
      |     |  |  `- 1.152 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |     |- 0.672 get_body_graphs  tf2onnx\graph.py:324
      |     |  |     |  `- 0.383 _graph_check  tf2onnx\graph.py:386
      |     |  |     `- 0.326 [self]  
      |     |  |- 1.106 version_1  tf2onnx\onnx_opset\tensor.py:644
      |     |  |  |- 0.780 insert_new_node_on_output  tf2onnx\graph.py:1191
      |     |  |  |  `- 0.736 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |  |     `- 0.394 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |  |        `- 0.228 get_body_graphs  tf2onnx\graph.py:324
      |     |  |  `- 0.306 find_output_consumers  tf2onnx\graph.py:1214
      |     |  |- 0.980 version_1  tf2onnx\onnx_opset\controlflow.py:420
      |     |  |  `- 0.938 wire_if_branch  tf2onnx\onnx_opset\controlflow.py:725
      |     |  |     `- 0.900 prefix_graph  tf2onnx\onnx_opset\controlflow.py:795
      |     |  |        |- 0.396 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |        |  `- 0.242 get_body_graphs  tf2onnx\graph.py:324
      |     |  |        `- 0.361 make_node  tf2onnx\graph.py:540
      |     |  |- 0.743 version_1  tf2onnx\onnx_opset\nn.py:475
      |     |  |  `- 0.736 conv_convert_inputs  tf2onnx\onnx_opset\nn.py:64
      |     |  |     |- 0.388 insert_new_node_on_output  tf2onnx\graph.py:1191
      |     |  |     |  `- 0.366 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  |     `- 0.311 find_output_consumers  tf2onnx\graph.py:1214
      |     |  |- 0.405 version_1  tf2onnx\onnx_opset\tensor.py:1028
      |     |  |  `- 0.381 replace_all_inputs  tf2onnx\graph.py:1228
      |     |  `- 0.388 version_1  tf2onnx\onnx_opset\tensor.py:1174
      |     |     `- 0.386 insert_new_node_on_output  tf2onnx\graph.py:1191
      |     |        `- 0.365 replace_all_inputs  tf2onnx\graph.py:1228
      |     |- 3.876 process_tf_graph  tf2onnx\tfonnx.py:335
      |     |  |- 1.707 run_rewriters  tf2onnx\tfonnx.py:303
      |     |  |  `- 0.323 reset_nodes  tf2onnx\graph.py:639
      |     |  |- 0.802 tensorflow_onnx_mapping  tf2onnx\tfonnx.py:211
      |     |  |  |- 0.289 error  logging\__init__.py:1402
      |     |  |  |     [116 frames hidden]  logging, absl, colorama, encodings, n...
      |     |  |  `- 0.252 version_6  tf2onnx\onnx_opset\nn.py:686
      |     |  `- 0.663 tensorflow_to_onnx  tf2onnx\tf_utils.py:220
      |     |     `- 0.661 tflist_to_onnx  tf2onnx\tf_utils.py:128
      |     |- 3.604 resolve_functions  tf2onnx\tf_loader.py:417
      |     |  |- 2.019 function_def_to_graph  tensorflow\python\framework\function_def_to_graph.py:33
      |     |  |     [564 frames hidden]  tensorflow, contextlib, threading, <s...
      |     |  `- 1.333 tflist_to_onnx  tf2onnx\tf_utils.py:128
      |     |     |- 0.232 [self]  
      |     |     `- 0.229 node_def  tensorflow\python\framework\ops.py:2412
      |     |           [16 frames hidden]  tensorflow, contextlib
      |     |- 0.837 update_proto  tf2onnx\graph.py:751
      |     |  `- 0.833 update_proto  tf2onnx\graph.py:336
      |     |     `- 0.676 make_graph  tf2onnx\graph.py:971
      |     |        `- 0.250 delete_unused_nodes  tf2onnx\graph.py:1334
      |     |- 0.696 run_rewriters  tf2onnx\tfonnx.py:303
      |     |  `- 0.266 reset_nodes  tf2onnx\graph.py:639
      |     |- 0.678 tensorflow_to_onnx  tf2onnx\tf_utils.py:220
      |     |  `- 0.678 tflist_to_onnx  tf2onnx\tf_utils.py:128
      |     `- 0.264 __init__  tf2onnx\graph.py:415
      |        `- 0.244 <listcomp>  tf2onnx\graph.py:450
      |           `- 0.242 __init__  tf2onnx\graph.py:35
      |              `- 0.236 [self]  
      `- 0.793 new_func  tensorflow\python\util\deprecation.py:473
            [274 frames hidden]  tensorflow, contextlib, inspect

After:

15.057 profile  profile_conversion_time.py:58
`- 15.057 convert  profile_conversion_time.py:53
   `- 15.057 spy_convert  profile_conversion_time.py:30
      |- 14.183 spy_convert_in  profile_conversion_time.py:35
      |  `- 14.181 process_tf_graph  tf2onnx\tfonnx.py:335
      |     |- 5.155 process_tf_graph  tf2onnx\tfonnx.py:335
      |     |  |- 2.801 run_rewriters  tf2onnx\tfonnx.py:303
      |     |  |  |- 0.826 reset_nodes  tf2onnx\graph.py:653
      |     |  |  |  `- 0.359 [self]  
      |     |  |  |- 0.337 rewrite_dropout  tf2onnx\rewriter\dropout_rewriter.py:19
      |     |  |  |  `- 0.286 match_ops  tf2onnx\graph_matcher.py:244
      |     |  |  |     `- 0.283 match_op  tf2onnx\graph_matcher.py:227
      |     |  |  |        `- 0.234 [self]  
      |     |  |  |- 0.279 rewrite_single_direction_gru  tf2onnx\rewriter\rnn.py:36
      |     |  |  |  `- 0.267 run  tf2onnx\rewriter\gru_rewriter.py:33
      |     |  |  |     `- 0.265 run  tf2onnx\rewriter\unit_rnn_rewriter_base.py:61
      |     |  |  |        `- 0.265 run_internal  tf2onnx\rewriter\loop_rewriter_base.py:195
      |     |  |  |           `- 0.259 delete_unused_nodes  tf2onnx\graph.py:1444
      |     |  |  |              `- 0.200 extract_sub_graph_nodes  tf2onnx\graph.py:1416
      |     |  |  |                 `- 0.176 _extract_sub_graph_nodes  tf2onnx\graph.py:1384
      |     |  |  |- 0.251 rewrite_single_direction_lstm  tf2onnx\rewriter\rnn.py:27
      |     |  |  |  `- 0.234 run  tf2onnx\rewriter\lstm_rewriter.py:33
      |     |  |  |     `- 0.232 run  tf2onnx\rewriter\unit_rnn_rewriter_base.py:61
      |     |  |  |        `- 0.232 run_internal  tf2onnx\rewriter\loop_rewriter_base.py:195
      |     |  |  |           `- 0.226 delete_unused_nodes  tf2onnx\graph.py:1444
      |     |  |  |              `- 0.166 extract_sub_graph_nodes  tf2onnx\graph.py:1416
      |     |  |  |- 0.226 rewrite_custom_rnn_cell  tf2onnx\rewriter\rnn.py:45
      |     |  |  |  `- 0.216 run  tf2onnx\rewriter\custom_rnn_rewriter.py:41
      |     |  |  |     `- 0.213 run_internal  tf2onnx\rewriter\loop_rewriter_base.py:195
      |     |  |  |        `- 0.208 delete_unused_nodes  tf2onnx\graph.py:1444
      |     |  |  |- 0.211 rewrite_generic_loop  tf2onnx\rewriter\rnn.py:49
      |     |  |  |  `- 0.202 run  tf2onnx\rewriter\loop_rewriter.py:33
      |     |  |  |     `- 0.201 run_internal  tf2onnx\rewriter\loop_rewriter_base.py:195
      |     |  |  |        `- 0.193 delete_unused_nodes  tf2onnx\graph.py:1444
      |     |  |  `- 0.160 rewrite_cond  tf2onnx\rewriter\cond_rewriter.py:320
      |     |  |     `- 0.157 rewrite  tf2onnx\rewriter\cond_rewriter.py:54
      |     |  |        `- 0.155 run  tf2onnx\rewriter\cond_rewriter.py:58
      |     |  |- 0.791 tensorflow_onnx_mapping  tf2onnx\tfonnx.py:211
      |     |  |  |- 0.285 error  logging\__init__.py:1402
      |     |  |  |     [116 frames hidden]  logging, absl, colorama, encodings, n...
      |     |  |  `- 0.246 version_6  tf2onnx\onnx_opset\nn.py:686
      |     |  |     `- 0.237 conv_convert_inputs  tf2onnx\onnx_opset\nn.py:64
      |     |  |- 0.736 tensorflow_to_onnx  tf2onnx\tf_utils.py:220
      |     |  |  `- 0.734 tflist_to_onnx  tf2onnx\tf_utils.py:128
      |     |  |     `- 0.162 node_def  tensorflow\python\framework\ops.py:2412
      |     |  |           [16 frames hidden]  tensorflow, contextlib
      |     |  |- 0.275 topological_sort  tf2onnx\tfonnx.py:291
      |     |  |  `- 0.270 topological_sort  tf2onnx\graph.py:941
      |     |  |- 0.185 delete_unused_nodes  tf2onnx\graph.py:1444
      |     |  `- 0.157 update_proto  tf2onnx\graph.py:781
      |     |     `- 0.154 update_proto  tf2onnx\graph.py:336
      |     |- 4.000 resolve_functions  tf2onnx\tf_loader.py:417
      |     |  |- 2.239 function_def_to_graph  tensorflow\python\framework\function_def_to_graph.py:33
      |     |  |     [552 frames hidden]  tensorflow, contextlib, threading, <s...
      |     |  `- 1.492 tflist_to_onnx  tf2onnx\tf_utils.py:128
      |     |     |- 0.255 node_def  tensorflow\python\framework\ops.py:2412
      |     |     |     [16 frames hidden]  tensorflow, contextlib
      |     |     |- 0.254 [self]  
      |     |     |- 0.247 get_tf_node_attr  tf2onnx\tf_utils.py:119
      |     |     |  `- 0.228 get_attr  tensorflow\python\framework\ops.py:2497
      |     |     |        [20 frames hidden]  tensorflow, contextlib
      |     |     |- 0.226 make_node  onnx\helper.py:20
      |     |     |     [33 frames hidden]  onnx, abc, typing
      |     |     `- 0.209 tf_to_onnx_tensor  tf2onnx\tf_utils.py:52
      |     |- 1.820 tensorflow_onnx_mapping  tf2onnx\tfonnx.py:211
      |     |  |- 1.208 version_1  tf2onnx\onnx_opset\controlflow.py:420
      |     |  |  `- 1.162 wire_if_branch  tf2onnx\onnx_opset\controlflow.py:725
      |     |  |     `- 1.118 prefix_graph  tf2onnx\onnx_opset\controlflow.py:795
      |     |  |        |- 0.488 replace_all_inputs  tf2onnx\graph.py:1301
      |     |  |        |  |- 0.274 _register_input_name  tf2onnx\graph.py:1276
      |     |  |        |  |  `- 0.205 [self]  
      |     |  |        |  `- 0.152 [self]  
      |     |  |        `- 0.437 make_node  tf2onnx\graph.py:542
      |     |  `- 0.329 version_1  tf2onnx\onnx_opset\nn.py:329
      |     |     `- 0.299 conv_convert_inputs  tf2onnx\onnx_opset\nn.py:64
      |     |- 1.382 run_rewriters  tf2onnx\tfonnx.py:303
      |     |  `- 0.760 reset_nodes  tf2onnx\graph.py:653
      |     |     `- 0.277 [self]  
      |     |- 0.744 tensorflow_to_onnx  tf2onnx\tf_utils.py:220
      |     |  `- 0.744 tflist_to_onnx  tf2onnx\tf_utils.py:128
      |     |     |- 0.157 get_tf_node_attr  tf2onnx\tf_utils.py:119
      |     |     `- 0.154 tf_to_onnx_tensor  tf2onnx\tf_utils.py:52
      |     |- 0.711 update_proto  tf2onnx\graph.py:781
      |     |  `- 0.709 update_proto  tf2onnx\graph.py:336
      |     |     `- 0.567 make_graph  tf2onnx\graph.py:1003
      |     |        `- 0.205 delete_unused_nodes  tf2onnx\graph.py:1444
      |     `- 0.163 topological_sort  tf2onnx\tfonnx.py:291
      |        `- 0.162 topological_sort  tf2onnx\graph.py:941
      `- 0.874 new_func  tensorflow\python\util\deprecation.py:473
            [271 frames hidden]  tensorflow, contextlib, inspect
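The numbers above (roughly 21.5 s before and 15.1 s after, for EfficientNetB2) come largely from keeping a reverse index from tensor name to consuming nodes, so that replace_all_inputs and find_output_consumers no longer scan every node in the graph. A minimal sketch of the idea, with illustrative names rather than the actual tf2onnx implementation:

```python
from collections import defaultdict

class TinyGraph:
    """Toy graph that maintains a reverse index (input name -> consumer names),
    mirroring the _input_to_node_name idea discussed in this PR."""

    def __init__(self):
        self.nodes = {}                               # node name -> list of input names
        self._input_to_node_name = defaultdict(set)   # input name -> consumer node names

    def add_node(self, name, inputs):
        self.nodes[name] = list(inputs)
        for inp in inputs:
            self._input_to_node_name[inp].add(name)

    def find_output_consumers(self, output_name):
        # O(#consumers) lookup instead of a scan over all nodes
        return sorted(self._input_to_node_name.get(output_name, ()))

    def replace_all_inputs(self, old_input, new_input):
        # only visit nodes known to consume old_input, updating the index as we go
        for node_name in self._input_to_node_name.pop(old_input, set()):
            self.nodes[node_name] = [new_input if i == old_input else i
                                     for i in self.nodes[node_name]]
            self._input_to_node_name[new_input].add(node_name)

g = TinyGraph()
g.add_node("conv", ["x", "w"])
g.add_node("relu", ["conv:0"])
g.replace_all_inputs("x", "x_transposed")
print(g.nodes["conv"])   # ['x_transposed', 'w']
```

The real implementation additionally recurses into subgraphs and lets callers pass an explicit list of nodes to bypass the index, both of which come up in the review below.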

lgtm-com bot commented Aug 12, 2020

This pull request introduces 2 alerts when merging bae92a6 into 0110037 - view on LGTM.com

new alerts:

  • 2 for Unused local variable

lgtm-com bot commented Aug 12, 2020

This pull request introduces 2 alerts when merging 74321f3 into d3d301a - view on LGTM.com

new alerts:

  • 2 for Unused local variable

lgtm-com bot commented Aug 12, 2020

This pull request introduces 2 alerts when merging 7b244e0 into d3d301a - view on LGTM.com

new alerts:

  • 2 for Unused local variable

lgtm-com bot commented Aug 20, 2020

This pull request introduces 2 alerts when merging f5c9406 into ee2b202 - view on LGTM.com

new alerts:

  • 2 for Unused local variable

Contributor @TomWildenhain-Microsoft left a comment

Looking forward to faster conversion speeds! Hope this feedback helps!

tf2onnx/graph.py:
# modify references in sub graphs
if old_input in self._input_to_graph:
for _, g in self._input_to_graph[old_input].items():
g.replace_all_inputs(g.get_nodes() if keep_ops else None, old_input, new_input)
Contributor:

Why do we need g.get_nodes()? Even if ops was passed in, shouldn't the _input_to_graph in g be valid?

Collaborator Author:

This is the case when ops is specified: the user bypasses _input_to_node_name, and I assumed they would want to bypass it in all subgraphs as well.

Contributor:

Ok, makes sense. Hopefully ops won't be used very often anyway.

tf2onnx/graph.py:
ops = self._input_to_node_name[output_name]
ops = [self.get_node_by_name(n) for n in ops]
else:
ops = self.get_nodes()
Contributor:

Why would this case trigger? Shouldn't the index always be up to date? If output_name is not in self._input_to_node_name, ops should be [], I think.

Collaborator Author:

Done.

tf2onnx/graph.py Outdated
nodes.extend(g.find_output_consumers(output_name))
# find consumers in sub graphs
if output_name in self._input_to_graph:
for _, g in self._input_to_graph[output_name].items():
Contributor:

Suggested change
for _, g in self._input_to_graph[output_name].items():
for g in self._input_to_graph[output_name].values():

Collaborator Author:

Fixed.

if to_ops is not None:
if node.name in to_ops:
# A node may take twice the same entry.
to_ops.remove(node.name)
Contributor:

I think this isn't technically correct if a node has two copies of the same input (node.input = [inp1, inp1]) and you replace just one copy (replace_input(node, inp1, inp2, 0) -> node.input = [inp2, inp1]): node is still a consumer of inp1. This bug should only occur if input_index is not None, since otherwise we replace everything.
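The duplicate-input edge case described above can be reproduced with a small sketch (replace_input here is a hypothetical stand-in, not the tf2onnx function):

```python
def replace_input(inputs, old, new, input_index=None):
    """Replace old with new in a node's input list; if input_index is given,
    replace only that position (hypothetical stand-in for illustration)."""
    if input_index is None:
        return [new if i == old else i for i in inputs]
    out = list(inputs)
    if out[input_index] == old:
        out[input_index] = new
    return out

node_inputs = ["inp1", "inp1"]   # a node taking the same tensor twice
node_inputs = replace_input(node_inputs, "inp1", "inp2", input_index=0)
print(node_inputs)               # ['inp2', 'inp1']
# The node still consumes inp1, so unregistering it from the inp1 index
# at this point would leave the reverse index incorrect.
still_consumer = "inp1" in node_inputs   # True
```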

Contributor:

Also why not use the _unregister helper function?

Collaborator Author:

You're right. I did not use it because sometimes replace_input is followed by a call to remove_input, and sometimes not. As a result, _input_to_node_name[input_name] may contain nodes that no longer use the input. That is not a correctness issue, since _input_to_node_name[input_name] is only used to retrieve the nodes consuming input_name: even if the set is larger than necessary, it is still smaller than get_nodes(), which is what was scanned previously. To call _unregister consistently I would need to review every call to replace_input and remove_input.

if to_ops is not None and old_input in to_ops:
# To avoid issues when a node
# takes twice the same entry.
to_ops.remove(old_input)
Contributor:

Could you use unregister here as well?

Collaborator Author:

Same reason as above: a change here would require other changes to stay consistent.

lgtm-com bot commented Aug 24, 2020

This pull request introduces 2 alerts when merging 03e2bb7 into 1a35937 - view on LGTM.com

new alerts:

  • 2 for Unused local variable

tf2onnx/graph.py:
"Unregister node taking a specific input."
node_name = node.name
if not only_graph:
if input_name in self._input_to_node_name[input_name]:
Contributor:

Suggested change
if input_name in self._input_to_node_name[input_name]:
if input_name in self._input_to_node_name:

Collaborator Author:

Good catch. I'm worried that the unit tests still pass; this line was probably never exercised.

Collaborator Author:

Done

Contributor:

I think this passes because there is no test to make sure the index doesn't contain extraneous entries. As I think you mentioned before, if the index has extra entries it isn't actually an issue as far as correctness is concerned and it only causes a slight performance decrease. Still, we may want to consider adding a validate_indices function that runs during unit tests between, say, optimization passes and once after the rewriters/handlers finish. What do you think?

Contributor:

I'm mostly concerned that later someone will forget to update the indices and introduce a subtle bug that is hard to catch.

Collaborator Author:

I don't mind going further, but I would have to look at every call to replace_all_inputs and possibly change them. That is a huge refactoring and the PR will grow. One big PR, or two smaller ones that are still not that small: that is the question.

Contributor:

Maybe we can have a check_graph() method on the graph class that validates everything is in order, and call it from a few places (e.g. in the unit tests, between conversion and the optimizer)?

Collaborator Author:

Same PR or another one?
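A validator along the lines of the check_graph()/validate_indices idea discussed above might look like this sketch (names and data layout are illustrative, not tf2onnx's). Extra entries in the index only cost performance, so only missing entries are flagged:

```python
def check_index_consistency(nodes, input_to_node_name):
    """Verify every edge in the graph is present in the reverse index.

    nodes: dict mapping node name -> list of input names
    input_to_node_name: dict mapping input name -> set of consumer node names
    Returns a list of error messages; empty means the index covers all edges.
    """
    errors = []
    for node_name, inputs in nodes.items():
        for inp in inputs:
            if node_name not in input_to_node_name.get(inp, set()):
                errors.append("node %r consumes %r but is not indexed" % (node_name, inp))
    return errors

nodes = {"conv": ["x", "w"], "relu": ["conv:0"]}
index = {"x": {"conv"}, "w": {"conv"}, "conv:0": {"relu"}}
print(check_index_consistency(nodes, index))   # []
```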

lgtm-com bot commented Aug 27, 2020

This pull request introduces 1 alert when merging ee8df44 into 6ec695b - view on LGTM.com

new alerts:

  • 1 for Module is imported with 'import' and 'import from'

tf2onnx/graph.py Outdated
Comment on lines 677 to 680
if op.type == 'Placeholder':
inps = [op.name]
elif op.type == 'Const':
inps = [op.name]
Contributor:

Why should Placeholder or Const ops be considered consumers of themselves?

Collaborator Author:

Removed placeholder.

tf2onnx/graph.py Outdated

for n in self._order_sensitive_inputs:
if n not in ops:
self._order_sensitive_inputs.remove(n)
for o in self.outputs:
if o not in self._output_to_node_name:
raise ValueError("graph output " + o + " not exist")
for i in self.inputs:
if i.name.startswith('Placeholder'):
Contributor:

We should never look at a name since they can be given by the user.

Collaborator Author:

Removed.

tf2onnx/graph.py Outdated

for i, name in enumerate(node.input):
if name == to_be_removed:
if node.input.count(node.input[i]) > 1:
raise RuntimeError(
Contributor:

Please don't raise RuntimeError; we use make_sure().

Collaborator Author:

Done
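For reference, make_sure acts as a condition check that raises ValueError with a formatted message when the condition fails; a minimal stand-in (not the exact tf2onnx.utils source) for illustration:

```python
def make_sure(bool_val, error_msg, *args):
    """Raise ValueError with a formatted message when the condition fails
    (a stand-in for tf2onnx's make_sure helper)."""
    if not bool_val:
        raise ValueError("make_sure failure: " + error_msg % args)

make_sure(1 + 1 == 2, "arithmetic is broken")   # passes silently
try:
    make_sure(False, "input %r appears more than once", "inp1")
except ValueError as e:
    print(e)   # make_sure failure: input 'inp1' appears more than once
```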

@@ -14,3 +14,14 @@ steps:
condition: succeededOrFailed()
env:
CI_ONNX_OPSET: '${{ onnx_opset }}'

- bash: |
export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
Contributor:

We don't want to run this in the master CI pipeline

Collaborator Author:

Fixed

@@ -0,0 +1,103 @@
# coding: utf-8
Contributor:

we want to limit top level directories - please move to tools/

Collaborator Author:

Done.

@@ -90,7 +90,7 @@ def _replace_node_with_const(node, graph, vals):
const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
graph.set_dtype(const_node.output[0], utils.map_numpy_to_onnx_dtype(val.dtype))
graph.set_shape(const_node.output[0], val.shape)
graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
graph.replace_all_inputs(None, old_input, const_node.output[0]) # graph.get_nodes()
Contributor:

I'd either keep graph.get_nodes() (which I prefer because since it makes it very clear which nodes are used) or remove the comment.

Contributor:

actually - this PR would be much smaller if we keep graph.get_nodes()

Contributor:

Let's make ops an optional arg that is None by default.

Contributor:

+1 - that would be perfect!

Collaborator Author:

ok
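The compromise settled on above, keeping ops as an optional argument that defaults to None, can be sketched like this (hypothetical, simplified signature, not the exact tf2onnx one):

```python
class Node:
    """Minimal stand-in for a graph node."""
    def __init__(self, name, inputs):
        self.name = name
        self.inputs = list(inputs)

def replace_all_inputs(all_nodes, old_input, new_input, ops=None):
    # ops=None means "find the consumers ourselves"; callers that already
    # know the affected nodes can pass them explicitly (the old behavior).
    if ops is None:
        ops = [n for n in all_nodes if old_input in n.inputs]
    for node in ops:
        node.inputs = [new_input if i == old_input else i for i in node.inputs]

nodes = [Node("conv", ["x", "w"]), Node("relu", ["conv:0"])]
replace_all_inputs(nodes, "x", "x_t")   # ops omitted, defaults to None
print(nodes[0].inputs)                  # ['x_t', 'w']
```

Existing call sites that pass graph.get_nodes() keep working, while new code can omit the argument and benefit from the index-based lookup.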

tf2onnx/graph.py Outdated
if i.name.startswith('keras_learning_phase'):
continue
if i.name not in self._input_to_node_name:
raise ValueError("graph input %r not exist in graph." % i.name)
Contributor:

I would leave the check for self._output_to_node_name but delete lines 687 to 691. Isn't everything in self.inputs a graph input? i.is_graph_input() will always be true, so this loop doesn't do anything, unless I'm missing something.

tf2onnx/graph.py Outdated
Comment on lines 687 to 691
for i in self.inputs:
if i.is_graph_input():
continue
if i.name not in self._output_to_consumers:
raise ValueError("graph input %r not exist in graph." % i.name)
Contributor:

Suggested change
for i in self.inputs:
if i.is_graph_input():
continue
if i.name not in self._output_to_consumers:
raise ValueError("graph input %r not exist in graph." % i.name)

Collaborator Author @xadupre commented Sep 1, 2020:

Right. I had not realized that inputs is a property that already returns only the nodes satisfying this condition. I removed the lines.

tf2onnx/graph.py Outdated
Comment on lines 646 to 647
raise RuntimeError(
"Input %r of node %r not found." % (op_input, node_name))
Contributor:

Should this be a make_sure?

Collaborator Author @xadupre commented Sep 1, 2020:

I hesitated because there was a mixed use of make_sure and RuntimeError in graph.py. I changed this one.

lgtm-com bot commented Sep 1, 2020

This pull request introduces 1 alert and fixes 1 when merging 5115124 into 0b15fe1 - view on LGTM.com

new alerts:

  • 1 for Syntax error

fixed alerts:

  • 1 for Except block handles 'BaseException'

Contributor @TomWildenhain-Microsoft left a comment

LGTM!

@xadupre xadupre merged commit 19d3f97 into onnx:master Sep 4, 2020