Skip to content

Commit

Permalink
First step to add forward indexes.
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Aug 3, 2020
1 parent 36d7413 commit 1de2a95
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 9 deletions.
1 change: 0 additions & 1 deletion tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,6 @@ def test_duplicated_duplicated_attributes(self):
op_type="ReduceSum", remaining_op_num=2)

def _check_initializer_num(self, graph_proto, num):
print(len(graph_proto.initializer))
return num == len(graph_proto.initializer)

def test_duplicated_duplicated_constant(self):
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MICROSOFT_DOMAIN = "com.microsoft"

# Default opset version for onnx domain
PREFERRED_OPSET = 8
PREFERRED_OPSET = 11

# Default opset for custom ops
TENSORFLOW_OPSET = helper.make_opsetid("ai.onnx.converters.tensorflow", 1)
Expand Down
70 changes: 64 additions & 6 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
self._nodes = []
self._nodes_by_name = {}
self._output_to_node_name = {}
self._input_to_node_name = {}
self.shapes = {}
self.graph_name = graph_name or "tf2onnx"
self._is_subgraph = is_subgraph
Expand All @@ -442,7 +443,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No

ops = [Node(node, self) for node in nodes]
self.reset_nodes(ops)

if not is_subgraph:
# add identity node after each output, in case it is renamed during conversion.
for o in self.outputs:
Expand Down Expand Up @@ -569,6 +570,11 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk

onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)

for name in onnx_node.input:
if name not in self._input_to_node_name:
self._input_to_node_name[name] = set()
self._input_to_node_name[name].add(onnx_node.name)

if op_type in ["If", "Loop", "Scan"]:
# we force the op containing inner graphs not skipped during conversion.
skip_conversion = False
Expand Down Expand Up @@ -606,6 +612,10 @@ def append_node(self, node):
self._output_to_node_name[name] = node.name
self.set_dtype(name, output_dtypes[i])
self.set_shape(name, output_shapes[i])
for name in node.input:
if name not in self._input_to_node_name:
self._input_to_node_name[name] = set()
self._input_to_node_name[name].add(node.name)

def remove_node(self, node_name):
"""Remove node in current graph."""
Expand All @@ -626,6 +636,13 @@ def remove_node(self, node_name):
if op_output in self._dtypes:
del self._dtypes[op_output]

for op_input in node.input:
if op_input not in self._input_to_node_name:
raise RuntimeError(
"Input %r of node %r not found." % (op_input, node.name))
if node.name in self._input_to_node_name[op_input]:
self._input_to_node_name[op_input].remove(node.name)

self._nodes.remove(node)
node.graph = None

Expand All @@ -649,16 +666,26 @@ def reset_nodes(self, ops):
self.contained_graphs = remained_sub_graphs
self._nodes_by_name = {op.name: op for op in ops}
self._output_to_node_name = {}
self._input_to_node_name = {}
for op in ops:
for op_output in op.output:
self._output_to_node_name[op_output] = op.name
for op_input in op.input:
if op_input not in self._input_to_node_name:
self._input_to_node_name[op_input] = set()
self._input_to_node_name[op_input].add(op.name)

for n in self._order_sensitive_inputs:
if n not in ops:
self._order_sensitive_inputs.remove(n)
for o in self.outputs:
if o not in self._output_to_node_name:
raise ValueError("graph output " + o + " not exist")
raise ValueError("graph output %r not exist" % o)
for i in self.inputs:
if i.name.startswith('Placeholder'):
continue
if i.name not in self._input_to_node_name:
raise ValueError("graph input %r not exist in graph." % i.name)

self._dtypes = remained_dtypes
self._output_shapes = remained_shapes
Expand Down Expand Up @@ -775,6 +802,14 @@ def get_node_by_output_in_current_graph(self, output):
ret = self._nodes_by_name.get(name)
return ret

def get_node_by_input_in_current_graph(self, input):
"""Get nodes by node input id."""
names = self._output_to_node_name.get(input)
ret = None
if name:
ret = [self._nodes_by_name.get(name) for name in names]
return ret

def get_node_by_name(self, name):
"""Get node by name."""
ret = self._nodes_by_name.get(name)
Expand All @@ -785,6 +820,10 @@ def set_node_by_name(self, node):
self._nodes_by_name[node.name] = node
for op_output in node.output:
self._output_to_node_name[op_output] = node.name
for name in node.input:
if name not in self._input_to_node_name:
self._input_to_node_name[name] = set()
self._input_to_node_name[name].add(node.name)

def change_node_name(self, node, new_name):
"""Remove node in current graph."""
Expand Down Expand Up @@ -1210,35 +1249,54 @@ def find_output_consumers(self, output_name):
nodes.extend(g.find_output_consumers(output_name))
return nodes

@staticmethod
def replace_all_inputs(ops, old_input, new_input):
def replace_all_inputs(self, ops, old_input, new_input):
"""Replace all inputs pointing to old_input with new_input."""
if old_input == new_input:
return
if new_input not in self._input_to_node_name:
self._input_to_node_name[new_input] = set()

to_ops = self._input_to_node_name.get(old_input, None)
if to_ops is None:
# This means old_input is a final output.
to_ops = set()

for node in ops:
if old_input in node.input and new_input in node.output:
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
self._input_to_node_name[new_input].add(node.name)

for i, input_name in enumerate(node.input):
if input_name == old_input:
node.input[i] = new_input
if node.name not in to_ops:
raise RuntimeError(
"Unable to replace %r by %r. Node %r is not using input %r." % (
old_input, new_input, node.name, old_input))


# modify references in sub graphs
body_graphs = node.get_body_graphs()
if body_graphs:
for g in body_graphs.values():
g.replace_all_inputs(g.get_nodes(), old_input, new_input)

@staticmethod
def replace_input(node, old_input, new_input):
def replace_input(self, node, old_input, new_input):
"""Replace node."""
assert isinstance(node, Node) and isinstance(old_input, six.text_type) and isinstance(new_input, six.text_type)
is_replaced = False
for i, input_name in enumerate(node.input):
if input_name == old_input:
node.input[i] = new_input
is_replaced = True

to_ops = self._input_to_node_name.get(old_input, None)
if to_ops is not None:
to_ops.remove(node.name)
if new_input not in self._input_to_node_name:
self._input_to_node_name[new_input] = set()
self._input_to_node_name[new_input].add(node.name)

return is_replaced

def _extract_sub_graph_nodes(self, dest_node, input_checker=None):
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def version_7(cls, ctx, node, **kwargs):
maximum_iterations_name = node.input[1]
maximum_iterations = node.inputs[1].get_tensor_value()
if maximum_iterations == -1:
maximum_iterations = sys.maxsize
maximum_iterations = np.iinfo(dtype_loop).max
consumers = ctx.find_output_consumers(maximum_iterations_name)
external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
if len(external_consumers) == 0:
Expand Down

0 comments on commit 1de2a95

Please sign in to comment.