From 7d1b5f540dcefc038854885cad972c843d115a83 Mon Sep 17 00:00:00 2001
From: shengxinhu <69130386+shengxinhu@users.noreply.github.com>
Date: Thu, 9 Dec 2021 16:46:22 +0800
Subject: [PATCH] [Frontend][ONNX] Support ONNX Scan operator (#9438)

* [Frontend][ONNX] Support ONNX Scan operator

* fix lint

* remove test_scan_sum in unsupported_onnx_tests

* support scan opset 8

* fix lint

* fix negative axes bug

* fix lint

Co-authored-by: Matthew Brookhart
---
 python/tvm/relay/frontend/onnx.py          | 220 +++++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py | 169 +++++++++++++++-
 2 files changed, 387 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 856fb37d5c5ff..79fe0d734b6d4 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3220,6 +3220,225 @@ def _impl_v1(cls, inputs, attr, params):
         return ret
 
 
+class Scan(OnnxOpConverter):
+    """Operator converter for Scan"""
+
+    @classmethod
+    def _impl_v8(cls, inputs, attr, params):
+        new_inputs = inputs[1:]
+        batch_num = infer_shape(inputs[1])[0]
+        out = []
+        for i in range(batch_num):
+            v9_inputs = [
+                _op.take(new_inputs[j], _expr.const(i), axis=0) for j in range(len(new_inputs))
+            ]
+            results = cls._impl_v9(v9_inputs, attr, params)
+            results = [_op.expand_dims(results[j], axis=0) for j in range(len(results))]
+            if i == 0:
+                out = results
+            else:
+                out = [_op.concatenate([out[j], results[j]], axis=0) for j in range(len(results))]
+
+        out = _expr.TupleWrapper(_expr.Tuple(out), len(out))
+        return out
+
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        body = attr.get("body")
+        num_scan_inputs = attr.get("num_scan_inputs")
+        num_all_inputs = len(inputs)
+        num_state_inputs = len(body.input) - num_scan_inputs
+        num_state_outputs = num_state_inputs
+        num_all_outputs = len(body.output)
+        num_scan_outputs = num_all_outputs - num_state_outputs
+        scan_input_axes = attr.get("scan_input_axes", [0] * num_scan_inputs)
+        scan_input_directions = attr.get("scan_input_directions", [0] * num_scan_inputs)
+        scan_output_axes = list(attr.get("scan_output_axes", [0] * num_scan_outputs))
+        scan_output_directions = attr.get("scan_output_directions", [0] * num_scan_outputs)
+        # The loop count is the same for all scan inputs, so take it from the first one.
+        # strided_slice does not support dynamic axes, so assume the input shapes are static.
+        max_loop_count = infer_shape(inputs[num_state_inputs])[scan_input_axes[0]]
+
+        # Create a copy of the body function to prevent the original
+        # from being modified.
+        body = copy.copy(attr["body"])
+
+        # Loop inputs will be packed as
+        # [iter_count, scan_state, scan_outputs]
+        def cond_fn(*loop_inputs):
+            i = loop_inputs[0]
+            return _op.less(i, relay.const(max_loop_count, "int32"))
+
+        # Get the current graph proto and create a clone for the subgraph
+        graph_scope = GraphProto.current
+        subgraph_scope = GraphProto(
+            graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
+        )
+        # Load nodes from outer graph into inner graph.
+        subgraph_scope._nodes = graph_scope._nodes.copy()
+
+        # Create a list of variables for each value updated in the loop.
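+        # get_var builds a typed relay.Var from an inferred value: dimensions that
+        # infer to 0 are treated as dynamic (Any), and with scan=True an extra
+        # leading dynamic dimension is prepended for the iteration axis.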
+        def get_var(name, val, scan=False):
+            checked_type = infer_type(val)
+            if hasattr(checked_type, "type_annotation"):
+                checked_type = checked_type.type_annotation
+            if hasattr(checked_type, "checked_type"):
+                checked_type = checked_type.checked_type
+            shape = get_const_tuple(checked_type.shape)
+            actual_shape = []
+            for dim in shape:
+                if isinstance(dim, int) and dim == 0:
+                    actual_shape.append(_ty.Any())
+                else:
+                    actual_shape.append(dim)
+            if scan:
+                return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype)
+
+            return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
+
+        # Construct variables and initial empty tensors for any scan outputs.
+        # To do this, we'll figure out the output shapes of the body subgraph by importing
+        # it and doing type inference.
+        scan_output_vars = []
+        scan_output_init = []
+        if num_scan_outputs > 0:
+            with subgraph_scope:
+                loop_outputs = subgraph_scope.from_onnx(
+                    body, graph_scope.opset, get_output_expr=True
+                )
+            loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output))
+
+            for i in range(num_scan_outputs):
+                name, _, _, _ = get_info(body.output[i + num_state_outputs])
+                output_node = infer_type(loop_outputs[i + num_state_outputs])
+                shape = list(get_const_tuple(output_node.checked_type.shape))
+                if scan_output_axes[i] < 0:
+                    scan_output_axes[i] = len(shape) + scan_output_axes[i] + 1
+                shape.insert(scan_output_axes[i], max_loop_count)
+                dtype = output_node.checked_type.dtype
+                scan_output_vars.append(_expr.var(name, shape=shape, dtype=dtype))
+                scan_output_init.append(_op.zeros(shape, dtype))
+
+        # loop vars = [iter_count, scan_state, scan_out]
+        loop_vars = [
+            _expr.var("iter", shape=(), dtype="int32"),  # iteration count
+        ]
+        loop_vars += [
+            get_var(body.input[i].name, v) for i, v in enumerate(inputs) if i < num_state_inputs
+        ]
+        loop_vars += scan_output_vars
+        body_input_var_names = ["iter"] + [body.input[i].name for i in range(len(body.input))]
+
+        # Now we can remove loop iter variables from our inner loop's inputs.
+        # This is kind of a hack since we have graph inputs that we don't
+        # want to treat as actual inputs.
+        while len(body.input) != 0:
+            body.input.pop(0)
+
+        # Define the loop body. In this function we need to unpack loop inputs,
+        # convert the loop subgraph, and pack outputs for the next iteration.
+        def body_fn(*loop_inputs):
+            # Unpack inputs
+            loop_count = loop_inputs[0]
+            state_vars = list(loop_inputs[1 : 1 + num_state_inputs])
+            scan_vars = list(loop_inputs[1 + num_state_inputs :])
+            # The body consumes one slice of each scan input per iteration, taken
+            # along that input's scan axis in the requested direction.
+            input_scan_exprs = []
+            for i in range(num_state_inputs, num_all_inputs):
+                if scan_input_directions[i - num_state_inputs] != 0:
+                    input_scan_exprs.append(
+                        relay.take(
+                            inputs[i],
+                            relay.const(max_loop_count - 1, "int32") - loop_count,
+                            axis=scan_input_axes[i - num_state_inputs],
+                        )
+                    )
+                else:
+                    input_scan_exprs.append(
+                        relay.take(
+                            inputs[i],
+                            loop_count,
+                            axis=scan_input_axes[i - num_state_inputs],
+                        )
+                    )
+
+            # Prepare body inputs by adding them to node dictionary.
+            body_inputs = [loop_count] + state_vars + input_scan_exprs
+            for i, inp in enumerate(body_inputs):
+                subgraph_scope._nodes[body_input_var_names[i]] = inp
+
+            # Get the output of the current loop using the updated inputs.
+            with subgraph_scope:
+                loop_outputs = subgraph_scope.from_onnx(
+                    body, graph_scope.opset, get_output_expr=True
+                )
+            # Unpack the body outputs and prepare variables for next iteration.
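+            # The scan-output buffers were preallocated as zeros of their final
+            # size, so each iteration concatenates the new slice onto one end and
+            # strided_slices a stale slice off the other end. This keeps every
+            # loop variable's shape static across iterations.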
+            new_state_vars = [loop_outputs[i] for i in range(num_state_outputs)]
+            new_scan_vars = [loop_outputs[i] for i in range(num_state_outputs, num_all_outputs)]
+
+            # Add new scan outputs to tracking
+            combined_scan_outputs = []
+            for i in range(num_scan_outputs):
+                if scan_output_directions[i] == 0:
+                    # append new scan output
+                    combined_scan = _op.concatenate(
+                        [scan_vars[i], _op.expand_dims(new_scan_vars[i], axis=scan_output_axes[i])],
+                        axis=scan_output_axes[i],
+                    )
+                    # pop head scan output
+                    combined_scan = _op.strided_slice(
+                        combined_scan,
+                        begin=[1],
+                        end=[max_loop_count + 1],
+                        strides=[1],
+                        axes=[scan_output_axes[i]],
+                    )
+                else:
+                    # prepend new scan output
+                    combined_scan = _op.concatenate(
+                        [_op.expand_dims(new_scan_vars[i], axis=scan_output_axes[i]), scan_vars[i]],
+                        axis=scan_output_axes[i],
+                    )
+                    # pop tail scan output
+                    combined_scan = _op.strided_slice(
+                        combined_scan,
+                        begin=[0],
+                        end=[max_loop_count],
+                        strides=[1],
+                        axes=[scan_output_axes[i]],
+                    )
+                combined_scan_outputs.append(combined_scan)
+
+            incr = _expr.const(1, dtype="int32")
+            loop_count = loop_count + incr
+
+            # Pack loop outputs for next iteration
+            # [iter_count, state_var, scan_var]
+            return [loop_count] + new_state_vars + combined_scan_outputs
+
+        # Create the loop function.
+        loop = fold_constant(_loops.while_loop(cond_fn, loop_vars, body_fn))
+
+        # Now need to run initial values through the graph.
+        init_count = _expr.const(0, dtype="int32")
+
+        input_states = [inputs[i] for i in range(num_state_inputs)]
+        loop_vals = loop(init_count, *input_states, *scan_output_init)
+
+        outputs = _expr.TupleWrapper(
+            _expr.Tuple([_expr.TupleGetItem(loop_vals, i + 1) for i in range(num_all_outputs)]),
+            num_all_outputs,
+        )
+
+        # Update outer graph with constants found in the subgraph.
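+        # Anything the loop still references from the subgraph conversion (e.g.
+        # weights lifted into params) must be merged back and re-registered in
+        # the outer graph so later nodes can resolve it by name.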
+        free_vars = analysis.free_vars(loop)
+        graph_scope._params.update(subgraph_scope._params)
+        graph_scope._nodes.update(subgraph_scope._nodes)
+        for var in free_vars:
+            graph_scope._nodes.update({var.name_hint: var})
+        return outputs
+
+
 class NonMaxSuppression(OnnxOpConverter):
     """Operator converter for NonMaxSuppression."""
 
@@ -4537,6 +4756,7 @@ def _get_convert_map(opset):
         "Adagrad": Adagrad.get_converter(opset),
         "Adam": Adam.get_converter(opset),
         "Momentum": Momentum.get_converter(opset),
+        "Scan": Scan.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 3e74a7ebcd048..701906e4be40d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5049,8 +5049,6 @@ def verify_eyelike(indata):
     "test_reduce_sum_negative_axes_keepdims_random",
     "test_rnn_seq_length",
     "test_round",
-    "test_scan9_sum",
-    "test_scan_sum",
     "test_sequence_insert_at_back",
     "test_sequence_insert_at_front",
     "test_shape_end_1",
@@ -6033,6 +6031,172 @@ def repeat(N, D):
     )
 
 
+@tvm.testing.parametrize_targets
+def test_scan(target, dev):
+    def verify_scan(
+        input_shapes,
+        output_shapes,
+        num_scan_inputs,
+        scan_input_axes,
+        scan_input_directions,
+        scan_output_axes,
+        scan_output_directions,
+        opset,
+    ):
+        import copy
+
+        body_input_shapes = copy.deepcopy(input_shapes)
+        num_state_inputs = len(input_shapes) - num_scan_inputs
+
+        if opset == 8:
+            for i in range(len(input_shapes)):
+                body_input_shapes[i].pop(0)
+            for i in range(num_state_inputs, len(input_shapes)):
+                body_input_shapes[i].pop(0)
+        else:
+            for i in range(num_state_inputs, len(input_shapes)):
+                body_input_shapes[i].pop(scan_input_axes[i - num_state_inputs])
+
+        initial0 = onnx.helper.make_tensor_value_info(
+            "initial0", onnx.TensorProto.FLOAT, body_input_shapes[0]
+        )
+        initial1 = onnx.helper.make_tensor_value_info(
+            "initial1", onnx.TensorProto.FLOAT, body_input_shapes[1]
+        )
+        input0 = onnx.helper.make_tensor_value_info(
+            "input0", onnx.TensorProto.FLOAT, body_input_shapes[2]
+        )
+        input1 = onnx.helper.make_tensor_value_info(
+            "input1", onnx.TensorProto.FLOAT, body_input_shapes[3]
+        )
+        input2 = onnx.helper.make_tensor_value_info(
+            "input2", onnx.TensorProto.FLOAT, body_input_shapes[4]
+        )
+        state0 = onnx.helper.make_tensor_value_info(
+            "state0", onnx.TensorProto.FLOAT, body_input_shapes[0]
+        )
+        scan_out0 = onnx.helper.make_tensor_value_info(
+            "scan_out0", onnx.TensorProto.FLOAT, body_input_shapes[0]
+        )
+        matmul_out = onnx.helper.make_tensor_value_info(
+            "matmul_out", onnx.TensorProto.FLOAT, body_input_shapes[1]
+        )
+        state1 = onnx.helper.make_tensor_value_info(
+            "state1", onnx.TensorProto.FLOAT, body_input_shapes[1]
+        )
+        scan_out1 = onnx.helper.make_tensor_value_info(
+            "scan_out1", onnx.TensorProto.FLOAT, body_input_shapes[1]
+        )
+        add_node = onnx.helper.make_node(
+            "Add",
+            inputs=["initial0", "input0"],
+            outputs=["state0"],
+        )
+        id_node_0 = onnx.helper.make_node(
+            "Identity",
+            inputs=["state0"],
+            outputs=["scan_out0"],
+        )
+        matmul_node = onnx.helper.make_node(
+            "MatMul",
+            inputs=["input1", "input2"],
+            outputs=["matmul_out"],
+        )
+        sub_node = onnx.helper.make_node(
+            "Sub",
+            inputs=["initial1", "matmul_out"],
+            outputs=["state1"],
+        )
+        id_node_1 = onnx.helper.make_node(
+            "Identity",
+            inputs=["state1"],
+            outputs=["scan_out1"],
+        )
+        scan_body = onnx.helper.make_graph(
+            [add_node, id_node_0, matmul_node, sub_node, id_node_1],
+            "scan_body",
+            [initial0, initial1, input0, input1, input2],
+            [state0, state1, scan_out0, scan_out1],
+        )
+        # create scan op node
+        scan_node = None
+        if opset == 8:
+            scan_node = onnx.helper.make_node(
+                "Scan",
+                inputs=["", "init0", "init1", "in0", "in1", "in2"],
+                outputs=["s0", "s1", "scan0", "scan1"],
+                num_scan_inputs=num_scan_inputs,
+                body=scan_body,
+            )
+        else:
+            scan_node = onnx.helper.make_node(
+                "Scan",
+                inputs=["init0", "init1", "in0", "in1", "in2"],
+                outputs=["s0", "s1", "scan0", "scan1"],
+                num_scan_inputs=num_scan_inputs,
+                body=scan_body,
+                scan_input_axes=scan_input_axes,
+                scan_input_directions=scan_input_directions,
+                scan_output_axes=scan_output_axes,
+                scan_output_directions=scan_output_directions,
+            )
+        input_info = [
+            helper.make_tensor_value_info("init0", TensorProto.FLOAT, input_shapes[0]),
+            helper.make_tensor_value_info("init1", TensorProto.FLOAT, input_shapes[1]),
+            helper.make_tensor_value_info("in0", TensorProto.FLOAT, input_shapes[2]),
+            helper.make_tensor_value_info("in1", TensorProto.FLOAT, input_shapes[3]),
+            helper.make_tensor_value_info("in2", TensorProto.FLOAT, input_shapes[4]),
+        ]
+        out_info = [
+            helper.make_tensor_value_info("s0", TensorProto.FLOAT, output_shapes[0]),
+            helper.make_tensor_value_info("s1", TensorProto.FLOAT, output_shapes[1]),
+            helper.make_tensor_value_info("scan0", TensorProto.FLOAT, output_shapes[2]),
+            helper.make_tensor_value_info("scan1", TensorProto.FLOAT, output_shapes[3]),
+        ]
+        graph = helper.make_graph(
+            nodes=[scan_node],
+            name="scan_test",
+            inputs=input_info,
+            outputs=out_info,
+        )
+        model = onnx.helper.make_model(graph, producer_name="scan-test")
+        init0 = np.random.uniform(low=0, high=255, size=input_shapes[0]).astype(np.float32)
+        init1 = np.random.uniform(low=0, high=255, size=input_shapes[1]).astype(np.float32)
+        in0 = np.random.uniform(low=0, high=255, size=input_shapes[2]).astype(np.float32)
+        in1 = np.random.uniform(low=0, high=255, size=input_shapes[3]).astype(np.float32)
+        in2 = np.random.uniform(low=0, high=255, size=input_shapes[4]).astype(np.float32)
+        input_values = [init0, init1, in0, in1, in2]
+
+        verify_with_ort_with_inputs(
+            model,
+            input_values,
+            target=target,
+            dev=dev,
+            opt_level=2,
+            use_vm=True,
+            opset=opset,
+        )
+
+    # opset 8
+    input_shapes = [[2, 6, 7, 8], [2, 3, 3], [2, 5, 6, 7, 8], [2, 5, 3, 4], [2, 5, 4, 3]]
+    output_shapes = [[2, 6, 7, 8], [2, 3, 3], [2, 5, 6, 7, 8], [2, 5, 3, 3]]
+    # input_shapes, output_shapes, num_scan_inputs, scan_input_axes, scan_input_directions,
+    # scan_output_axes, scan_output_directions, opset
+    verify_scan(input_shapes, output_shapes, 3, [0] * 3, [0] * 3, [0] * 2, [0] * 2, 8)
+    # opset 9
+    input_shapes = [[6, 7, 8], [3, 3], [5, 6, 7, 8], [5, 3, 4], [5, 4, 3]]
+    output_shapes = [[6, 7, 8], [3, 3], [5, 6, 7, 8], [5, 3, 3]]
+    verify_scan(input_shapes, output_shapes, 3, [0] * 3, [0] * 3, [0] * 2, [0] * 2, 9)
+
+    input_shapes = [[6, 7, 8], [3, 3], [5, 6, 7, 8], [3, 4, 5], [4, 5, 3]]
+    output_shapes = [[6, 7, 8], [3, 3], [6, 5, 7, 8], [3, 5, 3]]
+    verify_scan(input_shapes, output_shapes, 3, [0, 2, 1], [1] * 3, [1] * 2, [1] * 2, 9)
+    # Negative axes
+    input_shapes = [[6, 7, 8], [3, 3], [5, 6, 7, 8], [3, 4, 5], [4, 5, 3]]
+    output_shapes = [[6, 7, 8], [3, 3], [6, 5, 7, 8], [3, 5, 3]]
+    verify_scan(input_shapes, output_shapes, 3, [-4, -1, -2], [1] * 3, [-3, -2], [1] * 2, 9)
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -6124,6 +6288,7 @@ def repeat(N, D):
     test_convinteger()
     test_batch_matmul()
    test_global_lppool()
+    test_scan()
     test_random_uniform_like()
     test_random_normal()
     test_random_normal_like()
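Addendum for reviewers: the converter above reimplements ONNX Scan's semantics with a Relay while_loop. For reference, here is a minimal NumPy sketch of the opset-9 behavior for a single state tensor and a single scan input with default axes and forward direction; the function and variable names are illustrative only and not part of the patch:

import numpy as np


def scan_reference(initial_state, scan_input):
    # Body of this example Scan: state = state + slice, and the updated state
    # is also emitted as the scan-output slice (a running sum).
    state = initial_state
    scan_out_slices = []
    # One iteration per slice along axis 0 (scan_input_axes = [0]).
    for x in scan_input:
        state = state + x
        scan_out_slices.append(state)
    # Outputs are the final state plus the slices stacked along axis 0
    # (scan_output_axes = [0]).
    return state, np.stack(scan_out_slices, axis=0)


final_state, scan_out = scan_reference(np.zeros(3), np.ones((5, 3)))
# final_state == [5. 5. 5.]; scan_out[i] is the running sum after i + 1 slices.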