diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py index a5cdcd687..7694fef18 100644 --- a/tf2onnx/graph.py +++ b/tf2onnx/graph.py @@ -56,6 +56,11 @@ def input(self): @input.setter def input(self, val): + # The setter can catch that all inputs are change + # but it cannot catch that one input is changed. + # That's method replace_input and replace_inputs must + # be used to change inputs to let the graph instance + # update its internal indices. self._input = copy.deepcopy(val) @property @@ -1291,6 +1296,19 @@ def replace_all_inputs(self, ops, old_input, new_input): # This means old_input is a final output. to_ops = set() + # Verification that we can use the index to + # remove nodes. + for node in ops: + if old_input in node.input: + if old_input not in self._input_to_node_name: + raise RuntimeError( + "Input %r of node %r, old_input %r not in _input_to_node_name." % ( + old_input, node.name, old_input)) + if node.name not in self._input_to_node_name[old_input]: + raise RuntimeError( + "Input %r of node %r, node %r not in _input_to_node_name[%r]." % ( + old_input, node.name, node.name, old_input)) + for node in ops: if old_input in node.input and new_input in node.output: raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)