-
Notifications
You must be signed in to change notification settings - Fork 434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use the same syntax to replace an node input (2) + optimize replace_all_inputs #1060
Changes from 24 commits
469b4c4
bae92a6
74321f3
d972ace
d9ea76a
7b244e0
f5c9406
3db62e4
03e2bb7
085d6fa
cd0e711
c491bc9
b4a73be
eb68bcd
0adae73
71b0ddb
874e5c5
1250847
c9aa7ab
d130591
e1eccab
1daa780
0b469c0
6a54003
bb4d9a6
5115124
094ab85
89ad32a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -425,6 +425,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No | |||||||||||
self._nodes = [] | ||||||||||||
self._nodes_by_name = {} | ||||||||||||
self._output_to_node_name = {} | ||||||||||||
self._output_to_consumers = {} | ||||||||||||
self._input_to_graph = {} | ||||||||||||
self.shapes = {} | ||||||||||||
self.graph_name = graph_name or "tf2onnx" | ||||||||||||
self._is_subgraph = is_subgraph | ||||||||||||
|
@@ -471,7 +473,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No | |||||||||||
body_graph.parent_graph = self | ||||||||||||
new_node.set_body_graph_as_attr(attr_name, body_graph) | ||||||||||||
|
||||||||||||
self.replace_all_inputs(self.get_nodes(), o, new_output_name) | ||||||||||||
self.replace_all_inputs(o, new_output_name, ops=self.get_nodes()) | ||||||||||||
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs") | ||||||||||||
self.copy_shape(new_output_name, o) | ||||||||||||
self.copy_dtype(new_output_name, o) | ||||||||||||
|
@@ -576,6 +578,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk | |||||||||||
|
||||||||||||
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr) | ||||||||||||
|
||||||||||||
for name2 in onnx_node.input: | ||||||||||||
self._register_input_name(name2, onnx_node) | ||||||||||||
|
||||||||||||
if op_type in ["If", "Loop", "Scan"]: | ||||||||||||
# we force the op containing inner graphs not skipped during conversion. | ||||||||||||
skip_conversion = False | ||||||||||||
|
@@ -604,6 +609,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk | |||||||||||
return node | ||||||||||||
|
||||||||||||
def append_node(self, node): | ||||||||||||
"Add a node to the graph." | ||||||||||||
output_shapes = node.output_shapes | ||||||||||||
output_dtypes = node.output_dtypes | ||||||||||||
node.graph = self | ||||||||||||
|
@@ -613,6 +619,8 @@ def append_node(self, node): | |||||||||||
self._output_to_node_name[name] = node.name | ||||||||||||
self.set_dtype(name, output_dtypes[i]) | ||||||||||||
self.set_shape(name, output_shapes[i]) | ||||||||||||
for name in node.input: | ||||||||||||
self._register_input_name(name, node) | ||||||||||||
|
||||||||||||
def remove_node(self, node_name): | ||||||||||||
"""Remove node in current graph.""" | ||||||||||||
|
@@ -633,6 +641,12 @@ def remove_node(self, node_name): | |||||||||||
if op_output in self._dtypes: | ||||||||||||
del self._dtypes[op_output] | ||||||||||||
|
||||||||||||
for op_input in node.input: | ||||||||||||
if op_input not in self._output_to_consumers: | ||||||||||||
raise RuntimeError( | ||||||||||||
"Input %r of node %r not found." % (op_input, node_name)) | ||||||||||||
self._unregister_input_name(op_input, node) | ||||||||||||
|
||||||||||||
self._nodes.remove(node) | ||||||||||||
node.graph = None | ||||||||||||
|
||||||||||||
|
@@ -656,16 +670,25 @@ def reset_nodes(self, ops): | |||||||||||
self.contained_graphs = remained_sub_graphs | ||||||||||||
self._nodes_by_name = {op.name: op for op in ops} | ||||||||||||
self._output_to_node_name = {} | ||||||||||||
self._output_to_consumers = {} | ||||||||||||
for op in ops: | ||||||||||||
for op_output in op.output: | ||||||||||||
self._output_to_node_name[op_output] = op.name | ||||||||||||
inps = op.input | ||||||||||||
for op_input in inps: | ||||||||||||
self._register_input_name(op_input, op) | ||||||||||||
|
||||||||||||
for n in self._order_sensitive_inputs: | ||||||||||||
if n not in ops: | ||||||||||||
self._order_sensitive_inputs.remove(n) | ||||||||||||
for o in self.outputs: | ||||||||||||
if o not in self._output_to_node_name: | ||||||||||||
raise ValueError("graph output " + o + " not exist") | ||||||||||||
for i in self.inputs: | ||||||||||||
if i.is_graph_input(): | ||||||||||||
continue | ||||||||||||
if i.name not in self._output_to_consumers: | ||||||||||||
raise ValueError("graph input %r not exist in graph." % i.name) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This just means the input isn't used in the graph. Do we really need to throw an error? Why does keras_learning_phase get an exemption? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can remove those lines or leave them, if the exception is raised, it is probably an error made by the user. Maybe he would want to know. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is probably best to remove them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I remove the lines I added or the test There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would leave the check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. I did not check inputs was a property and was returning inputs verifying this condition. I removed the lines. |
||||||||||||
|
||||||||||||
self._dtypes = remained_dtypes | ||||||||||||
self._output_shapes = remained_shapes | ||||||||||||
|
@@ -792,6 +815,8 @@ def set_node_by_name(self, node): | |||||||||||
self._nodes_by_name[node.name] = node | ||||||||||||
for op_output in node.output: | ||||||||||||
self._output_to_node_name[op_output] = node.name | ||||||||||||
for name in node.input: | ||||||||||||
self._register_input_name(name, node) | ||||||||||||
|
||||||||||||
def change_node_name(self, node, new_name): | ||||||||||||
"""Remove node in current graph.""" | ||||||||||||
|
@@ -807,7 +832,7 @@ def change_node_name(self, node, new_name): | |||||||||||
if k == old_output: | ||||||||||||
self.outputs[j] = new_output | ||||||||||||
break | ||||||||||||
self.replace_all_inputs(self.get_nodes(), old_output, new_output) | ||||||||||||
self.replace_all_inputs(old_output, new_output, ops=self.get_nodes()) | ||||||||||||
return new_node | ||||||||||||
|
||||||||||||
def add_graph_input(self, name, dtype=None, shape=None): | ||||||||||||
|
@@ -1133,13 +1158,12 @@ def dump_node_statistics(self): | |||||||||||
op_cnt[n.type] += 1 | ||||||||||||
body_graphs = n.get_body_graphs() | ||||||||||||
if body_graphs: | ||||||||||||
for _, b_g in body_graphs.items(): | ||||||||||||
for b_g in body_graphs.values(): | ||||||||||||
op_cnt += b_g.dump_node_statistics() | ||||||||||||
|
||||||||||||
return op_cnt | ||||||||||||
|
||||||||||||
@staticmethod | ||||||||||||
def remove_input(node, to_be_removed, input_index=None): | ||||||||||||
def remove_input(self, node, to_be_removed, input_index=None): | ||||||||||||
"""Remove input from Node. | ||||||||||||
Args: | ||||||||||||
node: the node we expect the input on | ||||||||||||
|
@@ -1151,15 +1175,24 @@ def remove_input(node, to_be_removed, input_index=None): | |||||||||||
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type) | ||||||||||||
if input_index is not None: | ||||||||||||
assert node.input[input_index] == to_be_removed | ||||||||||||
if node.input[input_index] in self._output_to_consumers: | ||||||||||||
to_ops = self._output_to_consumers[node.input[input_index]] | ||||||||||||
if node.name in to_ops: | ||||||||||||
to_ops.remove(node.name) | ||||||||||||
del node.input[input_index] | ||||||||||||
return True | ||||||||||||
return | ||||||||||||
|
||||||||||||
for i, name in enumerate(node.input): | ||||||||||||
if name == to_be_removed: | ||||||||||||
utils.make_sure( | ||||||||||||
node.input.count(node.input[i]) <= 1, | ||||||||||||
"Node %r takes multiple times the same input %r. This case is not handled.", | ||||||||||||
node.name, node.input[i]) | ||||||||||||
self._unregister_input_name(node.input[i], node) | ||||||||||||
del node.input[i] | ||||||||||||
TomWildenhain-Microsoft marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
break | ||||||||||||
|
||||||||||||
# don't remove output from parent since others might depend on it | ||||||||||||
return True | ||||||||||||
|
||||||||||||
def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs): | ||||||||||||
"""Create and insert a new node into the graph. | ||||||||||||
|
@@ -1207,43 +1240,93 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k | |||||||||||
new_output = port_name(name) | ||||||||||||
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain) | ||||||||||||
|
||||||||||||
to_replace = [n for n in self.get_nodes() if n != new_node] | ||||||||||||
self.replace_all_inputs(to_replace, output_name, new_output) | ||||||||||||
to_replace = [self.get_node_by_name(n) for n in self._output_to_consumers[output_name]] | ||||||||||||
to_replace = [n for n in to_replace if n != new_node] | ||||||||||||
self.replace_all_inputs(output_name, new_output, ops=to_replace) | ||||||||||||
return new_node | ||||||||||||
|
||||||||||||
def find_output_consumers(self, output_name): | ||||||||||||
"""Find all nodes consuming a given output.""" | ||||||||||||
if output_name in self._output_to_consumers: | ||||||||||||
ops = self._output_to_consumers[output_name] | ||||||||||||
ops = [self.get_node_by_name(n) for n in ops] | ||||||||||||
else: | ||||||||||||
ops = [] # self.get_nodes() | ||||||||||||
nodes = [] | ||||||||||||
for node in self.get_nodes(): | ||||||||||||
for node in ops: | ||||||||||||
if node is None: | ||||||||||||
continue | ||||||||||||
if output_name in node.input: | ||||||||||||
nodes.append(node) | ||||||||||||
|
||||||||||||
# find consumers in sub graphs | ||||||||||||
body_graphs = node.get_body_graphs() | ||||||||||||
if body_graphs: | ||||||||||||
for g in body_graphs.values(): | ||||||||||||
nodes.extend(g.find_output_consumers(output_name)) | ||||||||||||
# find consumers in sub graphs | ||||||||||||
if output_name in self._input_to_graph: | ||||||||||||
for g in self._input_to_graph[output_name].values(): | ||||||||||||
nodes.extend(g.find_output_consumers(output_name)) | ||||||||||||
return nodes | ||||||||||||
|
||||||||||||
@staticmethod | ||||||||||||
def replace_all_inputs(ops, old_input, new_input): | ||||||||||||
"""Replace all inputs pointing to old_input with new_input.""" | ||||||||||||
def _register_input_name(self, input_name, node, only_graph=False): | ||||||||||||
"Register node taking a specific input." | ||||||||||||
if not only_graph: | ||||||||||||
if input_name not in self._output_to_consumers: | ||||||||||||
self._output_to_consumers[input_name] = set() | ||||||||||||
self._output_to_consumers[input_name].add(node.name) | ||||||||||||
if self.parent_graph is not None: | ||||||||||||
if input_name not in self.parent_graph._input_to_graph: | ||||||||||||
self.parent_graph._input_to_graph[input_name] = {} | ||||||||||||
self.parent_graph._input_to_graph[input_name][id(self)] = self | ||||||||||||
self.parent_graph._register_input_name(input_name, node, only_graph=True) | ||||||||||||
|
||||||||||||
def _unregister_input_name(self, input_name, node, only_graph=False): | ||||||||||||
"Unregister node taking a specific input." | ||||||||||||
node_name = node.name | ||||||||||||
if not only_graph: | ||||||||||||
if input_name in self._output_to_consumers[input_name]: | ||||||||||||
if node_name in self._output_to_consumers[input_name]: | ||||||||||||
self._output_to_consumers[input_name].remove(node_name) | ||||||||||||
if (self.parent_graph is not None and | ||||||||||||
input_name in self.parent_graph._input_to_graph and | ||||||||||||
id(self) in self.parent_graph._input_to_graph[input_name]): | ||||||||||||
TomWildenhain-Microsoft marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
del self.parent_graph._input_to_graph[input_name][id(self)] | ||||||||||||
self.parent_graph._unregister_input_name(input_name, node, only_graph=True) | ||||||||||||
|
||||||||||||
def replace_all_inputs(self, old_input, new_input, ops=None): | ||||||||||||
""" | ||||||||||||
Replace all inputs pointing to old_input with new_input. | ||||||||||||
*ops* is used if defined, otherwise `_output_to_consumers` | ||||||||||||
is used to determine the impacted nodes. | ||||||||||||
""" | ||||||||||||
if old_input == new_input: | ||||||||||||
return | ||||||||||||
if new_input not in self._output_to_consumers: | ||||||||||||
self._output_to_consumers[new_input] = set() | ||||||||||||
|
||||||||||||
if ops is not None: | ||||||||||||
keep_ops = True | ||||||||||||
elif old_input in self._output_to_consumers: | ||||||||||||
ops = list( | ||||||||||||
filter(lambda a: a is not None, | ||||||||||||
map(self.get_node_by_name, self._output_to_consumers[old_input]))) | ||||||||||||
keep_ops = False | ||||||||||||
else: | ||||||||||||
ops = [] | ||||||||||||
keep_ops = False | ||||||||||||
|
||||||||||||
for node in ops: | ||||||||||||
assert node is not None | ||||||||||||
if old_input in node.input and new_input in node.output: | ||||||||||||
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name) | ||||||||||||
self._register_input_name(new_input, node) | ||||||||||||
|
||||||||||||
for i, input_name in enumerate(node.input): | ||||||||||||
if input_name == old_input: | ||||||||||||
node.input[i] = new_input | ||||||||||||
self.replace_input(node, node.input[i], new_input, i) | ||||||||||||
|
||||||||||||
# modify references in sub graphs | ||||||||||||
body_graphs = node.get_body_graphs() | ||||||||||||
if body_graphs: | ||||||||||||
for g in body_graphs.values(): | ||||||||||||
g.replace_all_inputs(g.get_nodes(), old_input, new_input) | ||||||||||||
# modify references in sub graphs | ||||||||||||
if old_input in self._input_to_graph: | ||||||||||||
for g in self._input_to_graph[old_input].values(): | ||||||||||||
g.replace_all_inputs(old_input, new_input, | ||||||||||||
ops=g.get_nodes() if keep_ops else None) | ||||||||||||
|
||||||||||||
def replace_input(self, node, old_input, new_input, input_index=None): | ||||||||||||
""" | ||||||||||||
|
@@ -1263,11 +1346,31 @@ def replace_input(self, node, old_input, new_input, input_index=None): | |||||||||||
is_replaced = True | ||||||||||||
else: | ||||||||||||
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name)) | ||||||||||||
|
||||||||||||
to_ops = self._output_to_consumers.get(old_input, None) | ||||||||||||
if to_ops is not None: | ||||||||||||
if node.name in to_ops: | ||||||||||||
# A node may take twice the same entry. | ||||||||||||
to_ops.remove(node.name) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this isn't technically correct if a node has two copies of the same input (node.input = [inp1, inp1]) and you replace just one copy (replace_input(node, inp1, inp2, 0) -> node.input = [inp2, inp1]. node is still a consumer of inp1. This bug should only occur if input_index is not None since otherwise we replace everything. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also why not use the _unregister helper function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right. I did not do it because sometimes replace_input is followed by a call to remove_input, sometimes not. So I did not do it. That means |
||||||||||||
|
||||||||||||
self._register_input_name(new_input, node) | ||||||||||||
return is_replaced | ||||||||||||
|
||||||||||||
def replace_inputs(self, node, new_inputs): | ||||||||||||
"""Replace node inputs.""" | ||||||||||||
assert isinstance(node, Node) and isinstance(new_inputs, list) | ||||||||||||
|
||||||||||||
for old_input in node.input: | ||||||||||||
to_ops = self._output_to_consumers.get(old_input, None) | ||||||||||||
if to_ops is not None and old_input in to_ops: | ||||||||||||
# To avoid issues when a node | ||||||||||||
# takes twice the same entry. | ||||||||||||
to_ops.remove(old_input) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you use unregister here as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same reason as above, a change here means other changes to be consistent. |
||||||||||||
|
||||||||||||
for input_name in new_inputs: | ||||||||||||
assert isinstance(input_name, six.text_type) | ||||||||||||
self._register_input_name(input_name, node) | ||||||||||||
|
||||||||||||
node.input = new_inputs | ||||||||||||
return True | ||||||||||||
|
||||||||||||
|
@@ -1343,7 +1446,7 @@ def delete_unused_nodes(self, outputs_name): | |||||||||||
for node in related_nodes: | ||||||||||||
attr_body_graphs = node.get_body_graphs() | ||||||||||||
if attr_body_graphs: | ||||||||||||
for _, body_graph in attr_body_graphs.items(): | ||||||||||||
for body_graph in attr_body_graphs.values(): | ||||||||||||
body_graph.delete_unused_nodes(body_graph.outputs) | ||||||||||||
self.reset_nodes(related_nodes) | ||||||||||||
|
||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a make_sure?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hesitated because there was a mixed use of make_sure and RuntimeError in graph.py. I changed this one.