Use the same syntax to replace a node input (2) + optimize replace_all_inputs #1060

Merged
merged 28 commits into from Sep 4, 2020

Commits (28)
469b4c4  Use the same syntax to replace an node input (2) + optimisation (sdpython, Aug 12, 2020)
bae92a6  Update graph.py (sdpython, Aug 12, 2020)
74321f3  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Aug 12, 2020)
d972ace  missing cast (sdpython, Aug 12, 2020)
d9ea76a  Merge branch 'cast3' of https://github.com/xadupre/tensorflow-onnx in… (sdpython, Aug 12, 2020)
7b244e0  remove unexisting nodes (sdpython, Aug 12, 2020)
f5c9406  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Aug 20, 2020)
3db62e4  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Aug 24, 2020)
03e2bb7  addresses comments (sdpython, Aug 24, 2020)
085d6fa  remove lint alerts (sdpython, Aug 24, 2020)
cd0e711  improves find_output_consumers (sdpython, Aug 24, 2020)
c491bc9  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Aug 25, 2020)
b4a73be  a few updates (sdpython, Aug 25, 2020)
eb68bcd  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Aug 26, 2020)
0adae73  Update const_fold_optimizer.py (sdpython, Aug 26, 2020)
71b0ddb  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Aug 28, 2020)
874e5c5  make ops optional in replace_all_inputs (sdpython, Aug 28, 2020)
1250847  lint (sdpython, Aug 28, 2020)
c9aa7ab  rename into _output_to_consumers (sdpython, Aug 28, 2020)
d130591  remove keras_learning_phase, replace runtimeerror by make_sure (sdpython, Aug 28, 2020)
e1eccab  use is_graph_input (sdpython, Aug 28, 2020)
1daa780  remove one call to is_graph_input (sdpython, Aug 28, 2020)
0b469c0  remove unnecessary lines (sdpython, Aug 28, 2020)
6a54003  Removes specific case introduced for Const (sdpython, Aug 28, 2020)
bb4d9a6  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Sep 1, 2020)
5115124  removed unnecessary lines (sdpython, Sep 1, 2020)
094ab85  lint (sdpython, Sep 1, 2020)
89ad32a  Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into… (sdpython, Sep 2, 2020)
2 changes: 1 addition & 1 deletion tests/test_internals.py
@@ -139,7 +139,7 @@ def test_rewrite_subgraph(self):
op_name = utils.make_name("ReplacedOp")
out_name = utils.port_name(op_name)
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
g.replace_all_inputs(output_node.output[0], new_node.output[0]) # ops=ops
for n in set(match.get_nodes()):
g.remove_node(n.name)
g.topological_sort(ops)
148 changes: 123 additions & 25 deletions tf2onnx/graph.py
@@ -456,6 +456,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
self._nodes = []
self._nodes_by_name = {}
self._output_to_node_name = {}
self._output_to_consumers = {}
self._input_to_graph = {}
self.shapes = {}
self.graph_name = graph_name or "tf2onnx"
self._is_subgraph = is_subgraph
@@ -502,7 +504,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
body_graph.parent_graph = self
new_node.set_body_graph_as_attr(attr_name, body_graph)

self.replace_all_inputs(self.get_nodes(), o, new_output_name)
self.replace_all_inputs(o, new_output_name, ops=self.get_nodes())
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
self.copy_shape(new_output_name, o)
self.copy_dtype(new_output_name, o)
@@ -607,6 +609,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk

onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)

for name2 in onnx_node.input:
self._register_input_name(name2, onnx_node)

if op_type in ["If", "Loop", "Scan"]:
# we force the op containing inner graphs not skipped during conversion.
skip_conversion = False
@@ -635,6 +640,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
return node

def append_node(self, node):
"Add a node to the graph."
output_shapes = node.output_shapes
output_dtypes = node.output_dtypes
node.graph = self
@@ -644,6 +650,8 @@ def append_node(self, node):
self._output_to_node_name[name] = node.name
self.set_dtype(name, output_dtypes[i])
self.set_shape(name, output_shapes[i])
for name in node.input:
self._register_input_name(name, node)

def remove_node(self, node_name):
"""Remove node in current graph."""
@@ -664,6 +672,12 @@ def remove_node(self, node_name):
if op_output in self._dtypes:
del self._dtypes[op_output]

for op_input in node.input:
utils.make_sure(
op_input in self._output_to_consumers,
"Input %r of node %r not found.", op_input, node_name)
self._unregister_input_name(op_input, node)

self._nodes.remove(node)
node.graph = None

@@ -687,9 +701,13 @@ def reset_nodes(self, ops):
self.contained_graphs = remained_sub_graphs
self._nodes_by_name = {op.name: op for op in ops}
self._output_to_node_name = {}
self._output_to_consumers = {}
for op in ops:
for op_output in op.output:
self._output_to_node_name[op_output] = op.name
inps = op.input
for op_input in inps:
self._register_input_name(op_input, op)

for n in self._order_sensitive_inputs:
if n not in ops:
@@ -823,6 +841,8 @@ def set_node_by_name(self, node):
self._nodes_by_name[node.name] = node
for op_output in node.output:
self._output_to_node_name[op_output] = node.name
for name in node.input:
self._register_input_name(name, node)

def change_node_name(self, node, new_name):
"""Remove node in current graph."""
@@ -838,7 +858,7 @@ def change_node_name(self, node, new_name):
if k == old_output:
self.outputs[j] = new_output
break
self.replace_all_inputs(self.get_nodes(), old_output, new_output)
self.replace_all_inputs(old_output, new_output, ops=self.get_nodes())
return new_node

def add_graph_input(self, name, dtype=None, shape=None):
@@ -1164,13 +1184,12 @@ def dump_node_statistics(self):
op_cnt[n.type] += 1
body_graphs = n.get_body_graphs()
if body_graphs:
for _, b_g in body_graphs.items():
for b_g in body_graphs.values():
op_cnt += b_g.dump_node_statistics()

return op_cnt

@staticmethod
def remove_input(node, to_be_removed, input_index=None):
def remove_input(self, node, to_be_removed, input_index=None):
"""Remove input from Node.
Args:
node: the node we expect the input on
@@ -1182,15 +1201,24 @@ def remove_input(node, to_be_removed, input_index=None):
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
if input_index is not None:
assert node.input[input_index] == to_be_removed
if node.input[input_index] in self._output_to_consumers:
to_ops = self._output_to_consumers[node.input[input_index]]
if node.name in to_ops:
to_ops.remove(node.name)
del node.input[input_index]
return True
return

for i, name in enumerate(node.input):
if name == to_be_removed:
utils.make_sure(
node.input.count(node.input[i]) <= 1,
"Node %r takes multiple times the same input %r. This case is not handled.",
node.name, node.input[i])
self._unregister_input_name(node.input[i], node)
del node.input[i]
break

# don't remove output from parent since others might depend on it
return True

def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs):
"""Create and insert a new node into the graph.
@@ -1238,43 +1266,93 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
new_output = port_name(name)
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)

to_replace = [n for n in self.get_nodes() if n != new_node]
self.replace_all_inputs(to_replace, output_name, new_output)
to_replace = [self.get_node_by_name(n) for n in self._output_to_consumers[output_name]]
to_replace = [n for n in to_replace if n != new_node]
self.replace_all_inputs(output_name, new_output, ops=to_replace)
return new_node

def find_output_consumers(self, output_name):
"""Find all nodes consuming a given output."""
if output_name in self._output_to_consumers:
ops = self._output_to_consumers[output_name]
ops = [self.get_node_by_name(n) for n in ops]
else:
ops = [] # self.get_nodes()
nodes = []
for node in self.get_nodes():
for node in ops:
if node is None:
continue
if output_name in node.input:
nodes.append(node)

# find consumers in sub graphs
body_graphs = node.get_body_graphs()
if body_graphs:
for g in body_graphs.values():
nodes.extend(g.find_output_consumers(output_name))
# find consumers in sub graphs
if output_name in self._input_to_graph:
for g in self._input_to_graph[output_name].values():
nodes.extend(g.find_output_consumers(output_name))
return nodes

@staticmethod
def replace_all_inputs(ops, old_input, new_input):
"""Replace all inputs pointing to old_input with new_input."""
def _register_input_name(self, input_name, node, only_graph=False):
"Register node taking a specific input."
if not only_graph:
if input_name not in self._output_to_consumers:
self._output_to_consumers[input_name] = set()
self._output_to_consumers[input_name].add(node.name)
if self.parent_graph is not None:
if input_name not in self.parent_graph._input_to_graph:
self.parent_graph._input_to_graph[input_name] = {}
self.parent_graph._input_to_graph[input_name][id(self)] = self
self.parent_graph._register_input_name(input_name, node, only_graph=True)

def _unregister_input_name(self, input_name, node, only_graph=False):
"Unregister node taking a specific input."
node_name = node.name
if not only_graph:
if input_name in self._output_to_consumers[input_name]:
if node_name in self._output_to_consumers[input_name]:
self._output_to_consumers[input_name].remove(node_name)
if (self.parent_graph is not None and
input_name in self.parent_graph._input_to_graph and
id(self) in self.parent_graph._input_to_graph[input_name]):
del self.parent_graph._input_to_graph[input_name][id(self)]
self.parent_graph._unregister_input_name(input_name, node, only_graph=True)

def replace_all_inputs(self, old_input, new_input, ops=None):
"""
Replace all inputs pointing to old_input with new_input.
*ops* is used if defined, otherwise `_output_to_consumers`
is used to determine the impacted nodes.
"""
if old_input == new_input:
return
if new_input not in self._output_to_consumers:
self._output_to_consumers[new_input] = set()

if ops is not None:
keep_ops = True
elif old_input in self._output_to_consumers:
ops = list(
filter(lambda a: a is not None,
map(self.get_node_by_name, self._output_to_consumers[old_input])))
keep_ops = False
else:
ops = []
keep_ops = False

for node in ops:
assert node is not None
if old_input in node.input and new_input in node.output:
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
self._register_input_name(new_input, node)

for i, input_name in enumerate(node.input):
if input_name == old_input:
node.input[i] = new_input
self.replace_input(node, node.input[i], new_input, i)

# modify references in sub graphs
body_graphs = node.get_body_graphs()
if body_graphs:
for g in body_graphs.values():
g.replace_all_inputs(g.get_nodes(), old_input, new_input)
# modify references in sub graphs
if old_input in self._input_to_graph:
for g in self._input_to_graph[old_input].values():
g.replace_all_inputs(old_input, new_input,
ops=g.get_nodes() if keep_ops else None)

def replace_input(self, node, old_input, new_input, input_index=None):
"""
@@ -1294,11 +1372,31 @@ def replace_input(self, node, old_input, new_input, input_index=None):
is_replaced = True
else:
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))

to_ops = self._output_to_consumers.get(old_input, None)
if to_ops is not None:
if node.name in to_ops:
# A node may take twice the same entry.
to_ops.remove(node.name)

Contributor:
I think this isn't technically correct if a node has two copies of the same input (node.input = [inp1, inp1]) and you replace just one copy (replace_input(node, inp1, inp2, 0) -> node.input = [inp2, inp1]): the node is still a consumer of inp1. This bug should only occur if input_index is not None, since otherwise we replace everything.

Contributor:
Also, why not use the _unregister helper function?

Collaborator Author:
You're right. I did not do it because replace_input is sometimes followed by a call to remove_input and sometimes not. That means _input_to_node_name[input_name] (now _output_to_consumers[input_name]) may keep nodes that no longer use the input. This is not an issue, since that set is only used to retrieve all nodes consuming input_name: it was previously done with get_nodes(), and even if the new set is bigger than necessary, it is still smaller than get_nodes(). I would need to review all calls to replace_input and remove_input to be thorough enough to call _unregister.

self._register_input_name(new_input, node)
return is_replaced

def replace_inputs(self, node, new_inputs):
"""Replace node inputs."""
assert isinstance(node, Node) and isinstance(new_inputs, list)

for old_input in node.input:
to_ops = self._output_to_consumers.get(old_input, None)
if to_ops is not None and old_input in to_ops:
# To avoid issues when a node
# takes twice the same entry.
to_ops.remove(old_input)

Contributor:
Could you use unregister here as well?

Collaborator Author:
Same reason as above: a change here would require other changes to stay consistent.

for input_name in new_inputs:
assert isinstance(input_name, six.text_type)
self._register_input_name(input_name, node)

node.input = new_inputs
return True

@@ -1374,7 +1472,7 @@ def delete_unused_nodes(self, outputs_name):
for node in related_nodes:
attr_body_graphs = node.get_body_graphs()
if attr_body_graphs:
for _, body_graph in attr_body_graphs.items():
for body_graph in attr_body_graphs.values():
body_graph.delete_unused_nodes(body_graph.outputs)
self.reset_nodes(related_nodes)
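
As a rough standalone sketch of the optimization introduced in graph.py above (an output-to-consumers index, so find_output_consumers and replace_all_inputs no longer scan get_nodes()), the toy classes and names below are hypothetical stand-ins, not tf2onnx code:

# Toy illustration of the {output name -> set of consumer node names} index.
class TinyNode:
    def __init__(self, name, inputs, outputs):
        self.name = name
        self.input = list(inputs)
        self.output = list(outputs)

class TinyGraph:
    def __init__(self):
        self._nodes_by_name = {}
        self._output_to_consumers = {}  # output name -> set of consumer node names

    def add_node(self, node):
        self._nodes_by_name[node.name] = node
        for inp in node.input:
            self._output_to_consumers.setdefault(inp, set()).add(node.name)

    def find_output_consumers(self, output_name):
        names = self._output_to_consumers.get(output_name, set())
        return [self._nodes_by_name[n] for n in names if n in self._nodes_by_name]

    def replace_all_inputs(self, old_input, new_input, ops=None):
        # When ops is not given, the index tells us which nodes to touch.
        nodes = ops if ops is not None else self.find_output_consumers(old_input)
        for node in nodes:
            for i, name in enumerate(node.input):
                if name == old_input:
                    node.input[i] = new_input
                    self._output_to_consumers.setdefault(new_input, set()).add(node.name)
            if old_input not in node.input:
                self._output_to_consumers.get(old_input, set()).discard(node.name)

g = TinyGraph()
g.add_node(TinyNode("A", inputs=[], outputs=["a:0"]))
g.add_node(TinyNode("B", inputs=["a:0"], outputs=["b:0"]))
g.replace_all_inputs("a:0", "c:0")  # no node list needed
print([n.name for n in g.find_output_consumers("c:0")])  # ['B']
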

8 changes: 4 additions & 4 deletions tf2onnx/onnx_opset/controlflow.py
@@ -492,7 +492,7 @@ class TensorListStack:
def version_7(cls, ctx, node, **kwargs):
if node.inputs[0].is_while():
ctx.remove_node(node.name)
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], node.input[0])
ctx.replace_all_inputs(node.output[0], node.input[0]) # ops=ctx.get_nodes()


@tf_op(["While", "StatelessWhile"])
@@ -582,7 +582,7 @@ def version_7(cls, ctx, node, **kwargs):
for idx, n in reversed(to_remove):
ctx.remove_node(n.name)
# make the node output bad
ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
ctx.replace_all_inputs(n.output[0], "@@ALLOC") # ops=ctx.get_nodes()
del body.func_inputs[idx]
del cond_graph.func_inputs[idx]
del tf_while_inputs[idx]
@@ -618,7 +618,7 @@ def version_7(cls, ctx, node, **kwargs):

# shift output consumers
for k, v in output_map.items():
ctx.replace_all_inputs(ctx.get_nodes(), k, v)
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()

wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)
@@ -813,7 +813,7 @@ def prefix_graph(g, scope):
if old_output == oname:
g.outputs[i] = new_output
break
g.replace_all_inputs(ops, old_output, new_output)
g.replace_all_inputs(old_output, new_output, ops=ops)
to_remove.append(node)
for node in to_remove:
g.remove_node(node.name)
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/math.py
@@ -695,4 +695,4 @@ def atan2(y, x):
"Add", inputs=[atan_node.output[0], pi_part.output[0]],
op_name_scope=node.name + 'all',
shapes=[shape], dtypes=[onnx_dtype])
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/misc.py
@@ -30,7 +30,7 @@ def version_1(cls, ctx, node, **kwargs):
# if identity has a const as input, remove it
input_name = node.input[0]
output_name = node.output[0]
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
ctx.remove_node(node.name)


2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/nn.py
@@ -451,7 +451,7 @@ def version_1(cls, ctx, node, **kwargs):
downstream_nodes = ctx.find_output_consumers(node.output[0])
downstream_nodes.remove(output_shape)
downstream_nodes.remove(slice_node)
ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
ctx.replace_all_inputs(node.output[0], slice_node.output[0], ops=downstream_nodes)

conv_dims_attr(node, "strides", spatial=spatial)
conv_dims_attr(node, "dilations", spatial=spatial)
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/quantize.py
@@ -78,4 +78,4 @@ def version_10(cls, ctx, node, **kwargs):
"DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
op_name_scope=node.name, attr={"axis": axis},
shapes=[shape], dtypes=[dtype])
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/rnn.py
@@ -153,7 +153,7 @@ def make_sigmoid(i, w, b):
h_node = ctx.make_node("Mul", [co_node.output[0], o])

def replace_output(old_output, new_output):
ctx.replace_all_inputs(ctx.get_nodes(), old_output, new_output)
ctx.replace_all_inputs(old_output, new_output) # ops=ctx.get_nodes()
ctx.copy_dtype(old_output, new_output)
ctx.copy_shape(old_output, new_output)

6 changes: 3 additions & 3 deletions tf2onnx/onnx_opset/tensor.py
@@ -115,7 +115,7 @@ def version_1(cls, ctx, node, **kwargs):
# if identity has a const as input, remove it
input_name = node.input[0]
output_name = node.output[0]
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
ctx.remove_node(node.name)


@@ -125,7 +125,7 @@ class IdentityN:
def version_1(cls, ctx, node, **kwargs):
ctx.remove_node(node.name)
for input_name, output_name in zip(node.input, node.output):
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()


@tf_op("Reshape")
@@ -1050,7 +1050,7 @@ def version_1(cls, ctx, node, **kwargs):
# concat all unqueezes
concat = ctx.make_node("Concat", inputs, op_name_scope=node.name, attr={"axis": axis},
shapes=shapes, dtypes=dtypes)
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], concat.output[0])
ctx.replace_all_inputs(node.output[0], concat.output[0]) # ops=ctx.get_nodes()


@tf_op("Unpack")
Expand Down
Loading