Use the same syntax to replace a node input (2) + optimize replace_all_inputs #1060
Conversation
This pull request introduces 2 alerts when merging bae92a6 into 0110037 - view on LGTM.com.
This pull request introduces 2 alerts when merging 74321f3 into d3d301a - view on LGTM.com.
This pull request introduces 2 alerts when merging 7b244e0 into d3d301a - view on LGTM.com.
This pull request introduces 2 alerts when merging f5c9406 into ee2b202 - view on LGTM.com.
Looking forward to faster conversion speeds! Hope this feedback helps!
tf2onnx/graph.py
Outdated
```python
# modify references in sub graphs
if old_input in self._input_to_graph:
    for _, g in self._input_to_graph[old_input].items():
        g.replace_all_inputs(g.get_nodes() if keep_ops else None, old_input, new_input)
```
Why do we need g.get_nodes()? Even if ops was passed in, shouldn't the _input_to_graph in g be valid?
This is the case when ops is specified: the user bypasses _input_to_node_name, so I assumed they would bypass it in all subgraphs as well.
Ok, makes sense. Hopefully ops won't be used very often anyway.
tf2onnx/graph.py
Outdated
```python
    ops = self._input_to_node_name[output_name]
    ops = [self.get_node_by_name(n) for n in ops]
else:
    ops = self.get_nodes()
```
Why would this case trigger? Shouldn't the index always be up to date? If output_name is not in self._input_to_node_name, ops should be [], I think.
Done.
tf2onnx/graph.py
Outdated
```python
nodes.extend(g.find_output_consumers(output_name))
# find consumers in sub graphs
if output_name in self._input_to_graph:
    for _, g in self._input_to_graph[output_name].items():
```
Suggested change:
```diff
-for _, g in self._input_to_graph[output_name].items():
+for g in self._input_to_graph[output_name].values():
```
Fixed.
```python
if to_ops is not None:
    if node.name in to_ops:
        # A node may take twice the same entry.
        to_ops.remove(node.name)
```
I think this isn't technically correct if a node has two copies of the same input (node.input = [inp1, inp1]) and you replace just one copy (replace_input(node, inp1, inp2, 0) -> node.input = [inp2, inp1]); node is still a consumer of inp1. This bug should only occur if input_index is not None, since otherwise we replace everything.
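For illustration, a minimal standalone sketch of that scenario (hypothetical names, not the actual tf2onnx code):
```python
# Hypothetical illustration: why dropping the node from the consumer index
# after replacing only one input index can leave the index wrong.
consumers = {"inp1": {"node_a"}, "inp2": set()}   # input name -> consumer node names
node_input = ["inp1", "inp1"]                     # node_a takes the same entry twice

# replace_input(node_a, "inp1", "inp2", 0): only the first copy is replaced.
node_input[0] = "inp2"
consumers["inp1"].discard("node_a")               # what the code in question does
consumers["inp2"].add("node_a")

# node_a still consumes inp1 through its second input, but the index lost it.
assert "inp1" in node_input
assert "node_a" not in consumers["inp1"]
```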
Also why not use the _unregister helper function?
You're right. I did not do it because sometimes replace_input is followed by a call to remove_input and sometimes not. That means _input_to_node_name[input_name] may keep nodes that no longer use the input. This is not an issue, because _input_to_node_name[input_name] is only used to retrieve all nodes using input_name: it was previously done with get_nodes(), and even if the new set is bigger than necessary, it is still smaller than get_nodes(). To be thorough and call _unregister everywhere, I would need to review all calls to replace_input and remove_input.
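A small standalone sketch of that over-approximation point (hypothetical helper, not the actual implementation): stale entries only enlarge the candidate set, which is still confirmed against node.input afterwards.
```python
# Hypothetical sketch: the index is only a candidate set; stale entries are harmless.
def find_consumers(graph, input_name):
    # May contain names of nodes that no longer use input_name
    # (when _unregister was skipped after replace_input / remove_input).
    candidates = graph._input_to_node_name.get(input_name, set())
    nodes = [graph.get_node_by_name(n) for n in candidates]
    # Confirming against the actual inputs keeps the result correct;
    # the candidate set is still much smaller than graph.get_nodes().
    return [n for n in nodes if input_name in n.input]
```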
```python
if to_ops is not None and old_input in to_ops:
    # To avoid issues when a node
    # takes twice the same entry.
    to_ops.remove(old_input)
```
Could you use unregister here as well?
Same reason as above: a change here would require other changes to stay consistent.
This pull request introduces 2 alerts when merging 03e2bb7 into 1a35937 - view on LGTM.com.
tf2onnx/graph.py
Outdated
"Unregister node taking a specific input." | ||
node_name = node.name | ||
if not only_graph: | ||
if input_name in self._input_to_node_name[input_name]: |
Suggested change:
```diff
-if input_name in self._input_to_node_name[input_name]:
+if input_name in self._input_to_node_name:
```
Good catch. I'm worried that the unit tests still pass, then. This line was probably never exercised.
Done
I think this passes because there is no test to make sure the index doesn't contain extraneous entries. As I think you mentioned before, if the index has extra entries it isn't actually an issue as far as correctness is concerned and it only causes a slight performance decrease. Still, we may want to consider adding a validate_indices function that runs during unit tests between, say, optimization passes and once after the rewriters/handlers finish. What do you think?
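Something along these lines, perhaps (a rough sketch with assumed names, not existing code):
```python
# Hypothetical validate_indices-style check for unit tests: rebuild the expected
# consumer index from node.input and compare it with _input_to_node_name.
def validate_indices(graph):
    expected = {}
    for node in graph.get_nodes():
        for inp in node.input:
            expected.setdefault(inp, set()).add(node.name)
    for inp, names in graph._input_to_node_name.items():
        extra = set(names) - expected.get(inp, set())
        assert not extra, "extraneous index entries for %r: %r" % (inp, extra)
    for inp, names in expected.items():
        missing = names - set(graph._input_to_node_name.get(inp, set()))
        assert not missing, "missing index entries for %r: %r" % (inp, missing)
```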
I'm mostly concerned that later someone will forget to update the indices and introduce a subtle bug that is hard to catch.
I don't mind going further, but I'll have to look at every call to replace_all_inputs and possibly change them. This is a huge refactoring and the PR will grow. One big PR or two smaller ones (though not that small either), that is the question.
Maybe we can have a check_graph() method on the graph class that we can call from some places (e.g. the unit tests would be a good place to call this, between conversion and the optimizer) that validates all is in order?
Same PR or another one?
This pull request introduces 1 alert when merging ee8df44 into 6ec695b - view on LGTM.com.
tf2onnx/graph.py
Outdated
```python
if op.type == 'Placeholder':
    inps = [op.name]
elif op.type == 'Const':
    inps = [op.name]
```
Why should Placeholder or Const ops be considered consumers of themselves?
Removed placeholder.
tf2onnx/graph.py
Outdated
```python
for n in self._order_sensitive_inputs:
    if n not in ops:
        self._order_sensitive_inputs.remove(n)
for o in self.outputs:
    if o not in self._output_to_node_name:
        raise ValueError("graph output " + o + " not exist")
for i in self.inputs:
    if i.name.startswith('Placeholder'):
```
We should never look at a name since they can be given by the user.
Removed.
tf2onnx/graph.py
Outdated
```python
for i, name in enumerate(node.input):
    if name == to_be_removed:
        if node.input.count(node.input[i]) > 1:
            raise RuntimeError(
```
Please don't raise RuntimeError - we use make_sure()
Done
```yaml
@@ -14,3 +14,14 @@ steps:
    condition: succeededOrFailed()
    env:
      CI_ONNX_OPSET: '${{ onnx_opset }}'

  - bash: |
      export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
```
We don't want to run this in the master CI pipeline
Fixed
```
@@ -0,0 +1,103 @@
# coding: utf-8
```
we want to limit top level directories - please move to tools/
Done.
```diff
@@ -90,7 +90,7 @@ def _replace_node_with_const(node, graph, vals):
     const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
     graph.set_dtype(const_node.output[0], utils.map_numpy_to_onnx_dtype(val.dtype))
     graph.set_shape(const_node.output[0], val.shape)
-    graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
+    graph.replace_all_inputs(None, old_input, const_node.output[0])  # graph.get_nodes()
```
I'd either keep graph.get_nodes() (which I prefer because since it makes it very clear which nodes are used) or remove the comment.
actually - this PR would be much smaller if we keep graph.get_nodes()
Let's make ops an optional arg that is None by default.
+1 - that would be perfect!
ok
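For reference, a minimal sketch of what that could look like (assumptions: ops moves to the last position so it can default to None, and the per-node update delegates to replace_input; the real method would carry more logic, e.g. sub-graph handling):
```python
# Hypothetical sketch of the agreed change, not the final implementation.
def replace_all_inputs(self, old_input, new_input, ops=None):
    if ops is None:
        # Use the consumer index instead of scanning every node in the graph.
        names = self._input_to_node_name.get(old_input, set())
        ops = [self.get_node_by_name(n) for n in names]
    for node in ops:
        if old_input in node.input:
            self.replace_input(node, old_input, new_input)
```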
tf2onnx/graph.py
Outdated
```python
if i.name.startswith('keras_learning_phase'):
    continue
if i.name not in self._input_to_node_name:
    raise ValueError("graph input %r not exist in graph." % i.name)
```
I would leave the check for self._output_to_node_name, but delete lines 687 to 691. Isn't everything in self.inputs a graph input? if i.is_graph_input() will always be true, so this loop doesn't do anything, unless I'm missing something.
tf2onnx/graph.py
Outdated
```python
for i in self.inputs:
    if i.is_graph_input():
        continue
    if i.name not in self._output_to_consumers:
        raise ValueError("graph input %r not exist in graph." % i.name)
```
Suggested change (delete these lines):
```diff
-for i in self.inputs:
-    if i.is_graph_input():
-        continue
-    if i.name not in self._output_to_consumers:
-        raise ValueError("graph input %r not exist in graph." % i.name)
```
Right. I had not noticed that inputs is a property that already returns only the inputs verifying this condition. I removed the lines.
tf2onnx/graph.py
Outdated
```python
raise RuntimeError(
    "Input %r of node %r not found." % (op_input, node_name))
```
Should this be a make_sure?
I hesitated because there was a mixed use of make_sure and RuntimeError in graph.py. I changed this one.
This pull request introduces 1 alert and fixes 1 when merging 5115124 into 0b15fe1 - view on LGTM.com.
LGTM!
Measurements for EfficientNetB2:
Before:
After: