From 206bafa111cada236a2ee19109e87b451fd20e53 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 28 Aug 2023 18:45:20 +0000 Subject: [PATCH 01/10] Add initial spec for openai whisper decode op --- .../core/graph/contrib_ops/contrib_defs.cc | 111 +++++++----------- onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + 2 files changed, 42 insertions(+), 71 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..44345ab05fffa 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1061,77 +1061,46 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, updateOutputShape(ctx, 0, {N, C, H_out, W_out}); })); -ONNX_MS_OPERATOR_SET_SCHEMA( - UnfoldTensor, 1, - OpSchema() - .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " - "Step between two slices is given by step. " - "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " - "the returned tensor will be (sizedim - size) / step + 1. " - "An additional dimension of size size is appended in the returned tensor.") - .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) - .Attr("size", "specify the size", AttributeProto::INT) - .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) - .Input(0, "input", "input tensor", "T") - .Output(0, "output", "Output tensor.", "T") - .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - - if (!hasInputShape(ctx, 0)) return; - auto& input_shape = getInputShape(ctx, 0); - const int rank = input_shape.dim_size(); - int64_t dim = getAttribute(ctx, "dim", -1); - dim = HandleNegativeAxis(dim, rank); - if (!input_shape.dim(static_cast(dim)).has_dim_value()) { - return; - } - int64_t dim_size = input_shape.dim(static_cast(dim)).dim_value(); - - const int64_t step = getAttribute(ctx, "step", -1); - if (step <= 0) { - fail_shape_inference("size attribute in UnfoldTensor must greater than 0.") - } - int64_t size = -1; - auto size_proto = ctx.getAttribute("size"); - if (!(size_proto)) { - fail_shape_inference("size attribute in UnfoldTensor not specified!") - } - size = size_proto->i(); - if (size > dim_size || size <= 0) { - fail_shape_inference("size attribute in UnfoldTensor not positive and less than the dim size!") - } - - ONNX_NAMESPACE::TensorShapeProto output_shape; - for (int d = 0; d < rank; d++) { - if (d == dim) { - output_shape.add_dim()->set_dim_value((dim_size - size) / step + 1); - } else { - *output_shape.add_dim() = input_shape.dim(d); - } - } - output_shape.add_dim()->set_dim_value(size); - updateOutputShape(ctx, 0, output_shape); - })); - -ONNX_MS_OPERATOR_SET_SCHEMA( - DynamicTimeWarping, 1, - OpSchema() - .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" - "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. 
Given such cost matrix, return "
-            "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])"
-            "have the lowest cost among all paths from (0, 0) to (M-1, N-1).")
-        .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F")
-        .Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I")
-        .TypeConstraint("F", {"tensor(float)"}, "Constrain to float tensors.")
-        .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types.")
-        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-          updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32);
-          ONNX_NAMESPACE::TensorShapeProto resultShape;
-          resultShape.add_dim()->set_dim_value(2);
-          resultShape.add_dim();
-          updateOutputShape(ctx, 0, resultShape);
-        }));
+ONNX_MS_OPERATOR_SET_SCHEMA(WhisperDecode, 1,
+                            OpSchema()
+                                .SetDoc("Whisper decode for speech processing. Supports whisper beam search and greedy search.")
+                                .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
+                                .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
+                                .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast<int64_t>(-1))
+                                .Attr("no_repeat_ngram_size", "The size of n-grams that must not repeat", AttributeProto::INT, static_cast<int64_t>(0))
+                                .Attr("early_stopping", "Whether to stop the beam search early", AttributeProto::INT, static_cast<int64_t>(0))
+                                .Attr("search_type", "search type: 0 for GreedySearch; 1 for BeamSearch", AttributeProto::INT, static_cast<int64_t>(0))
+                                .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
+                                .Attr("vocab_size",
+                                      "Size of the vocabulary. "
+                                      "If not provided, it will be inferred from the decoder subgraph's output shape",
+                                      AttributeProto::INT, static_cast<int64_t>(-1))
+                                .Attr("no_speech_token",
+                                      "The token in the whisper model that marks all sequences as empty. With this token, whisper can output no_speech_prob afterwards. Default -1.",
+                                      AttributeProto::INT, OPTIONAL_VALUE)
+                                .Input(0, "logits", "Per-token logits of the probability distribution at the current step. Shape is (batch_size, vocab_size)", "F")
+                                .Input(1, "tokens", "All tokens in the context so far, including the prefix and sot_sequence tokens. Shape is (batch_size, sequence_length)", "F")
+                                .Input(2, "temperature", "The sampling temperature t. Shape is (1)", "F")
+                                .Input(3, "best_of", "The number of independent sample trajectories, if t > 0. Shape is (1)", OpSchema::Optional)
+                                .Input(4, "beam_size", "The number of beams in beam search, if t == 0. Shape is (1)", OpSchema::Optional)
+                                .Input(5, "patience", "The patience in beam search. Shape is (1)", OpSchema::Optional)
+                                .Input(6, "length_penalty", "alpha in Google NMT, or None for length norm, when ranking generations "
+                                                            "to select which to return among the beams or best-of-N samples. Shape is (1)", "T", OpSchema::Optional)
+                                .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
+                                .Input(8, "logits_filters", "Logits filters to be applied based on types of search used. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional)
+                                .Output(0, "sequences", "The tokens, appended with the selected next token. Shape is (batch_size, current_sequence_length + 1)", "F")
+                                .Output(1, "sequences_scores", "Sequence of Tensors containing candidate token sequences, for each audio input. 
Shape is ()", "T", OpSchema::Optional)
+                                .Output(2, "non_speech_probs",
+                                        "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token. "
+                                        "Currently we treat the last token's logits as what we need; in the future, extra graph logic may be added to the encoder/context-decoder subgraph. "
+                                        "The prob is saved before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", "T", OpSchema::Optional)
+                                .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.")
+                                .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.")
+                                .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
+                                .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
+                                .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+                                  BeamSearchShapeInference(ctx);
+                                }));
 
 ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
                             OpSchema()
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 5eef1b33a24dd..0e98e05f0d07b 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -110,6 +110,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Trilu);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, UnfoldTensor);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicTimeWarping);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Unique);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WhisperDecode);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention);
@@ -219,6 +220,7 @@ class OpSet_Microsoft_ver1 {
       fn(GetOpSchema());
       fn(GetOpSchema());
       fn(GetOpSchema());
+      fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WhisperDecode)>());
       fn(GetOpSchema());
       fn(GetOpSchema());
       fn(GetOpSchema());

From d35a7bf7330eaf418b6da6bbc2111547ef8f493c Mon Sep 17 00:00:00 2001
From: shubhambhokare1
Date: Mon, 25 Sep 2023 16:41:44 +0000
Subject: [PATCH 02/10] Add openai fusion files

---
 .../core/graph/contrib_ops/contrib_defs.cc    |  41 --
 onnxruntime/core/graph/contrib_ops/ms_opset.h |   2 -
 .../transformers/fusion_bart_attention_openai | 416 ++++++++++++++++++
 .../models/whisper/convert_to_onnx.py         |   4 +-
 .../models/whisper/whisper_helper.py          |   5 +
 .../tools/transformers/onnx_model_bart.py     |   1 +
 6 files changed, 424 insertions(+), 45 deletions(-)
 create mode 100644 onnxruntime/python/tools/transformers/fusion_bart_attention_openai

diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 44345ab05fffa..dea668f6d1ace 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1061,47 +1061,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1,
                               updateOutputShape(ctx, 0, {N, C, H_out, W_out});
                             }));
 
-ONNX_MS_OPERATOR_SET_SCHEMA(WhisperDecode, 1,
-                            OpSchema()
-                                .SetDoc("Whisper decode for speech processing. 
Supports whisper beam search and greedy search.") - .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) - .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) - .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) - .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) - .Attr("search_type", "search type: 0 for GreedySearch; 1 for BeamSearch", AttributeProto::INT, static_cast(0)) - .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) - .Attr("vocab_size", - "Size of the vocabulary. " - "If not provided, it will be inferred from the decoder subgraph's output shape", - AttributeProto::INT, static_cast(-1)) - .Attr("no_speech_token", - "The token in whisper model that mark all sequence empty. With this model, whisper could output no_speech_prob after Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) - .Input(0, "logits", "Per-token logits of the probability distribution at the current step. Shape is (batch_size, vocab_size)", "F") - .Input(1, "tokens", "All tokens in the context so far, including the prefix and sot_sequence tokens. Shape is (batch_size, sequence_length)", "F") - .Input(2, "temperature", "", "F") - .Input(3, "best_of", "The number of independent sample trajectories, if t > 0. Shape is (1)", OpSchema::Optional) - .Input(4, "beam_size", "The number of beams in beam search, if t == 0. Shape is (1)", OpSchema::Optional) - .Input(5, "patience", "The patience in beam search. Shape is (1)", OpSchema::Optional) - .Input(6, "length_penalty", "alpha in Google NMT, or None for length norm, when ranking generations" - "to select which to return among the beams or best-of-N samples. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(8, "logits_filters", "Logits filters to be applied based on types of search used. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) - .Output(0, "sequences", "The tokens, appended with the selected next token. Shape is (batch_size, current_sequence_length + 1)", "F") - .Output(1, "sequences_scores", "Sequence of Tensors containing candidate token sequences, for each audio input. Shape is ()", "T", OpSchema::Optional) - .Output(2, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. 
The shape of non_speech_probs is [B]", "T", OpSchema::Optional) - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") - .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") - .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") - .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - #BeamSearchShapeInference(ctx); - })); - ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, OpSchema() .SetDoc("Beam Search for text generation. Supports GPT-2 decoder.") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 0e98e05f0d07b..5eef1b33a24dd 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -110,7 +110,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Trilu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, UnfoldTensor); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicTimeWarping); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Unique); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WhisperDecode); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention); @@ -220,7 +219,6 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai new file mode 100644 index 0000000000000..7bb4cde91f210 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai @@ -0,0 +1,416 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx import TensorProto, helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionBartAttentionOpenai(FusionAttention): + """ + Fuse Bart Attention subgraph into one Attention node. 
+ """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def check_runtime_shape_path( + self, + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) + if concat_qkv_2_path is None: + return False + concat_qkv_2 = concat_qkv_2_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0]) + if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: + return False + if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: + return False + + reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: + return False + + mul_q = reshape_q_2_path[-1] + mul_k = reshape_k_2_path[-1] + mul_v = reshape_v_2_path[-1] + + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + add_out, + matmul_out, + reshape_qkv_2, + transpose_qkv, + reshape_qkv_1, + matmul_qkv, + ) = qkv_nodes + else: + return + + other_inputs = [] + for input in normalize_node.input: + if input not in output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + return + root_input = other_inputs[0] + + # Sometimes the input name to the attention MatMul nodes does not match the input name to the end + # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul + # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are + # children nodes for each of its output names. 
+ """ + root_input + +---------------------------------------------------+ + | | + | | + SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization + """ + skip_layernorm = output_name_to_node[root_input] + # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose + # child is the LayerNormalization node. + if skip_layernorm.op_type == "Add": + skip_layernorm = self.model.get_children(skip_layernorm)[0] + for output in skip_layernorm.output: + if not output: + continue + children = input_name_to_nodes[output] + children_types = [child.op_type for child in children] + if children_types.count("MatMul") >= 1: + root_input = output + break + + graph_input_names = set([node.name for node in self.model.graph().input]) + graph_output_names = set([node.name for node in self.model.graph().output]) + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, None], + ) + v_nodes_with_past_self_attn = self.model.match_parent_path( + # Decoder attention with past value concatenated before MatMul + matmul_qkv, + ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, None], + ) + v_nodes_with_past_cross_attn = self.model.match_parent_path( + # Decoder attention with past value directly used in MatMul + matmul_qkv, + ["Reshape"], + [1], + ) + past_v, present_v = "", "" + reshape_v_2, add_v = None, None + if v_nodes is not None: + (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + present_v = transpose_v.output[0] + elif v_nodes_with_past_self_attn is not None: + (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn + v_nodes = v_nodes_with_past_self_attn + past_v = concat_v.input[0] + present_v = concat_v.output[0] + elif ( + v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + else: + logger.debug("fuse_attention: failed to match v path") + return + past_v = past_v if past_v in graph_input_names else "" + present_v = present_v if present_v in graph_output_names else "" + + qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + qk_nodes_2 = self.model.match_parent_path( + matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] + ) + if qk_nodes_1 is not None: + _, matmul_qk = qk_nodes_1 + qk_nodes = qk_nodes_1 + elif qk_nodes_2 is not None: + _, _, add_qk, _, matmul_qk = qk_nodes_2 + qk_nodes = qk_nodes_2 + else: + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], + [0, 0, 0, 0, 0, 1], + ) + if q_nodes is not None: + reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes + else: + return + + k_nodes_with_bias = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, 0, 1], + ) + k_nodes_no_bias = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Transpose", 
"Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path( + # Decoder attention with past key concatenated before MatMul + matmul_qk, + ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Transpose", "Reshape"], + [1, 0], + ) + past_k, present_k = "", "" + reshape_k_2, reshape_k_1, matmul_k = None, None, None + if k_nodes_with_bias is not None: + _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias + k_nodes = k_nodes_with_bias + elif k_nodes_no_bias is not None: + _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias + k_nodes = k_nodes_no_bias + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + present_k = transpose_k_1.output[0] + elif k_nodes_no_bias_with_past_self_attn is not None: + _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn + k_nodes = k_nodes_no_bias_with_past_self_attn + past_k = concat_k.input[0] + present_k = concat_k.output[0] + elif ( + k_nodes_no_bias_with_past_cross_attn is not None + and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + else: + return + past_k = past_k if past_k in graph_input_names else "" + present_k = present_k if present_k in graph_output_names else "" + + if k_nodes in (k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + # Create empty Add node for attention graph + bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] + empty_bias_name = "empty_bias" + empty_tensor = self.model.get_initializer(empty_bias_name) + if empty_tensor is None: + empty_tensor = helper.make_tensor(empty_bias_name, TensorProto.FLOAT, [bias_dim], [0.0] * bias_dim) + self.model.add_initializer(empty_tensor, self.this_graph_name) + + add_name = self.model.create_node_name("Add") + add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) + + if not past_k and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + return + + three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals() + one_root_input = ( + not three_root_inputs + and matmul_k.input[0] == root_input + and matmul_q.input[0] == root_input + and matmul_v.input[0] == root_input + ) + two_root_inputs = ( + not three_root_inputs + and matmul_q.input[0] == root_input + and matmul_k.input[0] == matmul_v.input[0] + and matmul_k.input[0] != matmul_q.input[0] + ) + + # There are 5 types of attention: + # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 + # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2 + # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value + # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 + # 5) Decoder cross attention with past 
with three_root_inputs=True and qk_nodes=qk_nodes_1 + encoder_attention = one_root_input and qk_nodes == qk_nodes_1 + decoder_attention = one_root_input and qk_nodes == qk_nodes_2 + decoder_attention_with_past = encoder_attention and past_k and past_v + decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 + decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 + + # For decoder_attention, the attention mask needs to be included in the attention node + mask_index = None + if decoder_attention: + mask_nodes_bart = self.model.match_parent_path( + add_qk, + ["Where"], + [1], + ) + mask_nodes_whisper = self.model.match_parent_path( + add_qk, + ["Expand", "Unsqueeze", "Unsqueeze", "Where"], + [1, 0, 0, 0], + ) + if mask_nodes_whisper is not None: + mask_index = mask_nodes_whisper[0].output[-1] + elif mask_nodes_bart is not None: + mask_index = mask_nodes_bart[0].output[-1] + + if ( + encoder_attention + or decoder_attention + or decoder_attention_with_past + or decoder_cross_attention + or decoder_cross_attention_with_past + ): + attention_last_node = reshape_qkv_2 + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + return + + new_node = None + if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + # Note: Decoder attention with past key and past value is fused as multihead attention + # rather than attention because multihead attention supports separate past key and past + # value whereas attention supports concatenated past key and past value. + new_node = ( + self.create_multihead_attention_node( + matmul_q, + matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, + matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + add_q, + add_k if decoder_cross_attention or decoder_attention_with_past else None, + add_v if decoder_cross_attention or decoder_attention_with_past else None, + num_heads, + hidden_size, + attention_last_node.output[0], + past_k=past_k if decoder_attention_with_past else "", + past_v=past_v if decoder_attention_with_past else "", + present_k=present_k, + present_v=present_v, + packed_qkv=decoder_attention_with_past, + ) + if self.use_multi_head_attention + else None + ) + else: + # Temporarily set multihead attention flag to false + use_multi_head_attention_ground_truth = self.use_multi_head_attention + self.use_multi_head_attention = False + new_node = self.create_attention_node( + None, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + root_input, + attention_last_node.output[0], + add_qk_str=mask_index if decoder_attention else None, + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + self.use_multi_head_attention = use_multi_head_attention_ground_truth + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + 
if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + if self.disable_multi_head_attention_bias and ( + decoder_cross_attention or decoder_cross_attention_with_past + ): + if q_nodes[-1].op_type == "Add": + q_nodes.pop() + if k_nodes[-1].op_type == "Add": + k_nodes.pop() + if v_nodes[-1].op_type == "Add": + v_nodes.pop() + + self.nodes_to_remove.extend(q_nodes) + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index bb697fe1e1506..9a49b146009e8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -176,7 +176,7 @@ def parse_arguments(argv=None): action="store_true", help="Produce beam search model with chained encdecinit and decoder.", ) - parser.set_defaults(chain_model=True) + parser.set_defaults(chain_model=False) parser.add_argument( "--use_whisper_beamsearch", @@ -333,7 +333,7 @@ def export_onnx_models( models = WhisperHelper.load_model( model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path ) - config = models["decoder"].config + config = models["encoder_decoder_init"].config if (not use_external_data_format) and (config.num_hidden_layers > 24): logger.info("Try use_external_data_format when model size > 2GB") diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e2dc79ca247ce..d68384fafa6d8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,6 +6,7 @@ import logging import os +import io import sys from pathlib import Path from typing import Dict, Tuple, Union @@ -19,6 +20,10 @@ from whisper_encoder import WhisperEncoder, WhisperEncoderHelper from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper +from whisper.model import Whisper, ModelDimensions +from whisper import _MODELS, _ALIGNMENT_HEADS +from whisper import _download + from onnxruntime import InferenceSession sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 61a786d7af60b..d78b734a8545d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -124,6 +124,7 @@ class BartOnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) + print("reach") self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) From 1d2b214ee10bfab70709f1c82001258ca8d10253 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 9 Oct 2023 17:45:42 +0000 Subject: [PATCH 03/10] Modify decoding Logic --- ...n_bart_attention_openai => fusion_bart_attention_openai.py} | 0 .../tools/transformers/models/whisper/whisper_decoder.py | 3 +-- 
onnxruntime/python/tools/transformers/onnx_model_bart.py | 3 ++- 3 files changed, 3 insertions(+), 3 deletions(-) rename onnxruntime/python/tools/transformers/{fusion_bart_attention_openai => fusion_bart_attention_openai.py} (100%) diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py similarity index 100% rename from onnxruntime/python/tools/transformers/fusion_bart_attention_openai rename to onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 0d69960a095ac..27fa139b0f295 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -164,8 +164,7 @@ def create_dummy( self_attention_past_shape = [ batch_size, num_attention_heads, - past_decode_sequence_length, - head_size, + past_decode_sequence_length * head_size, ] cross_attention_past_shape = [ batch_size, diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index d78b734a8545d..08ddc352415cb 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -7,6 +7,7 @@ from fusion_attention import AttentionMask from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_options import FusionOptions from fusion_reshape import FusionReshape from onnx import numpy_helper @@ -125,7 +126,7 @@ def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) print("reach") - self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + self.attention_fusion = FusionBartAttentionOpenai(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): From 1ea7bc74ad529ebf062f59484bdaadedfd28a467 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 26 Oct 2023 11:03:25 +0000 Subject: [PATCH 04/10] Add optimizations to account for 3d to 4d past --- .../fusion_bart_attention_openai.py | 148 ++++++++++++++---- .../models/whisper/whisper_decoder.py | 3 +- .../models/whisper/whisper_helper.py | 142 ++++++++--------- .../tools/transformers/onnx_model_bart.py | 11 +- .../tools/transformers/onnx_model_bert.py | 3 +- 5 files changed, 198 insertions(+), 109 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py index 7bb4cde91f210..ad07f444f551b 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class FusionBartAttentionOpenai(FusionAttention): +class FusionBartAttention(FusionAttention): """ Fuse Bart Attention subgraph into one Attention node. 
""" @@ -75,18 +75,19 @@ def check_runtime_shape_path( def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + print("\n") qkv_nodes = self.model.match_parent_path( normalize_node, - ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], - [1, 1, 0, 0, 0, 0], + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], ) if qkv_nodes is not None: + print("Reached qkv level") ( add_out, matmul_out, - reshape_qkv_2, - transpose_qkv, reshape_qkv_1, + transpose_qkv, matmul_qkv, ) = qkv_nodes else: @@ -128,13 +129,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input = output break - graph_input_names = set([node.name for node in self.model.graph().input]) - graph_output_names = set([node.name for node in self.model.graph().output]) + graph_input_names = set([node.name for node in self.model.graph().input]) + graph_output_names = set([node.name for node in self.model.graph().output]) v_nodes = self.model.match_parent_path( matmul_qkv, - ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, None], + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], ) v_nodes_with_past_self_attn = self.model.match_parent_path( # Decoder attention with past value concatenated before MatMul @@ -145,15 +146,47 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): v_nodes_with_past_cross_attn = self.model.match_parent_path( # Decoder attention with past value directly used in MatMul matmul_qkv, - ["Reshape"], - [1], + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], ) past_v, present_v = "", "" reshape_v_2, add_v = None, None if v_nodes is not None: - (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + print("reach v path") + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - present_v = transpose_v.output[0] + #present_v = add_v.output[0] + + add_v_children = self.model.get_children(add_v) + for child in add_v_children: + if child.op_type == "Reshape": + #if child.output[0] in graph_output_names: + #present_v = child.output[0] + reshape_v_children = self.model.get_children(child) + for reshape_child in reshape_v_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_v = reshape_child.output[0] + if child.op_type == "Concat": + concat_v_children = self.model.get_children(child) + for concat_child in concat_v_children: + if concat_child.op_type == "Reshape": + reshape_v_children = self.model.get_children(concat_child) + for reshape_child in reshape_v_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_v = reshape_child.output[0] + print("reach v path with past self attn") + concat_v_parents = self.model.get_parents(child) + for concat_parent in concat_v_parents: + if concat_parent.op_type == "Reshape": + reshape_v_parents = self.model.get_parents(concat_parent) + for reshape_parent in reshape_v_parents: + if reshape_parent.op_type == "Transpose": + if reshape_parent.input[0] in graph_input_names: + past_v = reshape_parent.input[0] + print("reach v path with past self attn") + elif v_nodes_with_past_self_attn is not None: (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn v_nodes = 
v_nodes_with_past_self_attn @@ -162,6 +195,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif ( v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names ): + print("reach v path with past cross attn") v_nodes = v_nodes_with_past_cross_attn past_v = v_nodes[-1].input[0] present_v = v_nodes[-1].output[0] @@ -178,31 +212,34 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) qk_nodes_2 = self.model.match_parent_path( - matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] + matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0] ) if qk_nodes_1 is not None: + print("reach qk type 1") _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: - _, _, add_qk, _, matmul_qk = qk_nodes_2 + print("reach qk type 2") + _, add_qk, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 else: return q_nodes = self.model.match_parent_path( matmul_qk, - ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], - [0, 0, 0, 0, 0, 1], + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], ) if q_nodes is not None: - reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes + print("reach q path") + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes else: return k_nodes_with_bias = self.model.match_parent_path( matmul_qk, - ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, 0, 1], + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], ) k_nodes_no_bias = self.model.match_parent_path( matmul_qk, @@ -218,14 +255,53 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path( # Decoder attention with past key directly used in MatMul matmul_qk, - ["Transpose", "Reshape"], - [1, 0], + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], ) past_k, present_k = "", "" reshape_k_2, reshape_k_1, matmul_k = None, None, None if k_nodes_with_bias is not None: - _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias + print("reach k path") + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias k_nodes = k_nodes_with_bias + present_k = matmul_k.output[0] + mat_k_out_tmp = matmul_k.output[0] + "_temp" + #matmul_k.output[0] = matmul_k.output[0] + "_temp" + + matmul_k_children = self.model.get_children(matmul_k) + for child in matmul_k_children: + if child.op_type == "Reshape": + #if child.output[0] in graph_output_names: + # present_k = child.output[0] + reshape_k_children = self.model.get_children(child) + for reshape_child in reshape_k_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_k = reshape_child.output[0] + if child.op_type == "Concat": + concat_k_children = self.model.get_children(child) + for concat_child in concat_k_children: + if concat_child.op_type == "Reshape": + reshape_k_children = self.model.get_children(concat_child) + for reshape_child in reshape_k_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_k = reshape_child.output[0] + print("reach k path with past self attn") + concat_k_parents = self.model.get_parents(child) + for concat_parent in concat_k_parents: + if concat_parent.op_type == "Reshape": + reshape_k_parents = 
self.model.get_parents(concat_parent) + for reshape_parent in reshape_k_parents: + if reshape_parent.op_type == "Transpose": + if reshape_parent.input[0] in graph_input_names: + past_k = reshape_parent.input[0] + print("reach v path with past self attn") + print("reach k path with past self attn") + #else: + # matmul_k.output[0] = mat_k_out_tmp + + elif k_nodes_no_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias k_nodes = k_nodes_no_bias @@ -240,6 +316,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes_no_bias_with_past_cross_attn is not None and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names ): + print("reach k path with past cross attn") k_nodes = k_nodes_no_bias_with_past_cross_attn past_k = k_nodes[-1].input[0] present_k = k_nodes[-1].output[0] @@ -253,7 +330,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if k_nodes in (k_nodes_with_bias, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): # Create empty Add node for attention graph bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] empty_bias_name = "empty_bias" @@ -265,6 +342,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) + ''' if not past_k and not self.check_runtime_shape_path( reshape_qkv_2, reshape_qkv_1, @@ -274,6 +352,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input, ): return + ''' three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals() one_root_input = ( @@ -282,12 +361,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input ) + if one_root_input: print("one root input") two_root_inputs = ( not three_root_inputs and matmul_q.input[0] == root_input and matmul_k.input[0] == matmul_v.input[0] and matmul_k.input[0] != matmul_q.input[0] ) + if two_root_inputs: print("two root inputs") # There are 5 types of attention: # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 @@ -297,9 +378,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 encoder_attention = one_root_input and qk_nodes == qk_nodes_1 decoder_attention = one_root_input and qk_nodes == qk_nodes_2 - decoder_attention_with_past = encoder_attention and past_k and past_v + decoder_attention_with_past = decoder_attention and past_k and past_v decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 + print(three_root_inputs) + print( + encoder_attention, + decoder_attention, + decoder_attention_with_past, + decoder_cross_attention, + decoder_cross_attention_with_past, + ) # For decoder_attention, the attention mask needs to be included in the attention node mask_index = None @@ -311,11 +400,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) mask_nodes_whisper = self.model.match_parent_path( add_qk, - ["Expand", "Unsqueeze", "Unsqueeze", 
"Where"], - [1, 0, 0, 0], + ["Slice", "Slice", "Unsqueeze", "Gather"], + [1, 0, 2, 0], ) if mask_nodes_whisper is not None: - mask_index = mask_nodes_whisper[0].output[-1] + print("reach qk add") + #mask_index = mask_nodes_whisper[0].output[-1] elif mask_nodes_bart is not None: mask_index = mask_nodes_bart[0].output[-1] @@ -326,7 +416,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): or decoder_cross_attention or decoder_cross_attention_with_past ): - attention_last_node = reshape_qkv_2 + attention_last_node = reshape_qkv_1 num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 27fa139b0f295..0d69960a095ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -164,7 +164,8 @@ def create_dummy( self_attention_past_shape = [ batch_size, num_attention_heads, - past_decode_sequence_length * head_size, + past_decode_sequence_length, + head_size, ] cross_attention_past_shape = [ batch_size, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index d68384fafa6d8..e9ef82fc1fad8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -349,82 +349,76 @@ def verify_onnx( from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 - length_penalty, repetition_penalty = 1.0, 1.0 - inputs = { - "input_features": input_features.to(device), - "max_length": max_length, - "min_length": min_length, - "num_beams": num_beams, - "num_return_sequences": num_return_sequences, - "length_penalty": length_penalty, - "repetition_penalty": repetition_penalty, - "early_stopping": True, - "use_cache": True, - } - pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() - - del inputs["early_stopping"] - del inputs["use_cache"] - ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) - ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) - ort_to_np = { - "tensor(float)": np.float32, - "tensor(float16)": np.float16, - "tensor(int64)": np.int64, - "tensor(int32)": np.int32, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - } + for d in ds: + input_features = processor([d["audio"]["array"]], return_tensors="pt").input_features + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + length_penalty, repetition_penalty = 1.0, 1.0 + inputs = { + "input_features": input_features.to(device), + "max_length": max_length, + "min_length": min_length, + "num_beams": num_beams, + "num_return_sequences": num_return_sequences, + "length_penalty": length_penalty, + "repetition_penalty": repetition_penalty, + "early_stopping": True, + "use_cache": True, + } + pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + + del inputs["early_stopping"] + del inputs["use_cache"] + ort_names = list(map(lambda 
entry: entry.name, ort_session.get_inputs())) + ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) + ort_to_np = { + "tensor(float)": np.float32, + "tensor(float16)": np.float16, + "tensor(int64)": np.int64, + "tensor(int32)": np.int32, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + } - use_extra_decoding_ids = "extra_decoding_ids" in ort_names - for name, dtype in zip(ort_names, ort_dtypes): - if name == "input_features": - inputs[name] = inputs[name].detach().cpu().numpy() - elif name == "vocab_mask": - inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) - elif name == "prefix_vocab_mask": - inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) - elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 50359, 50363]] + for name, dtype in zip(ort_names, ort_dtypes): + if name == "input_features": + inputs[name] = inputs[name].detach().cpu().numpy() + elif name == "vocab_mask": + inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) + elif name == "prefix_vocab_mask": + inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) + elif name == "decoder_input_ids": + inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + elif name == "logits_processor": + inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + else: + inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) + ort_outputs = ort_session.run(None, inputs)[0][0] + + if pt_outputs.shape != ort_outputs.shape: + logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + + #diff = pt_outputs - ort_outputs + #max_diff = max(diff.min(), diff.max(), key=abs) + #print(max_diff) + + if max_diff == 0: + # For ONNX Runtime INT8 model + pt_expected_transcription = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." ) - inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) - elif name == "logits_processor": - inputs[name] = np.array([1], dtype=ort_to_np[dtype]) - elif name == "cross_qk_layer_head": - inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) - elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) - else: - inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) - ort_outputs = ort_session.run(None, inputs)[0][0] - - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") - - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) - - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) + print(pt_transcription) + ort_expected_transcription = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
+ ) + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + print(ort_transcription) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + parity = ( + pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] + ) + if parity: + max_diff = 0 return max_diff diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 08ddc352415cb..24083112eb0ac 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -6,8 +6,8 @@ from typing import Optional from fusion_attention import AttentionMask -from fusion_bart_attention import FusionBartAttention -from fusion_bart_attention_openai import FusionBartAttentionOpenai +#from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttention from fusion_options import FusionOptions from fusion_reshape import FusionReshape from onnx import numpy_helper @@ -125,8 +125,7 @@ class BartOnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) - print("reach") - self.attention_fusion = FusionBartAttentionOpenai(self, self.hidden_size, self.num_heads, self.attention_mask) + self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): @@ -137,7 +136,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo super().optimize(options, add_dynamic_axes) def fuse_attention(self): + #import onnx + #onnx.save_model(self.model, "intermediate.onnx") self.attention_fusion.apply() + import onnx + onnx.save_model(self.model, "intermediate2.onnx") def preprocess(self): self.adjust_reshape_and_expand() diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 431e64509e3cc..f85df2d8a737b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -8,7 +8,8 @@ from convert_to_packing_mode import PackingMode from fusion_attention import AttentionMask, FusionAttention -from fusion_bart_attention import FusionBartAttention +#from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttention from fusion_biasgelu import FusionBiasGelu from fusion_embedlayer import FusionEmbedLayerNormalization from fusion_fastgelu import FusionFastGelu From 27b5cfc941b62c5c12897416d20cda85af4a5236 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 26 Oct 2023 12:30:11 +0000 Subject: [PATCH 05/10] minor code cleanup --- .../fusion_bart_attention_openai.py | 28 +--- .../tools/transformers/fusion_options.py | 1 + .../models/whisper/convert_to_onnx.py | 5 +- .../whisper/whisper_encoder_decoder_init.py | 1 + .../models/whisper/whisper_helper.py | 137 +++++++++--------- .../tools/transformers/onnx_model_bart.py | 18 ++- .../tools/transformers/onnx_model_bert.py | 6 +- .../python/tools/transformers/optimizer.py | 5 +- 8 files changed, 92 insertions(+), 109 deletions(-) diff --git 
a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py index ad07f444f551b..e738de0b6e9e3 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class FusionBartAttention(FusionAttention): +class FusionBartAttentionOpenai(FusionAttention): """ Fuse Bart Attention subgraph into one Attention node. """ @@ -75,14 +75,12 @@ def check_runtime_shape_path( def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # SkipLayerNormalization has two inputs, and one of them is the root input for attention. - print("\n") qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 1, 0, 0, 0], ) if qkv_nodes is not None: - print("Reached qkv level") ( add_out, matmul_out, @@ -152,7 +150,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past_v, present_v = "", "" reshape_v_2, add_v = None, None if v_nodes is not None: - print("reach v path") (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) #present_v = add_v.output[0] @@ -176,7 +173,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if reshape_child.op_type == "Transpose": if reshape_child.output[0] in graph_output_names: present_v = reshape_child.output[0] - print("reach v path with past self attn") concat_v_parents = self.model.get_parents(child) for concat_parent in concat_v_parents: if concat_parent.op_type == "Reshape": @@ -185,7 +181,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if reshape_parent.op_type == "Transpose": if reshape_parent.input[0] in graph_input_names: past_v = reshape_parent.input[0] - print("reach v path with past self attn") elif v_nodes_with_past_self_attn is not None: (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn @@ -195,7 +190,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif ( v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names ): - print("reach v path with past cross attn") v_nodes = v_nodes_with_past_cross_attn past_v = v_nodes[-1].input[0] present_v = v_nodes[-1].output[0] @@ -215,11 +209,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0] ) if qk_nodes_1 is not None: - print("reach qk type 1") _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: - print("reach qk type 2") _, add_qk, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 else: @@ -231,7 +223,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [0, 0, 0, 0, 1], ) if q_nodes is not None: - print("reach q path") mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes else: return @@ -261,7 +252,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past_k, present_k = "", "" reshape_k_2, reshape_k_1, matmul_k = None, None, None if k_nodes_with_bias is not None: - print("reach k path") mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias k_nodes = k_nodes_with_bias present_k = matmul_k.output[0] @@ -287,7 +277,6 @@ def fuse(self, normalize_node, input_name_to_nodes, 
output_name_to_node): if reshape_child.op_type == "Transpose": if reshape_child.output[0] in graph_output_names: present_k = reshape_child.output[0] - print("reach k path with past self attn") concat_k_parents = self.model.get_parents(child) for concat_parent in concat_k_parents: if concat_parent.op_type == "Reshape": @@ -296,8 +285,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if reshape_parent.op_type == "Transpose": if reshape_parent.input[0] in graph_input_names: past_k = reshape_parent.input[0] - print("reach v path with past self attn") - print("reach k path with past self attn") #else: # matmul_k.output[0] = mat_k_out_tmp @@ -316,7 +303,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes_no_bias_with_past_cross_attn is not None and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names ): - print("reach k path with past cross attn") k_nodes = k_nodes_no_bias_with_past_cross_attn past_k = k_nodes[-1].input[0] present_k = k_nodes[-1].output[0] @@ -361,14 +347,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input ) - if one_root_input: print("one root input") two_root_inputs = ( not three_root_inputs and matmul_q.input[0] == root_input and matmul_k.input[0] == matmul_v.input[0] and matmul_k.input[0] != matmul_q.input[0] ) - if two_root_inputs: print("two root inputs") # There are 5 types of attention: # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 @@ -381,14 +365,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): decoder_attention_with_past = decoder_attention and past_k and past_v decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 - print(three_root_inputs) - print( - encoder_attention, - decoder_attention, - decoder_attention_with_past, - decoder_cross_attention, - decoder_cross_attention_with_past, - ) # For decoder_attention, the attention mask needs to be included in the attention node mask_index = None @@ -404,7 +380,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [1, 0, 2, 0], ) if mask_nodes_whisper is not None: - print("reach qk add") + pass #mask_index = mask_nodes_whisper[0].output[-1] elif mask_nodes_bart is not None: mask_index = mask_nodes_bart[0].output[-1] diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 4c43e4487bfb1..76d603cb0e56b 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -59,6 +59,7 @@ def __init__(self, model_type): if model_type == "clip": self.enable_embed_layer_norm = False + self.model_impl = "hf" # Set default to sequence length for BERT model to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. 
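For reference, a minimal usage sketch (illustrative only, not part of any patch in this series) of how the new FusionOptions.model_impl flag is intended to be consumed downstream: optimize_by_fusion reads it to pick the OpenAI-style BART attention fusion, as wired up in the optimizer.py and onnx_model_bart.py hunks later in this series. The exact optimize_by_fusion signature, the num_heads/hidden_size values, and the file paths below are assumptions.

import onnx
from fusion_options import FusionOptions
from optimizer import optimize_by_fusion

# Load an exported Whisper decoder graph (hypothetical path).
onnx_model = onnx.load("whisper_decoder.onnx")

# "bart" is the fusion profile used for Whisper; model_impl="openai" routes
# BartOnnxModel to FusionBartAttentionOpenai instead of FusionBartAttention.
options = FusionOptions("bart")
options.model_impl = "openai"
options.use_multi_head_attention = True

optimizer = optimize_by_fusion(
    onnx_model,
    model_type="bart",
    num_heads=8,          # illustrative values, not taken from the patches
    hidden_size=512,
    optimization_options=options,
)
optimizer.save_model_to_file("whisper_decoder_optimized.onnx")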
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 9a49b146009e8..36e6e2a845e75 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -176,7 +176,7 @@ def parse_arguments(argv=None): action="store_true", help="Produce beam search model with chained encdecinit and decoder.", ) - parser.set_defaults(chain_model=False) + parser.set_defaults(chain_model=True) parser.add_argument( "--use_whisper_beamsearch", @@ -333,7 +333,7 @@ def export_onnx_models( models = WhisperHelper.load_model( model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path ) - config = models["encoder_decoder_init"].config + config = models["decoder"].config if (not use_external_data_format) and (config.num_hidden_layers > 24): logger.info("Try use_external_data_format when model size > 2GB") @@ -390,6 +390,7 @@ def export_onnx_models( auto_mixed_precision=not disable_auto_mixed_precision, use_gpu=use_gpu, provider=provider, + model_impl=model_impl, ) onnx_path = output_path diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 351173f525727..250cca2ed8ebd 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -8,6 +8,7 @@ import logging import os import tempfile +import copy from pathlib import Path from typing import List, Optional diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e9ef82fc1fad8..af56ac49a06bd 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -292,12 +292,14 @@ def optimize_onnx( auto_mixed_precision: bool = True, use_gpu: bool = False, provider: str = "cpu", + model_impl: str = "hf", ): """Optimize ONNX model with an option to convert it to use mixed precision.""" from fusion_options import FusionOptions optimization_options = FusionOptions("bart") + optimization_options.model_impl = model_impl optimization_options.use_multi_head_attention = True optimization_options.disable_multi_head_attention_bias = provider == "rocm" @@ -349,76 +351,71 @@ def verify_onnx( from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - for d in ds: - input_features = processor([d["audio"]["array"]], return_tensors="pt").input_features - - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 - length_penalty, repetition_penalty = 1.0, 1.0 - inputs = { - "input_features": input_features.to(device), - "max_length": max_length, - "min_length": min_length, - "num_beams": num_beams, - "num_return_sequences": num_return_sequences, - "length_penalty": length_penalty, - "repetition_penalty": repetition_penalty, - "early_stopping": True, - "use_cache": True, - } - pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() - - del inputs["early_stopping"] - del inputs["use_cache"] - ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) - ort_dtypes = 
list(map(lambda entry: entry.type, ort_session.get_inputs())) - ort_to_np = { - "tensor(float)": np.float32, - "tensor(float16)": np.float16, - "tensor(int64)": np.int64, - "tensor(int32)": np.int32, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - } + input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + length_penalty, repetition_penalty = 1.0, 1.0 + inputs = { + "input_features": input_features.to(device), + "max_length": max_length, + "min_length": min_length, + "num_beams": num_beams, + "num_return_sequences": num_return_sequences, + "length_penalty": length_penalty, + "repetition_penalty": repetition_penalty, + "early_stopping": True, + "use_cache": True, + } + pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + + del inputs["early_stopping"] + del inputs["use_cache"] + ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) + ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) + ort_to_np = { + "tensor(float)": np.float32, + "tensor(float16)": np.float16, + "tensor(int64)": np.int64, + "tensor(int32)": np.int32, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + } + + for name, dtype in zip(ort_names, ort_dtypes): + if name == "input_features": + inputs[name] = inputs[name].detach().cpu().numpy() + elif name == "vocab_mask": + inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) + elif name == "prefix_vocab_mask": + inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) + elif name == "decoder_input_ids": + inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + elif name == "logits_processor": + inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + else: + inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) + ort_outputs = ort_session.run(None, inputs)[0][0] - for name, dtype in zip(ort_names, ort_dtypes): - if name == "input_features": - inputs[name] = inputs[name].detach().cpu().numpy() - elif name == "vocab_mask": - inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) - elif name == "prefix_vocab_mask": - inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) - elif name == "decoder_input_ids": - inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) - elif name == "logits_processor": - inputs[name] = np.array([1], dtype=ort_to_np[dtype]) - else: - inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) - ort_outputs = ort_session.run(None, inputs)[0][0] - - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") - - #diff = pt_outputs - ort_outputs - #max_diff = max(diff.min(), diff.max(), key=abs) - #print(max_diff) - - if max_diff == 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - print(pt_transcription) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
- ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) - print(ort_transcription) - - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if pt_outputs.shape != ort_outputs.shape: + logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) + + if max_diff == 0: + # For ONNX Runtime INT8 model + pt_expected_transcription = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) + ort_expected_transcription = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + + parity = ( + pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] + ) + if parity: + max_diff = 0 return max_diff diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 24083112eb0ac..1ef6a4329cb28 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -6,8 +6,8 @@ from typing import Optional from fusion_attention import AttentionMask -#from fusion_bart_attention import FusionBartAttention -from fusion_bart_attention_openai import FusionBartAttention +from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_options import FusionOptions from fusion_reshape import FusionReshape from onnx import numpy_helper @@ -125,7 +125,15 @@ class BartOnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) - self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + if model_impl == "openai": + self.attention_fusion = FusionBartAttentionOpenai( + self, + self.hidden_size, + self.num_heads, + self.attention_mask + ) + else: + self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): @@ -136,11 +144,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo super().optimize(options, add_dynamic_axes) def fuse_attention(self): - #import onnx - #onnx.save_model(self.model, "intermediate.onnx") self.attention_fusion.apply() - import onnx - onnx.save_model(self.model, "intermediate2.onnx") def preprocess(self): self.adjust_reshape_and_expand() diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index f85df2d8a737b..6f061052716e7 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -8,8 +8,8 @@ from convert_to_packing_mode import PackingMode from fusion_attention import AttentionMask, FusionAttention -#from fusion_bart_attention import FusionBartAttention -from fusion_bart_attention_openai import FusionBartAttention +from 
fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_biasgelu import FusionBiasGelu from fusion_embedlayer import FusionEmbedLayerNormalization from fusion_fastgelu import FusionFastGelu @@ -350,7 +350,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) - if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): + if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention) and not isinstance(self.attention_fusion, FusionBartAttentionOpenai): self.attention_fusion = FusionAttention( self, self.hidden_size, diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..584452bdf004a 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -226,7 +226,10 @@ def optimize_by_fusion( if optimization_options is None: optimization_options = FusionOptions(model_type) - optimizer = optimizer_class(model, num_heads, hidden_size) + if optimization_options.model_impl == "openai": + optimizer = optimizer_class(model, num_heads, hidden_size, model_impl=optimization_options.model_impl) + else: + optimizer = optimizer_class(model, num_heads, hidden_size) optimizer.optimize(optimization_options) From 4b2374459c27a46878d5cb8aaa21162e1ad3d544 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 7 Nov 2023 23:32:20 +0000 Subject: [PATCH 06/10] Modify method of model name input --- .../python/tools/transformers/models/whisper/whisper_helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index af56ac49a06bd..123686e3fbc29 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -352,6 +352,7 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features + prompt_ids_list = [config.decoder_start_token_id, 50259, 50359, 50363] batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 length_penalty, repetition_penalty = 1.0, 1.0 @@ -389,7 +390,7 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + inputs[name] = np.array([prompt_ids_list], dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) else: From 2781dbeb97d1705d8ecdf686b3c8f34a0e59dbf5 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 22 Nov 2023 17:58:50 +0000 Subject: [PATCH 07/10] Revert rebase conflicts --- .../core/graph/contrib_ops/contrib_defs.cc | 72 +++++++++++++++++++ .../models/whisper/whisper_helper.py | 12 +++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index dea668f6d1ace..27c968a59eb91 
100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1061,6 +1061,78 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, updateOutputShape(ctx, 0, {N, C, H_out, W_out}); })); +ONNX_MS_OPERATOR_SET_SCHEMA( + UnfoldTensor, 1, + OpSchema() + .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " + "Step between two slices is given by step. " + "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " + "the returned tensor will be (sizedim - size) / step + 1. " + "An additional dimension of size size is appended in the returned tensor.") + .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) + .Attr("size", "specify the size", AttributeProto::INT) + .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "input tensor", "T") + .Output(0, "output", "Output tensor.", "T") + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (!hasInputShape(ctx, 0)) return; + auto& input_shape = getInputShape(ctx, 0); + const int rank = input_shape.dim_size(); + int64_t dim = getAttribute(ctx, "dim", -1); + dim = HandleNegativeAxis(dim, rank); + if (!input_shape.dim(static_cast(dim)).has_dim_value()) { + return; + } + int64_t dim_size = input_shape.dim(static_cast(dim)).dim_value(); + + const int64_t step = getAttribute(ctx, "step", -1); + if (step <= 0) { + fail_shape_inference("size attribute in UnfoldTensor must greater than 0.") + } + int64_t size = -1; + auto size_proto = ctx.getAttribute("size"); + if (!(size_proto)) { + fail_shape_inference("size attribute in UnfoldTensor not specified!") + } + size = size_proto->i(); + if (size > dim_size || size <= 0) { + fail_shape_inference("size attribute in UnfoldTensor not positive and less than the dim size!") + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (int d = 0; d < rank; d++) { + if (d == dim) { + output_shape.add_dim()->set_dim_value((dim_size - size) / step + 1); + } else { + *output_shape.add_dim() = input_shape.dim(d); + } + } + output_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, output_shape); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicTimeWarping, 1, + OpSchema() + .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" + "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return " + "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" + "have the lowest cost among all paths from (0, 0) to (M-1, N-1).") + .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F") + .Output(0, "output", "Output tensor. 
shape is [2, x], where max(M, N) <= x < M + N", "I") + .TypeConstraint("F", {"tensor(float)"}, "Constrain to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); + ONNX_NAMESPACE::TensorShapeProto resultShape; + resultShape.add_dim()->set_dim_value(2); + resultShape.add_dim(); + updateOutputShape(ctx, 0, resultShape); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, OpSchema() .SetDoc("Beam Search for text generation. Supports GPT-2 decoder.") diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 123686e3fbc29..30d4edcc4a476 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -382,6 +382,7 @@ def verify_onnx( "tensor(uint8)": np.uint8, } + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() @@ -390,9 +391,18 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - inputs[name] = np.array([prompt_ids_list], dtype=ort_to_np[dtype]) + raw_input_ids = ( + [[config.decoder_start_token_id]] + if use_extra_decoding_ids + else [[config.decoder_start_token_id, 50259, 50359, 50363]] + ) + inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + elif name == "cross_qk_layer_head": + inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) + elif name == "extra_decoding_ids": + inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] From f5ff690ae503d3c7dd3e669684cc8fa841593205 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 14 Dec 2023 00:04:17 +0000 Subject: [PATCH 08/10] add new parameters for batch decoding --- .../cpu/transformers/beam_search.cc | 2 +- .../transformers/beam_search_impl_whisper.h | 8 +++++ .../transformers/beam_search_parameters.cc | 36 +++++++++++++++++++ .../transformers/generation_device_helper.cc | 32 +++++++++++++++-- .../transformers/generation_device_helper.h | 12 +++++-- .../cpu/transformers/generation_shared.h | 4 +++ .../transformers/subgraph_whisper_encoder.cc | 14 ++++++-- .../transformers/subgraph_whisper_encoder.h | 2 ++ .../cuda/transformers/beam_search.cc | 2 ++ .../core/graph/contrib_ops/contrib_defs.cc | 2 ++ .../fusion_bart_attention_openai.py | 10 ++++-- .../models/whisper/whisper_chain.py | 19 +++++++++- .../whisper/whisper_encoder_decoder_init.py | 30 +++++++++++++--- .../models/whisper/whisper_helper.py | 18 ++++++++-- .../models/whisper/whisper_openai_helper.py | 4 ++- 15 files changed, 175 insertions(+), 20 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 93cda00e5a3c3..58e842e72d83b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -169,7 +169,7 @@ Status 
BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, ORT_RETURN_IF_ERROR(whisper_encoder_subgraph_->Setup(session_state, subgraph_session_state)); encoder_feeds_fetches_manager_ = whisper_encoder_subgraph_->GetFeedsFetchesManager(); - ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs != 2, + ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs < 2, "Encoder subgraph shall have 2 inputs (encoder_input_ids, decoder_input_ids)"); } else if (attribute_name == "decoder") { ORT_ENFORCE(whisper_decoder_subgraph_ == nullptr, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..e77f0a8d572f5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -153,6 +153,12 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0); const Tensor& encoder_input_ids = encoder_input_ids_value->Get(); + const OrtValue* left_pad_mask_value = this->context_.GetInputOrtValue(15); + const Tensor& left_pad_mask = left_pad_mask_value->Get(); + + const OrtValue* position_ids_value = this->context_.GetInputOrtValue(16); + const Tensor& position_ids = position_ids_value->Get(); + BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, this->IsCuda(), @@ -166,6 +172,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds( encoder_input_ids, initial_decoder_input_ids_value, + left_pad_mask, + position_ids, parameters->decoder_start_token_id, this->implicit_inputs_, encoder_feeds, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..d0cb77c2f5600 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -64,6 +64,40 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } } + left_pad_mask = gsl::span(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && left_pad_mask_input_id > 0) { + const Tensor* left_pad_mask_tensor = context->Input(left_pad_mask_input_id); + if (left_pad_mask_tensor != nullptr) { + const auto& left_pad_mask_tensor_dims = left_pad_mask_tensor->Shape().GetDims(); + ORT_ENFORCE(left_pad_mask_tensor_dims.size() == 4, + "left_pad_mask_tensor shall have 4 dimensions. Got ", + left_pad_mask_tensor_dims.size()); + ORT_ENFORCE(left_pad_mask_tensor_dims[0] == batch_size, + "left_pad_mask_tensor first dim not same as batch_size. Got ", + left_pad_mask_tensor_dims[0], ", expecting ", batch_size); + if (left_pad_mask_tensor->Shape().Size() > 0) { + left_pad_mask = gsl::span(left_pad_mask_tensor->Data(), (size_t)left_pad_mask_tensor->Shape().Size()); + } + } + } + + position_ids = gsl::span(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && position_ids_input_id > 0) { + const Tensor* position_ids_tensor = context->Input(position_ids_input_id); + if (position_ids_tensor != nullptr) { + const auto& position_ids_tensor_dims = position_ids_tensor->Shape().GetDims(); + ORT_ENFORCE(position_ids_tensor_dims.size() == 2, + "position_ids_tensor shall have 2 dimensions. 
Got ", + position_ids_tensor_dims.size()); + ORT_ENFORCE(position_ids_tensor_dims[0] == batch_size, + "position_ids_tensor first dim not same as batch_size. Got ", + position_ids_tensor_dims[0], ", expecting ", batch_size); + if (position_ids_tensor->Shape().Size() > 0) { + position_ids = gsl::span(position_ids_tensor->Data(), (size_t)position_ids_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { @@ -156,6 +190,8 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; + left_pad_mask_input_id = 14; + position_ids_input_id = 15; cross_qk_output_id = 3; no_speech_probs_output_id = 4; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 927d3a58e5a6f..557b78ed6b8a2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -877,10 +877,14 @@ template Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids) { + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids) { const TensorShape& input_features_shape = original_encoder_input_features->Shape(); ORT_ENFORCE(input_features_shape.NumDimensions() == 3); const int64_t& batch_size = input_features_shape[0]; @@ -898,6 +902,20 @@ Status CreateWhisperEncoderInputs( allocator->Info(), encoder_input_features); + const TensorShape& left_pad_mask_shape = original_left_pad_mask->Shape(); + Tensor::InitOrtValue(DataTypeImpl::GetType(), + left_pad_mask_shape, + const_cast(original_left_pad_mask)->MutableData(), + allocator->Info(), + left_pad_mask); + + const TensorShape& position_ids_shape = original_position_ids->Shape(); + Tensor::InitOrtValue(DataTypeImpl::GetType(), + position_ids_shape, + const_cast(original_position_ids)->MutableData(), + allocator->Info(), + position_ids); + // decoder_input_ids is optional. 
if (original_decoder_input_ids_value == nullptr) { // Filled decoder_input_ids with start token ID @@ -1071,18 +1089,26 @@ template Status ExpandBuffer( template Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); template Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); Status UpdateDecoderCrossQK( [[maybe_unused]] int iteration_number, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 6dfdc6b027671..3923f624f1aea 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -190,10 +190,14 @@ using UpdateDecoderFeedsFunc = std::function; + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids)>; template using ExpandBufferFunc = std::function Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_ids, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); // --------------------------------------------------------------- // Utility Functions diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..2a79e9f3488ec 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,11 +183,15 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; + gsl::span left_pad_mask; + gsl::span position_ids; int32_t no_speech_token = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; int extra_decoding_ids_input_id = -1; + int left_pad_mask_input_id = -1; + int position_ids_input_id = -1; int cross_qk_output_id = -1; int no_speech_probs_output_id = -1; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc index 8480edc405e53..f6bde4123e709 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc @@ -40,7 +40,7 @@ namespace transformers { Status WhisperEncoderSubgraph::Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) { - ORT_RETURN_IF(num_subgraph_inputs != 2, "expect 2 inputs, got:", num_subgraph_inputs); + ORT_RETURN_IF(num_subgraph_inputs < 2, "expect 2 inputs, got:", 
num_subgraph_inputs); ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs); ORT_RETURN_IF((static_cast(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0, @@ -95,6 +95,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector& subgr Status WhisperEncoderSubgraph::CreateInitialFeeds( const Tensor& original_encoder_input_ids, const OrtValue* original_decoder_input_ids_value, + const Tensor& original_left_pad_mask, + const Tensor& original_position_ids, int start_token_id, const std::vector& implicit_inputs, std::vector& feeds, @@ -117,12 +119,18 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds( ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr"); OrtValue encoder_input_ids; + OrtValue left_pad_mask; + OrtValue position_ids; ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids, original_decoder_input_ids_value, + &original_left_pad_mask, + &original_position_ids, start_token_id, cpu_allocator, encoder_input_ids, - decoder_input_ids)); + decoder_input_ids, + left_pad_mask, + position_ids)); const IExecutionProvider* provider = GetProvider(); AllocatorPtr default_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeDefault)); @@ -130,7 +138,7 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds( const OrtMemoryInfo& location = default_allocator->Info(); ORT_RETURN_IF_ERROR(add_to_feeds_func( ort_stream, - {encoder_input_ids, decoder_input_ids}, + {encoder_input_ids, decoder_input_ids, left_pad_mask, position_ids}, feeds, buffer, default_allocator, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h index 10e7f43b0ea83..e50e1527a1881 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h @@ -22,6 +22,8 @@ class WhisperEncoderSubgraph : public T5EncoderSubgraph { Status CreateInitialFeeds( const Tensor& encoder_input_ids, const OrtValue* original_decoder_input_ids_value, + const Tensor& left_pad_mask, + const Tensor& position_ids, int start_token_id, const std::vector& implicit_inputs, std::vector& feeds, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 08cbb145a6f65..95e4728750702 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -50,6 +50,8 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 15) // 'left_pad_mask' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 16) // 'position_ids' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..d9dddace070d4 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1232,6 +1232,8 @@ 
ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) + .Input(15, "left_pad_mask", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "T", OpSchema::Optional) + .Input(16, "position_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py index e738de0b6e9e3..c2bd74ff33c0e 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -206,13 +206,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) qk_nodes_2 = self.model.match_parent_path( - matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0] + matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0] ) if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: - _, add_qk, matmul_qk = qk_nodes_2 + _, add_left_pad_mask, add_qk, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 else: return @@ -385,6 +385,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif mask_nodes_bart is not None: mask_index = mask_nodes_bart[0].output[-1] + left_pad_mask_index = None + if decoder_attention: + left_pad_mask_index = add_left_pad_mask.input[1] + if ( encoder_attention or decoder_attention @@ -440,7 +444,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): hidden_size, root_input, attention_last_node.output[0], - add_qk_str=mask_index if decoder_attention else None, + add_qk_str=left_pad_mask_index if decoder_attention else None, past_k=past_k, past_v=past_v, present_k=present_k, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index a74666b7af297..40f932a4c5c43 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -49,6 +49,10 @@ def chain_model(args): "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "", + "", + "left_pad_mask" if args.use_forced_decoder_ids else "", + "position_ids" if args.use_forced_decoder_ids else "", ] beam_outputs = ["sequences"] @@ -58,13 +62,15 @@ def chain_model(args): beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 + #assert len(beam_inputs) == 1 + ''' beam_inputs.extend( [ "cross_qk_layer_head" if args.collect_cross_qk else "", "extra_decoding_ids" if args.extra_decoding_ids else "", ] ) + ''' if 
args.collect_cross_qk: while len(beam_outputs) < 3: beam_outputs.extend([""]) @@ -169,6 +175,17 @@ def chain_model(args): ) graph_inputs.append(decoder_input_ids) + left_pad_mask = helper.make_tensor_value_info( + "left_pad_mask", TensorProto.INT32, ["batch_size", 1, "initial_sequence_length", "initial_sequence_length",] + ) + graph_inputs.append(left_pad_mask) + + position_ids = helper.make_tensor_value_info( + "position_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + graph_inputs.append(position_ids) + + if args.use_logits_processor: logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) graph_inputs.append(logits_processor) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 250cca2ed8ebd..68685f1a8b8a6 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -52,12 +52,14 @@ def forward( self, encoder_input_ids: torch.Tensor, decoder_input_ids: torch.Tensor = None, + left_pad_mask: torch.Tensor = None, + position_ids: torch.Tensor = None, ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states, left_pad_mask=left_pad_mask, position_ids=position_ids) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) @@ -67,9 +69,11 @@ def forward( class WhisperEncoderDecoderInitInputs: - def __init__(self, encoder_input_ids, decoder_input_ids=None): + def __init__(self, encoder_input_ids, decoder_input_ids=None, left_pad_mask=None, position_ids=None): self.encoder_input_ids: torch.LongTensor = encoder_input_ids self.decoder_input_ids: torch.LongTensor = decoder_input_ids + self.left_pad_mask: torch.LongTensor = left_pad_mask + self.position_ids: torch.LongTensor = position_ids @staticmethod def create_dummy( @@ -90,13 +94,17 @@ def create_dummy( if use_decoder_input_ids: dtype = torch.int32 if use_int32_inputs else torch.int64 decoder_input_ids = torch.ones((batch_size, 2), dtype=dtype, device=device) * config.decoder_start_token_id + left_pad_mask = torch.zeros((batch_size, 1, 2, 2), dtype=dtype, device=device) + position_ids = torch.zeros((batch_size, 2), dtype=dtype, device=device) - return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids) + return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids, left_pad_mask, position_ids) def to_list(self) -> List: input_list = [self.encoder_input_ids] if self.decoder_input_ids is not None: input_list.append(self.decoder_input_ids) + input_list.append(self.left_pad_mask) + input_list.append(self.position_ids) return input_list @@ -134,7 +142,7 @@ def export_onnx( # TODO : Investigate whether copy of model if needed cloned_model = copy.deepcopy(model).to(device) - out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) + out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids, inputs.left_pad_mask, 
inputs.position_ids) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) @@ -179,6 +187,18 @@ def export_onnx( 1: "decode_sequence_length", } + input_names.append("left_pad_mask") + dynamic_axes["left_pad_mask"] = { + 0: "batch_size", + 2: "decode_sequence_length", + 3: "decode_sequence_length", + } + input_names.append("position_ids") + dynamic_axes["position_ids"] = { + 0: "batch_size", + 1: "decode_sequence_length", + } + for name in present_names: if "cross" in name: dynamic_axes[name] = { @@ -244,6 +264,8 @@ def onnxruntime_inference(ort_session, inputs: WhisperEncoderDecoderInitInputs): } if inputs.decoder_input_ids is not None: ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()) + ort_inputs["left_pad_mask"] = numpy.ascontiguousarray(inputs.left_pad_mask.cpu().numpy()) + ort_inputs["position_ids"] = numpy.ascontiguousarray(inputs.position_ids.cpu().numpy()) ort_outputs = ort_session.run(None, ort_inputs) return ort_outputs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 30d4edcc4a476..9a0fa6ddd50c4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -352,7 +352,6 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - prompt_ids_list = [config.decoder_start_token_id, 50259, 50359, 50363] batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 length_penalty, repetition_penalty = 1.0, 1.0 @@ -382,8 +381,15 @@ def verify_onnx( "tensor(uint8)": np.uint8, } + # Generate prompts + prompt_text = "Christians" + prompt_ids = processor.get_prompt_ids(prompt_text) + #print(prompt_ids) + #print(processor.decode(pt_model.generate(**inputs, prompt_ids=prompt_ids))) + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): + print(name, dtype) if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() elif name == "vocab_mask": @@ -403,6 +409,10 @@ def verify_onnx( inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "left_pad_mask": + inputs[name] = np.zeros((batch_size, 1, 4, 4), dtype=ort_to_np[dtype]) + elif name == "position_ids": + inputs[name] = np.zeros((batch_size, 4), dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] @@ -413,16 +423,20 @@ def verify_onnx( diff = pt_outputs - ort_outputs max_diff = max(diff.min(), diff.max(), key=abs) - if max_diff == 0: + if True: # For ONNX Runtime INT8 model pt_expected_transcription = ( " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." ) + print(pt_outputs) + print(ort_outputs) pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) + print(pt_transcription) ort_expected_transcription = ( " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
) ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + print(ort_transcription) parity = ( pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index 941f61cf7cc29..805bbbb6d9c7f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -30,6 +30,8 @@ def forward( tokens, audio_features, past=None, + left_pad_mask=None, + position_ids=None, ): # Create a kv_cache for past_values past_kv_cache = dict() @@ -47,7 +49,7 @@ def forward( if not self.kv_cache: self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() - logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) + logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache, left_pad_mask=left_pad_mask, position_ids=position_ids) # Add concat node for past values if past is not None: From 12f0ff119c9aef6662db6fab5f4cfa138cdf06d0 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 19 Dec 2023 04:27:49 +0000 Subject: [PATCH 09/10] Correct datatype for position_ids --- .../cpu/transformers/beam_search_parameters.cc | 4 ++-- .../cpu/transformers/generation_device_helper.cc | 4 ++-- .../contrib_ops/cpu/transformers/generation_shared.h | 2 +- .../tools/transformers/fusion_bart_attention_openai.py | 8 +++++++- .../tools/transformers/models/whisper/whisper_chain.py | 2 +- .../tools/transformers/models/whisper/whisper_helper.py | 2 +- 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index d0cb77c2f5600..075d39d45bff1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -64,7 +64,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } } - left_pad_mask = gsl::span(); + left_pad_mask = gsl::span(); if (this->model_type == IGenerationParameters::kModelTypeWhisper && left_pad_mask_input_id > 0) { const Tensor* left_pad_mask_tensor = context->Input(left_pad_mask_input_id); if (left_pad_mask_tensor != nullptr) { @@ -76,7 +76,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { "left_pad_mask_tensor first dim not same as batch_size. 
Got ", left_pad_mask_tensor_dims[0], ", expecting ", batch_size); if (left_pad_mask_tensor->Shape().Size() > 0) { - left_pad_mask = gsl::span(left_pad_mask_tensor->Data(), (size_t)left_pad_mask_tensor->Shape().Size()); + left_pad_mask = gsl::span(left_pad_mask_tensor->Data(), (size_t)left_pad_mask_tensor->Shape().Size()); } } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 557b78ed6b8a2..d554bb6345131 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -910,9 +910,9 @@ Status CreateWhisperEncoderInputs( left_pad_mask); const TensorShape& position_ids_shape = original_position_ids->Shape(); - Tensor::InitOrtValue(DataTypeImpl::GetType(), + Tensor::InitOrtValue(element_type, position_ids_shape, - const_cast(original_position_ids)->MutableData(), + const_cast(original_position_ids)->MutableData(), allocator->Info(), position_ids); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 2a79e9f3488ec..e1b8ecacc07b9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,7 +183,7 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - gsl::span left_pad_mask; + gsl::span left_pad_mask; gsl::span position_ids; int32_t no_speech_token = -1; void* no_speech_probs = nullptr; diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py index c2bd74ff33c0e..0926a5f19c5ac 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -208,12 +208,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_2 = self.model.match_parent_path( matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0] ) + qk_nodes_3 = self.model.match_parent_path( + matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0] + ) if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: _, add_left_pad_mask, add_qk, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 + elif qk_nodes_3 is not None: + _, add_qk, matmul_qk = qk_nodes_3 + qk_nodes = qk_nodes_3 else: return @@ -362,7 +368,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 encoder_attention = one_root_input and qk_nodes == qk_nodes_1 decoder_attention = one_root_input and qk_nodes == qk_nodes_2 - decoder_attention_with_past = decoder_attention and past_k and past_v + decoder_attention_with_past = one_root_input and qk_nodes == qk_nodes_3 and past_k and past_v decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 40f932a4c5c43..75dd1b1f35355 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ 
b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -176,7 +176,7 @@ def chain_model(args): graph_inputs.append(decoder_input_ids) left_pad_mask = helper.make_tensor_value_info( - "left_pad_mask", TensorProto.INT32, ["batch_size", 1, "initial_sequence_length", "initial_sequence_length",] + "left_pad_mask", TensorProto.FLOAT, ["batch_size", 1, "initial_sequence_length", "initial_sequence_length",] ) graph_inputs.append(left_pad_mask) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 9a0fa6ddd50c4..18ec5c76ad88b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -412,7 +412,7 @@ def verify_onnx( elif name == "left_pad_mask": inputs[name] = np.zeros((batch_size, 1, 4, 4), dtype=ort_to_np[dtype]) elif name == "position_ids": - inputs[name] = np.zeros((batch_size, 4), dtype=ort_to_np[dtype]) + inputs[name] = np.repeat(np.array([[0, 1, 2, 3]], dtype=ort_to_np[dtype]), batch_size, 0) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] From d4e84abd5f679c561c4c3f56d583cf73eb4f732d Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 20 Feb 2024 15:25:40 +0000 Subject: [PATCH 10/10] fix linting issues --- .../transformers/beam_search_parameters.cc | 4 +-- .../core/graph/contrib_ops/contrib_defs.cc | 4 +-- .../fusion_bart_attention_openai.py | 31 ++++++++----------- .../models/whisper/whisper_chain.py | 16 +++++++--- .../whisper/whisper_encoder_decoder_init.py | 9 ++++-- .../models/whisper/whisper_helper.py | 27 +++++----------- .../models/whisper/whisper_openai_helper.py | 4 ++- .../tools/transformers/onnx_model_bart.py | 5 +-- .../tools/transformers/onnx_model_bert.py | 6 +++- 9 files changed, 50 insertions(+), 56 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 075d39d45bff1..c2c99c67ed8a1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -190,8 +190,8 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; - left_pad_mask_input_id = 14; - position_ids_input_id = 15; + left_pad_mask_input_id = 15; + position_ids_input_id = 16; cross_qk_output_id = 3; no_speech_probs_output_id = 4; } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index d9dddace070d4..49a62b739697c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1232,8 +1232,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) - .Input(15, "left_pad_mask", "The forced input id sequence for the decoder subgraph. 
Shape is (batch_size, initial_sequence_length)", "T", OpSchema::Optional) - .Input(16, "position_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) + .Input(15, "left_pad_mask", "The mask is added to qk node in the qkv attention function. Shape is (batch_size, initial_sequence_length)", "T", OpSchema::Optional) + .Input(16, "position_ids", "Used to select indices of positional embeddings. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py index 0926a5f19c5ac..3e4aa7add68e8 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -127,8 +127,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input = output break - graph_input_names = set([node.name for node in self.model.graph().input]) - graph_output_names = set([node.name for node in self.model.graph().output]) + graph_input_names = set([node.name for node in self.model.graph().input]) + graph_output_names = set([node.name for node in self.model.graph().output]) v_nodes = self.model.match_parent_path( matmul_qkv, @@ -152,13 +152,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if v_nodes is not None: (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - #present_v = add_v.output[0] + # present_v = add_v.output[0] add_v_children = self.model.get_children(add_v) for child in add_v_children: if child.op_type == "Reshape": - #if child.output[0] in graph_output_names: - #present_v = child.output[0] + # if child.output[0] in graph_output_names: + # present_v = child.output[0] reshape_v_children = self.model.get_children(child) for reshape_child in reshape_v_children: if reshape_child.op_type == "Transpose": @@ -205,12 +205,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): present_v = present_v if present_v in graph_output_names else "" qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) - qk_nodes_2 = self.model.match_parent_path( - matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0] - ) - qk_nodes_3 = self.model.match_parent_path( - matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0] - ) + qk_nodes_2 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0]) + qk_nodes_3 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 @@ -262,12 +258,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes = k_nodes_with_bias present_k = matmul_k.output[0] mat_k_out_tmp = matmul_k.output[0] + "_temp" - #matmul_k.output[0] = matmul_k.output[0] + "_temp" + # matmul_k.output[0] = matmul_k.output[0] + "_temp" matmul_k_children = self.model.get_children(matmul_k) for child in matmul_k_children: if 
child.op_type == "Reshape": - #if child.output[0] in graph_output_names: + # if child.output[0] in graph_output_names: # present_k = child.output[0] reshape_k_children = self.model.get_children(child) for reshape_child in reshape_k_children: @@ -291,10 +287,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if reshape_parent.op_type == "Transpose": if reshape_parent.input[0] in graph_input_names: past_k = reshape_parent.input[0] - #else: + # else: # matmul_k.output[0] = mat_k_out_tmp - elif k_nodes_no_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias k_nodes = k_nodes_no_bias @@ -334,7 +329,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) - ''' + """ if not past_k and not self.check_runtime_shape_path( reshape_qkv_2, reshape_qkv_1, @@ -344,7 +339,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input, ): return - ''' + """ three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals() one_root_input = ( @@ -387,7 +382,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) if mask_nodes_whisper is not None: pass - #mask_index = mask_nodes_whisper[0].output[-1] + # mask_index = mask_nodes_whisper[0].output[-1] elif mask_nodes_bart is not None: mask_index = mask_nodes_bart[0].output[-1] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 75dd1b1f35355..0d51623bd176f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -62,15 +62,15 @@ def chain_model(args): beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") if args.use_whisper_beamsearch: - #assert len(beam_inputs) == 1 - ''' + # assert len(beam_inputs) == 1 + """ beam_inputs.extend( [ "cross_qk_layer_head" if args.collect_cross_qk else "", "extra_decoding_ids" if args.extra_decoding_ids else "", ] ) - ''' + """ if args.collect_cross_qk: while len(beam_outputs) < 3: beam_outputs.extend([""]) @@ -176,7 +176,14 @@ def chain_model(args): graph_inputs.append(decoder_input_ids) left_pad_mask = helper.make_tensor_value_info( - "left_pad_mask", TensorProto.FLOAT, ["batch_size", 1, "initial_sequence_length", "initial_sequence_length",] + "left_pad_mask", + TensorProto.FLOAT, + [ + "batch_size", + 1, + "initial_sequence_length", + "initial_sequence_length", + ], ) graph_inputs.append(left_pad_mask) @@ -185,7 +192,6 @@ def chain_model(args): ) graph_inputs.append(position_ids) - if args.use_logits_processor: logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) graph_inputs.append(logits_processor) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 68685f1a8b8a6..11dce01c24e49 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -8,7 +8,6 @@ import logging import os import tempfile -import copy from pathlib import Path from typing import List, Optional @@ -59,7 
+58,9 @@ def forward( # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states, left_pad_mask=left_pad_mask, position_ids=position_ids) + decinit_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, encoder_hidden_states, left_pad_mask=left_pad_mask, position_ids=position_ids + ) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) @@ -142,7 +143,9 @@ def export_onnx( # TODO : Investigate whether copy of model if needed cloned_model = copy.deepcopy(model).to(device) - out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids, inputs.left_pad_mask, inputs.position_ids) + out = cloned_model( + inputs.encoder_input_ids, inputs.decoder_input_ids, inputs.left_pad_mask, inputs.position_ids + ) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 18ec5c76ad88b..79d717c17b67c 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,13 +6,15 @@ import logging import os -import io import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -20,16 +22,9 @@ from whisper_encoder import WhisperEncoder, WhisperEncoderHelper from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper -from whisper.model import Whisper, ModelDimensions -from whisper import _MODELS, _ALIGNMENT_HEADS -from whisper import _download - from onnxruntime import InferenceSession sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model logger = logging.getLogger(__name__) @@ -348,8 +343,6 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. 
Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features @@ -382,14 +375,11 @@ def verify_onnx( } # Generate prompts - prompt_text = "Christians" - prompt_ids = processor.get_prompt_ids(prompt_text) - #print(prompt_ids) - #print(processor.decode(pt_model.generate(**inputs, prompt_ids=prompt_ids))) + # prompt_text = "" + # prompt_ids = processor.get_prompt_ids(prompt_text) use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): - print(name, dtype) if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() elif name == "vocab_mask": @@ -423,20 +413,17 @@ def verify_onnx( diff = pt_outputs - ort_outputs max_diff = max(diff.min(), diff.max(), key=abs) - if True: + if max_diff > 0: # For ONNX Runtime INT8 model pt_expected_transcription = ( " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." ) - print(pt_outputs) - print(ort_outputs) pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - print(pt_transcription) + ort_expected_transcription = ( " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." ) ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) - print(ort_transcription) parity = ( pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index 805bbbb6d9c7f..13cdffaa836c3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -49,7 +49,9 @@ def forward( if not self.kv_cache: self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() - logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache, left_pad_mask=left_pad_mask, position_ids=position_ids) + logits = self.whisper_decoder( + tokens, audio_features, kv_cache=past_kv_cache, left_pad_mask=left_pad_mask, position_ids=position_ids + ) # Add concat node for past values if past is not None: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 1ef6a4329cb28..de0a418ae4daa 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -127,10 +127,7 @@ def __init__(self, model, num_heads, hidden_size, model_impl="hf"): self.attention_mask = AttentionMask(self) if model_impl == "openai": self.attention_fusion = FusionBartAttentionOpenai( - self, - self.hidden_size, - self.num_heads, - self.attention_mask + self, self.hidden_size, self.num_heads, self.attention_mask ) else: self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 6f061052716e7..2c77aeb784ec7 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -350,7 +350,11 @@ def 
optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) - if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention) and not isinstance(self.attention_fusion, FusionBartAttentionOpenai): + if ( + options.use_multi_head_attention + and not isinstance(self.attention_fusion, FusionBartAttention) + and not isinstance(self.attention_fusion, FusionBartAttentionOpenai) + ): self.attention_fusion = FusionAttention( self, self.hidden_size,