[Frontend][ONNX] Support ONNX Scan operator (apache#9438)
* [Frontend][ONNX] Support ONNX Scan operator

* fix lint

* remove test_scan_sum in unsupported_onnx_tests

* support scan opset 8

* fix lint

* fix negative axes bug

* fix lint

Co-authored-by: Matthew Brookhart <mbrookhart@octoml.ai>
2 people authored and yangulei committed Jan 11, 2022
1 parent 17c7a84 commit 7d1b5f5
Showing 2 changed files with 387 additions and 2 deletions.
220 changes: 220 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -3220,6 +3220,225 @@ def _impl_v1(cls, inputs, attr, params):
return ret


class Scan(OnnxOpConverter):
"""Operator converter for Scan"""

@classmethod
def _impl_v8(cls, inputs, attr, params):
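        # In opset 8, Scan takes an optional sequence_lens tensor first and
        # every state/scan tensor carries a leading batch dimension. Drop
        # sequence_lens, unroll over the batch, lower each element with the
        # opset-9 implementation, and re-stack the results along axis 0.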
new_inputs = inputs[1:]
batch_num = infer_shape(inputs[1])[0]
out = []
for i in range(batch_num):
v9_inputs = [
_op.take(new_inputs[j], _expr.const(i), axis=0) for j in range(len(new_inputs))
]
results = cls._impl_v9(v9_inputs, attr, params)
results = [_op.expand_dims(results[j], axis=0) for j in range(len(results))]
if i == 0:
out = results
else:
out = [_op.concatenate([out[j], results[j]], axis=0) for j in range(len(results))]

out = _expr.TupleWrapper(_expr.Tuple(out), len(out))
return out

@classmethod
def _impl_v9(cls, inputs, attr, params):
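        # Per the ONNX spec, the body's inputs are ordered
        # [state_1, ..., state_N, scan_1, ..., scan_M] and its outputs
        # [state_1, ..., state_N, scan_out_1, ..., scan_out_K]; the counts
        # below are derived from that layout.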
body = attr.get("body")
num_scan_inputs = attr.get("num_scan_inputs")
num_all_inputs = len(inputs)
num_state_inputs = len(body.input) - num_scan_inputs
num_state_outputs = num_state_inputs
num_all_outputs = len(body.output)
num_scan_outputs = num_all_outputs - num_state_outputs
scan_input_axes = attr.get("scan_input_axes", [0] * num_scan_inputs)
scan_input_directions = attr.get("scan_input_directions", [0] * num_scan_inputs)
scan_output_axes = list(attr.get("scan_output_axes", [0] * num_scan_outputs))
scan_output_directions = attr.get("scan_output_directions", [0] * num_scan_outputs)
        # The loop count is identical for all scan inputs, so take it from the first one.
        # strided_slice does not support dynamic axes, so assume the input shapes are static.
max_loop_count = infer_shape(inputs[num_state_inputs])[scan_input_axes[0]]

# Create a copy of the body function to prevent the original
# from being modified.
body = copy.copy(attr["body"])

# Loop inputs will be packed as
# [iter_count, loop_deps, scan_outputs]
def cond_fn(*loop_inputs):
i = loop_inputs[0]
return _op.less(i, relay.const(max_loop_count, "int32"))

# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(
graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()

# Create a list of variables for each value updated in the loop.
def get_var(name, val, scan=False):
checked_type = infer_type(val)
if hasattr(checked_type, "type_annotation"):
checked_type = checked_type.type_annotation
if hasattr(checked_type, "checked_type"):
checked_type = checked_type.checked_type
shape = get_const_tuple(checked_type.shape)
actual_shape = []
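            # A dimension of 0 in the inferred shape marks an unknown extent;
            # map it to Any so the loop variable accepts dynamic shapes.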
for dim in shape:
if isinstance(dim, int) and dim == 0:
actual_shape.append(_ty.Any())
else:
actual_shape.append(dim)
if scan:
return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype)

return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)

# Construct variables and initial empty tensors for any scan outputs.
# To do this, we'll figure out the output shapes of the body subgraph by importing
# it and doing type inference.
scan_output_vars = []
scan_output_init = []
if num_scan_outputs > 0:
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(
body, graph_scope.opset, get_output_expr=True
)
loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output))

for i in range(num_scan_outputs):
name, _, _, _ = get_info(body.output[i + num_state_outputs])
output_node = infer_type(loop_outputs[i + num_state_outputs])
shape = list(get_const_tuple(output_node.checked_type.shape))
if scan_output_axes[i] < 0:
scan_output_axes[i] = len(shape) + scan_output_axes[i] + 1
shape.insert(scan_output_axes[i], max_loop_count)
dtype = output_node.checked_type.dtype
scan_output_vars.append(_expr.var(name, shape=shape, dtype=dtype))
scan_output_init.append(_op.zeros(shape, dtype))

# loop vars = [iter_count, scan_state, scan_out]
loop_vars = [
_expr.var("iter", shape=(), dtype="int32"), # iteration count
]
loop_vars += [
get_var(body.input[i].name, v) for i, v in enumerate(inputs) if i < num_state_inputs
]
loop_vars += scan_output_vars
body_input_var_names = ["iter"] + [body.input[i].name for i in range(len(body.input))]

        # Now we can remove loop iter variables from our inner loop's inputs.
        # This is kind of a hack since we have graph inputs that we don't
        # want to treat as actual inputs.
while len(body.input) != 0:
body.input.pop(0)

# Define the loop body, in this function we need to unpack loop inputs,
# convert the loop subgraph, and pack outputs for the next iteration.
def body_fn(*loop_inputs):
# Unpack inputs
loop_count = loop_inputs[0]
state_vars = list(loop_inputs[1 : 1 + num_state_inputs])
scan_vars = list(loop_inputs[1 + num_state_inputs :])
            # The body consumes per-iteration slices of the scan inputs as
            # ordinary inputs, taken along each input's scan axis and direction.
input_scan_exprs = []
for i in range(num_state_inputs, num_all_inputs):
if scan_input_directions[i - num_state_inputs] != 0:
input_scan_exprs.append(
relay.take(
inputs[i],
relay.const(max_loop_count - 1, "int32") - loop_count,
axis=scan_input_axes[i - num_state_inputs],
)
)
else:
input_scan_exprs.append(
relay.take(
inputs[i],
loop_count,
axis=scan_input_axes[i - num_state_inputs],
)
)

# Prepare body inputs by adding them to node dictionary.
body_inputs = [loop_count] + state_vars + input_scan_exprs
for i, inp in enumerate(body_inputs):
subgraph_scope._nodes[body_input_var_names[i]] = inp

# Get the output of the current loop using the updated inputs.
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(
body, graph_scope.opset, get_output_expr=True
)
# Unpack the body outputs and prepare variables for next iteration.
new_state_vars = [loop_outputs[i] for i in range(num_state_outputs)]
new_scan_vars = [loop_outputs[i] for i in range(num_state_outputs, num_all_outputs)]

# Add new scan outputs to tracking
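            # The scan buffers were initialized to zeros of full length
            # max_loop_count; each iteration concatenates the newest slice on
            # one end and slides the window with strided_slice, so after
            # max_loop_count iterations only real outputs remain.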
combined_scan_outputs = []
for i in range(num_scan_outputs):
if scan_output_directions[i] == 0:
# append new scan output
combined_scan = _op.concatenate(
[scan_vars[i], _op.expand_dims(new_scan_vars[i], axis=scan_output_axes[i])],
axis=scan_output_axes[i],
)
# pop head scan output
combined_scan = _op.strided_slice(
combined_scan,
begin=[1],
end=[max_loop_count + 1],
strides=[1],
axes=[scan_output_axes[i]],
)
else:
# prepend new scan output
combined_scan = _op.concatenate(
[_op.expand_dims(new_scan_vars[i], axis=scan_output_axes[i]), scan_vars[i]],
axis=scan_output_axes[i],
)
# pop tail scan output
combined_scan = _op.strided_slice(
combined_scan,
begin=[0],
end=[max_loop_count],
strides=[1],
axes=[scan_output_axes[i]],
)
combined_scan_outputs.append(combined_scan)

incr = _expr.const(1, dtype="int32")
loop_count = loop_count + incr

# Pack loop outputs for next iteration
# [iter_count, state_var, scan_var]
return [loop_count] + new_state_vars + combined_scan_outputs

# Create the loop function.
loop = fold_constant(_loops.while_loop(cond_fn, loop_vars, body_fn))

# Now need to run initial values through the graph.
init_count = _expr.const(0, dtype="int32")

input_states = [inputs[i] for i in range(num_state_inputs)]
loop_vals = loop(init_count, *input_states, *scan_output_init)

outputs = _expr.TupleWrapper(
_expr.Tuple([_expr.TupleGetItem(loop_vals, i + 1) for i in range(num_all_outputs)]),
num_all_outputs,
)

# Update outer graph with constants found in the subgraph.
free_vars = analysis.free_vars(loop)
graph_scope._params.update(subgraph_scope._params)
graph_scope._nodes.update(subgraph_scope._nodes)
for var in free_vars:
graph_scope._nodes.update({var.name_hint: var})
return outputs


class NonMaxSuppression(OnnxOpConverter):
"""Operator converter for NonMaxSuppression."""

@@ -4537,6 +4756,7 @@ def _get_convert_map(opset):
"Adagrad": Adagrad.get_converter(opset),
"Adam": Adam.get_converter(opset),
"Momentum": Momentum.get_converter(opset),
"Scan": Scan.get_converter(opset),
}


167 changes: 167 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Diff not shown.
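As an end-to-end sanity check, the new converter can be exercised with a hand-built model. The sketch below is illustrative and not part of this commit: it assumes the standard onnx.helper API, builds a running-sum Scan (opset 9), and imports it through relay.frontend.from_onnx, running it on the VM executor (needed for Relay control flow). All graph, tensor, and shape names are made up for the example.

import numpy as np
from onnx import TensorProto, helper

import tvm
from tvm import relay

# Scan body: state_out = state_in + scan_in; the running sum is also
# emitted as the per-iteration scan output.
body = helper.make_graph(
    [
        helper.make_node("Add", ["state_in", "scan_in"], ["state_out"]),
        helper.make_node("Identity", ["state_out"], ["scan_out"]),
    ],
    "scan_body",
    [
        helper.make_tensor_value_info("state_in", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, [2]),
    ],
    [
        helper.make_tensor_value_info("state_out", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [2]),
    ],
)

scan = helper.make_node(
    "Scan", ["init", "xs"], ["final_state", "ys"], num_scan_inputs=1, body=body
)
graph = helper.make_graph(
    [scan],
    "scan_sum",
    [
        helper.make_tensor_value_info("init", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("xs", TensorProto.FLOAT, [5, 2]),
    ],
    [
        helper.make_tensor_value_info("final_state", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("ys", TensorProto.FLOAT, [5, 2]),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 9)])

mod, params = relay.frontend.from_onnx(model, shape={"init": (2,), "xs": (5, 2)})
run = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm").evaluate()
final_state, ys = run(np.zeros((2,), "float32"), np.ones((5, 2), "float32"))
# ys holds the running sums [[1, 1], [2, 2], ..., [5, 5]]; final_state is [5, 5].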
