Skip to content

Commit

Permalink
pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Aug 4, 2020
1 parent 57a3d6c commit 922ce19
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
29 changes: 15 additions & 14 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No

ops = [Node(node, self) for node in nodes]
self.reset_nodes(ops)

if not is_subgraph:
# add identity node after each output, in case it is renamed during conversion.
for o in self.outputs:
Expand Down Expand Up @@ -572,7 +572,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
n = self.get_node_by_output_in_current_graph(o)
utils.make_sure(n is None, "output tensor named %s already exists in node: \n%s", o, n)

onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)

for name in onnx_node.input:
if name not in self._input_to_node_name:
Expand Down Expand Up @@ -607,6 +607,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
return node

def append_node(self, node):
"Add a node to the graph."
output_shapes = node.output_shapes
output_dtypes = node.output_dtypes
node.graph = self
Expand Down Expand Up @@ -815,9 +816,9 @@ def get_node_by_output_in_current_graph(self, output):
ret = self._nodes_by_name.get(name)
return ret

def get_node_by_input_in_current_graph(self, input):
def get_node_by_input_in_current_graph(self, input_name):
"""Get nodes by node input id."""
names = self._output_to_node_name.get(input)
names = self._output_to_node_name.get(input_name)
ret = None
if name:
ret = [self._nodes_by_name.get(name) for name in names]
Expand Down Expand Up @@ -1198,15 +1199,15 @@ def remove_input(self, node, to_be_removed, i=None):
if node.name in to_ops:
to_ops.remove(node.name)
del node.input[i]
return
return True

for i, name in enumerate(node.input):
for i2, name in enumerate(node.input):
if name == to_be_removed:
if node.input[i] in self._input_to_node_name:
to_ops = self._input_to_node_name[node.input[i]]
if node.input[i2] in self._input_to_node_name:
to_ops = self._input_to_node_name[node.input[i2]]
if node.name in to_ops:
to_ops.remove(node.name)
del node.input[i]
del node.input[i2]
break
# don't remove output from parent since others might depend on it
return True
Expand Down Expand Up @@ -1281,7 +1282,7 @@ def replace_all_inputs(self, ops, old_input, new_input):
return
if new_input not in self._input_to_node_name:
self._input_to_node_name[new_input] = set()

to_ops = self._input_to_node_name.get(old_input, None)
if to_ops is None:
# This means old_input is a final output.
Expand All @@ -1307,16 +1308,16 @@ def replace_input(self, node, old_input, new_input, i=None):
assert isinstance(node, Node) and isinstance(old_input, six.text_type) and isinstance(new_input, six.text_type)
is_replaced = False
if i is None:
for i, input_name in enumerate(node.input):
for i2, input_name in enumerate(node.input):
if input_name == old_input:
node.input[i] = new_input
node.input[i2] = new_input
is_replaced = True
elif node.input[i] == old_input:
node.input[i] = new_input
is_replaced = True
else:
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))

to_ops = self._input_to_node_name.get(old_input, None)
if to_ops is not None:
# That might be an issue if a node
Expand All @@ -1338,7 +1339,7 @@ def replace_inputs(self, node, new_inputs):
# To avoid issues when a node
# takes twice the same entry.
to_ops.remove(old_input)

for input_name in new_inputs:
assert isinstance(input_name, six.text_type)
if input_name not in self._input_to_node_name:
Expand Down
1 change: 0 additions & 1 deletion tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import copy
import logging
import sys

import numpy as np

Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ class Identity:
def version_1(cls, ctx, node, **kwargs):
if node.inputs[0] is None:
raise RuntimeError(
"Issue with node {}\nI={}\nI2={}\nO={}.".format(
node, node.input, node._input, node.output))
"Issue with node {}\nI={}\nO={}.".format(
node, node.input, node.output))
if node.inputs[0].is_const():
# should not remove the identity node if it is output of the graph
if node.output[0] in ctx.outputs:
Expand Down

0 comments on commit 922ce19

Please sign in to comment.