From 32dc9a303a5ecb5953074072925b60322b15cf34 Mon Sep 17 00:00:00 2001 From: LiangGao Date: Mon, 9 Dec 2024 15:16:47 +0800 Subject: [PATCH 01/15] Add Group Query Attention support with OV base OPs --- .../com.microsoft/group_query_attention.cpp | 365 ++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp new file mode 100644 index 00000000000000..953d5d585697c9 --- /dev/null +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -0,0 +1,365 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "core/null_node.hpp" +#include "core/operator_set.hpp" +#include "openvino/frontend/exception.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/equal.hpp" +#include "openvino/op/floor.hpp" +#include "openvino/op/floor_mod.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/greater_eq.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/less_eq.hpp" +#include "openvino/op/log.hpp" +#include "openvino/op/logical_not.hpp" +#include "openvino/op/logical_or.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/maximum.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/negative.hpp" +#include "openvino/op/pad.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils/split.hpp" + +using namespace ov::op; +using ov::Shape; + +namespace ov { +namespace frontend { +namespace onnx { +namespace com_microsoft { +namespace detail { +namespace { + +std::shared_ptr get_present_state(const std::shared_ptr& K, + const std::shared_ptr& V, + const ov::OutputVector& op_inputs); +std::shared_ptr rotaryEmbedding(std::shared_ptr input, + std::shared_ptr past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved); +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +} // namespace +} // namespace detail + +namespace opset_1 { +ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { + const auto do_rotary = node.get_attribute_value("do_rotary"); + const auto num_heads = node.get_attribute_value("num_heads"); + const auto kv_num_heads = node.get_attribute_value("kv_num_heads"); + const auto scale = node.get_attribute_value("scale", 0.0f); + const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved"); + // TODO: add softcap support + + auto nodes = node.get_ov_inputs(); + const auto node_shape = std::make_shared(nodes[0]); + const auto batch_size = detail::get_dimensions(node_shape, {0}); + const auto current_seqlen_size = detail::get_dimensions(node_shape, {1}); + const auto hidden_size = detail::get_dimensions(node_shape, {2}); + const auto total_num_heads_node = + v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); + auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); + + // Q K V (batch_size, sequence_len, num_heads, head_size) + ov::Output oQ, oK, oV; + int index = 0; + oQ = nodes[index++]; + oK = nodes[index++]; + oV = nodes[index++]; + if (ov::op::util::is_null(oK)) { + // Handle the packed QKV + auto packed_qkv_shape = std::make_shared( + ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, + 0); + auto inputs_qkv = std::make_shared(oQ, packed_qkv_shape, false); + // split the node into 3 even parts Q, K, V with shape (batch_size, sequence_len, num_head, head_size) + auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 2); + oQ = split[0]; + oK = split[1]; + oV = split[2]; + } + + std::shared_ptr Q, K, V; + + const auto& past_key = nodes[index++].get_node_shared_ptr(); + const auto& past_value = nodes[index++].get_node_shared_ptr(); + const auto& seqlens_k = nodes[index++].get_node_shared_ptr(); + const auto& total_sequence_length = nodes[index++]; // unused, it's not always equal (seqlens_k + 1) + const auto& cos_cache = nodes[index++].get_node_shared_ptr(); + const auto& sin_cache = nodes[index++].get_node_shared_ptr(); + + // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + Q = std::make_shared(oQ, perm); + K = std::make_shared(oK, perm); + V = std::make_shared(oV, perm); + + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + auto seqlens_elemi64 = std::make_shared(seqlens_k, ov::element::i64); + auto real_seqlens = std::make_shared(seqlens_elemi64, one); + + // Only consider batch is 1 + auto seqlens_1d = std::make_shared(real_seqlens, one, false); + auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); + + if (do_rotary) { + Q = detail::rotaryEmbedding(Q, + past_sequence_length, + seqlens_1d, + cos_cache, + sin_cache, + head_size_node, + rotary_interleaved); + K = detail::rotaryEmbedding(K, + past_sequence_length, + seqlens_1d, + cos_cache, + sin_cache, + head_size_node, + rotary_interleaved); + } + // present = concat(K, V) if 'past' input is unavailable + // or + // present = concat(past, K, V) + auto construct_kv_cache = [&](const std::shared_ptr& past, const std::shared_ptr& current) { + auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); + auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); + return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); + }; + + K = construct_kv_cache(past_key, K); + V = construct_kv_cache(past_value, V); + auto present_k = K; + auto present_v = V; + + std::shared_ptr alpha; + if (scale == 0.0f) { + alpha = std::make_shared(head_size_node); + } else { + alpha = v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f / scale}); + } + const size_t kv_num_heads_factor = num_heads / kv_num_heads; + if (kv_num_heads_factor > 1) { + const auto kv_shape = std::make_shared(K); + // (batch_size, num_heads, sequence_len, head_size) + const auto kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1}); + const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3}); + auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); + K = std::make_shared(K, new_kv_shape, false); + V = std::make_shared(V, new_kv_shape, false); + K = std::make_shared(ov::NodeVector(kv_num_heads_factor, K), 2); + V = std::make_shared(ov::NodeVector(kv_num_heads_factor, V), 2); + auto q_shape = std::make_shared(Q); + // (batch_size, num_heads, sequence_len, head_size) + const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); + auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); + K = std::make_shared(K, extended_kv_shape, false); + V = std::make_shared(V, extended_kv_shape, false); + } + // compute softmax((Q x K') / sqrt(head_size)) + std::shared_ptr softmax_input = std::make_shared(Q, K, false, true); + softmax_input = std::make_shared(softmax_input, alpha); + + // need to apply low-triangle mask to attention score. + auto past_seq_len_scalar = std::make_shared(past_sequence_length, one_without_shape, false); + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); + std::shared_ptr mask_per_line_node = + std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), + seqlens_1d_scalar, + one_without_shape, + ov::element::i64); + auto mask_shape = std::make_shared(ov::NodeVector{one, one, one, seqlens_1d}, 0); + mask_per_line_node = std::make_shared(mask_per_line_node, mask_shape, false); + auto pad_end_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, seqlens_1d}, 0); + auto paded_mask = std::make_shared(mask_per_line_node, pad_end_shape); + std::shared_ptr compare_mask = + std::make_shared(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64); + auto compare_range_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, one}, 0); + compare_mask = std::make_shared(compare_mask, compare_range_shape, false); + auto lower_triangular_mask = std::make_shared(paded_mask, compare_mask); + auto higher_triangular_mask = std::make_shared(paded_mask, compare_mask); + auto negtive_const = v0::Constant::create(ov::element::f32, ov::Shape{}, {-1e20f}); + + auto convert_mask = std::make_shared(higher_triangular_mask, ov::element::f32); + auto input_offset_data = std::make_shared(convert_mask, negtive_const); + + convert_mask = std::make_shared(lower_triangular_mask, ov::element::f32); + auto softmax_input_masked = std::make_shared(softmax_input, convert_mask); + std::shared_ptr softmax_input_added = std::make_shared(softmax_input_masked, input_offset_data); + // softmax((Q x K' + mask) / sqrt(head_size)) + const auto softmax = std::make_shared(softmax_input_added, 3); + + // softmax((Q x K' + mask) / sqrt(head_size)) x V + std::shared_ptr output = std::make_shared(softmax, V); + + // transpose the result from (batch_size, num_heads, sequence_length, head_size) + // to (batch_size, sequence_length, num_heads, head_size) + output = std::make_shared(output, perm); + auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); + // reshape the result from (batch_size, sequence_length, num_heads, head_size) + // to (batch_size, sequence_length, num_heads * head_size) + output = std::make_shared(output, dim_merge_shape, true); + + return {output, present_k, present_v}; +} + +ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN); + +} // namespace opset_1 + +namespace detail { +namespace { + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { + static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} + +std::shared_ptr rotaryEmbedding(std::shared_ptr input, + std::shared_ptr past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved) { + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + + auto cache_shape = std::make_shared(cos_cache); + auto cache_last_dim = get_dimensions(cache_shape, {-1}); + auto cache_1st_dim = get_dimensions(cache_shape, {0}); + + // TODO: check the shape + auto input_shape = std::make_shared(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + // auto dim_head_size = get_dimensions(input_shape, {3}); + // half_last_dim is same as cos_cache + std::shared_ptr half_last_dim = cache_last_dim; + + auto real_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, dim_head_size}, 0); + auto slice_cache_dim_shape = seqlen_k; + + // auto end_lens = std::make_shared(half_last_dim, one); + // auto masks = std::make_shared(one, + // zero, + // end_lens, + // op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}), + // ov::op::PadMode::CONSTANT); + auto masks = std::make_shared(one, half_last_dim); + + if (interleaved) { + auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); + auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); + auto reshaped_input = std::make_shared(input, split_input_shape, false); + auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_one); + auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_one); + + auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -1); + + auto mask_shape = std::make_shared(ov::NodeVector{half_last_dim, one}, 0); + auto reshaped_mask = std::make_shared(masks, mask_shape, false); + auto negtive_mask = std::make_shared(reshaped_mask); + auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -1); + auto real_mask = std::make_shared(concat_mask, dim_head_size, false); + auto mask_f32 = std::make_shared(real_mask, ov::element::f32); + + auto real_input0 = std::make_shared(reshaped_input, input_shape, false); + auto real_input1 = std::make_shared(second_input, input_shape, false); + + auto new_cache_shape = std::make_shared(ov::NodeVector{cache_shape, two}, 0); + auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_shape, one}, 0); + auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); + auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); + auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); + auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); + auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); + auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); + auto sliced_cos_input = + std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto sliced_sin_input = + std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto add_input0 = std::make_shared(real_input0, sliced_cos_input); + auto add_input1 = std::make_shared(real_input1, sliced_sin_input); + auto multi_input1 = std::make_shared(add_input1, mask_f32); + auto result = std::make_shared(add_input0, multi_input1); + return result; + } else { + auto negtive_two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-2}); + auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, two, half_last_dim}, 0); + auto reshaped_input = std::make_shared(input, split_input_shape, false); + auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_two); + auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_two); + + auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -2); + + auto mask_shape = std::make_shared(ov::NodeVector{one, half_last_dim}, 0); + auto reshaped_mask = std::make_shared(masks, mask_shape, false); + auto negtive_mask = std::make_shared(reshaped_mask); + auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -2); + auto real_mask = std::make_shared(concat_mask, dim_head_size, false); + auto mask_f32 = std::make_shared(real_mask, ov::element::f32); + + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 1, 2, 4, 3}); + auto input0 = reshaped_input; // std::make_shared(reshaped_input, perm); + auto input1 = second_input; // std::make_shared(second_input, perm); + auto real_input0 = std::make_shared(input0, input_shape, false); + auto real_input1 = std::make_shared(input1, input_shape, false); + + auto new_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, two, cache_last_dim}, 0); + auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, one, cache_last_dim}, 0); + auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); + auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); + auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); + auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); + auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); + auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); + // TODO: change zero to sequence_K + auto sliced_cos_input = + std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto sliced_sin_input = + std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto add_input0 = std::make_shared(real_input0, sliced_cos_input); + auto add_input1 = std::make_shared(real_input1, sliced_sin_input); + auto multi_input1 = std::make_shared(add_input1, mask_f32); + auto result = std::make_shared(add_input0, multi_input1); + return result; + } +} +} // namespace +} // namespace detail +} // namespace com_microsoft +} // namespace onnx +} // namespace frontend +} // namespace ov From def6bd1ca8577fcc5ad5c09dfcba10b391a2446f Mon Sep 17 00:00:00 2001 From: LiangGao Date: Tue, 31 Dec 2024 15:41:52 +0800 Subject: [PATCH 02/15] Update the transformation code --- .../com.microsoft/group_query_attention.cpp | 138 +++++++----------- 1 file changed, 54 insertions(+), 84 deletions(-) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 953d5d585697c9..b4e85812afffcf 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -38,6 +38,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" +#include "openvino/op/convert_like.hpp" #include "utils/split.hpp" using namespace ov::op; @@ -53,8 +54,8 @@ namespace { std::shared_ptr get_present_state(const std::shared_ptr& K, const std::shared_ptr& V, const ov::OutputVector& op_inputs); -std::shared_ptr rotaryEmbedding(std::shared_ptr input, - std::shared_ptr past_seqlen, +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, std::shared_ptr seqlen_k, std::shared_ptr cos_cache, std::shared_ptr sin_cache, @@ -82,40 +83,41 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); + // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + // Q K V (batch_size, sequence_len, num_heads, head_size) - ov::Output oQ, oK, oV; + ov::Output Q, K, V; int index = 0; - oQ = nodes[index++]; - oK = nodes[index++]; - oV = nodes[index++]; - if (ov::op::util::is_null(oK)) { + Q = nodes[index++]; + K = nodes[index++]; + V = nodes[index++]; + if (ov::op::util::is_null(K)) { // Handle the packed QKV auto packed_qkv_shape = std::make_shared( ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, 0); - auto inputs_qkv = std::make_shared(oQ, packed_qkv_shape, false); - // split the node into 3 even parts Q, K, V with shape (batch_size, sequence_len, num_head, head_size) - auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 2); - oQ = split[0]; - oK = split[1]; - oV = split[2]; + auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); + // (batch_size, sequence_len, num_head, head_size) + inputs_qkv = std::make_shared(inputs_qkv, perm); + // split the node into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) + auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); + Q = split[0]; + K = split[1]; + V = split[2]; + } else { + Q = std::make_shared(Q, perm); + K = std::make_shared(K, perm); + V = std::make_shared(V, perm); } - std::shared_ptr Q, K, V; - - const auto& past_key = nodes[index++].get_node_shared_ptr(); - const auto& past_value = nodes[index++].get_node_shared_ptr(); + const auto& past_key = nodes[index++]; + const auto& past_value = nodes[index++]; const auto& seqlens_k = nodes[index++].get_node_shared_ptr(); const auto& total_sequence_length = nodes[index++]; // unused, it's not always equal (seqlens_k + 1) const auto& cos_cache = nodes[index++].get_node_shared_ptr(); const auto& sin_cache = nodes[index++].get_node_shared_ptr(); - // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) - auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); - Q = std::make_shared(oQ, perm); - K = std::make_shared(oK, perm); - V = std::make_shared(oV, perm); - auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); @@ -146,7 +148,7 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { // present = concat(K, V) if 'past' input is unavailable // or // present = concat(past, K, V) - auto construct_kv_cache = [&](const std::shared_ptr& past, const std::shared_ptr& current) { + auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); @@ -172,8 +174,8 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); K = std::make_shared(K, new_kv_shape, false); V = std::make_shared(V, new_kv_shape, false); - K = std::make_shared(ov::NodeVector(kv_num_heads_factor, K), 2); - V = std::make_shared(ov::NodeVector(kv_num_heads_factor, V), 2); + K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); + V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); auto q_shape = std::make_shared(Q); // (batch_size, num_heads, sequence_len, head_size) const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); @@ -193,24 +195,22 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { seqlens_1d_scalar, one_without_shape, ov::element::i64); - auto mask_shape = std::make_shared(ov::NodeVector{one, one, one, seqlens_1d}, 0); - mask_per_line_node = std::make_shared(mask_per_line_node, mask_shape, false); - auto pad_end_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, seqlens_1d}, 0); - auto paded_mask = std::make_shared(mask_per_line_node, pad_end_shape); - std::shared_ptr compare_mask = + mask_per_line_node = std::make_shared(mask_per_line_node, zero); + auto minus_inf = v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits::infinity()}); + auto mask_shape = std::make_shared(ov::NodeVector{current_seqlen_size, seqlens_1d}, 0); + auto compare_mask = std::make_shared(mask_per_line_node, mask_shape); + + std::shared_ptr vertical_range = std::make_shared(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64); - auto compare_range_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, one}, 0); - compare_mask = std::make_shared(compare_mask, compare_range_shape, false); - auto lower_triangular_mask = std::make_shared(paded_mask, compare_mask); - auto higher_triangular_mask = std::make_shared(paded_mask, compare_mask); - auto negtive_const = v0::Constant::create(ov::element::f32, ov::Shape{}, {-1e20f}); - - auto convert_mask = std::make_shared(higher_triangular_mask, ov::element::f32); - auto input_offset_data = std::make_shared(convert_mask, negtive_const); - - convert_mask = std::make_shared(lower_triangular_mask, ov::element::f32); - auto softmax_input_masked = std::make_shared(softmax_input, convert_mask); - std::shared_ptr softmax_input_added = std::make_shared(softmax_input_masked, input_offset_data); + vertical_range = std::make_shared(vertical_range, one); + + auto triu = std::make_shared(compare_mask, vertical_range); + auto typed_zero = std::make_shared(zero, softmax_input); + auto typed_minus_inf = std::make_shared(minus_inf, softmax_input); + auto minus_inf_mask = std::make_shared(typed_minus_inf, mask_shape); + auto atten_mask = std::make_shared(triu, minus_inf_mask, typed_zero); + + std::shared_ptr softmax_input_added = std::make_shared(softmax_input, atten_mask); // softmax((Q x K' + mask) / sqrt(head_size)) const auto softmax = std::make_shared(softmax_input_added, 3); @@ -245,8 +245,8 @@ std::shared_ptr get_dimensions(const std::shared_ptr& node, return get_dimensions(std::make_shared(node), dims); } -std::shared_ptr rotaryEmbedding(std::shared_ptr input, - std::shared_ptr past_seqlen, +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, std::shared_ptr seqlen_k, std::shared_ptr cos_cache, std::shared_ptr sin_cache, @@ -316,45 +316,15 @@ std::shared_ptr rotaryEmbedding(std::shared_ptr input, auto result = std::make_shared(add_input0, multi_input1); return result; } else { - auto negtive_two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-2}); - auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, two, half_last_dim}, 0); - auto reshaped_input = std::make_shared(input, split_input_shape, false); - auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_two); - auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_two); - - auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -2); - - auto mask_shape = std::make_shared(ov::NodeVector{one, half_last_dim}, 0); - auto reshaped_mask = std::make_shared(masks, mask_shape, false); - auto negtive_mask = std::make_shared(reshaped_mask); - auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -2); - auto real_mask = std::make_shared(concat_mask, dim_head_size, false); - auto mask_f32 = std::make_shared(real_mask, ov::element::f32); - - auto perm = v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 1, 2, 4, 3}); - auto input0 = reshaped_input; // std::make_shared(reshaped_input, perm); - auto input1 = second_input; // std::make_shared(second_input, perm); - auto real_input0 = std::make_shared(input0, input_shape, false); - auto real_input1 = std::make_shared(input1, input_shape, false); - - auto new_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, two, cache_last_dim}, 0); - auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, one, cache_last_dim}, 0); - auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); - auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); - auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); - auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); - auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); - auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); - // TODO: change zero to sequence_K - auto sliced_cos_input = - std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto sliced_sin_input = - std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto add_input0 = std::make_shared(real_input0, sliced_cos_input); - auto add_input1 = std::make_shared(real_input1, sliced_sin_input); - auto multi_input1 = std::make_shared(add_input1, mask_f32); - auto result = std::make_shared(add_input0, multi_input1); - return result; + auto cos = + std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = + std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto in_split = ov::op::util::make_split(input, 2, -1); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), std::make_shared(in_split[1], cos)); + + return std::make_shared(ov::NodeVector{res_0, res_1}, -1); } } } // namespace From cc7aa0dbb424229b2cf4dcf68e61f2862f540ef3 Mon Sep 17 00:00:00 2001 From: LiangGao Date: Wed, 8 Jan 2025 15:49:03 +0800 Subject: [PATCH 03/15] Use scaled_dot_product_attention to improve the perfomance --- .../com.microsoft/group_query_attention.cpp | 143 ++++++------------ 1 file changed, 50 insertions(+), 93 deletions(-) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index b4e85812afffcf..cc9dde3d256e0d 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -38,6 +38,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" #include "openvino/op/convert_like.hpp" #include "utils/split.hpp" @@ -159,12 +160,6 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { auto present_k = K; auto present_v = V; - std::shared_ptr alpha; - if (scale == 0.0f) { - alpha = std::make_shared(head_size_node); - } else { - alpha = v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f / scale}); - } const size_t kv_num_heads_factor = num_heads / kv_num_heads; if (kv_num_heads_factor > 1) { const auto kv_shape = std::make_shared(K); @@ -183,47 +178,43 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { K = std::make_shared(K, extended_kv_shape, false); V = std::make_shared(V, extended_kv_shape, false); } - // compute softmax((Q x K') / sqrt(head_size)) - std::shared_ptr softmax_input = std::make_shared(Q, K, false, true); - softmax_input = std::make_shared(softmax_input, alpha); // need to apply low-triangle mask to attention score. - auto past_seq_len_scalar = std::make_shared(past_sequence_length, one_without_shape, false); - auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); + // two steps, construct the total_sequence x total_sequence triangle, then slice the current length + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 std::shared_ptr mask_per_line_node = std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), seqlens_1d_scalar, one_without_shape, - ov::element::i64); - mask_per_line_node = std::make_shared(mask_per_line_node, zero); - auto minus_inf = v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits::infinity()}); - auto mask_shape = std::make_shared(ov::NodeVector{current_seqlen_size, seqlens_1d}, 0); - auto compare_mask = std::make_shared(mask_per_line_node, mask_shape); - - std::shared_ptr vertical_range = - std::make_shared(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64); - vertical_range = std::make_shared(vertical_range, one); - - auto triu = std::make_shared(compare_mask, vertical_range); - auto typed_zero = std::make_shared(zero, softmax_input); - auto typed_minus_inf = std::make_shared(minus_inf, softmax_input); - auto minus_inf_mask = std::make_shared(typed_minus_inf, mask_shape); - auto atten_mask = std::make_shared(triu, minus_inf_mask, typed_zero); - - std::shared_ptr softmax_input_added = std::make_shared(softmax_input, atten_mask); - // softmax((Q x K' + mask) / sqrt(head_size)) - const auto softmax = std::make_shared(softmax_input_added, 3); - - // softmax((Q x K' + mask) / sqrt(head_size)) x V - std::shared_ptr output = std::make_shared(softmax, V); + ov::element::i64); // [0,1,2,...,] + auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 + auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 + auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 + auto typed_zero = v0::Constant::create(ov::element::f32, ov::Shape{}, {0}); + auto minus_inf = v0::Constant::create(ov::element::f32, ov::Shape{}, {-std::numeric_limits::infinity()}); + auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 + auto atten_mask_sliced = std::make_shared(atten_mask, + past_sequence_length, + seqlens_1d, + one, + zero); // slice to current query seqlen, 12x12 or 1x13 + + // compute softmax((Q x K') / sqrt(head_size)) x V + std::shared_ptr qga_output; + if (scale != 0.0f) { + auto scale_node = v0::Constant::create(ov::element::f32, Shape{}, {scale}); + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, scale_node, false); + } else { + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, false); + } // transpose the result from (batch_size, num_heads, sequence_length, head_size) // to (batch_size, sequence_length, num_heads, head_size) - output = std::make_shared(output, perm); + auto qga_output_transposed = std::make_shared(qga_output, perm); auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); // reshape the result from (batch_size, sequence_length, num_heads, head_size) // to (batch_size, sequence_length, num_heads * head_size) - output = std::make_shared(output, dim_merge_shape, true); + auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true); return {output, present_k, present_v}; } @@ -254,75 +245,41 @@ std::shared_ptr rotaryEmbedding(ov::Output input, bool interleaved) { auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - auto cache_shape = std::make_shared(cos_cache); - auto cache_last_dim = get_dimensions(cache_shape, {-1}); - auto cache_1st_dim = get_dimensions(cache_shape, {0}); + auto slice_cache_dim_shape = seqlen_k; - // TODO: check the shape - auto input_shape = std::make_shared(input); + auto cos = std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); - auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); - // auto dim_head_size = get_dimensions(input_shape, {3}); - // half_last_dim is same as cos_cache - std::shared_ptr half_last_dim = cache_last_dim; + if (interleaved) { + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - auto real_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, dim_head_size}, 0); - auto slice_cache_dim_shape = seqlen_k; + auto cache_shape = std::make_shared(cos_cache); + auto cache_last_dim = get_dimensions(cos_cache, {-1}); - // auto end_lens = std::make_shared(half_last_dim, one); - // auto masks = std::make_shared(one, - // zero, - // end_lens, - // op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}), - // ov::op::PadMode::CONSTANT); - auto masks = std::make_shared(one, half_last_dim); + auto input_shape = std::make_shared(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + std::shared_ptr half_last_dim = cache_last_dim; - if (interleaved) { auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); auto reshaped_input = std::make_shared(input, split_input_shape, false); - auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_one); - auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_one); - - auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -1); - - auto mask_shape = std::make_shared(ov::NodeVector{half_last_dim, one}, 0); - auto reshaped_mask = std::make_shared(masks, mask_shape, false); - auto negtive_mask = std::make_shared(reshaped_mask); - auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -1); - auto real_mask = std::make_shared(concat_mask, dim_head_size, false); - auto mask_f32 = std::make_shared(real_mask, ov::element::f32); - - auto real_input0 = std::make_shared(reshaped_input, input_shape, false); - auto real_input1 = std::make_shared(second_input, input_shape, false); - - auto new_cache_shape = std::make_shared(ov::NodeVector{cache_shape, two}, 0); - auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_shape, one}, 0); - auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); - auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); - auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); - auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); - auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); - auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); - auto sliced_cos_input = - std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto sliced_sin_input = - std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto add_input0 = std::make_shared(real_input0, sliced_cos_input); - auto add_input1 = std::make_shared(real_input1, sliced_sin_input); - auto multi_input1 = std::make_shared(add_input1, mask_f32); - auto result = std::make_shared(add_input0, multi_input1); - return result; + + auto in_split = ov::op::util::make_split(reshaped_input, 2, -1); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), + std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), + std::make_shared(in_split[1], cos)); + + auto concat_ret = std::make_shared(ov::NodeVector{res_0, res_1}, -1); + return std::make_shared(concat_ret, input_shape, false); } else { - auto cos = - std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); - auto sin = - std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); auto in_split = ov::op::util::make_split(input, 2, -1); - auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), std::make_shared(in_split[1], sin)); - auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), std::make_shared(in_split[1], cos)); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), + std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), + std::make_shared(in_split[1], cos)); return std::make_shared(ov::NodeVector{res_0, res_1}, -1); } From e2897af1f27c35f1baa2413936a823781e411785 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 16 Jan 2025 13:15:18 +0800 Subject: [PATCH 04/15] Add OV GQA op and decomposition pass Fix interleave logic in decomposition Add ONNX frontend tests --- .../group_query_attention_decomposition.hpp | 24 + .../common_optimizations.cpp | 2 + .../group_query_attention_decomposition.cpp | 322 ++++++++++++++ .../openvino/op/group_query_attention.hpp | 54 +++ src/core/include/openvino/op/null.hpp | 47 ++ src/core/include/openvino/op/ops.hpp | 2 + .../include/openvino/opsets/opset15_tbl.hpp | 2 + src/core/src/op/group_query_attention.cpp | 92 ++++ src/core/tests/opset.cpp | 2 +- .../com.microsoft/group_query_attention.cpp | 277 +----------- .../gqa_past_0_input_1_rotary.prototxt | 247 +++++++++++ ...past_0_input_1_rotary_interleaved.prototxt | 247 +++++++++++ .../gqa_past_1_input_1_rotary.prototxt | 244 +++++++++++ ...past_1_input_1_rotary_interleaved.prototxt | 244 +++++++++++ .../tests/onnx_import_com_microsoft.in.cpp | 412 ++++++++++++++++++ 15 files changed, 1957 insertions(+), 261 deletions(-) create mode 100644 src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp create mode 100644 src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp create mode 100644 src/core/include/openvino/op/group_query_attention.hpp create mode 100644 src/core/include/openvino/op/null.hpp create mode 100644 src/core/src/op/group_query_attention.cpp create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp new file mode 100644 index 00000000000000..51c21f4808c61e --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/group_query_attention.hpp" +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API GroupQueryAttentionDecomposition; + +} // namespace pass +} // namespace ov + +class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("GroupQueryAttentionDecomposition", "0"); + GroupQueryAttentionDecomposition(); + ov::OutputVector decompose(std::shared_ptr node); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 87813bae65538d..66651b3907f344 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -108,6 +108,7 @@ #include "transformations/op_conversions/eye_decomposition.hpp" #include "transformations/op_conversions/gelu7_downgrade.hpp" #include "transformations/op_conversions/group_normalization_decomposition.hpp" +#include "transformations/op_conversions/group_query_attention_decomposition.hpp" #include "transformations/op_conversions/hsigmoid_decomposition.hpp" #include "transformations/op_conversions/hswish_decomposition.hpp" #include "transformations/op_conversions/log_softmax_decomposition.hpp" @@ -156,6 +157,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr(); + ADD_MATCHER(decomp, GroupQueryAttentionDecomposition) ADD_MATCHER(decomp, ScaledDotProductAttentionDecomposition) ADD_MATCHER(decomp, Gelu7Downgrade) ADD_MATCHER(decomp, BidirectionalSequenceDecomposition) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp new file mode 100644 index 00000000000000..1441e3e67a633d --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -0,0 +1,322 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/group_query_attention_decomposition.hpp" + +#include + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/null.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/op/variadic_split.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved); +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims); +ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); +ov::OutputVector make_split(const ov::Output& value, const std::vector& split_lengths, int64_t axis); +std::shared_ptr create_minus_inf(const ov::element::Type& T); + +ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { + MATCHER_SCOPE(GroupQeuryAttentionDecomposition); + auto pattern_node = ov::pass::pattern::wrap_type(); + + matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto node = + ov::as_type_ptr(pattern_to_output.at(pattern_node).get_node_shared_ptr()); + + if (node == nullptr || transformation_callback(node)) { + return false; + } + + auto new_output_node = decompose(node); + ov::replace_node(node, new_output_node); + return true; + }; + + auto m = std::make_shared(pattern_node, matcher_name); + register_matcher(m, callback); +} + +ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( + std::shared_ptr node) { + using namespace ov::op; + + const auto num_heads = node->get_num_heads(); + const auto kv_num_heads = node->get_kv_num_heads(); + const auto scale = node->get_scale(); + const auto do_rotary = node->get_do_rotary(); + const auto rotary_interleaved = node->get_rotary_interleaved(); + // TODO: add softcap support + + auto Q = node->input_value(0); + auto K = node->input_value(1); + auto V = node->input_value(2); + auto past_key = node->input_value(3); + auto past_value = node->input_value(4); + auto seqlens_k = node->input_value(5); + auto total_sequence_length = node->input_value(6); // unused, it's not always equal (seqlens_k + 1) + auto cos_cache = node->input_value(7); + auto sin_cache = node->input_value(8); + + // The length of all tokens (past + current) is `seqlens_k` + 1, current = Q.shape[2], past = `seqlens_k` + 1 - current + + const auto T = Q.get_element_type(); + const auto node_shape = std::make_shared(Q); + const auto batch_size = get_dimensions(node_shape, {0}); + const auto current_seqlen_size = get_dimensions(node_shape, {1}); + const auto hidden_size = get_dimensions(node_shape, {2}); + const auto total_num_heads_node = + v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); + auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); // should be equal to the last dim of past_key + + // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + + if (v15::Null::is_null(K)) { + // Handle the packed QKV + auto packed_qkv_shape = std::make_shared( + ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, + 0); + auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); + // (batch_size, sequence_len, num_head, head_size) + inputs_qkv = std::make_shared(inputs_qkv, perm); + // split the node into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) + auto split = make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); + Q = split[0]; + K = split[1]; + V = split[2]; + } else { + Q = std::make_shared(Q, perm); + K = std::make_shared(K, perm); + V = std::make_shared(V, perm); + } + + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + auto seqlens_elemi64 = std::make_shared(seqlens_k, ov::element::i64); + auto real_seqlens = std::make_shared(seqlens_elemi64, one); + + // Only consider batch is 1 + auto seqlens_1d = std::make_shared(real_seqlens, one, false); + auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); + + if (do_rotary) { + Q = rotaryEmbedding(Q, + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); + K = rotaryEmbedding(K, + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); + } + // present = concat(K, V) if 'past' input is unavailable + // or + // present = concat(past, K, V) + auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { + auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); + auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); + return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); + }; + + K = construct_kv_cache(past_key, K); + V = construct_kv_cache(past_value, V); + auto present_k = K; + auto present_v = V; + + const size_t kv_num_heads_factor = num_heads / kv_num_heads; + if (kv_num_heads_factor > 1) { + const auto kv_shape = std::make_shared(K); + // (batch_size, num_heads, sequence_len, head_size) + const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1}); + const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3}); + auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); + K = std::make_shared(K, new_kv_shape, false); + V = std::make_shared(V, new_kv_shape, false); + K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); + V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); + auto q_shape = std::make_shared(Q); + // (batch_size, num_heads, sequence_len, head_size) + const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1}); + auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); + K = std::make_shared(K, extended_kv_shape, false); + V = std::make_shared(V, extended_kv_shape, false); + } + + // need to apply low-triangle mask to attention score. + // two steps, construct the total_sequence x total_sequence triangle, then slice the current length + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 + std::shared_ptr mask_per_line_node = + std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), + seqlens_1d_scalar, + one_without_shape, + ov::element::i64); // [0,1,2,...,] + auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 + auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 + auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 + auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0}); + auto minus_inf = create_minus_inf(T); + auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 + auto atten_mask_sliced = std::make_shared(atten_mask, + past_sequence_length, + seqlens_1d, + one, + zero); // slice to current query seqlen, 12x12 or 1x13 + + // compute softmax((Q x K') / sqrt(head_size)) x V + std::shared_ptr qga_output; + if (scale != 0.0f) { + auto scale_node = v0::Constant::create(T, Shape{}, {scale}); + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, scale_node, false); + } else { + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, false); + } + + // transpose the result from (batch_size, num_heads, sequence_length, head_size) + // to (batch_size, sequence_length, num_heads, head_size) + auto qga_output_transposed = std::make_shared(qga_output, perm); + auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); + // reshape the result from (batch_size, sequence_length, num_heads, head_size) + // to (batch_size, sequence_length, num_heads * head_size) + auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true)->output(0); + + return {output, present_k, present_v}; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims) { + using namespace ov::op; + static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} + +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved) { + using namespace ov::op; + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + + auto slice_cache_dim_shape = seqlen_k; + + auto cos = std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); + + if (interleaved) { + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + + auto cache_shape = std::make_shared(cos_cache); + auto cache_last_dim = get_dimensions(cos_cache, {-1}); + + auto input_shape = std::make_shared(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + std::shared_ptr half_last_dim = cache_last_dim; + + auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); + auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); + auto reshaped_input = std::make_shared(input, split_input_shape, false); + + auto in_split = make_split(reshaped_input, 2, -1); + split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim}, 0); + auto in_split_0 = std::make_shared(in_split[0], split_input_shape, false); + auto in_split_1 = std::make_shared(in_split[1], split_input_shape, false); + + auto res_0 = std::make_shared(std::make_shared(in_split_0, cos), + std::make_shared(in_split_1, sin)); + auto res_1 = std::make_shared(std::make_shared(in_split_0, sin), + std::make_shared(in_split_1, cos)); + + split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, one}, 0); + auto res_0_5d = std::make_shared(res_0, split_input_shape, false); + auto res_1_5d = std::make_shared(res_1, split_input_shape, false); + + auto concat_ret = std::make_shared(ov::NodeVector{res_0_5d, res_1_5d}, -1); + return std::make_shared(concat_ret, input_shape, false); + } else { + auto in_split = make_split(input, 2, -1); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), + std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), + std::make_shared(in_split[1], cos)); + + return std::make_shared(ov::NodeVector{res_0, res_1}, -1); + } +} + +// make split functions is a copy-past from ONNX FE. TODO: move it to one place +ov::OutputVector make_split(const ov::Output& value, + const std::vector& split_lengths, + int64_t axis) { + using namespace ov::op; + const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); + const auto split_lengths_node = + v0::Constant::create(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths); + const auto variadic_split = std::make_shared(value, axis_node, split_lengths_node); + + return variadic_split->outputs(); +} + +ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis) { + using namespace ov::op; + const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); + const auto split = std::make_shared(value, axis_node, num_splits); + + return split->outputs(); +} + +std::shared_ptr create_minus_inf(const ov::element::Type& T) { + // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp + if (T == ov::element::f32) { + return ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); + } else if (T == ov::element::f16) { + return ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); + } else { + OPENVINO_THROW("GroupQueryAttention only supports f32 and f16"); + } +} diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp new file mode 100644 index 00000000000000..25337e0faf9be5 --- /dev/null +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace v15 { + +// This is an experimental operation that is implemented in the plugins. +class OPENVINO_API GroupQueryAttention : public Op { +public: + OPENVINO_OP("GroupQueryAttention", "opset15", op::Op); + + GroupQueryAttention() = default; + GroupQueryAttention(const ov::OutputVector& args, + unsigned int num_heads, + unsigned int kv_num_heads, + float scale, + bool do_rotary, + bool rotary_interleaved); + void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + unsigned int get_num_heads() const { + return m_num_heads; + } + unsigned int get_kv_num_heads() const { + return m_kv_num_heads; + } + float get_scale() const { + return m_scale; + } + bool get_do_rotary() const { + return m_do_rotary; + } + bool get_rotary_interleaved() const { + return m_rotary_interleaved; + } + +private: + unsigned int m_num_heads; + unsigned int m_kv_num_heads; + float m_scale = 0; + bool m_do_rotary = false; + bool m_rotary_interleaved = false; +}; + +} // namespace v15 +} // namespace op +} // namespace ov diff --git a/src/core/include/openvino/op/null.hpp b/src/core/include/openvino/op/null.hpp new file mode 100644 index 00000000000000..346ff3d77c532b --- /dev/null +++ b/src/core/include/openvino/op/null.hpp @@ -0,0 +1,47 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace v15 { + +/// \brief Represents a missing optional input or output of an ONNX node +/// +/// Some ONNX operators have inputs or outputs that are marked as optional, +/// which means that a referring node MAY forgo providing values for such inputs +/// or computing these outputs. +/// An empty string is used in place of a name of such input or output. +/// +/// More: +/// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs +class OPENVINO_API Null : public Op { +public: + OPENVINO_OP("Null", "opset15", op::Op); + Null() { + set_output_size(1); + } + + static bool is_null(const ov::Node* node) { + return ov::as_type(node) != nullptr; + } + + static bool is_null(const std::shared_ptr& node) { + return is_null(node.get()); + } + + static bool is_null(const Output& output) { + return is_null(output.get_node()); + } + + virtual std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override { + return std::make_shared(); + } +}; +} // namespace v15 +} // namespace op +} // namespace ov diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp index adeb9c25611960..098f2872e4a586 100644 --- a/src/core/include/openvino/op/ops.hpp +++ b/src/core/include/openvino/op/ops.hpp @@ -168,6 +168,8 @@ #include "openvino/op/roll.hpp" #include "openvino/op/round.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/null.hpp" +#include "openvino/op/group_query_attention.hpp" #include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/scatter_nd_update.hpp" #include "openvino/op/scatter_update.hpp" diff --git a/src/core/include/openvino/opsets/opset15_tbl.hpp b/src/core/include/openvino/opsets/opset15_tbl.hpp index a9e8d2a8dcc840..8dd7c0d7b830cc 100644 --- a/src/core/include/openvino/opsets/opset15_tbl.hpp +++ b/src/core/include/openvino/opsets/opset15_tbl.hpp @@ -234,3 +234,5 @@ _OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15) _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15) _OPENVINO_OP_REG(SliceScatter, ov::op::v15) _OPENVINO_OP_REG(SearchSorted, ov::op::v15) +_OPENVINO_OP_REG(GroupQueryAttention, ov::op::v15) +_OPENVINO_OP_REG(Null, ov::op::v15) diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp new file mode 100644 index 00000000000000..475110e66bf5e3 --- /dev/null +++ b/src/core/src/op/group_query_attention.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/group_query_attention.hpp" + +#include "itt.hpp" +#include "openvino/op/null.hpp" + +using namespace std; +namespace ov { +namespace op { +namespace v15 { + +GroupQueryAttention::GroupQueryAttention(const OutputVector& args, + unsigned int num_heads, + unsigned int kv_num_heads, + float scale, + bool do_rotary, + bool rotary_interleaved) + : Op(args), + m_num_heads(num_heads), + m_kv_num_heads(kv_num_heads), + m_scale(scale), + m_do_rotary(do_rotary), + m_rotary_interleaved(rotary_interleaved) { + constructor_validate_and_infer_types(); +} + +int64_t get_head_size(const PartialShape& input_shape, int num_heads, int kv_num_heads) { + return input_shape[2].get_length() / (num_heads + kv_num_heads * 2); +} + +std::vector get_qkv_sizes(const PartialShape& input_shape, int num_heads, int kv_num_heads) { + int64_t per_head_size = get_head_size(input_shape, num_heads, kv_num_heads); + const std::vector qkv_sizes = {num_heads * per_head_size, + kv_num_heads * per_head_size, + kv_num_heads * per_head_size}; + return qkv_sizes; +} + +void GroupQueryAttention::validate_and_infer_types() { + OV_OP_SCOPE(v15_GroupQueryAttention_validate_and_infer_types); + PartialShape input_shape = get_input_partial_shape(0); + Dimension batch_size = input_shape[0]; + Dimension sequence_len = input_shape[1]; + Dimension head_size; + if (Null::is_null(input_value(1)) && Null::is_null(input_value(2))) { + head_size = get_head_size(input_shape, m_num_heads, m_kv_num_heads); + } else { + head_size = input_shape[2].get_length() / m_num_heads; + } + Dimension output_kv_len; + PartialShape kv_past_shape = get_input_partial_shape(3); + // FIXME: https://github.com/openvinotoolkit/openvino/pull/27648 + if (kv_past_shape[2].is_static()) { + output_kv_len = kv_past_shape[2] + sequence_len; + } else { + output_kv_len = ov::Dimension(); + } + auto element_type = get_input_element_type(0); + NODE_VALIDATION_CHECK(this, + element_type == element::f32 || element_type == element::f16, + "GroupQueryAttention only suuports f32 and f16"); + set_output_type(0, element_type, PartialShape{batch_size, sequence_len, head_size * m_num_heads}); + set_output_type(1, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size}); + set_output_type(2, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size}); +} + +bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) { + OV_OP_SCOPE(v15_GroupQueryAttention_visit_attributes); + visitor.on_attribute("do_rotary", m_do_rotary); + visitor.on_attribute("kv_num_heads", m_kv_num_heads); + visitor.on_attribute("num_heads", m_num_heads); + visitor.on_attribute("rotary_interleaved", m_rotary_interleaved); + visitor.on_attribute("scale", m_scale); + return true; +} + +std::shared_ptr GroupQueryAttention::clone_with_new_inputs(const ov::OutputVector& new_args) const { + OV_OP_SCOPE(v15_GroupQueryAttention_clone_with_new_inputs); + return std::make_shared(new_args, + m_num_heads, + m_kv_num_heads, + m_scale, + m_do_rotary, + m_rotary_interleaved); +} + +} // namespace v15 +} // namespace op +} // namespace ov diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp index 3006e04a02f960..49f24a5dfd41de 100644 --- a/src/core/tests/opset.cpp +++ b/src/core/tests/opset.cpp @@ -76,7 +76,7 @@ INSTANTIATE_TEST_SUITE_P(opset, OpsetTestParams{ov::get_opset12, 178}, OpsetTestParams{ov::get_opset13, 186}, OpsetTestParams{ov::get_opset14, 188}, - OpsetTestParams{ov::get_opset15, 199}, + OpsetTestParams{ov::get_opset15, 201}, OpsetTestParams{ov::get_opset16, 6}), OpsetTestNameGenerator{}); diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index cc9dde3d256e0d..692d2529caf7e5 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -2,45 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/group_query_attention.hpp" +#include "openvino/op/null.hpp" + #include "core/null_node.hpp" #include "core/operator_set.hpp" #include "openvino/frontend/exception.hpp" -#include "openvino/op/add.hpp" -#include "openvino/op/broadcast.hpp" -#include "openvino/op/concat.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/convert.hpp" -#include "openvino/op/divide.hpp" -#include "openvino/op/equal.hpp" -#include "openvino/op/floor.hpp" -#include "openvino/op/floor_mod.hpp" -#include "openvino/op/gather.hpp" -#include "openvino/op/greater.hpp" -#include "openvino/op/greater_eq.hpp" -#include "openvino/op/less.hpp" -#include "openvino/op/less_eq.hpp" -#include "openvino/op/log.hpp" -#include "openvino/op/logical_not.hpp" -#include "openvino/op/logical_or.hpp" -#include "openvino/op/matmul.hpp" -#include "openvino/op/maximum.hpp" -#include "openvino/op/multiply.hpp" -#include "openvino/op/negative.hpp" -#include "openvino/op/pad.hpp" -#include "openvino/op/range.hpp" -#include "openvino/op/reshape.hpp" -#include "openvino/op/select.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/slice.hpp" -#include "openvino/op/softmax.hpp" -#include "openvino/op/sqrt.hpp" -#include "openvino/op/squeeze.hpp" -#include "openvino/op/subtract.hpp" -#include "openvino/op/transpose.hpp" -#include "openvino/op/unsqueeze.hpp" -#include "openvino/op/scaled_dot_product_attention.hpp" -#include "openvino/op/convert_like.hpp" -#include "utils/split.hpp" using namespace ov::op; using ov::Shape; @@ -49,243 +16,33 @@ namespace ov { namespace frontend { namespace onnx { namespace com_microsoft { -namespace detail { -namespace { - -std::shared_ptr get_present_state(const std::shared_ptr& K, - const std::shared_ptr& V, - const ov::OutputVector& op_inputs); -std::shared_ptr rotaryEmbedding(ov::Output input, - ov::Output past_seqlen, - std::shared_ptr seqlen_k, - std::shared_ptr cos_cache, - std::shared_ptr sin_cache, - std::shared_ptr dim_head_size, - bool interleaved); -std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); -} // namespace -} // namespace detail namespace opset_1 { ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { - const auto do_rotary = node.get_attribute_value("do_rotary"); + const auto onnx_op_inputs = node.get_ov_inputs(); const auto num_heads = node.get_attribute_value("num_heads"); const auto kv_num_heads = node.get_attribute_value("kv_num_heads"); const auto scale = node.get_attribute_value("scale", 0.0f); - const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved"); - // TODO: add softcap support - - auto nodes = node.get_ov_inputs(); - const auto node_shape = std::make_shared(nodes[0]); - const auto batch_size = detail::get_dimensions(node_shape, {0}); - const auto current_seqlen_size = detail::get_dimensions(node_shape, {1}); - const auto hidden_size = detail::get_dimensions(node_shape, {2}); - const auto total_num_heads_node = - v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); - auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); - - // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) - auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); - - // Q K V (batch_size, sequence_len, num_heads, head_size) - ov::Output Q, K, V; - int index = 0; - Q = nodes[index++]; - K = nodes[index++]; - V = nodes[index++]; - if (ov::op::util::is_null(K)) { - // Handle the packed QKV - auto packed_qkv_shape = std::make_shared( - ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, - 0); - auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); - // (batch_size, sequence_len, num_head, head_size) - inputs_qkv = std::make_shared(inputs_qkv, perm); - // split the node into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) - auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); - Q = split[0]; - K = split[1]; - V = split[2]; - } else { - Q = std::make_shared(Q, perm); - K = std::make_shared(K, perm); - V = std::make_shared(V, perm); - } - - const auto& past_key = nodes[index++]; - const auto& past_value = nodes[index++]; - const auto& seqlens_k = nodes[index++].get_node_shared_ptr(); - const auto& total_sequence_length = nodes[index++]; // unused, it's not always equal (seqlens_k + 1) - const auto& cos_cache = nodes[index++].get_node_shared_ptr(); - const auto& sin_cache = nodes[index++].get_node_shared_ptr(); - - auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); - auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); - auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - auto seqlens_elemi64 = std::make_shared(seqlens_k, ov::element::i64); - auto real_seqlens = std::make_shared(seqlens_elemi64, one); - - // Only consider batch is 1 - auto seqlens_1d = std::make_shared(real_seqlens, one, false); - auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); + const auto do_rotary = node.get_attribute_value("do_rotary", 0); + const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved", 0.0f); - if (do_rotary) { - Q = detail::rotaryEmbedding(Q, - past_sequence_length, - seqlens_1d, - cos_cache, - sin_cache, - head_size_node, - rotary_interleaved); - K = detail::rotaryEmbedding(K, - past_sequence_length, - seqlens_1d, - cos_cache, - sin_cache, - head_size_node, - rotary_interleaved); + OutputVector ov_op_inputs; + ov_op_inputs.reserve(onnx_op_inputs.size()); + for (const auto& input : onnx_op_inputs) { + ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared() : input); } - // present = concat(K, V) if 'past' input is unavailable - // or - // present = concat(past, K, V) - auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { - auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); - auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); - return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); - }; - - K = construct_kv_cache(past_key, K); - V = construct_kv_cache(past_value, V); - auto present_k = K; - auto present_v = V; - - const size_t kv_num_heads_factor = num_heads / kv_num_heads; - if (kv_num_heads_factor > 1) { - const auto kv_shape = std::make_shared(K); - // (batch_size, num_heads, sequence_len, head_size) - const auto kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1}); - const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3}); - auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); - K = std::make_shared(K, new_kv_shape, false); - V = std::make_shared(V, new_kv_shape, false); - K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); - V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); - auto q_shape = std::make_shared(Q); - // (batch_size, num_heads, sequence_len, head_size) - const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); - auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); - K = std::make_shared(K, extended_kv_shape, false); - V = std::make_shared(V, extended_kv_shape, false); - } - - // need to apply low-triangle mask to attention score. - // two steps, construct the total_sequence x total_sequence triangle, then slice the current length - auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 - std::shared_ptr mask_per_line_node = - std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), - seqlens_1d_scalar, - one_without_shape, - ov::element::i64); // [0,1,2,...,] - auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 - auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 - auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 - auto typed_zero = v0::Constant::create(ov::element::f32, ov::Shape{}, {0}); - auto minus_inf = v0::Constant::create(ov::element::f32, ov::Shape{}, {-std::numeric_limits::infinity()}); - auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 - auto atten_mask_sliced = std::make_shared(atten_mask, - past_sequence_length, - seqlens_1d, - one, - zero); // slice to current query seqlen, 12x12 or 1x13 - - // compute softmax((Q x K') / sqrt(head_size)) x V - std::shared_ptr qga_output; - if (scale != 0.0f) { - auto scale_node = v0::Constant::create(ov::element::f32, Shape{}, {scale}); - qga_output = std::make_shared(Q, K, V, atten_mask_sliced, scale_node, false); - } else { - qga_output = std::make_shared(Q, K, V, atten_mask_sliced, false); - } - - // transpose the result from (batch_size, num_heads, sequence_length, head_size) - // to (batch_size, sequence_length, num_heads, head_size) - auto qga_output_transposed = std::make_shared(qga_output, perm); - auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); - // reshape the result from (batch_size, sequence_length, num_heads, head_size) - // to (batch_size, sequence_length, num_heads * head_size) - auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true); - - return {output, present_k, present_v}; + return std::make_shared(ov_op_inputs, + num_heads, + kv_num_heads, + scale, + do_rotary, + rotary_interleaved) + ->outputs(); } ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN); } // namespace opset_1 - -namespace detail { -namespace { - -std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { - static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); - const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); - return std::make_shared(shape, dims_const, zero); -} - -std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { - return get_dimensions(std::make_shared(node), dims); -} - -std::shared_ptr rotaryEmbedding(ov::Output input, - ov::Output past_seqlen, - std::shared_ptr seqlen_k, - std::shared_ptr cos_cache, - std::shared_ptr sin_cache, - std::shared_ptr dim_head_size, - bool interleaved) { - auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); - auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - - auto slice_cache_dim_shape = seqlen_k; - - auto cos = std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); - auto sin = std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); - - if (interleaved) { - auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - - auto cache_shape = std::make_shared(cos_cache); - auto cache_last_dim = get_dimensions(cos_cache, {-1}); - - auto input_shape = std::make_shared(input); - - auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); - std::shared_ptr half_last_dim = cache_last_dim; - - auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); - auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); - auto reshaped_input = std::make_shared(input, split_input_shape, false); - - auto in_split = ov::op::util::make_split(reshaped_input, 2, -1); - auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), - std::make_shared(in_split[1], sin)); - auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), - std::make_shared(in_split[1], cos)); - - auto concat_ret = std::make_shared(ov::NodeVector{res_0, res_1}, -1); - return std::make_shared(concat_ret, input_shape, false); - } else { - auto in_split = ov::op::util::make_split(input, 2, -1); - auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), - std::make_shared(in_split[1], sin)); - auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), - std::make_shared(in_split[1], cos)); - - return std::make_shared(ov::NodeVector{res_0, res_1}, -1); - } -} -} // namespace -} // namespace detail } // namespace com_microsoft } // namespace onnx } // namespace frontend diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt new file mode 100644 index 00000000000000..a7dacf0dc94ebc --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt @@ -0,0 +1,247 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} \ No newline at end of file diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..d1400ad344e717 --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt @@ -0,0 +1,247 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} \ No newline at end of file diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt new file mode 100644 index 00000000000000..f5ec39c9b0bd8e --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..b61cf39552efc7 --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 170476aae05dd3..a7a91cbbf4c72d 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1740,3 +1740,415 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_bias_add) { test_case.run(); } + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { + const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = {}; + std::vector past_value = {}; + std::vector seqlens_k = {0}; + std::vector total_sequence_length = {1}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + }; + + std::vector expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + + std::vector expected_present_key = {1.2561098, + 1.0199738, + -0.05948371, + -0.16574995, + 2.5059946, + -1.738188, + -0.03158256, + -0.35975295, + 1.0918287, + -0.90313876, + -0.4790303, + 0.67029977, + -0.87039495, + 0.7783688, + -0.81333745, + 0.89886224}; + + std::vector expected_present_value = {-0.2188, + -2.4351, + -0.0729, + -0.034, + 0.9625, + 0.3492, + -0.9215, + -0.0562, + -0.6227, + -0.4637, + 1.9218, + -0.4025, + 0.1239, + 1.1648, + 0.9234, + 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) { + const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary_interleaved.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = {}; + std::vector past_value = {}; + std::vector seqlens_k = {0}; + std::vector total_sequence_length = {1}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + }; + + std::vector expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + + std::vector expected_present_key = {2.118801, + -0.2640816, + -0.5926066, + -0.19455537, + 0.9903903, + 2.954185, + -0.35343042, + -0.07457897, + -0.25603274, + -0.03627284, + 0.56591415, + 0.02181074, + -0.1586003, + 0.96567893, + -0.8591481, + 0.85514885}; + + std::vector expected_present_value = {-0.2188, + -2.4351, + -0.0729, + -0.034, + 0.9625, + 0.3492, + -0.9215, + -0.0562, + -0.6227, + -0.4637, + 1.9218, + -0.4025, + 0.1239, + 1.1648, + 0.9234, + 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { + const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = { + -0.6136, + 0.0316, + -0.4927, + 0.2484, + 0.4397, + 0.1124, + 0.6408, + 0.4412, + -0.1023, + 0.7924, + -0.2897, + 0.0525, + 0.5229, + 2.3022, + -1.4689, + -1.5867, + }; + std::vector past_value = { + -0.5692, + 0.9200, + 1.1108, + 1.2899, + -1.4782, + 2.5672, + -0.4731, + 0.3356, + -1.6293, + -0.5497, + -0.4798, + -0.4997, + -1.0670, + 1.1149, + -0.1407, + 0.8058, + }; + std::vector seqlens_k = {1}; + std::vector total_sequence_length = {2}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + -0.9635, + -0.8046, + 0.4139, + 0.9863, + 0.4117, + 0.9874, + -0.9743, + 0.9494, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + 0.2677, + -0.5938, + -0.9103, + -0.1650, + -0.9113, + -0.1583, + 0.2253, + 0.3140, + }; + + std::vector expected_output = { + -0.53934956, 0.6341806, 1.0099611, 1.1771176, -1.270278, 2.3782496, -0.511299, 0.30222273, + -1.5435482, -0.5423737, -0.27520883, -0.4914196, -0.96554786, 1.1191509, -0.05004983, 0.85533774, + -0.49356747, 0.19581467, 0.8553029, 1.0041412, -0.9513843, 2.088453, -0.5698854, 0.25103146, + -1.4120293, -0.5311372, 0.03857604, -0.47871974, -0.8099488, 1.1256707, 0.08898184, 0.93131447}; + + std::vector expected_present_key = { + -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412, + -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867, + -1.6519198, 1.1400802, 0.45031136, 0.5877534, -0.65952265, -1.8121169, 0.04630837, 0.5568472, + 0.20271924, 0.7458131, -0.17379119, 0.3623912, 2.5696063, -0.58594, -0.8126341, -0.7919839}; + + std::vector expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356, + -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) { + const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary_interleaved.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = { + -0.6136, + 0.0316, + -0.4927, + 0.2484, + 0.4397, + 0.1124, + 0.6408, + 0.4412, + -0.1023, + 0.7924, + -0.2897, + 0.0525, + 0.5229, + 2.3022, + -1.4689, + -1.5867, + }; + std::vector past_value = { + -0.5692, + 0.9200, + 1.1108, + 1.2899, + -1.4782, + 2.5672, + -0.4731, + 0.3356, + -1.6293, + -0.5497, + -0.4798, + -0.4997, + -1.0670, + 1.1149, + -0.1407, + 0.8058, + }; + std::vector seqlens_k = {1}; + std::vector total_sequence_length = {2}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + -0.9635, + -0.8046, + 0.4139, + 0.9863, + 0.4117, + 0.9874, + -0.9743, + 0.9494, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + 0.2677, + -0.5938, + -0.9103, + -0.1650, + -0.9113, + -0.1583, + 0.2253, + 0.3140, + }; + + std::vector expected_output = { + -0.33396345, -1.332403, 0.31613833, 0.40111685, 0.16033238, 1.0781744, -0.7741276, 0.07257013, + -0.9535321, -0.491965, 1.1324831, -0.43444604, -0.2675047, 1.1483997, 0.57366973, 1.1961825, + -0.24709277, -2.164195, 0.02267693, 0.07289726, 0.7654276, 0.5282906, -0.8852943, -0.02456442, + -0.7039771, -0.47064403, 1.7278847, -0.41034833, 0.02774171, 1.1607709, 0.83748007, 1.3403473}; + + std::vector expected_present_key = { + -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412, + -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867, + -1.2216992, 1.7511603, 0.03145146, -0.62293506, -2.625969, 1.6767058, -0.17887366, 0.313817, + 0.1717277, -0.19334024, 0.4056727, 0.39516917, -0.25018305, 0.9460988, 1.0327814, -0.6345757}; + + std::vector expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356, + -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} From 09563379f1aabe3f5f57abf80958c7d6fbc2bd52 Mon Sep 17 00:00:00 2001 From: Zijun Yu Date: Fri, 7 Feb 2025 09:51:51 +0800 Subject: [PATCH 05/15] Apply suggestions from code review Co-authored-by: Tomasz Jankowski --- .../op_conversions/group_query_attention_decomposition.hpp | 2 +- src/core/include/openvino/op/group_query_attention.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp index 51c21f4808c61e..3cad7ab229b110 100644 --- a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -18,7 +18,7 @@ class TRANSFORMATIONS_API GroupQueryAttentionDecomposition; class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("GroupQueryAttentionDecomposition", "0"); + OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition"); GroupQueryAttentionDecomposition(); ov::OutputVector decompose(std::shared_ptr node); }; diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp index 25337e0faf9be5..744dac6aebd82a 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -12,7 +12,7 @@ namespace v15 { // This is an experimental operation that is implemented in the plugins. class OPENVINO_API GroupQueryAttention : public Op { public: - OPENVINO_OP("GroupQueryAttention", "opset15", op::Op); + OPENVINO_OP("GroupQueryAttention", "opset15"); GroupQueryAttention() = default; GroupQueryAttention(const ov::OutputVector& args, From 3e24958a1a8366df49cb93074cbfc7ec74072b71 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 7 Feb 2025 11:06:42 +0800 Subject: [PATCH 06/15] Remove redundant type check --- .../group_query_attention_decomposition.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 1441e3e67a633d..5f1476e4bceae8 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -41,7 +41,6 @@ std::shared_ptr get_dimensions(const std::shared_ptr& dims); ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); ov::OutputVector make_split(const ov::Output& value, const std::vector& split_lengths, int64_t axis); -std::shared_ptr create_minus_inf(const ov::element::Type& T); ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { MATCHER_SCOPE(GroupQeuryAttentionDecomposition); @@ -191,7 +190,12 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0}); - auto minus_inf = create_minus_inf(T); + // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp + std::shared_ptr minus_inf = nullptr; + if (T == ov::element::f32) + minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); + else if (T == ov::element::f16) + minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 auto atten_mask_sliced = std::make_shared(atten_mask, past_sequence_length, @@ -309,14 +313,3 @@ ov::OutputVector make_split(const ov::Output& value, int64_t num_split return split->outputs(); } - -std::shared_ptr create_minus_inf(const ov::element::Type& T) { - // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp - if (T == ov::element::f32) { - return ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); - } else if (T == ov::element::f16) { - return ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); - } else { - OPENVINO_THROW("GroupQueryAttention only supports f32 and f16"); - } -} From da868f3ba6e4452067bd7c1b0e5a7788ad877eb5 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 7 Feb 2025 17:14:35 +0800 Subject: [PATCH 07/15] set input total_sequence_length to Null --- .../group_query_attention_decomposition.cpp | 1 - .../include/openvino/op/group_query_attention.hpp | 12 ++++++------ src/core/src/op/group_query_attention.cpp | 4 ++-- .../src/op/com.microsoft/group_query_attention.cpp | 2 ++ 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 5f1476e4bceae8..4dae56ef8b348b 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -81,7 +81,6 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto past_key = node->input_value(3); auto past_value = node->input_value(4); auto seqlens_k = node->input_value(5); - auto total_sequence_length = node->input_value(6); // unused, it's not always equal (seqlens_k + 1) auto cos_cache = node->input_value(7); auto sin_cache = node->input_value(8); diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp index 744dac6aebd82a..1efdfa53a07e3b 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -16,8 +16,8 @@ class OPENVINO_API GroupQueryAttention : public Op { GroupQueryAttention() = default; GroupQueryAttention(const ov::OutputVector& args, - unsigned int num_heads, - unsigned int kv_num_heads, + int64_t num_heads, + int64_t kv_num_heads, float scale, bool do_rotary, bool rotary_interleaved); @@ -25,10 +25,10 @@ class OPENVINO_API GroupQueryAttention : public Op { bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - unsigned int get_num_heads() const { + int64_t get_num_heads() const { return m_num_heads; } - unsigned int get_kv_num_heads() const { + int64_t get_kv_num_heads() const { return m_kv_num_heads; } float get_scale() const { @@ -42,8 +42,8 @@ class OPENVINO_API GroupQueryAttention : public Op { } private: - unsigned int m_num_heads; - unsigned int m_kv_num_heads; + int64_t m_num_heads; + int64_t m_kv_num_heads; float m_scale = 0; bool m_do_rotary = false; bool m_rotary_interleaved = false; diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 475110e66bf5e3..448ee2471b581a 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -13,8 +13,8 @@ namespace op { namespace v15 { GroupQueryAttention::GroupQueryAttention(const OutputVector& args, - unsigned int num_heads, - unsigned int kv_num_heads, + int64_t num_heads, + int64_t kv_num_heads, float scale, bool do_rotary, bool rotary_interleaved) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 692d2529caf7e5..696360e7475201 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -31,6 +31,8 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { for (const auto& input : onnx_op_inputs) { ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared() : input); } + // total_sequence_length is not used currently in OV GQA + ov_op_inputs[6] = std::make_shared(); return std::make_shared(ov_op_inputs, num_heads, kv_num_heads, From 76c18fa96d74ee1413121c5f5b316a2625d840ea Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 11 Feb 2025 14:50:24 +0800 Subject: [PATCH 08/15] Fix ONNX FE GQA tests --- .../gqa_past_0_input_1_rotary.prototxt | 15 +-------------- ...gqa_past_0_input_1_rotary_interleaved.prototxt | 15 +-------------- .../gqa_past_1_input_1_rotary.prototxt | 15 +-------------- ...gqa_past_1_input_1_rotary_interleaved.prototxt | 15 +-------------- .../onnx/tests/onnx_import_com_microsoft.in.cpp | 8 ++++---- 5 files changed, 8 insertions(+), 60 deletions(-) diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt index a7dacf0dc94ebc..1924d016db44b6 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt @@ -7,7 +7,7 @@ graph { input: "past_key" input: "past_value" input: "seqlens_k" - input: "total_sequence_length" + input: "" input: "cos_cache" input: "sin_cache" output: "output" @@ -129,19 +129,6 @@ graph { } } } - input { - name: "total_sequence_length" - type { - tensor_type { - elem_type: 6 - shape { - dim { - dim_value: 1 - } - } - } - } - } input { name: "cos_cache" type { diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt index d1400ad344e717..e65f2985c4a302 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt @@ -7,7 +7,7 @@ graph { input: "past_key" input: "past_value" input: "seqlens_k" - input: "total_sequence_length" + input: "" input: "cos_cache" input: "sin_cache" output: "output" @@ -129,19 +129,6 @@ graph { } } } - input { - name: "total_sequence_length" - type { - tensor_type { - elem_type: 6 - shape { - dim { - dim_value: 1 - } - } - } - } - } input { name: "cos_cache" type { diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt index f5ec39c9b0bd8e..f3f8a82c00bb03 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt @@ -7,7 +7,7 @@ graph { input: "past_key" input: "past_value" input: "seqlens_k" - input: "total_sequence_length" + input: "" input: "cos_cache" input: "sin_cache" output: "output" @@ -129,19 +129,6 @@ graph { } } } - input { - name: "total_sequence_length" - type { - tensor_type { - elem_type: 6 - shape { - dim { - dim_value: 1 - } - } - } - } - } input { name: "cos_cache" type { diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt index b61cf39552efc7..9aad910b32e0bd 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt @@ -7,7 +7,7 @@ graph { input: "past_key" input: "past_value" input: "seqlens_k" - input: "total_sequence_length" + input: "" input: "cos_cache" input: "sin_cache" output: "output" @@ -129,19 +129,6 @@ graph { } } } - input { - name: "total_sequence_length" - type { - tensor_type { - elem_type: 6 - shape { - dim { - dim_value: 1 - } - } - } - } - } input { name: "cos_cache" type { diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index a7a91cbbf4c72d..7c75f74d602f16 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1821,7 +1821,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - test_case.add_input(total_sequence_length); + // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); @@ -1910,7 +1910,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - test_case.add_input(total_sequence_length); + // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); @@ -2027,7 +2027,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - test_case.add_input(total_sequence_length); + // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); @@ -2144,7 +2144,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - test_case.add_input(total_sequence_length); + // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); From b8005ea7fa64e8ffe96b84565071184ab619d58c Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 11 Feb 2025 17:02:43 +0800 Subject: [PATCH 09/15] Add more checks from code review --- .../group_query_attention_decomposition.hpp | 2 +- src/core/src/op/group_query_attention.cpp | 25 ++++++------------- .../com.microsoft/group_query_attention.cpp | 6 ++++- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp index 3cad7ab229b110..dd6b496e0edf6c 100644 --- a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 448ee2471b581a..1fe85296ea9821 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -27,34 +27,23 @@ GroupQueryAttention::GroupQueryAttention(const OutputVector& args, constructor_validate_and_infer_types(); } -int64_t get_head_size(const PartialShape& input_shape, int num_heads, int kv_num_heads) { - return input_shape[2].get_length() / (num_heads + kv_num_heads * 2); -} - -std::vector get_qkv_sizes(const PartialShape& input_shape, int num_heads, int kv_num_heads) { - int64_t per_head_size = get_head_size(input_shape, num_heads, kv_num_heads); - const std::vector qkv_sizes = {num_heads * per_head_size, - kv_num_heads * per_head_size, - kv_num_heads * per_head_size}; - return qkv_sizes; -} - void GroupQueryAttention::validate_and_infer_types() { OV_OP_SCOPE(v15_GroupQueryAttention_validate_and_infer_types); PartialShape input_shape = get_input_partial_shape(0); + NODE_VALIDATION_CHECK(this, input_shape[2].is_static(), "GroupQueryAttention: head size should not be dynamic"); + Dimension batch_size = input_shape[0]; Dimension sequence_len = input_shape[1]; Dimension head_size; if (Null::is_null(input_value(1)) && Null::is_null(input_value(2))) { - head_size = get_head_size(input_shape, m_num_heads, m_kv_num_heads); + head_size = input_shape[2] / (m_num_heads + m_kv_num_heads * 2); } else { - head_size = input_shape[2].get_length() / m_num_heads; + head_size = input_shape[2] / m_num_heads; } Dimension output_kv_len; - PartialShape kv_past_shape = get_input_partial_shape(3); - // FIXME: https://github.com/openvinotoolkit/openvino/pull/27648 - if (kv_past_shape[2].is_static()) { - output_kv_len = kv_past_shape[2] + sequence_len; + Dimension past_sequence_len = get_input_partial_shape(3)[2]; + if (past_sequence_len.is_static() && sequence_len.is_static()) { + output_kv_len = past_sequence_len + sequence_len; } else { output_kv_len = ov::Dimension(); } diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 696360e7475201..7b35c0d0dd656c 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -8,6 +8,7 @@ #include "core/null_node.hpp" #include "core/operator_set.hpp" #include "openvino/frontend/exception.hpp" +#include "utils/common.hpp" using namespace ov::op; using ov::Shape; @@ -19,12 +20,15 @@ namespace com_microsoft { namespace opset_1 { ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { + // At least given "query" and "seqlens_k" + common::default_op_checks(node, 2); + const auto onnx_op_inputs = node.get_ov_inputs(); const auto num_heads = node.get_attribute_value("num_heads"); const auto kv_num_heads = node.get_attribute_value("kv_num_heads"); const auto scale = node.get_attribute_value("scale", 0.0f); const auto do_rotary = node.get_attribute_value("do_rotary", 0); - const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved", 0.0f); + const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved", 0); OutputVector ov_op_inputs; ov_op_inputs.reserve(onnx_op_inputs.size()); From f8a363a644eed45de58f3e147aa9246eef19f0ba Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 13 Feb 2025 09:36:16 +0800 Subject: [PATCH 10/15] Update copyrights --- .../op_conversions/group_query_attention_decomposition.cpp | 2 +- src/core/include/openvino/op/group_query_attention.hpp | 2 +- src/core/src/op/group_query_attention.cpp | 2 +- .../frontend/src/op/com.microsoft/group_query_attention.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 4dae56ef8b348b..90467b397ae3ad 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp index 1efdfa53a07e3b..261152c48d5ec5 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 1fe85296ea9821..46ab1c602bf94d 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 7b35c0d0dd656c..4f2aaad3104de0 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2024 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // From 733076aa510085f83478d3e58a18b1c928f7d47c Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 13 Feb 2025 15:50:54 +0800 Subject: [PATCH 11/15] Move GQA unpack to frontend, remove Null --- .../group_query_attention_decomposition.cpp | 108 +++++++----------- .../openvino/op/group_query_attention.hpp | 4 + src/core/include/openvino/op/null.hpp | 47 -------- src/core/include/openvino/op/ops.hpp | 1 - .../include/openvino/opsets/opset15_tbl.hpp | 1 - src/core/src/op/group_query_attention.cpp | 28 ++--- src/core/tests/opset.cpp | 2 +- .../com.microsoft/group_query_attention.cpp | 88 +++++++++++++- 8 files changed, 140 insertions(+), 139 deletions(-) delete mode 100644 src/core/include/openvino/op/null.hpp diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 90467b397ae3ad..48478f1d885931 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -12,11 +12,9 @@ #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" -#include "openvino/op/divide.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/greater.hpp" #include "openvino/op/multiply.hpp" -#include "openvino/op/null.hpp" #include "openvino/op/range.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" @@ -30,6 +28,12 @@ #include "openvino/op/variadic_split.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" + +namespace ov { +namespace detail { +namespace { +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); std::shared_ptr rotaryEmbedding(ov::Output input, ov::Output past_seqlen, std::shared_ptr seqlen_k, @@ -37,10 +41,9 @@ std::shared_ptr rotaryEmbedding(ov::Output input, std::shared_ptr sin_cache, std::shared_ptr dim_head_size, bool interleaved); -std::shared_ptr get_dimensions(const std::shared_ptr& shape, - const std::vector& dims); -ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); -ov::OutputVector make_split(const ov::Output& value, const std::vector& split_lengths, int64_t axis); +} // namespace +} // namespace detail +} // namespace ov ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { MATCHER_SCOPE(GroupQeuryAttentionDecomposition); @@ -81,41 +84,15 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto past_key = node->input_value(3); auto past_value = node->input_value(4); auto seqlens_k = node->input_value(5); - auto cos_cache = node->input_value(7); - auto sin_cache = node->input_value(8); + auto cos_cache = node->input_value(6); + auto sin_cache = node->input_value(7); // The length of all tokens (past + current) is `seqlens_k` + 1, current = Q.shape[2], past = `seqlens_k` + 1 - current const auto T = Q.get_element_type(); - const auto node_shape = std::make_shared(Q); - const auto batch_size = get_dimensions(node_shape, {0}); - const auto current_seqlen_size = get_dimensions(node_shape, {1}); - const auto hidden_size = get_dimensions(node_shape, {2}); - const auto total_num_heads_node = - v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); - auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); // should be equal to the last dim of past_key - - // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) - auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); - - if (v15::Null::is_null(K)) { - // Handle the packed QKV - auto packed_qkv_shape = std::make_shared( - ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, - 0); - auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); - // (batch_size, sequence_len, num_head, head_size) - inputs_qkv = std::make_shared(inputs_qkv, perm); - // split the node into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) - auto split = make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); - Q = split[0]; - K = split[1]; - V = split[2]; - } else { - Q = std::make_shared(Q, perm); - K = std::make_shared(K, perm); - V = std::make_shared(V, perm); - } + const auto q_shape = std::make_shared(Q); + const auto current_sequence_length = detail::get_dimensions(q_shape, {2}); + auto head_size_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {node->get_head_size()}); auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); @@ -126,17 +103,16 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( // Only consider batch is 1 auto seqlens_1d = std::make_shared(real_seqlens, one, false); - auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); - + auto past_sequence_length = std::make_shared(seqlens_1d, current_sequence_length); if (do_rotary) { - Q = rotaryEmbedding(Q, + Q = detail::rotaryEmbedding(Q, past_sequence_length, seqlens_1d, cos_cache.get_node_shared_ptr(), sin_cache.get_node_shared_ptr(), head_size_node, rotary_interleaved); - K = rotaryEmbedding(K, + K = detail::rotaryEmbedding(K, past_sequence_length, seqlens_1d, cos_cache.get_node_shared_ptr(), @@ -149,7 +125,7 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( // present = concat(past, K, V) auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); - auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); + auto curr_datas = std::make_shared(current, zero, current_sequence_length, one, two); return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); }; @@ -162,8 +138,8 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( if (kv_num_heads_factor > 1) { const auto kv_shape = std::make_shared(K); // (batch_size, num_heads, sequence_len, head_size) - const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1}); - const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3}); + const auto kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1}); + const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3}); auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); K = std::make_shared(K, new_kv_shape, false); V = std::make_shared(V, new_kv_shape, false); @@ -171,7 +147,7 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); auto q_shape = std::make_shared(Q); // (batch_size, num_heads, sequence_len, head_size) - const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1}); + const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); K = std::make_shared(K, extended_kv_shape, false); V = std::make_shared(V, extended_kv_shape, false); @@ -212,16 +188,28 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( } // transpose the result from (batch_size, num_heads, sequence_length, head_size) - // to (batch_size, sequence_length, num_heads, head_size) + // to (batch_size, sequence_length, num_heads * head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); auto qga_output_transposed = std::make_shared(qga_output, perm); auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); - // reshape the result from (batch_size, sequence_length, num_heads, head_size) - // to (batch_size, sequence_length, num_heads * head_size) auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true)->output(0); return {output, present_k, present_v}; } + +namespace ov { +namespace detail { +namespace { +// make split functions is a copy-past from ONNX FE. TODO: move it to one place +ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis) { + using namespace ov::op; + const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); + const auto split = std::make_shared(value, axis_node, num_splits); + + return split->outputs(); +} + std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { using namespace ov::op; @@ -291,24 +279,6 @@ std::shared_ptr rotaryEmbedding(ov::Output input, return std::make_shared(ov::NodeVector{res_0, res_1}, -1); } } - -// make split functions is a copy-past from ONNX FE. TODO: move it to one place -ov::OutputVector make_split(const ov::Output& value, - const std::vector& split_lengths, - int64_t axis) { - using namespace ov::op; - const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); - const auto split_lengths_node = - v0::Constant::create(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths); - const auto variadic_split = std::make_shared(value, axis_node, split_lengths_node); - - return variadic_split->outputs(); -} - -ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis) { - using namespace ov::op; - const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); - const auto split = std::make_shared(value, axis_node, num_splits); - - return split->outputs(); -} +} // namespace +} // namespace detail +} // namespace ov \ No newline at end of file diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp index 261152c48d5ec5..f0a68c9800aabf 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -40,6 +40,9 @@ class OPENVINO_API GroupQueryAttention : public Op { bool get_rotary_interleaved() const { return m_rotary_interleaved; } + int64_t get_head_size() const { + return m_head_size; + } private: int64_t m_num_heads; @@ -47,6 +50,7 @@ class OPENVINO_API GroupQueryAttention : public Op { float m_scale = 0; bool m_do_rotary = false; bool m_rotary_interleaved = false; + int64_t m_head_size; }; } // namespace v15 diff --git a/src/core/include/openvino/op/null.hpp b/src/core/include/openvino/op/null.hpp deleted file mode 100644 index 346ff3d77c532b..00000000000000 --- a/src/core/include/openvino/op/null.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/op/op.hpp" - -namespace ov { -namespace op { -namespace v15 { - -/// \brief Represents a missing optional input or output of an ONNX node -/// -/// Some ONNX operators have inputs or outputs that are marked as optional, -/// which means that a referring node MAY forgo providing values for such inputs -/// or computing these outputs. -/// An empty string is used in place of a name of such input or output. -/// -/// More: -/// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs -class OPENVINO_API Null : public Op { -public: - OPENVINO_OP("Null", "opset15", op::Op); - Null() { - set_output_size(1); - } - - static bool is_null(const ov::Node* node) { - return ov::as_type(node) != nullptr; - } - - static bool is_null(const std::shared_ptr& node) { - return is_null(node.get()); - } - - static bool is_null(const Output& output) { - return is_null(output.get_node()); - } - - virtual std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override { - return std::make_shared(); - } -}; -} // namespace v15 -} // namespace op -} // namespace ov diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp index 098f2872e4a586..630044b0ef8135 100644 --- a/src/core/include/openvino/op/ops.hpp +++ b/src/core/include/openvino/op/ops.hpp @@ -168,7 +168,6 @@ #include "openvino/op/roll.hpp" #include "openvino/op/round.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" -#include "openvino/op/null.hpp" #include "openvino/op/group_query_attention.hpp" #include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/scatter_nd_update.hpp" diff --git a/src/core/include/openvino/opsets/opset15_tbl.hpp b/src/core/include/openvino/opsets/opset15_tbl.hpp index 8dd7c0d7b830cc..e628d7b479ccbc 100644 --- a/src/core/include/openvino/opsets/opset15_tbl.hpp +++ b/src/core/include/openvino/opsets/opset15_tbl.hpp @@ -235,4 +235,3 @@ _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15) _OPENVINO_OP_REG(SliceScatter, ov::op::v15) _OPENVINO_OP_REG(SearchSorted, ov::op::v15) _OPENVINO_OP_REG(GroupQueryAttention, ov::op::v15) -_OPENVINO_OP_REG(Null, ov::op::v15) diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 46ab1c602bf94d..9685aa6ae8ffd4 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -5,7 +5,6 @@ #include "openvino/op/group_query_attention.hpp" #include "itt.hpp" -#include "openvino/op/null.hpp" using namespace std; namespace ov { @@ -29,17 +28,18 @@ GroupQueryAttention::GroupQueryAttention(const OutputVector& args, void GroupQueryAttention::validate_and_infer_types() { OV_OP_SCOPE(v15_GroupQueryAttention_validate_and_infer_types); - PartialShape input_shape = get_input_partial_shape(0); - NODE_VALIDATION_CHECK(this, input_shape[2].is_static(), "GroupQueryAttention: head size should not be dynamic"); + // GQA expectes the following inputs: query, key, value, past_key, past_value, seqlens_k, cos_cache, sin_cache + // All qkv's should have the shape [batch, num_heads, seq_len, head_size] ([B, N, S, H]) + // It has three outputs: output of shape [B, S, N * H], and present_key/value of shape [B, N, S, H] + // seqlens_k is number of 1's in the attention_mask minus 1 + + PartialShape q_shape = get_input_partial_shape(0); + NODE_VALIDATION_CHECK(this, q_shape[3].is_static(), "GroupQueryAttention: head size should not be dynamic"); + m_head_size = q_shape[3].get_length(); + + Dimension batch_size = q_shape[0]; + Dimension sequence_len = q_shape[2]; - Dimension batch_size = input_shape[0]; - Dimension sequence_len = input_shape[1]; - Dimension head_size; - if (Null::is_null(input_value(1)) && Null::is_null(input_value(2))) { - head_size = input_shape[2] / (m_num_heads + m_kv_num_heads * 2); - } else { - head_size = input_shape[2] / m_num_heads; - } Dimension output_kv_len; Dimension past_sequence_len = get_input_partial_shape(3)[2]; if (past_sequence_len.is_static() && sequence_len.is_static()) { @@ -51,9 +51,9 @@ void GroupQueryAttention::validate_and_infer_types() { NODE_VALIDATION_CHECK(this, element_type == element::f32 || element_type == element::f16, "GroupQueryAttention only suuports f32 and f16"); - set_output_type(0, element_type, PartialShape{batch_size, sequence_len, head_size * m_num_heads}); - set_output_type(1, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size}); - set_output_type(2, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size}); + set_output_type(0, element_type, PartialShape{batch_size, sequence_len, m_head_size * m_num_heads}); + set_output_type(1, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, m_head_size}); + set_output_type(2, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, m_head_size}); } bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) { diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp index 49f24a5dfd41de..cf1409f629820c 100644 --- a/src/core/tests/opset.cpp +++ b/src/core/tests/opset.cpp @@ -76,7 +76,7 @@ INSTANTIATE_TEST_SUITE_P(opset, OpsetTestParams{ov::get_opset12, 178}, OpsetTestParams{ov::get_opset13, 186}, OpsetTestParams{ov::get_opset14, 188}, - OpsetTestParams{ov::get_opset15, 201}, + OpsetTestParams{ov::get_opset15, 200}, OpsetTestParams{ov::get_opset16, 6}), OpsetTestNameGenerator{}); diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 4f2aaad3104de0..99dc40ae9c9047 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -3,12 +3,19 @@ // #include "openvino/op/group_query_attention.hpp" -#include "openvino/op/null.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/transpose.hpp" #include "core/null_node.hpp" #include "core/operator_set.hpp" #include "openvino/frontend/exception.hpp" #include "utils/common.hpp" +#include "utils/split.hpp" +#include using namespace ov::op; using ov::Shape; @@ -17,6 +24,11 @@ namespace ov { namespace frontend { namespace onnx { namespace com_microsoft { +namespace detail { +namespace { +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +} // namespace +} // namespace detail namespace opset_1 { ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { @@ -30,13 +42,62 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { const auto do_rotary = node.get_attribute_value("do_rotary", 0); const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved", 0); + // In ONNX, the format of input QKV is [B, S, N*H] and of past_kv is [B, N, S, H] + // In OV, we always use [B, N, S, H] + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + + auto Q = onnx_op_inputs[0]; + auto K = onnx_op_inputs[1]; + auto V = onnx_op_inputs[2]; + const auto q_shape_node = std::make_shared(Q); + const auto batch_size_node = detail::get_dimensions(q_shape_node, {0}); + const auto current_seqlen_size_node = detail::get_dimensions(q_shape_node, {1}); + const auto hidden_size_node = detail::get_dimensions(q_shape_node, {2}); + OutputVector ov_op_inputs; - ov_op_inputs.reserve(onnx_op_inputs.size()); - for (const auto& input : onnx_op_inputs) { - ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared() : input); + if (ov::op::util::is_null(K)) { + auto total_num_heads_node = + v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); + auto head_size_node = std::make_shared(hidden_size_node, total_num_heads_node); + auto packed_qkv_shape = std::make_shared( + ov::NodeVector{batch_size_node, current_seqlen_size_node, total_num_heads_node, head_size_node}, + 0); + + auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); + inputs_qkv = std::make_shared(inputs_qkv, perm); + auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); + + std::copy(split.begin(), split.end(), std::back_inserter(ov_op_inputs)); + } else { + auto num_heads_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads}); + auto head_size_node = std::make_shared(hidden_size_node, num_heads_node); + auto q_shape = std::make_shared( + ov::NodeVector{batch_size_node, current_seqlen_size_node, num_heads_node, head_size_node}, + 0); + + Q = std::make_shared(Q, q_shape, false)->output(0); + Q = std::make_shared(Q, perm); + ov_op_inputs.push_back(Q); + + auto kv_num_heads_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {kv_num_heads}); + auto kv_shape = std::make_shared( + ov::NodeVector{batch_size_node, current_seqlen_size_node, kv_num_heads_node, head_size_node}, + 0); + + K = std::make_shared(K, kv_shape, false)->output(0); + V = std::make_shared(V, kv_shape, false)->output(0); + K = std::make_shared(K, perm); + V = std::make_shared(V, perm); + ov_op_inputs.push_back(K); + ov_op_inputs.push_back(V); + } + + for (int i = 3; i < 9; ++i) { + // skip total_sequence_length + if (i == 6) + continue; + ov_op_inputs.push_back(onnx_op_inputs[i]); } - // total_sequence_length is not used currently in OV GQA - ov_op_inputs[6] = std::make_shared(); return std::make_shared(ov_op_inputs, num_heads, kv_num_heads, @@ -49,6 +110,21 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN); } // namespace opset_1 + + +namespace detail { +namespace { +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { + static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} +} // namespace +} // namespace detail } // namespace com_microsoft } // namespace onnx } // namespace frontend From 31d5431efeecba8bd28ff6cfc0d39744f20be345 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 18 Feb 2025 12:19:19 +0800 Subject: [PATCH 12/15] Move GQA to dev_api, clean code --- .../group_query_attention_decomposition.hpp | 4 +- .../group_query_attention_decomposition.cpp | 63 ++++++++----------- .../openvino/op/group_query_attention.hpp | 4 +- src/core/include/openvino/op/ops.hpp | 1 - .../include/openvino/opsets/opset15_tbl.hpp | 1 - src/core/src/op/group_query_attention.cpp | 8 +-- src/core/tests/opset.cpp | 2 +- .../com.microsoft/group_query_attention.cpp | 34 +++++----- .../gqa_past_0_input_1_rotary.prototxt | 2 +- ...past_0_input_1_rotary_interleaved.prototxt | 2 +- .../tests/onnx_import_com_microsoft.in.cpp | 8 --- 11 files changed, 51 insertions(+), 78 deletions(-) rename src/core/{include => dev_api}/openvino/op/group_query_attention.hpp (94%) diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp index dd6b496e0edf6c..fe762533303bed 100644 --- a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -20,5 +20,7 @@ class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass public: OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition"); GroupQueryAttentionDecomposition(); - ov::OutputVector decompose(std::shared_ptr node); + +private: + ov::OutputVector decompose(std::shared_ptr node); }; diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 48478f1d885931..8c0b7566f82248 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -25,14 +25,13 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" -#include "openvino/op/variadic_split.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" - namespace ov { namespace detail { namespace { std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims); ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); std::shared_ptr rotaryEmbedding(ov::Output input, ov::Output past_seqlen, @@ -47,12 +46,12 @@ std::shared_ptr rotaryEmbedding(ov::Output input, ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { MATCHER_SCOPE(GroupQeuryAttentionDecomposition); - auto pattern_node = ov::pass::pattern::wrap_type(); + auto pattern_node = ov::pass::pattern::wrap_type(); matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { auto& pattern_to_output = m.get_pattern_value_map(); auto node = - ov::as_type_ptr(pattern_to_output.at(pattern_node).get_node_shared_ptr()); + ov::as_type_ptr(pattern_to_output.at(pattern_node).get_node_shared_ptr()); if (node == nullptr || transformation_callback(node)) { return false; @@ -68,7 +67,7 @@ ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { } ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( - std::shared_ptr node) { + std::shared_ptr node) { using namespace ov::op; const auto num_heads = node->get_num_heads(); @@ -87,7 +86,8 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto cos_cache = node->input_value(6); auto sin_cache = node->input_value(7); - // The length of all tokens (past + current) is `seqlens_k` + 1, current = Q.shape[2], past = `seqlens_k` + 1 - current + // The length of all tokens (past + current) is `seqlens_k` + 1 + // current = Q.shape[2], past = `seqlens_k` + 1 - current const auto T = Q.get_element_type(); const auto q_shape = std::make_shared(Q); @@ -106,29 +106,26 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto past_sequence_length = std::make_shared(seqlens_1d, current_sequence_length); if (do_rotary) { Q = detail::rotaryEmbedding(Q, - past_sequence_length, - seqlens_1d, - cos_cache.get_node_shared_ptr(), - sin_cache.get_node_shared_ptr(), - head_size_node, - rotary_interleaved); + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); K = detail::rotaryEmbedding(K, - past_sequence_length, - seqlens_1d, - cos_cache.get_node_shared_ptr(), - sin_cache.get_node_shared_ptr(), - head_size_node, - rotary_interleaved); + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); } - // present = concat(K, V) if 'past' input is unavailable - // or - // present = concat(past, K, V) + auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); auto curr_datas = std::make_shared(current, zero, current_sequence_length, one, two); return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); }; - K = construct_kv_cache(past_key, K); V = construct_kv_cache(past_value, V); auto present_k = K; @@ -155,15 +152,15 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( // need to apply low-triangle mask to attention score. // two steps, construct the total_sequence x total_sequence triangle, then slice the current length - auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); std::shared_ptr mask_per_line_node = std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), seqlens_1d_scalar, one_without_shape, - ov::element::i64); // [0,1,2,...,] - auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 - auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 - auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 + ov::element::i64); + auto hori_range = std::make_shared(mask_per_line_node, zero); + auto vert_range = std::make_shared(mask_per_line_node, one); + auto triu = std::make_shared(hori_range, vert_range); auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0}); // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp std::shared_ptr minus_inf = nullptr; @@ -171,14 +168,9 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); else if (T == ov::element::f16) minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); - auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 - auto atten_mask_sliced = std::make_shared(atten_mask, - past_sequence_length, - seqlens_1d, - one, - zero); // slice to current query seqlen, 12x12 or 1x13 - - // compute softmax((Q x K') / sqrt(head_size)) x V + auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); + auto atten_mask_sliced = std::make_shared(atten_mask, past_sequence_length, seqlens_1d, one, zero); + std::shared_ptr qga_output; if (scale != 0.0f) { auto scale_node = v0::Constant::create(T, Shape{}, {scale}); @@ -197,7 +189,6 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( return {output, present_k, present_v}; } - namespace ov { namespace detail { namespace { diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/dev_api/openvino/op/group_query_attention.hpp similarity index 94% rename from src/core/include/openvino/op/group_query_attention.hpp rename to src/core/dev_api/openvino/op/group_query_attention.hpp index f0a68c9800aabf..fe298a774f257e 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/dev_api/openvino/op/group_query_attention.hpp @@ -7,12 +7,11 @@ namespace ov { namespace op { -namespace v15 { // This is an experimental operation that is implemented in the plugins. class OPENVINO_API GroupQueryAttention : public Op { public: - OPENVINO_OP("GroupQueryAttention", "opset15"); + OPENVINO_OP("GroupQueryAttention"); GroupQueryAttention() = default; GroupQueryAttention(const ov::OutputVector& args, @@ -53,6 +52,5 @@ class OPENVINO_API GroupQueryAttention : public Op { int64_t m_head_size; }; -} // namespace v15 } // namespace op } // namespace ov diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp index 630044b0ef8135..adeb9c25611960 100644 --- a/src/core/include/openvino/op/ops.hpp +++ b/src/core/include/openvino/op/ops.hpp @@ -168,7 +168,6 @@ #include "openvino/op/roll.hpp" #include "openvino/op/round.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" -#include "openvino/op/group_query_attention.hpp" #include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/scatter_nd_update.hpp" #include "openvino/op/scatter_update.hpp" diff --git a/src/core/include/openvino/opsets/opset15_tbl.hpp b/src/core/include/openvino/opsets/opset15_tbl.hpp index e628d7b479ccbc..a9e8d2a8dcc840 100644 --- a/src/core/include/openvino/opsets/opset15_tbl.hpp +++ b/src/core/include/openvino/opsets/opset15_tbl.hpp @@ -234,4 +234,3 @@ _OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15) _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15) _OPENVINO_OP_REG(SliceScatter, ov::op::v15) _OPENVINO_OP_REG(SearchSorted, ov::op::v15) -_OPENVINO_OP_REG(GroupQueryAttention, ov::op::v15) diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 9685aa6ae8ffd4..af3f27f299deeb 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -9,7 +9,6 @@ using namespace std; namespace ov { namespace op { -namespace v15 { GroupQueryAttention::GroupQueryAttention(const OutputVector& args, int64_t num_heads, @@ -27,7 +26,7 @@ GroupQueryAttention::GroupQueryAttention(const OutputVector& args, } void GroupQueryAttention::validate_and_infer_types() { - OV_OP_SCOPE(v15_GroupQueryAttention_validate_and_infer_types); + OV_OP_SCOPE(GroupQueryAttention_validate_and_infer_types); // GQA expectes the following inputs: query, key, value, past_key, past_value, seqlens_k, cos_cache, sin_cache // All qkv's should have the shape [batch, num_heads, seq_len, head_size] ([B, N, S, H]) // It has three outputs: output of shape [B, S, N * H], and present_key/value of shape [B, N, S, H] @@ -57,7 +56,7 @@ void GroupQueryAttention::validate_and_infer_types() { } bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) { - OV_OP_SCOPE(v15_GroupQueryAttention_visit_attributes); + OV_OP_SCOPE(GroupQueryAttention_visit_attributes); visitor.on_attribute("do_rotary", m_do_rotary); visitor.on_attribute("kv_num_heads", m_kv_num_heads); visitor.on_attribute("num_heads", m_num_heads); @@ -67,7 +66,7 @@ bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) { } std::shared_ptr GroupQueryAttention::clone_with_new_inputs(const ov::OutputVector& new_args) const { - OV_OP_SCOPE(v15_GroupQueryAttention_clone_with_new_inputs); + OV_OP_SCOPE(GroupQueryAttention_clone_with_new_inputs); return std::make_shared(new_args, m_num_heads, m_kv_num_heads, @@ -76,6 +75,5 @@ std::shared_ptr GroupQueryAttention::clone_with_new_inputs(const ov::O m_rotary_interleaved); } -} // namespace v15 } // namespace op } // namespace ov diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp index cf1409f629820c..3006e04a02f960 100644 --- a/src/core/tests/opset.cpp +++ b/src/core/tests/opset.cpp @@ -76,7 +76,7 @@ INSTANTIATE_TEST_SUITE_P(opset, OpsetTestParams{ov::get_opset12, 178}, OpsetTestParams{ov::get_opset13, 186}, OpsetTestParams{ov::get_opset14, 188}, - OpsetTestParams{ov::get_opset15, 200}, + OpsetTestParams{ov::get_opset15, 199}, OpsetTestParams{ov::get_opset16, 6}), OpsetTestNameGenerator{}); diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 99dc40ae9c9047..37647a63663989 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -3,22 +3,21 @@ // #include "openvino/op/group_query_attention.hpp" -#include "openvino/op/reshape.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/gather.hpp" -#include "openvino/op/divide.hpp" -#include "openvino/op/concat.hpp" -#include "openvino/op/transpose.hpp" + +#include #include "core/null_node.hpp" #include "core/operator_set.hpp" -#include "openvino/frontend/exception.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/transpose.hpp" #include "utils/common.hpp" #include "utils/split.hpp" -#include using namespace ov::op; -using ov::Shape; namespace ov { namespace frontend { @@ -98,12 +97,12 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { continue; ov_op_inputs.push_back(onnx_op_inputs[i]); } - return std::make_shared(ov_op_inputs, - num_heads, - kv_num_heads, - scale, - do_rotary, - rotary_interleaved) + return std::make_shared(ov_op_inputs, + num_heads, + kv_num_heads, + scale, + do_rotary, + rotary_interleaved) ->outputs(); } @@ -111,7 +110,6 @@ ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_que } // namespace opset_1 - namespace detail { namespace { std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { @@ -119,10 +117,6 @@ std::shared_ptr get_dimensions(const std::shared_ptr& sha const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); return std::make_shared(shape, dims_const, zero); } - -std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { - return get_dimensions(std::make_shared(node), dims); -} } // namespace } // namespace detail } // namespace com_microsoft diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt index 1924d016db44b6..a105b7c03c17f0 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt @@ -231,4 +231,4 @@ opset_import { opset_import { domain: "com.microsoft" version: 1 -} \ No newline at end of file +} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt index e65f2985c4a302..6b5748b653f0fc 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt @@ -231,4 +231,4 @@ opset_import { opset_import { domain: "com.microsoft" version: 1 -} \ No newline at end of file +} diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 7c75f74d602f16..3fdd654e230f8b 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1755,7 +1755,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { std::vector past_key = {}; std::vector past_value = {}; std::vector seqlens_k = {0}; - std::vector total_sequence_length = {1}; std::vector cos_cache = { 0.8437, -0.7849, @@ -1821,7 +1820,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); @@ -1844,7 +1842,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) std::vector past_key = {}; std::vector past_value = {}; std::vector seqlens_k = {0}; - std::vector total_sequence_length = {1}; std::vector cos_cache = { 0.8437, -0.7849, @@ -1910,7 +1907,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); @@ -1967,7 +1963,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { 0.8058, }; std::vector seqlens_k = {1}; - std::vector total_sequence_length = {2}; std::vector cos_cache = { 0.8437, -0.7849, @@ -2027,7 +2022,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); @@ -2084,7 +2078,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) 0.8058, }; std::vector seqlens_k = {1}; - std::vector total_sequence_length = {2}; std::vector cos_cache = { 0.8437, -0.7849, @@ -2144,7 +2137,6 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) test_case.add_input(past_key); test_case.add_input(past_value); test_case.add_input(seqlens_k); - // test_case.add_input(total_sequence_length); test_case.add_input(cos_cache); test_case.add_input(sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); From e4838e1060689b8c46856d881827a2b7af24db6e Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 20 Feb 2025 10:19:04 +0800 Subject: [PATCH 13/15] Disable GQA onnx tests for interpreter backend --- .../onnx/tests/runtime/interpreter/unit_test.manifest | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest b/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest index 82fa651aaf6e94..060ef5e86aa888 100644 --- a/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest +++ b/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest @@ -72,3 +72,9 @@ INTERPRETER.onnx_softmax_crossentropy_loss_mean # Incorrect order of elements returned by the TopK implementation INTERPRETER.onnx_model_top_k_repeating INTERPRETER.onnx_model_top_k_repeating_unsorted + +# Interpreter backend doesn't implement evaluate method for OP ScaledDotProductAttention +INTERPRETER.onnx_model_gqa_past_0_input_1_rotary +INTERPRETER.onnx_model_gqa_past_0_input_1_rotary_interleaved +INTERPRETER.onnx_model_gqa_past_1_input_1_rotary +INTERPRETER.onnx_model_gqa_past_1_input_1_rotary_interleaved From 157103bc93551cd13554decba4b0dd944665b939 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 20 Feb 2025 12:08:24 +0800 Subject: [PATCH 14/15] Add trailing "f" to float-point numbers in onnx tests --- .../tests/onnx_import_com_microsoft.in.cpp | 578 +++++++++--------- 1 file changed, 289 insertions(+), 289 deletions(-) diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 3fdd654e230f8b..46e1b7ec1253be 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1745,75 +1745,75 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx"); std::vector query = { - -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, - 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, - -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, - -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, - 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, - -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, + 0.3081f, 0.1198f, 1.2377f, 1.1168f, -0.2473f, -1.3527f, -1.6959f, 0.5667f, 0.7935f, 0.5988f, -1.5551f, + -0.3414f, 1.8530f, 0.7502f, -0.5855f, -0.1734f, 0.1835f, 1.3894f, 1.5863f, 0.9463f, -0.8437f, 1.6459f, + -1.3602f, 0.3446f, 0.5199f, -2.6133f, -1.6965f, -0.2282f, 0.2800f, 0.2469f, 0.0769f, 0.3380f, 0.4544f, + 0.4569f, -0.8654f, 0.7813f, -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f, 0.3492f, -0.9215f, + -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, }; std::vector past_key = {}; std::vector past_value = {}; std::vector seqlens_k = {0}; std::vector cos_cache = { - 0.8437, - -0.7849, - -0.7829, - 0.4581, - -0.9870, - 0.6273, - -0.9483, - -0.9962, + 0.8437f, + -0.7849f, + -0.7829f, + 0.4581f, + -0.9870f, + 0.6273f, + -0.9483f, + -0.9962f, }; std::vector sin_cache = { - 0.5368, - 0.6196, - -0.6222, - 0.8889, - 0.1605, - -0.7788, - 0.3174, - -0.0872, - }; - - std::vector expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, - -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, - -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, - -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; - - std::vector expected_present_key = {1.2561098, - 1.0199738, - -0.05948371, - -0.16574995, - 2.5059946, - -1.738188, - -0.03158256, - -0.35975295, - 1.0918287, - -0.90313876, - -0.4790303, - 0.67029977, - -0.87039495, - 0.7783688, - -0.81333745, - 0.89886224}; - - std::vector expected_present_value = {-0.2188, - -2.4351, - -0.0729, - -0.034, - 0.9625, - 0.3492, - -0.9215, - -0.0562, - -0.6227, - -0.4637, - 1.9218, - -0.4025, - 0.1239, - 1.1648, - 0.9234, - 1.3873}; + 0.5368f, + 0.6196f, + -0.6222f, + 0.8889f, + 0.1605f, + -0.7788f, + 0.3174f, + -0.0872f, + }; + + std::vector expected_output = {-0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, -0.9215f, -0.0562f, + -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, + -0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, -0.9215f, -0.0562f, + -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; + + std::vector expected_present_key = {1.2561098f, + 1.0199738f, + -0.05948371f, + -0.16574995f, + 2.5059946f, + -1.738188f, + -0.03158256f, + -0.35975295f, + 1.0918287f, + -0.90313876f, + -0.4790303f, + 0.67029977f, + -0.87039495f, + 0.7783688f, + -0.81333745f, + 0.89886224f}; + + std::vector expected_present_value = {-0.2188f, + -2.4351f, + -0.0729f, + -0.034f, + 0.9625f, + 0.3492f, + -0.9215f, + -0.0562f, + -0.6227f, + -0.4637f, + 1.9218f, + -0.4025f, + 0.1239f, + 1.1648f, + 0.9234f, + 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); test_case.add_input(query); @@ -1832,75 +1832,75 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary_interleaved.onnx"); std::vector query = { - -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, - 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, - -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, - -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, - 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, - -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, + 0.3081f, 0.1198f, 1.2377f, 1.1168f, -0.2473f, -1.3527f, -1.6959f, 0.5667f, 0.7935f, 0.5988f, -1.5551f, + -0.3414f, 1.8530f, 0.7502f, -0.5855f, -0.1734f, 0.1835f, 1.3894f, 1.5863f, 0.9463f, -0.8437f, 1.6459f, + -1.3602f, 0.3446f, 0.5199f, -2.6133f, -1.6965f, -0.2282f, 0.2800f, 0.2469f, 0.0769f, 0.3380f, 0.4544f, + 0.4569f, -0.8654f, 0.7813f, -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f, 0.3492f, -0.9215f, + -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, }; std::vector past_key = {}; std::vector past_value = {}; std::vector seqlens_k = {0}; std::vector cos_cache = { - 0.8437, - -0.7849, - -0.7829, - 0.4581, - -0.9870, - 0.6273, - -0.9483, - -0.9962, + 0.8437f, + -0.7849f, + -0.7829f, + 0.4581f, + -0.9870f, + 0.6273f, + -0.9483f, + -0.9962f, }; std::vector sin_cache = { - 0.5368, - 0.6196, - -0.6222, - 0.8889, - 0.1605, - -0.7788, - 0.3174, - -0.0872, - }; - - std::vector expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, - -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, - -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, - -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; - - std::vector expected_present_key = {2.118801, - -0.2640816, - -0.5926066, - -0.19455537, - 0.9903903, - 2.954185, - -0.35343042, - -0.07457897, - -0.25603274, - -0.03627284, - 0.56591415, - 0.02181074, - -0.1586003, - 0.96567893, - -0.8591481, - 0.85514885}; - - std::vector expected_present_value = {-0.2188, - -2.4351, - -0.0729, - -0.034, - 0.9625, - 0.3492, - -0.9215, - -0.0562, - -0.6227, - -0.4637, - 1.9218, - -0.4025, - 0.1239, - 1.1648, - 0.9234, - 1.3873}; + 0.5368f, + 0.6196f, + -0.6222f, + 0.8889f, + 0.1605f, + -0.7788f, + 0.3174f, + -0.0872f, + }; + + std::vector expected_output = {-0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, -0.9215f, -0.0562f, + -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, + -0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, -0.9215f, -0.0562f, + -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; + + std::vector expected_present_key = {2.118801f, + -0.2640816f, + -0.5926066f, + -0.19455537f, + 0.9903903f, + 2.954185f, + -0.35343042f, + -0.07457897f, + -0.25603274f, + -0.03627284f, + 0.56591415f, + 0.02181074f, + -0.1586003f, + 0.96567893f, + -0.8591481f, + 0.85514885f}; + + std::vector expected_present_value = {-0.2188f, + -2.4351f, + -0.0729f, + -0.034f, + 0.9625f, + 0.3492f, + -0.9215f, + -0.0562f, + -0.6227f, + -0.4637f, + 1.9218f, + -0.4025f, + 0.1239f, + 1.1648f, + 0.9234f, + 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); test_case.add_input(query); @@ -1919,103 +1919,103 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary.onnx"); std::vector query = { - -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, - 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, - -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, - -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, - 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, - -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, + 0.3081f, 0.1198f, 1.2377f, 1.1168f, -0.2473f, -1.3527f, -1.6959f, 0.5667f, 0.7935f, 0.5988f, -1.5551f, + -0.3414f, 1.8530f, 0.7502f, -0.5855f, -0.1734f, 0.1835f, 1.3894f, 1.5863f, 0.9463f, -0.8437f, 1.6459f, + -1.3602f, 0.3446f, 0.5199f, -2.6133f, -1.6965f, -0.2282f, 0.2800f, 0.2469f, 0.0769f, 0.3380f, 0.4544f, + 0.4569f, -0.8654f, 0.7813f, -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f, 0.3492f, -0.9215f, + -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, }; std::vector past_key = { - -0.6136, - 0.0316, - -0.4927, - 0.2484, - 0.4397, - 0.1124, - 0.6408, - 0.4412, - -0.1023, - 0.7924, - -0.2897, - 0.0525, - 0.5229, - 2.3022, - -1.4689, - -1.5867, + -0.6136f, + 0.0316f, + -0.4927f, + 0.2484f, + 0.4397f, + 0.1124f, + 0.6408f, + 0.4412f, + -0.1023f, + 0.7924f, + -0.2897f, + 0.0525f, + 0.5229f, + 2.3022f, + -1.4689f, + -1.5867f, }; std::vector past_value = { - -0.5692, - 0.9200, - 1.1108, - 1.2899, - -1.4782, - 2.5672, - -0.4731, - 0.3356, - -1.6293, - -0.5497, - -0.4798, - -0.4997, - -1.0670, - 1.1149, - -0.1407, - 0.8058, + -0.5692f, + 0.9200f, + 1.1108f, + 1.2899f, + -1.4782f, + 2.5672f, + -0.4731f, + 0.3356f, + -1.6293f, + -0.5497f, + -0.4798f, + -0.4997f, + -1.0670f, + 1.1149f, + -0.1407f, + 0.8058f, }; std::vector seqlens_k = {1}; std::vector cos_cache = { - 0.8437, - -0.7849, - -0.7829, - 0.4581, - -0.9870, - 0.6273, - -0.9483, - -0.9962, - -0.9635, - -0.8046, - 0.4139, - 0.9863, - 0.4117, - 0.9874, - -0.9743, - 0.9494, + 0.8437f, + -0.7849f, + -0.7829f, + 0.4581f, + -0.9870f, + 0.6273f, + -0.9483f, + -0.9962f, + -0.9635f, + -0.8046f, + 0.4139f, + 0.9863f, + 0.4117f, + 0.9874f, + -0.9743f, + 0.9494f, }; std::vector sin_cache = { - 0.5368, - 0.6196, - -0.6222, - 0.8889, - 0.1605, - -0.7788, - 0.3174, - -0.0872, - 0.2677, - -0.5938, - -0.9103, - -0.1650, - -0.9113, - -0.1583, - 0.2253, - 0.3140, + 0.5368f, + 0.6196f, + -0.6222f, + 0.8889f, + 0.1605f, + -0.7788f, + 0.3174f, + -0.0872f, + 0.2677f, + -0.5938f, + -0.9103f, + -0.1650f, + -0.9113f, + -0.1583f, + 0.2253f, + 0.3140f, }; std::vector expected_output = { - -0.53934956, 0.6341806, 1.0099611, 1.1771176, -1.270278, 2.3782496, -0.511299, 0.30222273, - -1.5435482, -0.5423737, -0.27520883, -0.4914196, -0.96554786, 1.1191509, -0.05004983, 0.85533774, - -0.49356747, 0.19581467, 0.8553029, 1.0041412, -0.9513843, 2.088453, -0.5698854, 0.25103146, - -1.4120293, -0.5311372, 0.03857604, -0.47871974, -0.8099488, 1.1256707, 0.08898184, 0.93131447}; + -0.53934956f, 0.6341806f, 1.0099611f, 1.1771176f, -1.270278f, 2.3782496f, -0.511299f, 0.30222273f, + -1.5435482f, -0.5423737f, -0.27520883f, -0.4914196f, -0.96554786f, 1.1191509f, -0.05004983f, 0.85533774f, + -0.49356747f, 0.19581467f, 0.8553029f, 1.0041412f, -0.9513843f, 2.088453f, -0.5698854f, 0.25103146f, + -1.4120293f, -0.5311372f, 0.03857604f, -0.47871974f, -0.8099488f, 1.1256707f, 0.08898184f, 0.93131447f}; std::vector expected_present_key = { - -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412, - -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867, - -1.6519198, 1.1400802, 0.45031136, 0.5877534, -0.65952265, -1.8121169, 0.04630837, 0.5568472, - 0.20271924, 0.7458131, -0.17379119, 0.3623912, 2.5696063, -0.58594, -0.8126341, -0.7919839}; + -0.6136f, 0.0316f, -0.4927f, 0.2484f, 0.4397f, 0.1124f, 0.6408f, 0.4412f, + -0.1023f, 0.7924f, -0.2897f, 0.0525f, 0.5229f, 2.3022f, -1.4689f, -1.5867f, + -1.6519198f, 1.1400802f, 0.45031136f, 0.5877534f, -0.65952265f, -1.8121169f, 0.04630837f, 0.5568472f, + 0.20271924f, 0.7458131f, -0.17379119f, 0.3623912f, 2.5696063f, -0.58594f, -0.8126341f, -0.7919839f}; - std::vector expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356, - -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058, - -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, - -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + std::vector expected_present_value = { + -0.5692f, 0.9200f, 1.1108f, 1.2899f, -1.4782f, 2.5672f, -0.4731f, 0.3356f, -1.6293f, -0.5497f, -0.4798f, + -0.4997f, -1.0670f, 1.1149f, -0.1407f, 0.8058f, -0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, + -0.9215f, -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); test_case.add_input(query); @@ -2034,103 +2034,103 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary_interleaved.onnx"); std::vector query = { - -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, - 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, - -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, - -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, - 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, - -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, + 0.3081f, 0.1198f, 1.2377f, 1.1168f, -0.2473f, -1.3527f, -1.6959f, 0.5667f, 0.7935f, 0.5988f, -1.5551f, + -0.3414f, 1.8530f, 0.7502f, -0.5855f, -0.1734f, 0.1835f, 1.3894f, 1.5863f, 0.9463f, -0.8437f, 1.6459f, + -1.3602f, 0.3446f, 0.5199f, -2.6133f, -1.6965f, -0.2282f, 0.2800f, 0.2469f, 0.0769f, 0.3380f, 0.4544f, + 0.4569f, -0.8654f, 0.7813f, -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f, 0.3492f, -0.9215f, + -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, }; std::vector past_key = { - -0.6136, - 0.0316, - -0.4927, - 0.2484, - 0.4397, - 0.1124, - 0.6408, - 0.4412, - -0.1023, - 0.7924, - -0.2897, - 0.0525, - 0.5229, - 2.3022, - -1.4689, - -1.5867, + -0.6136f, + 0.0316f, + -0.4927f, + 0.2484f, + 0.4397f, + 0.1124f, + 0.6408f, + 0.4412f, + -0.1023f, + 0.7924f, + -0.2897f, + 0.0525f, + 0.5229f, + 2.3022f, + -1.4689f, + -1.5867f, }; std::vector past_value = { - -0.5692, - 0.9200, - 1.1108, - 1.2899, - -1.4782, - 2.5672, - -0.4731, - 0.3356, - -1.6293, - -0.5497, - -0.4798, - -0.4997, - -1.0670, - 1.1149, - -0.1407, - 0.8058, + -0.5692f, + 0.9200f, + 1.1108f, + 1.2899f, + -1.4782f, + 2.5672f, + -0.4731f, + 0.3356f, + -1.6293f, + -0.5497f, + -0.4798f, + -0.4997f, + -1.0670f, + 1.1149f, + -0.1407f, + 0.8058f, }; std::vector seqlens_k = {1}; std::vector cos_cache = { - 0.8437, - -0.7849, - -0.7829, - 0.4581, - -0.9870, - 0.6273, - -0.9483, - -0.9962, - -0.9635, - -0.8046, - 0.4139, - 0.9863, - 0.4117, - 0.9874, - -0.9743, - 0.9494, + 0.8437f, + -0.7849f, + -0.7829f, + 0.4581f, + -0.9870f, + 0.6273f, + -0.9483f, + -0.9962f, + -0.9635f, + -0.8046f, + 0.4139f, + 0.9863f, + 0.4117f, + 0.9874f, + -0.9743f, + 0.9494f, }; std::vector sin_cache = { - 0.5368, - 0.6196, - -0.6222, - 0.8889, - 0.1605, - -0.7788, - 0.3174, - -0.0872, - 0.2677, - -0.5938, - -0.9103, - -0.1650, - -0.9113, - -0.1583, - 0.2253, - 0.3140, + 0.5368f, + 0.6196f, + -0.6222f, + 0.8889f, + 0.1605f, + -0.7788f, + 0.3174f, + -0.0872f, + 0.2677f, + -0.5938f, + -0.9103f, + -0.1650f, + -0.9113f, + -0.1583f, + 0.2253f, + 0.3140f, }; std::vector expected_output = { - -0.33396345, -1.332403, 0.31613833, 0.40111685, 0.16033238, 1.0781744, -0.7741276, 0.07257013, - -0.9535321, -0.491965, 1.1324831, -0.43444604, -0.2675047, 1.1483997, 0.57366973, 1.1961825, - -0.24709277, -2.164195, 0.02267693, 0.07289726, 0.7654276, 0.5282906, -0.8852943, -0.02456442, - -0.7039771, -0.47064403, 1.7278847, -0.41034833, 0.02774171, 1.1607709, 0.83748007, 1.3403473}; + -0.33396345f, -1.332403f, 0.31613833f, 0.40111685f, 0.16033238f, 1.0781744f, -0.7741276f, 0.07257013f, + -0.9535321f, -0.491965f, 1.1324831f, -0.43444604f, -0.2675047f, 1.1483997f, 0.57366973f, 1.1961825f, + -0.24709277f, -2.164195f, 0.02267693f, 0.07289726f, 0.7654276f, 0.5282906f, -0.8852943f, -0.02456442f, + -0.7039771f, -0.47064403f, 1.7278847f, -0.41034833f, 0.02774171f, 1.1607709f, 0.83748007f, 1.3403473f}; std::vector expected_present_key = { - -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412, - -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867, - -1.2216992, 1.7511603, 0.03145146, -0.62293506, -2.625969, 1.6767058, -0.17887366, 0.313817, - 0.1717277, -0.19334024, 0.4056727, 0.39516917, -0.25018305, 0.9460988, 1.0327814, -0.6345757}; - - std::vector expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356, - -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058, - -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, - -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + -0.6136f, 0.0316f, -0.4927f, 0.2484f, 0.4397f, 0.1124f, 0.6408f, 0.4412f, + -0.1023f, 0.7924f, -0.2897f, 0.0525f, 0.5229f, 2.3022f, -1.4689f, -1.5867f, + -1.2216992f, 1.7511603f, 0.03145146f, -0.62293506f, -2.625969f, 1.6767058f, -0.17887366f, 0.313817f, + 0.1717277f, -0.19334024f, 0.4056727f, 0.39516917f, -0.25018305f, 0.9460988f, 1.0327814f, -0.6345757f}; + + std::vector expected_present_value = { + -0.5692f, 0.9200f, 1.1108f, 1.2899f, -1.4782f, 2.5672f, -0.4731f, 0.3356f, -1.6293f, -0.5497f, -0.4798f, + -0.4997f, -1.0670f, 1.1149f, -0.1407f, 0.8058f, -0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, + -0.9215f, -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); test_case.add_input(query); From aeea6bc4dd5825752f9647adcf849c3aa12cd7c9 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Sun, 23 Feb 2025 21:38:34 +0800 Subject: [PATCH 15/15] Fix exception during onnx tests exiting --- .../group_query_attention_decomposition.cpp | 4 +- .../gqa_past_1_input_1_rotary.prototxt | 231 ------------------ ...past_1_input_1_rotary_interleaved.prototxt | 231 ------------------ ..._1_rotary.prototxt => gqa_rotary.prototxt} | 109 +++------ ...ototxt => gqa_rotary_interleaved.prototxt} | 109 +++------ .../tests/onnx_import_com_microsoft.in.cpp | 56 ++--- 6 files changed, 85 insertions(+), 655 deletions(-) delete mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt delete mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt rename src/frontends/onnx/tests/models/com.microsoft/{gqa_past_0_input_1_rotary.prototxt => gqa_rotary.prototxt} (60%) rename src/frontends/onnx/tests/models/com.microsoft/{gqa_past_0_input_1_rotary_interleaved.prototxt => gqa_rotary_interleaved.prototxt} (60%) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 8c0b7566f82248..b2f3d8566843e2 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -134,7 +134,6 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( const size_t kv_num_heads_factor = num_heads / kv_num_heads; if (kv_num_heads_factor > 1) { const auto kv_shape = std::make_shared(K); - // (batch_size, num_heads, sequence_len, head_size) const auto kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1}); const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3}); auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); @@ -143,7 +142,6 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); auto q_shape = std::make_shared(Q); - // (batch_size, num_heads, sequence_len, head_size) const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); K = std::make_shared(K, extended_kv_shape, false); @@ -204,7 +202,7 @@ ov::OutputVector make_split(const ov::Output& value, int64_t num_split std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { using namespace ov::op; - static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); return std::make_shared(shape, dims_const, zero); } diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt deleted file mode 100644 index f3f8a82c00bb03..00000000000000 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt +++ /dev/null @@ -1,231 +0,0 @@ -ir_version: 10 -graph { - node { - input: "query" - input: "" - input: "" - input: "past_key" - input: "past_value" - input: "seqlens_k" - input: "" - input: "cos_cache" - input: "sin_cache" - output: "output" - output: "present_key" - output: "present_value" - name: "GroupQueryAttention_0" - op_type: "GroupQueryAttention" - attribute { - name: "do_rotary" - i: 1 - type: INT - } - attribute { - name: "kv_num_heads" - i: 1 - type: INT - } - attribute { - name: "local_window_size" - i: -1 - type: INT - } - attribute { - name: "num_heads" - i: 2 - type: INT - } - attribute { - name: "rotary_interleaved" - i: 0 - type: INT - } - attribute { - name: "smooth_softmax" - i: 0 - type: INT - } - attribute { - name: "softcap" - f: 0 - type: FLOAT - } - domain: "com.microsoft" - } - name: "GroupQueryAttention_Graph" - input { - name: "query" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 64 - } - } - } - } - } - input { - name: "past_key" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } - } - } - } - } - input { - name: "past_value" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } - } - } - } - } - input { - name: "seqlens_k" - type { - tensor_type { - elem_type: 6 - shape { - dim { - dim_value: 1 - } - } - } - } - } - input { - name: "cos_cache" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 2 - } - dim { - dim_value: 8 - } - } - } - } - } - input { - name: "sin_cache" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 2 - } - dim { - dim_value: 8 - } - } - } - } - } - output { - name: "output" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 32 - } - } - } - } - } - output { - name: "present_key" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 2 - } - dim { - dim_value: 16 - } - } - } - } - } - output { - name: "present_value" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 2 - } - dim { - dim_value: 16 - } - } - } - } - } -} -opset_import { - domain: "" - version: 21 -} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt deleted file mode 100644 index 9aad910b32e0bd..00000000000000 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt +++ /dev/null @@ -1,231 +0,0 @@ -ir_version: 10 -graph { - node { - input: "query" - input: "" - input: "" - input: "past_key" - input: "past_value" - input: "seqlens_k" - input: "" - input: "cos_cache" - input: "sin_cache" - output: "output" - output: "present_key" - output: "present_value" - name: "GroupQueryAttention_0" - op_type: "GroupQueryAttention" - attribute { - name: "do_rotary" - i: 1 - type: INT - } - attribute { - name: "kv_num_heads" - i: 1 - type: INT - } - attribute { - name: "local_window_size" - i: -1 - type: INT - } - attribute { - name: "num_heads" - i: 2 - type: INT - } - attribute { - name: "rotary_interleaved" - i: 1 - type: INT - } - attribute { - name: "smooth_softmax" - i: 0 - type: INT - } - attribute { - name: "softcap" - f: 0 - type: FLOAT - } - domain: "com.microsoft" - } - name: "GroupQueryAttention_Graph" - input { - name: "query" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 64 - } - } - } - } - } - input { - name: "past_key" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } - } - } - } - } - input { - name: "past_value" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } - } - } - } - } - input { - name: "seqlens_k" - type { - tensor_type { - elem_type: 6 - shape { - dim { - dim_value: 1 - } - } - } - } - } - input { - name: "cos_cache" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 2 - } - dim { - dim_value: 8 - } - } - } - } - } - input { - name: "sin_cache" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 2 - } - dim { - dim_value: 8 - } - } - } - } - } - output { - name: "output" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 32 - } - } - } - } - } - output { - name: "present_key" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 2 - } - dim { - dim_value: 16 - } - } - } - } - } - output { - name: "present_value" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 2 - } - dim { - dim_value: 16 - } - } - } - } - } -} -opset_import { - domain: "" - version: 21 -} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary.prototxt similarity index 60% rename from src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt rename to src/frontends/onnx/tests/models/com.microsoft/gqa_rotary.prototxt index a105b7c03c17f0..179e7935d7dcd6 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary.prototxt @@ -59,15 +59,9 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 64 - } + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 64 } } } } @@ -78,18 +72,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 0 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } } } } @@ -100,18 +86,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 0 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } } } } @@ -122,9 +100,8 @@ graph { tensor_type { elem_type: 6 shape { - dim { - dim_value: 1 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } } } } @@ -135,12 +112,8 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 8 - } + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } } } } @@ -151,12 +124,8 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 8 - } + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } } } } @@ -167,15 +136,9 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 32 - } + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 32 } } } } @@ -186,18 +149,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } } } } @@ -208,18 +163,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } } } } diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary_interleaved.prototxt similarity index 60% rename from src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt rename to src/frontends/onnx/tests/models/com.microsoft/gqa_rotary_interleaved.prototxt index 6b5748b653f0fc..9186bafc770c65 100644 --- a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary_interleaved.prototxt @@ -59,15 +59,9 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 64 - } + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 64 } } } } @@ -78,18 +72,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 0 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } } } } @@ -100,18 +86,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 0 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } } } } @@ -122,9 +100,8 @@ graph { tensor_type { elem_type: 6 shape { - dim { - dim_value: 1 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } } } } @@ -135,12 +112,8 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 8 - } + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } } } } @@ -151,12 +124,8 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 8 - } + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } } } } @@ -167,15 +136,9 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 32 - } + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 32 } } } } @@ -186,18 +149,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } } } } @@ -208,18 +163,10 @@ graph { tensor_type { elem_type: 1 shape { - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 1 - } - dim { - dim_value: 16 - } + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } } } } diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 46e1b7ec1253be..764e4d8ca12956 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1742,7 +1742,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_bias_add) { } OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { - const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx"); + const auto model = convert_model("com.microsoft/gqa_rotary.onnx"); std::vector query = { -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, @@ -1816,12 +1816,12 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); - test_case.add_input(query); - test_case.add_input(past_key); - test_case.add_input(past_value); - test_case.add_input(seqlens_k); - test_case.add_input(cos_cache); - test_case.add_input(sin_cache); + test_case.add_input(Shape{1, 1, 64}, query); + test_case.add_input(Shape{1, 1, 0, 16}, past_key); + test_case.add_input(Shape{1, 1, 0, 16}, past_value); + test_case.add_input(Shape{1, 1}, seqlens_k); + test_case.add_input(Shape{1, 8}, cos_cache); + test_case.add_input(Shape{1, 8}, sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key); test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value); @@ -1829,7 +1829,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { } OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) { - const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary_interleaved.onnx"); + const auto model = convert_model("com.microsoft/gqa_rotary_interleaved.onnx"); std::vector query = { -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, @@ -1903,12 +1903,12 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); - test_case.add_input(query); - test_case.add_input(past_key); - test_case.add_input(past_value); - test_case.add_input(seqlens_k); - test_case.add_input(cos_cache); - test_case.add_input(sin_cache); + test_case.add_input(Shape{1, 1, 64}, query); + test_case.add_input(Shape{1, 1, 0, 16}, past_key); + test_case.add_input(Shape{1, 1, 0, 16}, past_value); + test_case.add_input(Shape{1, 1}, seqlens_k); + test_case.add_input(Shape{1, 8}, cos_cache); + test_case.add_input(Shape{1, 8}, sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key); test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value); @@ -1916,7 +1916,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) } OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { - const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary.onnx"); + const auto model = convert_model("com.microsoft/gqa_rotary.onnx"); std::vector query = { -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, @@ -2018,12 +2018,12 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { -0.9215f, -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); - test_case.add_input(query); - test_case.add_input(past_key); - test_case.add_input(past_value); - test_case.add_input(seqlens_k); - test_case.add_input(cos_cache); - test_case.add_input(sin_cache); + test_case.add_input(Shape{1, 1, 64}, query); + test_case.add_input(Shape{1, 1, 1, 16}, past_key); + test_case.add_input(Shape{1, 1, 1, 16}, past_value); + test_case.add_input(Shape{1, 1}, seqlens_k); + test_case.add_input(Shape{2, 8}, cos_cache); + test_case.add_input(Shape{2, 8}, sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); @@ -2031,7 +2031,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { } OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) { - const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary_interleaved.onnx"); + const auto model = convert_model("com.microsoft/gqa_rotary_interleaved.onnx"); std::vector query = { -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, @@ -2133,12 +2133,12 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) -0.9215f, -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; auto test_case = ov::test::TestCase(model, s_device); - test_case.add_input(query); - test_case.add_input(past_key); - test_case.add_input(past_value); - test_case.add_input(seqlens_k); - test_case.add_input(cos_cache); - test_case.add_input(sin_cache); + test_case.add_input(Shape{1, 1, 64}, query); + test_case.add_input(Shape{1, 1, 1, 16}, past_key); + test_case.add_input(Shape{1, 1, 1, 16}, past_value); + test_case.add_input(Shape{1, 1}, seqlens_k); + test_case.add_input(Shape{2, 8}, cos_cache); + test_case.add_input(Shape{2, 8}, sin_cache); test_case.add_expected_output(Shape{1, 1, 32}, expected_output); test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value);