diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 93cda00e5a3c3..58e842e72d83b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -169,7 +169,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, ORT_RETURN_IF_ERROR(whisper_encoder_subgraph_->Setup(session_state, subgraph_session_state)); encoder_feeds_fetches_manager_ = whisper_encoder_subgraph_->GetFeedsFetchesManager(); - ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs != 2, + ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs < 2, "Encoder subgraph shall have 2 inputs (encoder_input_ids, decoder_input_ids)"); } else if (attribute_name == "decoder") { ORT_ENFORCE(whisper_decoder_subgraph_ == nullptr, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..e77f0a8d572f5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -153,6 +153,12 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0); const Tensor& encoder_input_ids = encoder_input_ids_value->Get<Tensor>(); + const OrtValue* left_pad_mask_value = this->context_.GetInputOrtValue(15); + const Tensor& left_pad_mask = left_pad_mask_value->Get<Tensor>(); + + const OrtValue* position_ids_value = this->context_.GetInputOrtValue(16); + const Tensor& position_ids = position_ids_value->Get<Tensor>(); + BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, this->IsCuda(), @@ -166,6 +172,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds( encoder_input_ids, initial_decoder_input_ids_value, + left_pad_mask, + position_ids, parameters->decoder_start_token_id, this->implicit_inputs_, encoder_feeds, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..c2c99c67ed8a1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -64,6 +64,40 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } } + left_pad_mask = gsl::span<float>(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && left_pad_mask_input_id > 0) { + const Tensor* left_pad_mask_tensor = context->Input<Tensor>(left_pad_mask_input_id); + if (left_pad_mask_tensor != nullptr) { + const auto& left_pad_mask_tensor_dims = left_pad_mask_tensor->Shape().GetDims(); + ORT_ENFORCE(left_pad_mask_tensor_dims.size() == 4, + "left_pad_mask_tensor shall have 4 dimensions. Got ", + left_pad_mask_tensor_dims.size()); + ORT_ENFORCE(left_pad_mask_tensor_dims[0] == batch_size, + "left_pad_mask_tensor first dim not same as batch_size. 
Got ", + left_pad_mask_tensor_dims[0], ", expecting ", batch_size); + if (left_pad_mask_tensor->Shape().Size() > 0) { + left_pad_mask = gsl::span<const float>(left_pad_mask_tensor->Data<float>(), (size_t)left_pad_mask_tensor->Shape().Size()); + } + } + } + + position_ids = gsl::span<int32_t>(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && position_ids_input_id > 0) { + const Tensor* position_ids_tensor = context->Input<Tensor>(position_ids_input_id); + if (position_ids_tensor != nullptr) { + const auto& position_ids_tensor_dims = position_ids_tensor->Shape().GetDims(); + ORT_ENFORCE(position_ids_tensor_dims.size() == 2, + "position_ids_tensor shall have 2 dimensions. Got ", + position_ids_tensor_dims.size()); + ORT_ENFORCE(position_ids_tensor_dims[0] == batch_size, + "position_ids_tensor first dim not same as batch_size. Got ", + position_ids_tensor_dims[0], ", expecting ", batch_size); + if (position_ids_tensor->Shape().Size() > 0) { + position_ids = gsl::span<const int32_t>(position_ids_tensor->Data<int32_t>(), (size_t)position_ids_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast<int>(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { @@ -156,6 +190,8 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; + left_pad_mask_input_id = 15; + position_ids_input_id = 16; cross_qk_output_id = 3; no_speech_probs_output_id = 4; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 927d3a58e5a6f..d554bb6345131 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -877,10 +877,14 @@ template <typename T> Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids) { + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids) { const TensorShape& input_features_shape = original_encoder_input_features->Shape(); ORT_ENFORCE(input_features_shape.NumDimensions() == 3); const int64_t& batch_size = input_features_shape[0]; @@ -898,6 +902,20 @@ Status CreateWhisperEncoderInputs( allocator->Info(), encoder_input_features); + const TensorShape& left_pad_mask_shape = original_left_pad_mask->Shape(); + Tensor::InitOrtValue(DataTypeImpl::GetType<T>(), + left_pad_mask_shape, + const_cast<Tensor*>(original_left_pad_mask)->MutableData<T>(), + allocator->Info(), + left_pad_mask); + + const TensorShape& position_ids_shape = original_position_ids->Shape(); + Tensor::InitOrtValue(element_type, + position_ids_shape, + const_cast<Tensor*>(original_position_ids)->MutableData<int32_t>(), + allocator->Info(), + position_ids); + // decoder_input_ids is optional. 
if (original_decoder_input_ids_value == nullptr) { // Filled decoder_input_ids with start token ID @@ -1071,18 +1089,26 @@ template Status ExpandBuffer<MLFloat16>( template Status CreateWhisperEncoderInputs<float>( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); template Status CreateWhisperEncoderInputs<MLFloat16>( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); Status UpdateDecoderCrossQK( [[maybe_unused]] int iteration_number, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 6dfdc6b027671..3923f624f1aea 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -190,10 +190,14 @@ using UpdateDecoderFeedsFunc = std::function<Status( using CreateWhisperEncoderInputsFunc = std::function<Status( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_ids, - OrtValue& decoder_input_ids)>; + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids)>; template <typename T> using ExpandBufferFunc = std::function<Status( @@ -376,10 +380,14 @@ template <typename T> Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_ids, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); // --------------------------------------------------------------- // Utility Functions diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..e1b8ecacc07b9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,11 +183,15 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span<const int32_t> extra_decoding_ids; + gsl::span<const float> left_pad_mask; + gsl::span<const int32_t> position_ids; int32_t no_speech_token = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; int extra_decoding_ids_input_id = -1; + int left_pad_mask_input_id = -1; + int position_ids_input_id = -1; int cross_qk_output_id = -1; int no_speech_probs_output_id = -1; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc index 8480edc405e53..f6bde4123e709 
100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc @@ -40,7 +40,7 @@ namespace transformers { Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs, const std::vector<const NodeArg*>& subgraph_outputs) { - ORT_RETURN_IF(num_subgraph_inputs != 2, "expect 2 inputs, got:", num_subgraph_inputs); + ORT_RETURN_IF(num_subgraph_inputs < 2, "expect 2 inputs, got:", num_subgraph_inputs); ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs); ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0, @@ -95,6 +95,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr Status WhisperEncoderSubgraph::CreateInitialFeeds( const Tensor& original_encoder_input_ids, const OrtValue* original_decoder_input_ids_value, + const Tensor& original_left_pad_mask, + const Tensor& original_position_ids, int start_token_id, const std::vector<const OrtValue*>& implicit_inputs, std::vector<OrtValue>& feeds, @@ -117,12 +119,18 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds( ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr"); OrtValue encoder_input_ids; + OrtValue left_pad_mask; + OrtValue position_ids; ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids, original_decoder_input_ids_value, + &original_left_pad_mask, + &original_position_ids, start_token_id, cpu_allocator, encoder_input_ids, - decoder_input_ids)); + decoder_input_ids, + left_pad_mask, + position_ids)); const IExecutionProvider* provider = GetProvider(); AllocatorPtr default_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeDefault)); @@ -130,7 +138,7 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds( const OrtMemoryInfo& location = default_allocator->Info(); ORT_RETURN_IF_ERROR(add_to_feeds_func( ort_stream, - {encoder_input_ids, decoder_input_ids}, + {encoder_input_ids, decoder_input_ids, left_pad_mask, position_ids}, feeds, buffer, default_allocator, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h index 10e7f43b0ea83..e50e1527a1881 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h @@ -22,6 +22,8 @@ class WhisperEncoderSubgraph : public T5EncoderSubgraph { Status CreateInitialFeeds( const Tensor& encoder_input_ids, const OrtValue* original_decoder_input_ids_value, + const Tensor& left_pad_mask, + const Tensor& position_ids, int start_token_id, const std::vector<const OrtValue*>& implicit_inputs, std::vector<OrtValue>& feeds, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 08cbb145a6f65..95e4728750702 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -50,6 +50,8 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 15) // 'left_pad_mask' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 16) // 
'position_ids' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..49a62b739697c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1232,6 +1232,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) + .Input(15, "left_pad_mask", "The mask is added to qk node in the qkv attention function. Shape is (batch_size, initial_sequence_length)", "T", OpSchema::Optional) + .Input(16, "position_ids", "Used to select indices of positional embeddings. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py new file mode 100644 index 0000000000000..3e4aa7add68e8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -0,0 +1,487 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx import TensorProto, helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionBartAttentionOpenai(FusionAttention): + """ + Fuse Bart Attention subgraph into one Attention node. 
+ """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def check_runtime_shape_path( + self, + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) + if concat_qkv_2_path is None: + return False + concat_qkv_2 = concat_qkv_2_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0]) + if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: + return False + if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: + return False + + reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: + return False + + mul_q = reshape_q_2_path[-1] + mul_k = reshape_k_2_path[-1] + mul_v = reshape_v_2_path[-1] + + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + add_out, + matmul_out, + reshape_qkv_1, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + else: + return + + other_inputs = [] + for input in normalize_node.input: + if input not in output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + return + root_input = other_inputs[0] + + # Sometimes the input name to the attention MatMul nodes does not match the input name to the end + # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul + # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are + # children nodes for each of its output names. + """ + root_input + +---------------------------------------------------+ + | | + | | + SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization + """ + skip_layernorm = output_name_to_node[root_input] + # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose + # child is the LayerNormalization node. 
+ if skip_layernorm.op_type == "Add": + skip_layernorm = self.model.get_children(skip_layernorm)[0] + for output in skip_layernorm.output: + if not output: + continue + children = input_name_to_nodes[output] + children_types = [child.op_type for child in children] + if children_types.count("MatMul") >= 1: + root_input = output + break + + graph_input_names = set([node.name for node in self.model.graph().input]) + graph_output_names = set([node.name for node in self.model.graph().output]) + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], + ) + v_nodes_with_past_self_attn = self.model.match_parent_path( + # Decoder attention with past value concatenated before MatMul + matmul_qkv, + ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, None], + ) + v_nodes_with_past_cross_attn = self.model.match_parent_path( + # Decoder attention with past value directly used in MatMul + matmul_qkv, + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], + ) + past_v, present_v = "", "" + reshape_v_2, add_v = None, None + if v_nodes is not None: + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + # present_v = add_v.output[0] + + add_v_children = self.model.get_children(add_v) + for child in add_v_children: + if child.op_type == "Reshape": + # if child.output[0] in graph_output_names: + # present_v = child.output[0] + reshape_v_children = self.model.get_children(child) + for reshape_child in reshape_v_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_v = reshape_child.output[0] + if child.op_type == "Concat": + concat_v_children = self.model.get_children(child) + for concat_child in concat_v_children: + if concat_child.op_type == "Reshape": + reshape_v_children = self.model.get_children(concat_child) + for reshape_child in reshape_v_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_v = reshape_child.output[0] + concat_v_parents = self.model.get_parents(child) + for concat_parent in concat_v_parents: + if concat_parent.op_type == "Reshape": + reshape_v_parents = self.model.get_parents(concat_parent) + for reshape_parent in reshape_v_parents: + if reshape_parent.op_type == "Transpose": + if reshape_parent.input[0] in graph_input_names: + past_v = reshape_parent.input[0] + + elif v_nodes_with_past_self_attn is not None: + (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn + v_nodes = v_nodes_with_past_self_attn + past_v = concat_v.input[0] + present_v = concat_v.output[0] + elif ( + v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + else: + logger.debug("fuse_attention: failed to match v path") + return + past_v = past_v if past_v in graph_input_names else "" + present_v = present_v if present_v in graph_output_names else "" + + qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + 
qk_nodes_2 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0]) + qk_nodes_3 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + if qk_nodes_1 is not None: + _, matmul_qk = qk_nodes_1 + qk_nodes = qk_nodes_1 + elif qk_nodes_2 is not None: + _, add_left_pad_mask, add_qk, matmul_qk = qk_nodes_2 + qk_nodes = qk_nodes_2 + elif qk_nodes_3 is not None: + _, add_qk, matmul_qk = qk_nodes_3 + qk_nodes = qk_nodes_3 + else: + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + if q_nodes is not None: + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes + else: + return + + k_nodes_with_bias = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], + ) + k_nodes_no_bias = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path( + # Decoder attention with past key concatenated before MatMul + matmul_qk, + ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) + past_k, present_k = "", "" + reshape_k_2, reshape_k_1, matmul_k = None, None, None + if k_nodes_with_bias is not None: + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias + k_nodes = k_nodes_with_bias + present_k = matmul_k.output[0] + mat_k_out_tmp = matmul_k.output[0] + "_temp" + # matmul_k.output[0] = matmul_k.output[0] + "_temp" + + matmul_k_children = self.model.get_children(matmul_k) + for child in matmul_k_children: + if child.op_type == "Reshape": + # if child.output[0] in graph_output_names: + # present_k = child.output[0] + reshape_k_children = self.model.get_children(child) + for reshape_child in reshape_k_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_k = reshape_child.output[0] + if child.op_type == "Concat": + concat_k_children = self.model.get_children(child) + for concat_child in concat_k_children: + if concat_child.op_type == "Reshape": + reshape_k_children = self.model.get_children(concat_child) + for reshape_child in reshape_k_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_k = reshape_child.output[0] + concat_k_parents = self.model.get_parents(child) + for concat_parent in concat_k_parents: + if concat_parent.op_type == "Reshape": + reshape_k_parents = self.model.get_parents(concat_parent) + for reshape_parent in reshape_k_parents: + if reshape_parent.op_type == "Transpose": + if reshape_parent.input[0] in graph_input_names: + past_k = reshape_parent.input[0] + # else: + # matmul_k.output[0] = mat_k_out_tmp + + elif k_nodes_no_bias is not None: + _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias + k_nodes = k_nodes_no_bias + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + present_k = transpose_k_1.output[0] + elif k_nodes_no_bias_with_past_self_attn is not None: + _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn + k_nodes = 
k_nodes_no_bias_with_past_self_attn + past_k = concat_k.input[0] + present_k = concat_k.output[0] + elif ( + k_nodes_no_bias_with_past_cross_attn is not None + and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + else: + return + past_k = past_k if past_k in graph_input_names else "" + present_k = present_k if present_k in graph_output_names else "" + + if k_nodes in (k_nodes_with_bias, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + # Create empty Add node for attention graph + bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] + empty_bias_name = "empty_bias" + empty_tensor = self.model.get_initializer(empty_bias_name) + if empty_tensor is None: + empty_tensor = helper.make_tensor(empty_bias_name, TensorProto.FLOAT, [bias_dim], [0.0] * bias_dim) + self.model.add_initializer(empty_tensor, self.this_graph_name) + + add_name = self.model.create_node_name("Add") + add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) + + """ + if not past_k and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + return + """ + + three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals() + one_root_input = ( + not three_root_inputs + and matmul_k.input[0] == root_input + and matmul_q.input[0] == root_input + and matmul_v.input[0] == root_input + ) + two_root_inputs = ( + not three_root_inputs + and matmul_q.input[0] == root_input + and matmul_k.input[0] == matmul_v.input[0] + and matmul_k.input[0] != matmul_q.input[0] + ) + + # There are 5 types of attention: + # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 + # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2 + # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value + # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 + # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 + encoder_attention = one_root_input and qk_nodes == qk_nodes_1 + decoder_attention = one_root_input and qk_nodes == qk_nodes_2 + decoder_attention_with_past = one_root_input and qk_nodes == qk_nodes_3 and past_k and past_v + decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 + decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 + + # For decoder_attention, the attention mask needs to be included in the attention node + mask_index = None + if decoder_attention: + mask_nodes_bart = self.model.match_parent_path( + add_qk, + ["Where"], + [1], + ) + mask_nodes_whisper = self.model.match_parent_path( + add_qk, + ["Slice", "Slice", "Unsqueeze", "Gather"], + [1, 0, 2, 0], + ) + if mask_nodes_whisper is not None: + pass + # mask_index = mask_nodes_whisper[0].output[-1] + elif mask_nodes_bart is not None: + mask_index = mask_nodes_bart[0].output[-1] + + left_pad_mask_index = None + if decoder_attention: + left_pad_mask_index = add_left_pad_mask.input[1] + + if ( + encoder_attention + or 
decoder_attention + or decoder_attention_with_past + or decoder_cross_attention + or decoder_cross_attention_with_past + ): + attention_last_node = reshape_qkv_1 + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + return + + new_node = None + if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + # Note: Decoder attention with past key and past value is fused as multihead attention + # rather than attention because multihead attention supports separate past key and past + # value whereas attention supports concatenated past key and past value. + new_node = ( + self.create_multihead_attention_node( + matmul_q, + matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, + matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + add_q, + add_k if decoder_cross_attention or decoder_attention_with_past else None, + add_v if decoder_cross_attention or decoder_attention_with_past else None, + num_heads, + hidden_size, + attention_last_node.output[0], + past_k=past_k if decoder_attention_with_past else "", + past_v=past_v if decoder_attention_with_past else "", + present_k=present_k, + present_v=present_v, + packed_qkv=decoder_attention_with_past, + ) + if self.use_multi_head_attention + else None + ) + else: + # Temporarily set multihead attention flag to false + use_multi_head_attention_ground_truth = self.use_multi_head_attention + self.use_multi_head_attention = False + new_node = self.create_attention_node( + None, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + root_input, + attention_last_node.output[0], + add_qk_str=left_pad_mask_index if decoder_attention else None, + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + self.use_multi_head_attention = use_multi_head_attention_ground_truth + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + if self.disable_multi_head_attention_bias and ( + decoder_cross_attention or decoder_cross_attention_with_past + ): + if q_nodes[-1].op_type == "Add": + q_nodes.pop() + if k_nodes[-1].op_type == "Add": + k_nodes.pop() + if v_nodes[-1].op_type == "Add": + v_nodes.pop() + + self.nodes_to_remove.extend(q_nodes) + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. 
+ self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 4c43e4487bfb1..76d603cb0e56b 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -59,6 +59,7 @@ def __init__(self, model_type): if model_type == "clip": self.enable_embed_layer_norm = False + self.model_impl = "hf" # Set default to sequence length for BERT model to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index bb697fe1e1506..36e6e2a845e75 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -390,6 +390,7 @@ def export_onnx_models( auto_mixed_precision=not disable_auto_mixed_precision, use_gpu=use_gpu, provider=provider, + model_impl=model_impl, ) onnx_path = output_path diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index a74666b7af297..0d51623bd176f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -49,6 +49,10 @@ def chain_model(args): "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "", + "", + "left_pad_mask" if args.use_forced_decoder_ids else "", + "position_ids" if args.use_forced_decoder_ids else "", ] beam_outputs = ["sequences"] @@ -58,13 +62,15 @@ def chain_model(args): beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 + # assert len(beam_inputs) == 1 + """ beam_inputs.extend( [ "cross_qk_layer_head" if args.collect_cross_qk else "", "extra_decoding_ids" if args.extra_decoding_ids else "", ] ) + """ if args.collect_cross_qk: while len(beam_outputs) < 3: beam_outputs.extend([""]) @@ -169,6 +175,23 @@ def chain_model(args): ) graph_inputs.append(decoder_input_ids) + left_pad_mask = helper.make_tensor_value_info( + "left_pad_mask", + TensorProto.FLOAT, + [ + "batch_size", + 1, + "initial_sequence_length", + "initial_sequence_length", + ], + ) + graph_inputs.append(left_pad_mask) + + position_ids = helper.make_tensor_value_info( + "position_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + graph_inputs.append(position_ids) + if args.use_logits_processor: logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) graph_inputs.append(logits_processor) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 351173f525727..11dce01c24e49 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -51,12 +51,16 @@ def forward( self, encoder_input_ids: torch.Tensor, decoder_input_ids: torch.Tensor = None, + left_pad_mask: torch.Tensor = None, + position_ids: torch.Tensor = 
None, ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + decinit_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, encoder_hidden_states, left_pad_mask=left_pad_mask, position_ids=position_ids + ) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) @@ -66,9 +70,11 @@ def forward( class WhisperEncoderDecoderInitInputs: - def __init__(self, encoder_input_ids, decoder_input_ids=None): + def __init__(self, encoder_input_ids, decoder_input_ids=None, left_pad_mask=None, position_ids=None): self.encoder_input_ids: torch.LongTensor = encoder_input_ids self.decoder_input_ids: torch.LongTensor = decoder_input_ids + self.left_pad_mask: torch.LongTensor = left_pad_mask + self.position_ids: torch.LongTensor = position_ids @staticmethod def create_dummy( @@ -89,13 +95,17 @@ def create_dummy( if use_decoder_input_ids: dtype = torch.int32 if use_int32_inputs else torch.int64 decoder_input_ids = torch.ones((batch_size, 2), dtype=dtype, device=device) * config.decoder_start_token_id + left_pad_mask = torch.zeros((batch_size, 1, 2, 2), dtype=dtype, device=device) + position_ids = torch.zeros((batch_size, 2), dtype=dtype, device=device) - return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids) + return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids, left_pad_mask, position_ids) def to_list(self) -> List: input_list = [self.encoder_input_ids] if self.decoder_input_ids is not None: input_list.append(self.decoder_input_ids) + input_list.append(self.left_pad_mask) + input_list.append(self.position_ids) return input_list @@ -133,7 +143,9 @@ def export_onnx( # TODO : Investigate whether copy of model if needed cloned_model = copy.deepcopy(model).to(device) - out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) + out = cloned_model( + inputs.encoder_input_ids, inputs.decoder_input_ids, inputs.left_pad_mask, inputs.position_ids + ) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) @@ -178,6 +190,18 @@ def export_onnx( 1: "decode_sequence_length", } + input_names.append("left_pad_mask") + dynamic_axes["left_pad_mask"] = { + 0: "batch_size", + 2: "decode_sequence_length", + 3: "decode_sequence_length", + } + input_names.append("position_ids") + dynamic_axes["position_ids"] = { + 0: "batch_size", + 1: "decode_sequence_length", + } + for name in present_names: if "cross" in name: dynamic_axes[name] = { @@ -243,6 +267,8 @@ def onnxruntime_inference(ort_session, inputs: WhisperEncoderDecoderInitInputs): } if inputs.decoder_input_ids is not None: ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()) + ort_inputs["left_pad_mask"] = numpy.ascontiguousarray(inputs.left_pad_mask.cpu().numpy()) + ort_inputs["position_ids"] = numpy.ascontiguousarray(inputs.position_ids.cpu().numpy()) ort_outputs = ort_session.run(None, ort_inputs) return ort_outputs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e2dc79ca247ce..79d717c17b67c 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -12,6 +12,9 @@ import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -22,9 +25,6 @@ from onnxruntime import InferenceSession sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model logger = logging.getLogger(__name__) @@ -287,12 +287,14 @@ def optimize_onnx( auto_mixed_precision: bool = True, use_gpu: bool = False, provider: str = "cpu", + model_impl: str = "hf", ): """Optimize ONNX model with an option to convert it to use mixed precision.""" from fusion_options import FusionOptions optimization_options = FusionOptions("bart") + optimization_options.model_impl = model_impl optimization_options.use_multi_head_attention = True optimization_options.disable_multi_head_attention_bias = provider == "rocm" @@ -341,8 +343,6 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features @@ -374,6 +374,10 @@ def verify_onnx( "tensor(uint8)": np.uint8, } + # Generate prompts + # prompt_text = "" + # prompt_ids = processor.get_prompt_ids(prompt_text) + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): if name == "input_features": @@ -395,6 +399,10 @@ def verify_onnx( inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "left_pad_mask": + inputs[name] = np.zeros((batch_size, 1, 4, 4), dtype=ort_to_np[dtype]) + elif name == "position_ids": + inputs[name] = np.repeat(np.array([[0, 1, 2, 3]], dtype=ort_to_np[dtype]), batch_size, 0) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] @@ -411,6 +419,7 @@ def verify_onnx( " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." ) pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) + ort_expected_transcription = ( " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
) @@ -421,5 +430,4 @@ def verify_onnx( ) if parity: max_diff = 0 - return max_diff diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index 941f61cf7cc29..13cdffaa836c3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -30,6 +30,8 @@ def forward( tokens, audio_features, past=None, + left_pad_mask=None, + position_ids=None, ): # Create a kv_cache for past_values past_kv_cache = dict() @@ -47,7 +49,9 @@ def forward( if not self.kv_cache: self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() - logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) + logits = self.whisper_decoder( + tokens, audio_features, kv_cache=past_kv_cache, left_pad_mask=left_pad_mask, position_ids=position_ids + ) # Add concat node for past values if past is not None: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 61a786d7af60b..de0a418ae4daa 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -7,6 +7,7 @@ from fusion_attention import AttentionMask from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_options import FusionOptions from fusion_reshape import FusionReshape from onnx import numpy_helper @@ -124,7 +125,12 @@ class BartOnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) - self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + if model_impl == "openai": + self.attention_fusion = FusionBartAttentionOpenai( + self, self.hidden_size, self.num_heads, self.attention_mask + ) + else: + self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 431e64509e3cc..2c77aeb784ec7 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -9,6 +9,7 @@ from convert_to_packing_mode import PackingMode from fusion_attention import AttentionMask, FusionAttention from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_biasgelu import FusionBiasGelu from fusion_embedlayer import FusionEmbedLayerNormalization from fusion_fastgelu import FusionFastGelu @@ -349,7 +350,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) - if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): + if ( + options.use_multi_head_attention + and not isinstance(self.attention_fusion, FusionBartAttention) + and not isinstance(self.attention_fusion, FusionBartAttentionOpenai) + ): self.attention_fusion = 
FusionAttention( self, self.hidden_size, diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..584452bdf004a 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -226,7 +226,10 @@ def optimize_by_fusion( if optimization_options is None: optimization_options = FusionOptions(model_type) - optimizer = optimizer_class(model, num_heads, hidden_size) + if optimization_options.model_impl == "openai": + optimizer = optimizer_class(model, num_heads, hidden_size, model_impl=optimization_options.model_impl) + else: + optimizer = optimizer_class(model, num_heads, hidden_size) optimizer.optimize(optimization_options)
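
For context on how a caller might supply the two new optional WhisperBeamSearch inputs from Python, here is a minimal sketch of the feed entries only. The feeds dict is hypothetical; the 4-D left_pad_mask layout follows what chain_model declares and BeamSearchParameters::ParseFromInputs validates, and the flat position_ids row mirrors the values used in verify_onnx.

import numpy as np

batch_size, prompt_len = 2, 4

# Hypothetical feeds for an exported WhisperBeamSearch model; only the two new
# optional inputs introduced by this change are shown, all other inputs omitted.
feeds = {}

# "left_pad_mask" (operator input 15): additive mask combined with the QK scores
# inside the fused attention. 0.0 keeps a position; a large negative value would
# hide a left-padded prompt token. An all-zero mask means nothing is padded.
feeds["left_pad_mask"] = np.zeros(
    (batch_size, 1, prompt_len, prompt_len), dtype=np.float32
)

# "position_ids" (operator input 16): indices used to select positional
# embeddings for the prompt tokens, one row per batch entry.
feeds["position_ids"] = np.repeat(
    np.arange(prompt_len, dtype=np.int32)[None, :], batch_size, axis=0
)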
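
The patch does not show how these tensors are expected to be built for a batch of prompts with unequal lengths. The following is only one plausible construction under that assumption: the helper name build_prompt_inputs, the -1e9 fill value, and the choice to zero the padded position ids are illustrative, not anything defined by this change.

import numpy as np

def build_prompt_inputs(prompt_lengths, max_len, mask_value=-1e9):
    # Illustrative only: left-pad every prompt to max_len, hide the padded key
    # columns with a large negative additive value, and count positions from 0
    # starting at the first real token. None of these conventions are mandated
    # by the operator; they are just one consistent way to fill the inputs.
    batch_size = len(prompt_lengths)
    left_pad_mask = np.zeros((batch_size, 1, max_len, max_len), dtype=np.float32)
    position_ids = np.zeros((batch_size, max_len), dtype=np.int32)
    for b, n in enumerate(prompt_lengths):
        pad = max_len - n
        left_pad_mask[b, 0, :, :pad] = mask_value   # mask padded key positions
        position_ids[b, pad:] = np.arange(n, dtype=np.int32)
    return left_pad_mask, position_ids

mask, pos = build_prompt_inputs([4, 2], max_len=4)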
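
The operator description says the mask "is added to qk node in the qkv attention function", which corresponds to the Softmax <- Add <- Add <- MatMul chain (qk_nodes_2) that the new fusion matches in the decoder self-attention. A toy numpy version of that arithmetic, not taken from the runtime and with the pre-scaling of q and k mirroring the Mul nodes in the matched q/k paths, might look like:

import numpy as np

def toy_masked_attention(q, k, v, causal_mask, left_pad_mask):
    # q, k, v: (batch, heads, seq, head_dim); both masks broadcast over scores.
    scale = q.shape[-1] ** -0.25                               # Mul on q and k
    scores = (q * scale) @ (k * scale).transpose(0, 1, 3, 2)   # matmul_qk
    scores = scores + causal_mask + left_pad_mask              # the two Add nodes
    scores = scores - scores.max(axis=-1, keepdims=True)       # stable Softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                                           # matmul_qkv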
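
For readers unfamiliar with the matching helper the fusion is built on, this sketch shows the convention OnnxModel.match_parent_path follows: op types are listed child-to-parent and a parallel index list picks which input edge to walk at each hop. The wrapper name match_left_pad_decoder_attention is made up for illustration; the pattern and the add_left_pad_mask.input[1] read are taken from the fusion above.

from onnx_model import OnnxModel

def match_left_pad_decoder_attention(model: OnnxModel, matmul_qkv):
    # Walk upward from the probs @ V MatMul through the two additive masks to
    # the Q @ K^T MatMul, always following input 0 at each step.
    path = model.match_parent_path(
        matmul_qkv,
        ["Softmax", "Add", "Add", "MatMul"],
        [0, 0, 0, 0],
    )
    if path is None:
        return None
    softmax, add_left_pad_mask, add_qk, matmul_qk = path
    # The fusion reads the mask tensor name from the second input of the Add
    # feeding Softmax and later forwards it to the fused node via add_qk_str.
    return add_left_pad_mask.input[1], matmul_qk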
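
Finally, a short usage sketch of how the new model_impl option selects the OpenAI-style attention fusion during offline optimization; whisper_helper.optimize_onnx performs roughly these steps internally. The file names are placeholders, and the head/hidden sizes shown correspond to a whisper-tiny style decoder.

from fusion_options import FusionOptions
from optimizer import optimize_model

options = FusionOptions("bart")
options.model_impl = "openai"           # new attribute; defaults to "hf"
options.use_multi_head_attention = True

m = optimize_model(
    "whisper_decoder.onnx",             # hypothetical decoder export
    model_type="bart",
    num_heads=6,
    hidden_size=384,
    opt_level=0,
    optimization_options=options,
)
m.save_model_to_file("whisper_decoder_optimized.onnx")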