Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Export of OpenAI Whisper [Batched decoding version] #18815

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
ORT_RETURN_IF_ERROR(whisper_encoder_subgraph_->Setup(session_state, subgraph_session_state));
encoder_feeds_fetches_manager_ = whisper_encoder_subgraph_->GetFeedsFetchesManager();

ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs != 2,
ORT_RETURN_IF(whisper_encoder_subgraph_->num_subgraph_inputs < 2,
"Encoder subgraph shall have 2 inputs (encoder_input_ids, decoder_input_ids)");
} else if (attribute_name == "decoder") {
ORT_ENFORCE(whisper_decoder_subgraph_ == nullptr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0);
const Tensor& encoder_input_ids = encoder_input_ids_value->Get<Tensor>();

const OrtValue* left_pad_mask_value = this->context_.GetInputOrtValue(15);
const Tensor& left_pad_mask = left_pad_mask_value->Get<Tensor>();

const OrtValue* position_ids_value = this->context_.GetInputOrtValue(16);
const Tensor& position_ids = position_ids_value->Get<Tensor>();

BeamSearchCpuState cpu_state{*parameters,
this->cpu_allocator_,
this->IsCuda(),
Expand All @@ -166,6 +172,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds(
encoder_input_ids,
initial_decoder_input_ids_value,
left_pad_mask,
position_ids,
parameters->decoder_start_token_id,
this->implicit_inputs_,
encoder_feeds,
Expand Down
36 changes: 36 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,40 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
}
}

left_pad_mask = gsl::span<float>();
if (this->model_type == IGenerationParameters::kModelTypeWhisper && left_pad_mask_input_id > 0) {
const Tensor* left_pad_mask_tensor = context->Input<Tensor>(left_pad_mask_input_id);
if (left_pad_mask_tensor != nullptr) {
const auto& left_pad_mask_tensor_dims = left_pad_mask_tensor->Shape().GetDims();
ORT_ENFORCE(left_pad_mask_tensor_dims.size() == 4,
"left_pad_mask_tensor shall have 4 dimensions. Got ",
left_pad_mask_tensor_dims.size());
ORT_ENFORCE(left_pad_mask_tensor_dims[0] == batch_size,
"left_pad_mask_tensor first dim not same as batch_size. Got ",
left_pad_mask_tensor_dims[0], ", expecting ", batch_size);
if (left_pad_mask_tensor->Shape().Size() > 0) {
left_pad_mask = gsl::span<const float>(left_pad_mask_tensor->Data<float>(), (size_t)left_pad_mask_tensor->Shape().Size());
}
}
}

position_ids = gsl::span<int32_t>();
if (this->model_type == IGenerationParameters::kModelTypeWhisper && position_ids_input_id > 0) {
const Tensor* position_ids_tensor = context->Input<Tensor>(position_ids_input_id);
if (position_ids_tensor != nullptr) {
const auto& position_ids_tensor_dims = position_ids_tensor->Shape().GetDims();
ORT_ENFORCE(position_ids_tensor_dims.size() == 2,
"position_ids_tensor shall have 2 dimensions. Got ",
position_ids_tensor_dims.size());
ORT_ENFORCE(position_ids_tensor_dims[0] == batch_size,
"position_ids_tensor first dim not same as batch_size. Got ",
position_ids_tensor_dims[0], ", expecting ", batch_size);
if (position_ids_tensor->Shape().Size() > 0) {
position_ids = gsl::span<const int32_t>(position_ids_tensor->Data<int32_t>(), (size_t)position_ids_tensor->Shape().Size());
}
}
}

if (this->model_type == IGenerationParameters::kModelTypeGpt) {
sequence_length = static_cast<int>(dims[1]);
} else if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
Expand Down Expand Up @@ -156,6 +190,8 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info)
no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL));
cross_qk_layer_head_input_id = 12;
extra_decoding_ids_input_id = 13;
left_pad_mask_input_id = 15;
position_ids_input_id = 16;
cross_qk_output_id = 3;
no_speech_probs_output_id = 4;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,10 +877,14 @@ template <typename T>
Status CreateWhisperEncoderInputs(
const Tensor* original_encoder_input_features,
const OrtValue* original_decoder_input_ids_value,
const Tensor* original_left_pad_mask,
const Tensor* original_position_ids,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_features,
OrtValue& decoder_input_ids) {
OrtValue& decoder_input_ids,
OrtValue& left_pad_mask,
OrtValue& position_ids) {
const TensorShape& input_features_shape = original_encoder_input_features->Shape();
ORT_ENFORCE(input_features_shape.NumDimensions() == 3);
const int64_t& batch_size = input_features_shape[0];
Expand All @@ -898,6 +902,20 @@ Status CreateWhisperEncoderInputs(
allocator->Info(),
encoder_input_features);

const TensorShape& left_pad_mask_shape = original_left_pad_mask->Shape();
Tensor::InitOrtValue(DataTypeImpl::GetType<T>(),
left_pad_mask_shape,
const_cast<Tensor*>(original_left_pad_mask)->MutableData<T>(),
allocator->Info(),
left_pad_mask);

const TensorShape& position_ids_shape = original_position_ids->Shape();
Tensor::InitOrtValue(element_type,
position_ids_shape,
const_cast<Tensor*>(original_position_ids)->MutableData<int32_t>(),
allocator->Info(),
position_ids);

// decoder_input_ids is optional.
if (original_decoder_input_ids_value == nullptr) {
// Filled decoder_input_ids with start token ID
Expand Down Expand Up @@ -1071,18 +1089,26 @@ template Status ExpandBuffer<MLFloat16>(
template Status CreateWhisperEncoderInputs<float>(
const Tensor* original_encoder_input_features,
const OrtValue* original_decoder_input_ids_value,
const Tensor* original_left_pad_mask,
const Tensor* original_position_ids,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_features,
OrtValue& decoder_input_ids);
OrtValue& decoder_input_ids,
OrtValue& left_pad_mask,
OrtValue& position_ids);

template Status CreateWhisperEncoderInputs<MLFloat16>(
const Tensor* original_encoder_input_features,
const OrtValue* original_decoder_input_ids_value,
const Tensor* original_left_pad_mask,
const Tensor* original_position_ids,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_features,
OrtValue& decoder_input_ids);
OrtValue& decoder_input_ids,
OrtValue& left_pad_mask,
OrtValue& position_ids);

Status UpdateDecoderCrossQK(
[[maybe_unused]] int iteration_number,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,14 @@ using UpdateDecoderFeedsFunc = std::function<Status(
using CreateWhisperEncoderInputsFunc = std::function<Status(
const Tensor* original_encoder_input_features,
const OrtValue* original_decoder_input_ids_value,
const Tensor* original_left_pad_mask,
const Tensor* original_position_ids,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_ids,
OrtValue& decoder_input_ids)>;
OrtValue& decoder_input_ids,
OrtValue& left_pad_mask,
OrtValue& position_ids)>;

template <typename T>
using ExpandBufferFunc = std::function<Status(
Expand Down Expand Up @@ -376,10 +380,14 @@ template <typename T>
Status CreateWhisperEncoderInputs(
const Tensor* original_encoder_input_features,
const OrtValue* original_decoder_input_ids_value,
const Tensor* original_left_pad_mask,
const Tensor* original_position_ids,
int start_token_id,
AllocatorPtr allocator,
OrtValue& encoder_input_ids,
OrtValue& decoder_input_ids);
OrtValue& decoder_input_ids,
OrtValue& left_pad_mask,
OrtValue& position_ids);

// ---------------------------------------------------------------
// Utility Functions
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,15 @@ struct IGenerationParameters {
// Parameters for whisper model
bool decoder_output_cross_qk = false;
gsl::span<const int32_t> extra_decoding_ids;
gsl::span<const float> left_pad_mask;
gsl::span<const int32_t> position_ids;
int32_t no_speech_token = -1;
void* no_speech_probs = nullptr;

int cross_qk_layer_head_input_id = -1;
int extra_decoding_ids_input_id = -1;
int left_pad_mask_input_id = -1;
int position_ids_input_id = -1;
int cross_qk_output_id = -1;
int no_speech_probs_output_id = -1;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace transformers {

Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_inputs != 2, "expect 2 inputs, got:", num_subgraph_inputs);
ORT_RETURN_IF(num_subgraph_inputs < 2, "expect 2 inputs, got:", num_subgraph_inputs);

ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0,
Expand Down Expand Up @@ -95,6 +95,8 @@ Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr
Status WhisperEncoderSubgraph::CreateInitialFeeds(
const Tensor& original_encoder_input_ids,
const OrtValue* original_decoder_input_ids_value,
const Tensor& original_left_pad_mask,
const Tensor& original_position_ids,
int start_token_id,
const std::vector<const OrtValue*>& implicit_inputs,
std::vector<OrtValue>& feeds,
Expand All @@ -117,20 +119,26 @@ Status WhisperEncoderSubgraph::CreateInitialFeeds(
ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr");

OrtValue encoder_input_ids;
OrtValue left_pad_mask;
OrtValue position_ids;
ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids,
original_decoder_input_ids_value,
&original_left_pad_mask,
&original_position_ids,
start_token_id,
cpu_allocator,
encoder_input_ids,
decoder_input_ids));
decoder_input_ids,
left_pad_mask,
position_ids));

const IExecutionProvider* provider = GetProvider();
AllocatorPtr default_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeDefault));
AllocatorPtr pinned_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeCPU));
const OrtMemoryInfo& location = default_allocator->Info();
ORT_RETURN_IF_ERROR(add_to_feeds_func(
ort_stream,
{encoder_input_ids, decoder_input_ids},
{encoder_input_ids, decoder_input_ids, left_pad_mask, position_ids},
feeds,
buffer,
default_allocator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class WhisperEncoderSubgraph : public T5EncoderSubgraph {
Status CreateInitialFeeds(
const Tensor& encoder_input_ids,
const OrtValue* original_decoder_input_ids_value,
const Tensor& left_pad_mask,
const Tensor& position_ids,
int start_token_id,
const std::vector<const OrtValue*>& implicit_inputs,
std::vector<OrtValue>& feeds,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 15) // 'left_pad_mask' needs to be on CPU
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update new input descriptions

.InputMemoryType(OrtMemTypeCPUInput, 16) // 'position_ids' needs to be on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1,
"are treated as stop of the extra_decoding_ids for corresponding batch.",
"I", OpSchema::Optional)
.Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional)
.Input(15, "left_pad_mask", "The mask is added to qk node in the qkv attention function. Shape is (batch_size, initial_sequence_length)", "T", OpSchema::Optional)
.Input(16, "position_ids", "Used to select indices of positional embeddings. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",
Expand Down
Loading