diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 93cda00e5a3c3..58e842e72d83b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -169,7 +169,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, ORT_RETURN_IF_ERROR(whisper_encoder_subgraph_->Setup(session_state, subgraph_session_state)); encoder_feeds_fetches_manager_ = whisper_encoder_subgraph_->GetFeedsFetchesManager(); - ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs != 2, + ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs < 2, "Encoder subgraph shall have 2 inputs (encoder_input_ids, decoder_input_ids)"); } else if (attribute_name == "decoder") { ORT_ENFORCE(whisper_decoder_subgraph_ == nullptr, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..e77f0a8d572f5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -153,6 +153,12 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0); const Tensor& encoder_input_ids = encoder_input_ids_value->Get<Tensor>(); + const OrtValue* left_pad_mask_value = this->context_.GetInputOrtValue(15); + const Tensor& left_pad_mask = left_pad_mask_value->Get<Tensor>(); + + const OrtValue* position_ids_value = this->context_.GetInputOrtValue(16); + const Tensor& position_ids = position_ids_value->Get<Tensor>(); + BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, this->IsCuda(), @@ -166,6 +172,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds( encoder_input_ids, initial_decoder_input_ids_value, + left_pad_mask, + position_ids, parameters->decoder_start_token_id, this->implicit_inputs_, encoder_feeds, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..c2c99c67ed8a1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -64,6 +64,40 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } } + left_pad_mask = gsl::span<float>(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && left_pad_mask_input_id > 0) { + const Tensor* left_pad_mask_tensor = context->Input<Tensor>(left_pad_mask_input_id); + if (left_pad_mask_tensor != nullptr) { + const auto& left_pad_mask_tensor_dims = left_pad_mask_tensor->Shape().GetDims(); + ORT_ENFORCE(left_pad_mask_tensor_dims.size() == 4, + "left_pad_mask_tensor shall have 4 dimensions. Got ", + left_pad_mask_tensor_dims.size()); + ORT_ENFORCE(left_pad_mask_tensor_dims[0] == batch_size, + "left_pad_mask_tensor first dim not same as batch_size. 
Got ", + left_pad_mask_tensor_dims[0], ", expecting ", batch_size); + if (left_pad_mask_tensor->Shape().Size() > 0) { + left_pad_mask = gsl::span<const float>(left_pad_mask_tensor->Data<float>(), (size_t)left_pad_mask_tensor->Shape().Size()); + } + } + } + + position_ids = gsl::span<int32_t>(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && position_ids_input_id > 0) { + const Tensor* position_ids_tensor = context->Input<Tensor>(position_ids_input_id); + if (position_ids_tensor != nullptr) { + const auto& position_ids_tensor_dims = position_ids_tensor->Shape().GetDims(); + ORT_ENFORCE(position_ids_tensor_dims.size() == 2, + "position_ids_tensor shall have 2 dimensions. Got ", + position_ids_tensor_dims.size()); + ORT_ENFORCE(position_ids_tensor_dims[0] == batch_size, + "position_ids_tensor first dim not same as batch_size. Got ", + position_ids_tensor_dims[0], ", expecting ", batch_size); + if (position_ids_tensor->Shape().Size() > 0) { + position_ids = gsl::span<const int32_t>(position_ids_tensor->Data<int32_t>(), (size_t)position_ids_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast<int>(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { @@ -156,6 +190,8 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; + left_pad_mask_input_id = 15; + position_ids_input_id = 16; cross_qk_output_id = 3; no_speech_probs_output_id = 4; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 927d3a58e5a6f..d554bb6345131 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -877,10 +877,14 @@ template <typename T> Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids) { + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids) { const TensorShape& input_features_shape = original_encoder_input_features->Shape(); ORT_ENFORCE(input_features_shape.NumDimensions() == 3); const int64_t& batch_size = input_features_shape[0]; @@ -898,6 +902,20 @@ Status CreateWhisperEncoderInputs( allocator->Info(), encoder_input_features); + const TensorShape& left_pad_mask_shape = original_left_pad_mask->Shape(); + Tensor::InitOrtValue(DataTypeImpl::GetType<T>(), + left_pad_mask_shape, + const_cast<Tensor*>(original_left_pad_mask)->MutableData<T>(), + allocator->Info(), + left_pad_mask); + + const TensorShape& position_ids_shape = original_position_ids->Shape(); + Tensor::InitOrtValue(element_type, + position_ids_shape, + const_cast<Tensor*>(original_position_ids)->MutableData<int32_t>(), + allocator->Info(), + position_ids); + // decoder_input_ids is optional. 
if (original_decoder_input_ids_value == nullptr) { // Filled decoder_input_ids with start token ID @@ -1071,18 +1089,26 @@ template Status ExpandBuffer<MLFloat16>( template Status CreateWhisperEncoderInputs<float>( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); template Status CreateWhisperEncoderInputs<MLFloat16>( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_features, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); Status UpdateDecoderCrossQK( [[maybe_unused]] int iteration_number, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 6dfdc6b027671..3923f624f1aea 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -190,10 +190,14 @@ using UpdateDecoderFeedsFunc = std::function<Status( using CreateWhisperEncoderInputsFunc = std::function<Status( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_ids, - OrtValue& decoder_input_ids)>; + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids)>; template <typename T> using ExpandBufferFunc = std::function<Status( @@ -376,10 +380,14 @@ template <typename T> Status CreateWhisperEncoderInputs( const Tensor* original_encoder_input_features, const OrtValue* original_decoder_input_ids_value, + const Tensor* original_left_pad_mask, + const Tensor* original_position_ids, int start_token_id, AllocatorPtr allocator, OrtValue& encoder_input_ids, - OrtValue& decoder_input_ids); + OrtValue& decoder_input_ids, + OrtValue& left_pad_mask, + OrtValue& position_ids); // --------------------------------------------------------------- // Utility Functions diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..e1b8ecacc07b9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,11 +183,15 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span<const int32_t> extra_decoding_ids; + gsl::span<const float> left_pad_mask; + gsl::span<const int32_t> position_ids; int32_t no_speech_token = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; int extra_decoding_ids_input_id = -1; + int left_pad_mask_input_id = -1; + int position_ids_input_id = -1; int cross_qk_output_id = -1; int no_speech_probs_output_id = -1; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc index 8480edc405e53..f6bde4123e709 
100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc @@ -40,7 +40,7 @@ namespace transformers { Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs, const std::vector<const NodeArg*>& subgraph_outputs) { - ORT_RETURN_IF(num_subgraph_inputs != 2, "expect 2 inputs, got:", num_subgraph_inputs); + ORT_RETURN_IF(num_subgraph_inputs < 2, "expect 2 inputs, got:", num_subgraph_inputs); ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs); ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0, @@ -95,6 +95,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr Status WhisperEncoderSubgraph::CreateInitialFeeds( const Tensor& original_encoder_input_ids, const OrtValue* original_decoder_input_ids_value, + const Tensor& original_left_pad_mask, + const Tensor& original_position_ids, int start_token_id, const std::vector<const OrtValue*>& implicit_inputs, std::vector<OrtValue>& feeds, @@ -117,12 +119,18 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds( ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr"); OrtValue encoder_input_ids; + OrtValue left_pad_mask; + OrtValue position_ids; ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids, original_decoder_input_ids_value, + &original_left_pad_mask, + &original_position_ids, start_token_id, cpu_allocator, encoder_input_ids, - decoder_input_ids)); + decoder_input_ids, + left_pad_mask, + position_ids)); const IExecutionProvider* provider = GetProvider(); AllocatorPtr default_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeDefault)); @@ -130,7 +138,7 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds( const OrtMemoryInfo& location = default_allocator->Info(); ORT_RETURN_IF_ERROR(add_to_feeds_func( ort_stream, - {encoder_input_ids, decoder_input_ids}, + {encoder_input_ids, decoder_input_ids, left_pad_mask, position_ids}, feeds, buffer, default_allocator, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h index 10e7f43b0ea83..e50e1527a1881 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h @@ -22,6 +22,8 @@ class WhisperEncoderSubgraph : public T5EncoderSubgraph { Status CreateInitialFeeds( const Tensor& encoder_input_ids, const OrtValue* original_decoder_input_ids_value, + const Tensor& left_pad_mask, + const Tensor& position_ids, int start_token_id, const std::vector<const OrtValue*>& implicit_inputs, std::vector<OrtValue>& feeds, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 08cbb145a6f65..95e4728750702 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -50,6 +50,8 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 15) // 'left_pad_mask' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 16) // 
'position_ids' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..49a62b739697c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1232,6 +1232,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) + .Input(15, "left_pad_mask", "The mask is added to qk node in the qkv attention function. Shape is (batch_size, initial_sequence_length)", "T", OpSchema::Optional) + .Input(16, "position_ids", "Used to select indices of positional embeddings. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py new file mode 100644 index 0000000000000..3e4aa7add68e8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py @@ -0,0 +1,487 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx import TensorProto, helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionBartAttentionOpenai(FusionAttention): + """ + Fuse Bart Attention subgraph into one Attention node. 
+ """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def check_runtime_shape_path( + self, + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) + if concat_qkv_2_path is None: + return False + concat_qkv_2 = concat_qkv_2_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0]) + if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: + return False + if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: + return False + + reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: + return False + + mul_q = reshape_q_2_path[-1] + mul_k = reshape_k_2_path[-1] + mul_v = reshape_v_2_path[-1] + + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + add_out, + matmul_out, + reshape_qkv_1, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + else: + return + + other_inputs = [] + for input in normalize_node.input: + if input not in output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + return + root_input = other_inputs[0] + + # Sometimes the input name to the attention MatMul nodes does not match the input name to the end + # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul + # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are + # children nodes for each of its output names. + """ + root_input + +---------------------------------------------------+ + | | + | | + SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization + """ + skip_layernorm = output_name_to_node[root_input] + # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose + # child is the LayerNormalization node. 
+ if skip_layernorm.op_type == "Add": + skip_layernorm = self.model.get_children(skip_layernorm)[0] + for output in skip_layernorm.output: + if not output: + continue + children = input_name_to_nodes[output] + children_types = [child.op_type for child in children] + if children_types.count("MatMul") >= 1: + root_input = output + break + + graph_input_names = set([node.name for node in self.model.graph().input]) + graph_output_names = set([node.name for node in self.model.graph().output]) + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], + ) + v_nodes_with_past_self_attn = self.model.match_parent_path( + # Decoder attention with past value concatenated before MatMul + matmul_qkv, + ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, None], + ) + v_nodes_with_past_cross_attn = self.model.match_parent_path( + # Decoder attention with past value directly used in MatMul + matmul_qkv, + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], + ) + past_v, present_v = "", "" + reshape_v_2, add_v = None, None + if v_nodes is not None: + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + # present_v = add_v.output[0] + + add_v_children = self.model.get_children(add_v) + for child in add_v_children: + if child.op_type == "Reshape": + # if child.output[0] in graph_output_names: + # present_v = child.output[0] + reshape_v_children = self.model.get_children(child) + for reshape_child in reshape_v_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_v = reshape_child.output[0] + if child.op_type == "Concat": + concat_v_children = self.model.get_children(child) + for concat_child in concat_v_children: + if concat_child.op_type == "Reshape": + reshape_v_children = self.model.get_children(concat_child) + for reshape_child in reshape_v_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_v = reshape_child.output[0] + concat_v_parents = self.model.get_parents(child) + for concat_parent in concat_v_parents: + if concat_parent.op_type == "Reshape": + reshape_v_parents = self.model.get_parents(concat_parent) + for reshape_parent in reshape_v_parents: + if reshape_parent.op_type == "Transpose": + if reshape_parent.input[0] in graph_input_names: + past_v = reshape_parent.input[0] + + elif v_nodes_with_past_self_attn is not None: + (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn + v_nodes = v_nodes_with_past_self_attn + past_v = concat_v.input[0] + present_v = concat_v.output[0] + elif ( + v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + else: + logger.debug("fuse_attention: failed to match v path") + return + past_v = past_v if past_v in graph_input_names else "" + present_v = present_v if present_v in graph_output_names else "" + + qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + 
qk_nodes_2 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0]) + qk_nodes_3 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + if qk_nodes_1 is not None: + _, matmul_qk = qk_nodes_1 + qk_nodes = qk_nodes_1 + elif qk_nodes_2 is not None: + _, add_left_pad_mask, add_qk, matmul_qk = qk_nodes_2 + qk_nodes = qk_nodes_2 + elif qk_nodes_3 is not None: + _, add_qk, matmul_qk = qk_nodes_3 + qk_nodes = qk_nodes_3 + else: + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + if q_nodes is not None: + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes + else: + return + + k_nodes_with_bias = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], + ) + k_nodes_no_bias = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path( + # Decoder attention with past key concatenated before MatMul + matmul_qk, + ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) + past_k, present_k = "", "" + reshape_k_2, reshape_k_1, matmul_k = None, None, None + if k_nodes_with_bias is not None: + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias + k_nodes = k_nodes_with_bias + present_k = matmul_k.output[0] + mat_k_out_tmp = matmul_k.output[0] + "_temp" + # matmul_k.output[0] = matmul_k.output[0] + "_temp" + + matmul_k_children = self.model.get_children(matmul_k) + for child in matmul_k_children: + if child.op_type == "Reshape": + # if child.output[0] in graph_output_names: + # present_k = child.output[0] + reshape_k_children = self.model.get_children(child) + for reshape_child in reshape_k_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_k = reshape_child.output[0] + if child.op_type == "Concat": + concat_k_children = self.model.get_children(child) + for concat_child in concat_k_children: + if concat_child.op_type == "Reshape": + reshape_k_children = self.model.get_children(concat_child) + for reshape_child in reshape_k_children: + if reshape_child.op_type == "Transpose": + if reshape_child.output[0] in graph_output_names: + present_k = reshape_child.output[0] + concat_k_parents = self.model.get_parents(child) + for concat_parent in concat_k_parents: + if concat_parent.op_type == "Reshape": + reshape_k_parents = self.model.get_parents(concat_parent) + for reshape_parent in reshape_k_parents: + if reshape_parent.op_type == "Transpose": + if reshape_parent.input[0] in graph_input_names: + past_k = reshape_parent.input[0] + # else: + # matmul_k.output[0] = mat_k_out_tmp + + elif k_nodes_no_bias is not None: + _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias + k_nodes = k_nodes_no_bias + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + present_k = transpose_k_1.output[0] + elif k_nodes_no_bias_with_past_self_attn is not None: + _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn + k_nodes = 
k_nodes_no_bias_with_past_self_attn + past_k = concat_k.input[0] + present_k = concat_k.output[0] + elif ( + k_nodes_no_bias_with_past_cross_attn is not None + and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + else: + return + past_k = past_k if past_k in graph_input_names else "" + present_k = present_k if present_k in graph_output_names else "" + + if k_nodes in (k_nodes_with_bias, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + # Create empty Add node for attention graph + bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] + empty_bias_name = "empty_bias" + empty_tensor = self.model.get_initializer(empty_bias_name) + if empty_tensor is None: + empty_tensor = helper.make_tensor(empty_bias_name, TensorProto.FLOAT, [bias_dim], [0.0] * bias_dim) + self.model.add_initializer(empty_tensor, self.this_graph_name) + + add_name = self.model.create_node_name("Add") + add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) + + """ + if not past_k and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + return + """ + + three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals() + one_root_input = ( + not three_root_inputs + and matmul_k.input[0] == root_input + and matmul_q.input[0] == root_input + and matmul_v.input[0] == root_input + ) + two_root_inputs = ( + not three_root_inputs + and matmul_q.input[0] == root_input + and matmul_k.input[0] == matmul_v.input[0] + and matmul_k.input[0] != matmul_q.input[0] + ) + + # There are 5 types of attention: + # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 + # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2 + # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value + # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 + # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 + encoder_attention = one_root_input and qk_nodes == qk_nodes_1 + decoder_attention = one_root_input and qk_nodes == qk_nodes_2 + decoder_attention_with_past = one_root_input and qk_nodes == qk_nodes_3 and past_k and past_v + decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 + decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 + + # For decoder_attention, the attention mask needs to be included in the attention node + mask_index = None + if decoder_attention: + mask_nodes_bart = self.model.match_parent_path( + add_qk, + ["Where"], + [1], + ) + mask_nodes_whisper = self.model.match_parent_path( + add_qk, + ["Slice", "Slice", "Unsqueeze", "Gather"], + [1, 0, 2, 0], + ) + if mask_nodes_whisper is not None: + pass + # mask_index = mask_nodes_whisper[0].output[-1] + elif mask_nodes_bart is not None: + mask_index = mask_nodes_bart[0].output[-1] + + left_pad_mask_index = None + if decoder_attention: + left_pad_mask_index = add_left_pad_mask.input[1] + + if ( + encoder_attention + or 
decoder_attention + or decoder_attention_with_past + or decoder_cross_attention + or decoder_cross_attention_with_past + ): + attention_last_node = reshape_qkv_1 + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + return + + new_node = None + if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + # Note: Decoder attention with past key and past value is fused as multihead attention + # rather than attention because multihead attention supports separate past key and past + # value whereas attention supports concatenated past key and past value. + new_node = ( + self.create_multihead_attention_node( + matmul_q, + matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, + matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + add_q, + add_k if decoder_cross_attention or decoder_attention_with_past else None, + add_v if decoder_cross_attention or decoder_attention_with_past else None, + num_heads, + hidden_size, + attention_last_node.output[0], + past_k=past_k if decoder_attention_with_past else "", + past_v=past_v if decoder_attention_with_past else "", + present_k=present_k, + present_v=present_v, + packed_qkv=decoder_attention_with_past, + ) + if self.use_multi_head_attention + else None + ) + else: + # Temporarily set multihead attention flag to false + use_multi_head_attention_ground_truth = self.use_multi_head_attention + self.use_multi_head_attention = False + new_node = self.create_attention_node( + None, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + root_input, + attention_last_node.output[0], + add_qk_str=left_pad_mask_index if decoder_attention else None, + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + self.use_multi_head_attention = use_multi_head_attention_ground_truth + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + if self.disable_multi_head_attention_bias and ( + decoder_cross_attention or decoder_cross_attention_with_past + ): + if q_nodes[-1].op_type == "Add": + q_nodes.pop() + if k_nodes[-1].op_type == "Add": + k_nodes.pop() + if v_nodes[-1].op_type == "Add": + v_nodes.pop() + + self.nodes_to_remove.extend(q_nodes) + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. 
+ self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 4c43e4487bfb1..76d603cb0e56b 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -59,6 +59,7 @@ def __init__(self, model_type): if model_type == "clip": self.enable_embed_layer_norm = False + self.model_impl = "hf" # Set default to sequence length for BERT model to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index bb697fe1e1506..36e6e2a845e75 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -390,6 +390,7 @@ def export_onnx_models( auto_mixed_precision=not disable_auto_mixed_precision, use_gpu=use_gpu, provider=provider, + model_impl=model_impl, ) onnx_path = output_path diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index a74666b7af297..0d51623bd176f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -49,6 +49,10 @@ def chain_model(args): "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "", + "", + "left_pad_mask" if args.use_forced_decoder_ids else "", + "position_ids" if args.use_forced_decoder_ids else "", ] beam_outputs = ["sequences"] @@ -58,13 +62,15 @@ def chain_model(args): beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 + # assert len(beam_inputs) == 1 + """ beam_inputs.extend( [ "cross_qk_layer_head" if args.collect_cross_qk else "", "extra_decoding_ids" if args.extra_decoding_ids else "", ] ) + """ if args.collect_cross_qk: while len(beam_outputs) < 3: beam_outputs.extend([""]) @@ -169,6 +175,23 @@ def chain_model(args): ) graph_inputs.append(decoder_input_ids) + left_pad_mask = helper.make_tensor_value_info( + "left_pad_mask", + TensorProto.FLOAT, + [ + "batch_size", + 1, + "initial_sequence_length", + "initial_sequence_length", + ], + ) + graph_inputs.append(left_pad_mask) + + position_ids = helper.make_tensor_value_info( + "position_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + graph_inputs.append(position_ids) + if args.use_logits_processor: logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) graph_inputs.append(logits_processor) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 351173f525727..11dce01c24e49 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -51,12 +51,16 @@ def forward( self, encoder_input_ids: torch.Tensor, decoder_input_ids: torch.Tensor = None, + left_pad_mask: torch.Tensor = None, + position_ids: torch.Tensor = 
None, ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + decinit_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, encoder_hidden_states, left_pad_mask=left_pad_mask, position_ids=position_ids + ) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) @@ -66,9 +70,11 @@ def forward( class WhisperEncoderDecoderInitInputs: - def __init__(self, encoder_input_ids, decoder_input_ids=None): + def __init__(self, encoder_input_ids, decoder_input_ids=None, left_pad_mask=None, position_ids=None): self.encoder_input_ids: torch.LongTensor = encoder_input_ids self.decoder_input_ids: torch.LongTensor = decoder_input_ids + self.left_pad_mask: torch.LongTensor = left_pad_mask + self.position_ids: torch.LongTensor = position_ids @staticmethod def create_dummy( @@ -89,13 +95,17 @@ def create_dummy( if use_decoder_input_ids: dtype = torch.int32 if use_int32_inputs else torch.int64 decoder_input_ids = torch.ones((batch_size, 2), dtype=dtype, device=device) * config.decoder_start_token_id + left_pad_mask = torch.zeros((batch_size, 1, 2, 2), dtype=dtype, device=device) + position_ids = torch.zeros((batch_size, 2), dtype=dtype, device=device) - return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids) + return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids, left_pad_mask, position_ids) def to_list(self) -> List: input_list = [self.encoder_input_ids] if self.decoder_input_ids is not None: input_list.append(self.decoder_input_ids) + input_list.append(self.left_pad_mask) + input_list.append(self.position_ids) return input_list @@ -133,7 +143,9 @@ def export_onnx( # TODO : Investigate whether copy of model if needed cloned_model = copy.deepcopy(model).to(device) - out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) + out = cloned_model( + inputs.encoder_input_ids, inputs.decoder_input_ids, inputs.left_pad_mask, inputs.position_ids + ) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) @@ -178,6 +190,18 @@ def export_onnx( 1: "decode_sequence_length", } + input_names.append("left_pad_mask") + dynamic_axes["left_pad_mask"] = { + 0: "batch_size", + 2: "decode_sequence_length", + 3: "decode_sequence_length", + } + input_names.append("position_ids") + dynamic_axes["position_ids"] = { + 0: "batch_size", + 1: "decode_sequence_length", + } + for name in present_names: if "cross" in name: dynamic_axes[name] = { @@ -243,6 +267,8 @@ def onnxruntime_inference(ort_session, inputs: WhisperEncoderDecoderInitInputs): } if inputs.decoder_input_ids is not None: ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()) + ort_inputs["left_pad_mask"] = numpy.ascontiguousarray(inputs.left_pad_mask.cpu().numpy()) + ort_inputs["position_ids"] = numpy.ascontiguousarray(inputs.position_ids.cpu().numpy()) ort_outputs = ort_session.run(None, ort_inputs) return ort_outputs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e2dc79ca247ce..79d717c17b67c 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -12,6 +12,9 @@ import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -22,9 +25,6 @@ from onnxruntime import InferenceSession sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model logger = logging.getLogger(__name__) @@ -287,12 +287,14 @@ def optimize_onnx( auto_mixed_precision: bool = True, use_gpu: bool = False, provider: str = "cpu", + model_impl: str = "hf", ): """Optimize ONNX model with an option to convert it to use mixed precision.""" from fusion_options import FusionOptions optimization_options = FusionOptions("bart") + optimization_options.model_impl = model_impl optimization_options.use_multi_head_attention = True optimization_options.disable_multi_head_attention_bias = provider == "rocm" @@ -341,8 +343,6 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features @@ -374,6 +374,10 @@ def verify_onnx( "tensor(uint8)": np.uint8, } + # Generate prompts + # prompt_text = "" + # prompt_ids = processor.get_prompt_ids(prompt_text) + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): if name == "input_features": @@ -395,6 +399,10 @@ def verify_onnx( inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "left_pad_mask": + inputs[name] = np.zeros((batch_size, 1, 4, 4), dtype=ort_to_np[dtype]) + elif name == "position_ids": + inputs[name] = np.repeat(np.array([[0, 1, 2, 3]], dtype=ort_to_np[dtype]), batch_size, 0) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] @@ -411,6 +419,7 @@ def verify_onnx( " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." ) pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) + ort_expected_transcription = ( " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
) @@ -421,5 +430,4 @@ def verify_onnx( ) if parity: max_diff = 0 - return max_diff diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index 941f61cf7cc29..13cdffaa836c3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -30,6 +30,8 @@ def forward( tokens, audio_features, past=None, + left_pad_mask=None, + position_ids=None, ): # Create a kv_cache for past_values past_kv_cache = dict() @@ -47,7 +49,9 @@ def forward( if not self.kv_cache: self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() - logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) + logits = self.whisper_decoder( + tokens, audio_features, kv_cache=past_kv_cache, left_pad_mask=left_pad_mask, position_ids=position_ids + ) # Add concat node for past values if past is not None: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 61a786d7af60b..de0a418ae4daa 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -7,6 +7,7 @@ from fusion_attention import AttentionMask from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_options import FusionOptions from fusion_reshape import FusionReshape from onnx import numpy_helper @@ -124,7 +125,12 @@ class BartOnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) - self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + if model_impl == "openai": + self.attention_fusion = FusionBartAttentionOpenai( + self, self.hidden_size, self.num_heads, self.attention_mask + ) + else: + self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 431e64509e3cc..2c77aeb784ec7 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -9,6 +9,7 @@ from convert_to_packing_mode import PackingMode from fusion_attention import AttentionMask, FusionAttention from fusion_bart_attention import FusionBartAttention +from fusion_bart_attention_openai import FusionBartAttentionOpenai from fusion_biasgelu import FusionBiasGelu from fusion_embedlayer import FusionEmbedLayerNormalization from fusion_fastgelu import FusionFastGelu @@ -349,7 +350,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) - if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): + if ( + options.use_multi_head_attention + and not isinstance(self.attention_fusion, FusionBartAttention) + and not isinstance(self.attention_fusion, FusionBartAttentionOpenai) + ): self.attention_fusion = 
FusionAttention( self, self.hidden_size, diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..584452bdf004a 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -226,7 +226,10 @@ def optimize_by_fusion( if optimization_options is None: optimization_options = FusionOptions(model_type) - optimizer = optimizer_class(model, num_heads, hidden_size) + if optimization_options.model_impl == "openai": + optimizer = optimizer_class(model, num_heads, hidden_size, model_impl=optimization_options.model_impl) + else: + optimizer = optimizer_class(model, num_heads, hidden_size) optimizer.optimize(optimization_options)
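
For context on how a caller might supply the two new optional WhisperBeamSearch inputs from Python, here is a minimal sketch of the feed entries only. The feeds dict is hypothetical; the 4-D left_pad_mask layout follows what chain_model declares and BeamSearchParameters::ParseFromInputs validates, and the flat position_ids row mirrors the values used in verify_onnx.

import numpy as np

batch_size, prompt_len = 2, 4

# Hypothetical feeds for an exported WhisperBeamSearch model; only the two new
# optional inputs introduced by this change are shown, all other inputs omitted.
feeds = {}

# "left_pad_mask" (operator input 15): additive mask combined with the QK scores
# inside the fused attention. 0.0 keeps a position; a large negative value would
# hide a left-padded prompt token. An all-zero mask means nothing is padded.
feeds["left_pad_mask"] = np.zeros(
    (batch_size, 1, prompt_len, prompt_len), dtype=np.float32
)

# "position_ids" (operator input 16): indices used to select positional
# embeddings for the prompt tokens, one row per batch entry.
feeds["position_ids"] = np.repeat(
    np.arange(prompt_len, dtype=np.int32)[None, :], batch_size, axis=0
)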
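
The patch does not show how these tensors are expected to be built for a batch of prompts with unequal lengths. The following is only one plausible construction under that assumption: the helper name build_prompt_inputs, the -1e9 fill value, and the choice to zero the padded position ids are illustrative, not anything defined by this change.

import numpy as np

def build_prompt_inputs(prompt_lengths, max_len, mask_value=-1e9):
    # Illustrative only: left-pad every prompt to max_len, hide the padded key
    # columns with a large negative additive value, and count positions from 0
    # starting at the first real token. None of these conventions are mandated
    # by the operator; they are just one consistent way to fill the inputs.
    batch_size = len(prompt_lengths)
    left_pad_mask = np.zeros((batch_size, 1, max_len, max_len), dtype=np.float32)
    position_ids = np.zeros((batch_size, max_len), dtype=np.int32)
    for b, n in enumerate(prompt_lengths):
        pad = max_len - n
        left_pad_mask[b, 0, :, :pad] = mask_value   # mask padded key positions
        position_ids[b, pad:] = np.arange(n, dtype=np.int32)
    return left_pad_mask, position_ids

mask, pos = build_prompt_inputs([4, 2], max_len=4)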
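
The operator description says the mask "is added to qk node in the qkv attention function", which corresponds to the Softmax <- Add <- Add <- MatMul chain (qk_nodes_2) that the new fusion matches in the decoder self-attention. A toy numpy version of that arithmetic, not taken from the runtime and with the pre-scaling of q and k mirroring the Mul nodes in the matched q/k paths, might look like:

import numpy as np

def toy_masked_attention(q, k, v, causal_mask, left_pad_mask):
    # q, k, v: (batch, heads, seq, head_dim); both masks broadcast over scores.
    scale = q.shape[-1] ** -0.25                               # Mul on q and k
    scores = (q * scale) @ (k * scale).transpose(0, 1, 3, 2)   # matmul_qk
    scores = scores + causal_mask + left_pad_mask              # the two Add nodes
    scores = scores - scores.max(axis=-1, keepdims=True)       # stable Softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                                           # matmul_qkv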
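
For readers unfamiliar with the matching helper the fusion is built on, this sketch shows the convention OnnxModel.match_parent_path follows: op types are listed child-to-parent and a parallel index list picks which input edge to walk at each hop. The wrapper name match_left_pad_decoder_attention is made up for illustration; the pattern and the add_left_pad_mask.input[1] read are taken from the fusion above.

from onnx_model import OnnxModel

def match_left_pad_decoder_attention(model: OnnxModel, matmul_qkv):
    # Walk upward from the probs @ V MatMul through the two additive masks to
    # the Q @ K^T MatMul, always following input 0 at each step.
    path = model.match_parent_path(
        matmul_qkv,
        ["Softmax", "Add", "Add", "MatMul"],
        [0, 0, 0, 0],
    )
    if path is None:
        return None
    softmax, add_left_pad_mask, add_qk, matmul_qk = path
    # The fusion reads the mask tensor name from the second input of the Add
    # feeding Softmax and later forwards it to the fused node via add_qk_str.
    return add_left_pad_mask.input[1], matmul_qk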
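
Finally, a short usage sketch of how the new model_impl option selects the OpenAI-style attention fusion during offline optimization; whisper_helper.optimize_onnx performs roughly these steps internally. The file names are placeholders, and the head/hidden sizes shown correspond to a whisper-tiny style decoder.

from fusion_options import FusionOptions
from optimizer import optimize_model

options = FusionOptions("bart")
options.model_impl = "openai"           # new attribute; defaults to "hf"
options.use_multi_head_attention = True

m = optimize_model(
    "whisper_decoder.onnx",             # hypothetical decoder export
    model_type="bart",
    num_heads=6,
    hidden_size=384,
    opt_level=0,
    optimization_options=options,
)
m.save_model_to_file("whisper_decoder_optimized.onnx")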