Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Aug 12, 2020
1 parent 09119b9 commit ffac86a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,6 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
new_output = port_name(name)
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)

# to_replace = [n for n in self.get_nodes() if n != new_node]
to_replace = [self.get_node_by_name(n) for n in self._input_to_node_name[output_name]]
to_replace = [n for n in to_replace if n != new_node]
self.replace_all_inputs(to_replace, output_name, new_output)
Expand Down Expand Up @@ -1324,7 +1323,9 @@ def replace_all_inputs(self, ops, old_input, new_input):
if ops is not None:
keep_ops = True
elif old_input in self._input_to_node_name:
ops = [self.get_node_by_name(n) for n in self._input_to_node_name[old_input]]
ops = list(
filter(lambda a: a is not None,
map(self.get_node_by_name, self._input_to_node_name[old_input])))
keep_ops = False
else:
ops = []
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,8 @@ def version_7(cls, ctx, node, **kwargs):
del output_names[idx]
del body.outputs[idx]

# remove tensor array that are passed in to the loop
removed_scan_outputs = {}
# remove tensor array that are passed in to the loop
for idx, n in reversed(to_remove):
ctx.remove_node(n.name)
# make the node output bad
Expand Down

0 comments on commit ffac86a

Please sign in to comment.