Skip to content

Commit

Permalink
Optimization for tflite loops
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed Jan 22, 2021
1 parent 1d07510 commit 748f491
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
self.graph_name = graph_name or "tf2onnx"
self._is_subgraph = is_subgraph
self.ta_reads = []
# A list of (index, output) tuples describing potential scan outputs in this graph.
# Used by the tflite while loop handler.
self.scan_outputs = []
self.func_inputs = []

self._target = set(target)
Expand Down
46 changes: 44 additions & 2 deletions tf2onnx/tflite_handlers/tfl_controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tf2onnx.handler import tfl_op
from tf2onnx import utils
from tf2onnx.tf_loader import find_function
from tf2onnx.graph_builder import GraphBuilder
from tf2onnx.onnx_opset.controlflow import parameter_binding, inline_subgraph


Expand Down Expand Up @@ -40,6 +41,19 @@ def version_7(cls, ctx, node, **kwargs):
cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

# Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
# They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
scan_outputs = sorted(body.scan_outputs, reverse=True)
def input_is_unused(g, index):
return len(g.find_output_consumers(g.func_inputs[index])) == 0
scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

for idx, _ in scan_outputs:
del tfl_while_inputs[idx]
output_shapes.append(output_shapes.pop(idx))
output_dtypes.append(output_dtypes.pop(idx))
output_names.append(output_names.pop(idx))

max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
Expand All @@ -52,15 +66,21 @@ def version_7(cls, ctx, node, **kwargs):
for k, v in output_map.items():
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()

body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph)
body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs)

for i in range(len(scan_outputs)):
squeeze_node = GraphBuilder(body).make_squeeze(
{'data': body.outputs[-1-i], "axes": [0]}, return_node=True)
body.outputs[-1-i] = squeeze_node.output[0]

loop_node.set_body_graph_as_attr("body", body)

def wire_tfl_while_body(g, loop_node_inputs, output_shapes,
output_dtypes, cond_graph):
output_dtypes, cond_graph, scan_outputs):
"""Wire subgraph graph into main."""

g = copy.deepcopy(g)
graph_inputs = g.func_inputs.copy()

# onnx will pass in cond as argument
iter_node = g.make_node("Placeholder", [], name=utils.make_name("iteration_num"),
Expand All @@ -69,6 +89,28 @@ def wire_tfl_while_body(g, loop_node_inputs, output_shapes,
output_count=1, dtypes=[TensorProto.BOOL], shapes=[[]])
cond_binding = parameter_binding(cond_graph, g.outputs)

to_remove = set()
for idx, scan_output in scan_outputs:
inp = g.get_node_by_output(graph_inputs[idx])

# Remove consumers of scan input
stack = [inp]
while stack:
node = stack.pop()
if node not in to_remove:
to_remove.add(node)
for out in node.output:
stack += g.find_output_consumers(out)

# Remove scan input from cond graph
cond_binding = {k: "@@ALLOC" if v == g.outputs[idx] else v for k, v in cond_binding.items()}
del g.func_inputs[idx]
del g.outputs[idx]
g.outputs.append(scan_output)

for node in to_remove:
g.remove_node(node.name)

# in onnx the body inputs are: index, cond, [loop_vars]
g.func_inputs = [iter_node.output[0], cond_node.output[0]] + g.func_inputs
# tell graph lib to keep inputs in order
Expand Down
7 changes: 7 additions & 0 deletions tf2onnx/tflite_rewriters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

"""tf2onnx.tflite_rewriters module"""

from . import (
tfl_scan_output_rewriter
)
156 changes: 156 additions & 0 deletions tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0


"""
tf2onnx.tflite_rewriters.tfl_scan_output_rewriter - Identify a common slice/concat pattern in tflite subgraphs
Effectively replace A = A[:i] + [B] + A[i+1:] with A[i] = B
"""
import numpy as np

from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher


# pylint: disable=missing-docstring

def rewrite_slice_concat_to_scatter(g, ops):
    """Record scan-output candidates found in graph ``g``.

    Looks for ``concat(slice(A, ...), middle, slice(A, ...))`` along axis 0,
    i.e. ``A = A[:i] + [B] + A[i+1:]``, where both ``A`` and the index ``i``
    are graph inputs and ``i`` is fed back as ``i + 1``.  For every match that
    passes all checks, ``(input index of A, output name of B)`` is appended to
    ``g.scan_outputs``.  When ``g.opset < 10`` the consumers of the concat are
    additionally rewired to a ``TMP_SCAN_OUTPUT`` node.

    Returns ``ops`` unchanged.
    """
    slice_concat_pattern = OpTypePattern(
        'TFL_CONCATENATION', name='concat', inputs=[
            OpTypePattern('TFL_SLICE', name='begin_slice'),
            OpTypePattern('*', name='middle'),
            OpTypePattern('TFL_SLICE', name='end_slice'),
        ])

    def first_input_chain(output_name, stop_at=None):
        # Producer of `output_name` followed by the producers reached by
        # repeatedly following input[0]; ends at an input-less node or at `stop_at`.
        current = g.get_node_by_output(output_name)
        chain = [current]
        while len(current.input) > 0 and (stop_at is None or current != stop_at):
            current = g.get_node_by_output(current.input[0])
            chain.append(current)
        return chain

    matcher = GraphMatcher(slice_concat_pattern, allow_reorder=False)
    # Materialise the matches up front: the loop body may add nodes to the graph.
    for match in list(matcher.match_ops(ops)):
        concat = match.get_op("concat")
        begin_slice = match.get_op("begin_slice")
        middle = match.get_op("middle")
        end_slice = match.get_op("end_slice")
        middle_shape = g.get_shape(middle.output[0])

        # Both slices must be slicing the same tensor
        original_tensor = begin_slice.input[0]
        if original_tensor != end_slice.input[0]:
            continue
        if concat.get_attr_int("axis") != 0:
            continue
        # The inserted slice must have length 1 (to be a single index)
        is_single_row = middle_shape is not None and len(middle_shape) > 0 and middle_shape[0] == 1
        if not is_single_row:
            continue
        rank = len(middle_shape)
        scan_output = middle.output[0]
        if not (begin_slice.inputs[1].is_const() and end_slice.inputs[2].is_const()):
            continue
        # The first slice must start from the beginning (0) for all dims
        if any(v != 0 for v in begin_slice.inputs[1].get_tensor_value()):
            continue
        # The second slice must slice to the end (-1) for all dims
        if any(v != -1 for v in end_slice.inputs[2].get_tensor_value()):
            continue

        if rank > 1:
            # The other slice dims are assembled by concatenation if rank > 1
            begin_concat = begin_slice.inputs[2]
            end_concat = end_slice.inputs[1]
            if begin_concat.type != "TFL_CONCATENATION" or end_concat.type != "TFL_CONCATENATION":
                continue
            # Except for dim 0, slice from beginning to end
            if any(get_uniform_const_val(inp) != -1 for inp in begin_concat.inputs[1:]):
                continue
            if any(get_uniform_const_val(inp) != 0 for inp in end_concat.inputs[1:]):
                continue
            begin_idx = begin_concat.inputs[0]
            end_idx = end_concat.inputs[0]
        else:
            begin_idx = begin_slice.inputs[2]
            end_idx = end_slice.inputs[1]

        # For dim 0, slice to i for first part and from i+1 for second
        if not node_is_one_plus_node(begin_idx, end_idx):
            continue
        index_tensor, _ = get_out_and_offset(begin_idx)
        graph_inps = [inp_node.output[0] for inp_node in g.inputs]
        # To be a scan output, i must be a graph input
        if index_tensor not in graph_inps:
            continue
        # The array being sliced must be a graph input
        if original_tensor not in graph_inps:
            continue
        # The input/output index of i
        idx = graph_inps.index(index_tensor)
        # The input/output index of the array
        scan_output_idx = graph_inps.index(original_tensor)
        # For a scan output, i must be assigned to i+1 with each iteration
        index_in = g.get_node_by_output(index_tensor)
        index_out = g.get_node_by_output(g.outputs[idx])
        if not node_is_one_plus_node(index_in, index_out):
            continue
        if len(g.find_output_consumers(concat.output[0])) > 1:
            continue

        # The concat has at most one consumer here (checked just above).
        if g.opset < 10:
            # If opset is < 10, conversion of the subgraph will fail unless we remove the slice nodes
            # We add a tmp node to replace them.
            tmp_node = g.make_node(
                "TMP_SCAN_OUTPUT", [original_tensor, scan_output],
                shapes=[g.get_shape(concat.output[0])],
                dtypes=[g.get_dtype(concat.output[0])])
            g.replace_all_inputs(concat.output[0], tmp_node.output[0])

        # Collect the nodes tied to the array: the chain feeding its graph output
        # (walked back until concat), the matched nodes, and the chain producing
        # the sliced tensor.  Nothing is removed here; the list is only checked.
        to_remove = first_input_chain(g.outputs[scan_output_idx], stop_at=concat)
        to_remove += [begin_slice, end_slice, concat]
        to_remove += first_input_chain(original_tensor)

        if g.is_safe_to_remove_nodes(to_remove):
            g.scan_outputs.append((scan_output_idx, scan_output))
    return ops

def get_uniform_const_val(n):
    """Return the value shared by every element of const node ``n``.

    Returns None when ``n`` is not a constant, when its tensor is empty, or
    when its elements are not all equal.
    """
    if n.is_const():
        flat = n.get_tensor_value(as_list=False).flatten()
        # An empty tensor has no representative element to compare against.
        if len(flat) > 0 and np.all(flat == flat[0]):
            return flat[0]
    return None

def get_out_and_offset(n):
    """Express node ``n`` as ``(base_output_name, constant_offset)``.

    Reshape/identity nodes are looked through via their first input.  A
    ``TFL_ADD`` with a uniform-constant operand contributes that constant to
    the offset; if both operands are constant the base name is ``''``.  Any
    other node is its own base with offset 0.
    """
    passthrough_ops = ('TFL_RESHAPE', 'TFL_IDENTITY', 'Identity')
    if n.type in passthrough_ops:
        return get_out_and_offset(n.inputs[0])

    if n.type == 'TFL_ADD':
        lhs_const = get_uniform_const_val(n.inputs[0])
        rhs_const = get_uniform_const_val(n.inputs[1])
        if lhs_const is not None:
            if rhs_const is not None:
                # Constant + constant: no base tensor, only an offset.
                return '', lhs_const + rhs_const
            base, offset = get_out_and_offset(n.inputs[1])
            return base, lhs_const + offset
        if rhs_const is not None:
            base, offset = get_out_and_offset(n.inputs[0])
            return base, rhs_const + offset
        # Neither operand is a uniform constant: fall through, treat as opaque.

    return n.output[0], 0

def node_is_one_plus_node(node, one_plus_node):
    """Tell whether ``one_plus_node`` equals ``node`` plus one.

    Both nodes are reduced with ``get_out_and_offset``; they must share the
    same base output name and their constant offsets must differ by exactly 1.
    """
    base, base_offset = get_out_and_offset(node)
    other_base, other_offset = get_out_and_offset(one_plus_node)
    if base != other_base:
        return False
    return base_offset + 1 == other_offset
2 changes: 1 addition & 1 deletion tf2onnx/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist()
buf = model.Buffers(tensor.Buffer())
dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type())
if not buf.DataIsNone():
if not buf.DataIsNone() and tensor.Buffer() > 0:
# For const values we use TF to decode the binary data from the buffer
t = tensor_pb2.TensorProto()
t.tensor_content = buf.DataAsNumpy().tobytes()
Expand Down

0 comments on commit 748f491

Please sign in to comment.