From 922ce1930128ce7cd7a0f4e7d943c2bb10eaaa48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 4 Aug 2020 16:07:12 +0200 Subject: [PATCH] pep8 --- tf2onnx/graph.py | 29 +++++++++++++++-------------- tf2onnx/onnx_opset/controlflow.py | 1 - tf2onnx/onnx_opset/tensor.py | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py index 430a4ff78..f2ceb4d6c 100644 --- a/tf2onnx/graph.py +++ b/tf2onnx/graph.py @@ -447,7 +447,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No ops = [Node(node, self) for node in nodes] self.reset_nodes(ops) - + if not is_subgraph: # add identity node after each output, in case it is renamed during conversion. for o in self.outputs: @@ -572,7 +572,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk n = self.get_node_by_output_in_current_graph(o) utils.make_sure(n is None, "output tensor named %s already exists in node: \n%s", o, n) - onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr) + onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr) for name in onnx_node.input: if name not in self._input_to_node_name: @@ -607,6 +607,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk return node def append_node(self, node): + "Add a node to the graph." output_shapes = node.output_shapes output_dtypes = node.output_dtypes node.graph = self @@ -815,9 +816,9 @@ def get_node_by_output_in_current_graph(self, output): ret = self._nodes_by_name.get(name) return ret - def get_node_by_input_in_current_graph(self, input): + def get_node_by_input_in_current_graph(self, input_name): """Get nodes by node input id.""" - names = self._output_to_node_name.get(input) + names = self._output_to_node_name.get(input_name) ret = None if name: ret = [self._nodes_by_name.get(name) for name in names] @@ -1198,15 +1199,15 @@ def remove_input(self, node, to_be_removed, i=None): if node.name in to_ops: to_ops.remove(node.name) del node.input[i] - return + return True - for i, name in enumerate(node.input): + for i2, name in enumerate(node.input): if name == to_be_removed: - if node.input[i] in self._input_to_node_name: - to_ops = self._input_to_node_name[node.input[i]] + if node.input[i2] in self._input_to_node_name: + to_ops = self._input_to_node_name[node.input[i2]] if node.name in to_ops: to_ops.remove(node.name) - del node.input[i] + del node.input[i2] break # don't remove output from parent since others might depend on it return True @@ -1281,7 +1282,7 @@ def replace_all_inputs(self, ops, old_input, new_input): return if new_input not in self._input_to_node_name: self._input_to_node_name[new_input] = set() - + to_ops = self._input_to_node_name.get(old_input, None) if to_ops is None: # This means old_input is a final output. @@ -1307,16 +1308,16 @@ def replace_input(self, node, old_input, new_input, i=None): assert isinstance(node, Node) and isinstance(old_input, six.text_type) and isinstance(new_input, six.text_type) is_replaced = False if i is None: - for i, input_name in enumerate(node.input): + for i2, input_name in enumerate(node.input): if input_name == old_input: - node.input[i] = new_input + node.input[i2] = new_input is_replaced = True elif node.input[i] == old_input: node.input[i] = new_input is_replaced = True else: raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name)) - + to_ops = self._input_to_node_name.get(old_input, None) if to_ops is not None: # That might be an issue if a node @@ -1338,7 +1339,7 @@ def replace_inputs(self, node, new_inputs): # To avoid issues when a node # takes twice the same entry. to_ops.remove(old_input) - + for input_name in new_inputs: assert isinstance(input_name, six.text_type) if input_name not in self._input_to_node_name: diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index a06258bc6..b66cd9d08 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -11,7 +11,6 @@ import copy import logging -import sys import numpy as np diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 8aca7d6ea..f68cfe871 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -111,8 +111,8 @@ class Identity: def version_1(cls, ctx, node, **kwargs): if node.inputs[0] is None: raise RuntimeError( - "Issue with node {}\nI={}\nI2={}\nO={}.".format( - node, node.input, node._input, node.output)) + "Issue with node {}\nI={}\nO={}.".format( + node, node.input, node.output)) if node.inputs[0].is_const(): # should not remove the identity node if it is output of the graph if node.output[0] in ctx.outputs: