Commit 1c2d388: Correct datatype for position_ids

shubhambhokare1 committed Dec 19, 2023
1 parent 09fb943 commit 1c2d388
Showing 6 changed files with 14 additions and 8 deletions.
onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc

@@ -64,7 +64,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
     }
   }

-  left_pad_mask = gsl::span<int32_t>();
+  left_pad_mask = gsl::span<float>();
   if (this->model_type == IGenerationParameters::kModelTypeWhisper && left_pad_mask_input_id > 0) {
     const Tensor* left_pad_mask_tensor = context->Input<Tensor>(left_pad_mask_input_id);
     if (left_pad_mask_tensor != nullptr) {
@@ -76,7 +76,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
                   "left_pad_mask_tensor first dim not same as batch_size. Got ",
                   left_pad_mask_tensor_dims[0], ", expecting ", batch_size);
       if (left_pad_mask_tensor->Shape().Size() > 0) {
-        left_pad_mask = gsl::span<const int32_t>(left_pad_mask_tensor->Data<int32_t>(), (size_t)left_pad_mask_tensor->Shape().Size());
+        left_pad_mask = gsl::span<const float>(left_pad_mask_tensor->Data<float>(), (size_t)left_pad_mask_tensor->Shape().Size());
       }
     }
   }

GitHub Actions / Lint C++ (cpplint via reviewdog) flagged line 79 of beam_search_parameters.cc twice: lines should be <= 120 characters long [whitespace/line_length] [2], and "Using C-style cast. Use static_cast<size_t>(...) instead" [readability/casting] [4].
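Reviewer note: the int32 -> float switch makes sense if left_pad_mask is an additive attention mask, which matches the Softmax <- Add <- Add <- MatMul pattern this commit teaches the fusion to recognize below. A minimal numpy sketch of such a mask, assuming 0.0 for visible positions and a large negative value for left-padded keys (the helper name is hypothetical, not part of this commit):

import numpy as np

def make_left_pad_mask(pad_lengths, seq_len):
    # Hypothetical helper: additive float mask of shape
    # (batch, 1, seq_len, seq_len). Padded key positions receive a large
    # negative value so Softmax drives their attention weight to ~0.
    mask = np.zeros((len(pad_lengths), 1, seq_len, seq_len), dtype=np.float32)
    for b, pad in enumerate(pad_lengths):
        mask[b, :, :, :pad] = -1e9
    return mask

# First sequence left-padded by 2 tokens, second sequence unpadded.
print(make_left_pad_mask([2, 0], seq_len=4)[0, 0])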
@@ -910,9 +910,9 @@ Status CreateWhisperEncoderInputs(
                        left_pad_mask);

   const TensorShape& position_ids_shape = original_position_ids->Shape();
-  Tensor::InitOrtValue(DataTypeImpl::GetType<T>(),
+  Tensor::InitOrtValue(element_type,
                        position_ids_shape,
-                       const_cast<Tensor*>(original_position_ids)->MutableData<T>(),
+                       const_cast<Tensor*>(original_position_ids)->MutableData<int32_t>(),
                        allocator->Info(),
                        position_ids);
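Reviewer note: the hunk above stops typing position_ids with the model's template parameter T and reuses the buffer as int32. The invariant, sketched in numpy (names here are illustrative only), is that position indices stay integral no matter what dtype the model computes in:

import numpy as np

# position_ids are token positions, so they remain int32 even when the
# model itself runs in fp32 or fp16.
batch_size, seq_len = 2, 4
position_ids = np.tile(np.arange(seq_len, dtype=np.int32), (batch_size, 1))
print(position_ids.dtype, position_ids.shape)  # int32 (2, 4)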
@@ -180,7 +180,7 @@ struct IGenerationParameters {
   // Parameters for whisper model
   bool decoder_output_cross_qk = false;
   gsl::span<const int32_t> extra_decoding_ids;
-  gsl::span<const int32_t> left_pad_mask;
+  gsl::span<const float> left_pad_mask;
   gsl::span<const int32_t> position_ids;
   int32_t no_speech_token = -1;
   void* no_speech_probs = nullptr;
@@ -208,12 +208,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         qk_nodes_2 = self.model.match_parent_path(
             matmul_qkv, ["Softmax", "Add", "Add", "MatMul"], [0, 0, 0, 0]
         )
+        qk_nodes_3 = self.model.match_parent_path(
+            matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]
+        )
         if qk_nodes_1 is not None:
             _, matmul_qk = qk_nodes_1
             qk_nodes = qk_nodes_1
         elif qk_nodes_2 is not None:
             _, add_left_pad_mask, add_qk, matmul_qk = qk_nodes_2
             qk_nodes = qk_nodes_2
+        elif qk_nodes_3 is not None:
+            _, add_qk, matmul_qk = qk_nodes_3
+            qk_nodes = qk_nodes_3
         else:
             return

@@ -362,7 +368,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1
         encoder_attention = one_root_input and qk_nodes == qk_nodes_1
         decoder_attention = one_root_input and qk_nodes == qk_nodes_2
-        decoder_attention_with_past = decoder_attention and past_k and past_v
+        decoder_attention_with_past = one_root_input and qk_nodes == qk_nodes_3 and past_k and past_v
         decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1
         decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1
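Reviewer note: match_parent_path walks upward from matmul_qkv and returns the matched node list, or None when the ops do not line up. The three recognized QK -> Softmax shapes, written out as op-type paths (qk_nodes_1's path is inferred from its two-element unpacking above, not shown in this diff):

# Op-type paths walked upward from the final MatMul(probs, V):
qk_patterns = {
    "qk_nodes_1": ["Softmax", "MatMul"],                # no additive mask
    "qk_nodes_2": ["Softmax", "Add", "Add", "MatMul"],  # attention mask + left_pad_mask
    "qk_nodes_3": ["Softmax", "Add", "MatMul"],         # single mask; decoder self-attention with past
}

The second hunk then classifies decoder self-attention with past by qk_nodes_3 directly instead of deriving it from the two-Add decoder pattern.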
@@ -159,7 +159,7 @@ def chain_model(args):
         graph_inputs.append(decoder_input_ids)

     left_pad_mask = helper.make_tensor_value_info(
-        "left_pad_mask", TensorProto.INT32, ["batch_size", 1, "initial_sequence_length", "initial_sequence_length",]
+        "left_pad_mask", TensorProto.FLOAT, ["batch_size", 1, "initial_sequence_length", "initial_sequence_length",]
     )
     graph_inputs.append(left_pad_mask)
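Reviewer note: since the graph input is now declared FLOAT, any feed for left_pad_mask must be float32 or the session will fail the type check. A minimal sketch (shape values arbitrary):

import numpy as np

# Matches the (batch_size, 1, initial_sequence_length, initial_sequence_length)
# value info declared above; an int32 array here would now be rejected.
feed = {"left_pad_mask": np.zeros((1, 1, 4, 4), dtype=np.float32)}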
@@ -409,7 +409,7 @@ def verify_onnx(
         elif name == "left_pad_mask":
             inputs[name] = np.zeros((batch_size, 1, 4, 4), dtype=ort_to_np[dtype])
         elif name == "position_ids":
-            inputs[name] = np.zeros((batch_size, 4), dtype=ort_to_np[dtype])
+            inputs[name] = np.repeat(np.array([[0, 1, 2, 3]], dtype=ort_to_np[dtype]), batch_size, 0)
         else:
             inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
     ort_outputs = ort_session.run(None, inputs)[0][0]
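Reviewer note: the new test input swaps all-zero position_ids for per-row sequential positions; np.repeat along axis 0 simply replicates the [0, 1, 2, 3] row for every batch entry:

import numpy as np

position_ids = np.repeat(np.array([[0, 1, 2, 3]], dtype=np.int32), 2, 0)
print(position_ids)
# [[0 1 2 3]
#  [0 1 2 3]]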
