-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Group Query Attention support with OV base OPs #28163
base: master
Are you sure you want to change the base?
Changes from all commits
32dc9a3
def6bd1
cc7aa0d
e2897af
0956337
3e24958
da868f3
76c18fa
b8005ea
f8a363a
733076a
31d5431
e4838e1
157103b
aeea6bc
d0fd748
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (C) 2018-2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/group_query_attention.hpp" | ||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API GroupQueryAttentionDecomposition; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
/// @brief Matcher pass that replaces every ov::op::GroupQueryAttention node with an
/// equivalent subgraph of base operations built around v13::ScaledDotProductAttention.
class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition"); | ||
/// @brief Registers the matcher for GroupQueryAttention together with the replacement callback.
GroupQueryAttentionDecomposition(); | ||
|
||
private: | ||
/// @brief Builds the replacement subgraph for @p node.
/// @param node Matched GroupQueryAttention operation.
/// @return Three outputs, in order: attention output, present key, present value.
ov::OutputVector decompose(std::shared_ptr<ov::op::GroupQueryAttention> node); | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
// Copyright (C) 2018-2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/op_conversions/group_query_attention_decomposition.hpp" | ||
|
||
#include <memory> | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/gather.hpp" | ||
#include "openvino/op/greater.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/range.hpp" | ||
#include "openvino/op/reshape.hpp" | ||
#include "openvino/op/scaled_dot_product_attention.hpp" | ||
#include "openvino/op/select.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
#include "openvino/op/slice.hpp" | ||
#include "openvino/op/split.hpp" | ||
#include "openvino/op/subtract.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/unsqueeze.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
|
||
namespace ov { | ||
namespace detail { | ||
namespace { | ||
std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<op::v3::ShapeOf>& shape, const std::vector<int>& dims); | ||
std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims); | ||
ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis); | ||
std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input, | ||
ov::Output<ov::Node> past_seqlen, | ||
std::shared_ptr<ov::Node> seqlen_k, | ||
std::shared_ptr<ov::Node> cos_cache, | ||
std::shared_ptr<ov::Node> sin_cache, | ||
std::shared_ptr<ov::Node> dim_head_size, | ||
bool interleaved); | ||
} // namespace | ||
} // namespace detail | ||
} // namespace ov | ||
|
||
/// Registers a matcher for ov::op::GroupQueryAttention; every matched node that is not
/// rejected by the transformation callback is replaced with the outputs of decompose().
ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() {
    // The scope name must spell the pass name exactly (it previously read "GroupQeury...").
    MATCHER_SCOPE(GroupQueryAttentionDecomposition);
    auto pattern_node = ov::pass::pattern::wrap_type<ov::op::GroupQueryAttention>();

    matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto node =
            ov::as_type_ptr<ov::op::GroupQueryAttention>(pattern_to_output.at(pattern_node).get_node_shared_ptr());

        // Leave the graph untouched when the match is not a GQA node or the callback vetoes it.
        if (node == nullptr || transformation_callback(node)) {
            return false;
        }

        auto new_output_node = decompose(node);
        ov::replace_node(node, new_output_node);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, matcher_name);
    register_matcher(m, callback);
}
|
||
/// Builds the base-op subgraph that replaces a GroupQueryAttention node.
///
/// Inputs are read by index: 0 = Q, 1 = K, 2 = V, 3 = past key, 4 = past value,
/// 5 = seqlens_k, 6 = cos cache, 7 = sin cache.
///
/// @param node GroupQueryAttention operation to decompose.
/// @return {attention output, present key, present value}.
/// @throws ov::Exception when kv_num_heads is not positive or when the element type of Q is
///         neither f32 nor f16 (no mask constant is defined for other types).
ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
    std::shared_ptr<ov::op::GroupQueryAttention> node) {
    using namespace ov::op;

    const auto num_heads = node->get_num_heads();
    const auto kv_num_heads = node->get_kv_num_heads();
    const auto scale = node->get_scale();
    const auto do_rotary = node->get_do_rotary();
    const auto rotary_interleaved = node->get_rotary_interleaved();
    // TODO: add softcap support

    // kv_num_heads is used as a divisor further down; reject zero/negative values up front.
    OPENVINO_ASSERT(kv_num_heads > 0,
                    "GroupQueryAttentionDecomposition: kv_num_heads must be positive, got ",
                    kv_num_heads);

    auto Q = node->input_value(0);
    auto K = node->input_value(1);
    auto V = node->input_value(2);
    auto past_key = node->input_value(3);
    auto past_value = node->input_value(4);
    auto seqlens_k = node->input_value(5);
    auto cos_cache = node->input_value(6);
    auto sin_cache = node->input_value(7);

    // The length of all tokens (past + current) is `seqlens_k` + 1
    // current = Q.shape[2], past = `seqlens_k` + 1 - current

    const auto T = Q.get_element_type();
    // Only f32 and f16 have a mask fill value defined below; fail early with a clear message
    // instead of handing a null node to Select.
    OPENVINO_ASSERT(T == ov::element::f32 || T == ov::element::f16,
                    "GroupQueryAttentionDecomposition: unsupported element type ",
                    T,
                    ", only f32 and f16 are supported");

    const auto q_shape = std::make_shared<v3::ShapeOf>(Q);
    const auto current_sequence_length = detail::get_dimensions(q_shape, {2});
    auto head_size_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {node->get_head_size()});

    auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
    auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
    auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
    auto seqlens_elemi64 = std::make_shared<v0::Convert>(seqlens_k, ov::element::i64);
    auto real_seqlens = std::make_shared<v1::Add>(seqlens_elemi64, one);

    // Only consider batch is 1
    auto seqlens_1d = std::make_shared<v1::Reshape>(real_seqlens, one, false);
    auto past_sequence_length = std::make_shared<v1::Subtract>(seqlens_1d, current_sequence_length);
    if (do_rotary) {
        Q = detail::rotaryEmbedding(Q,
                                    past_sequence_length,
                                    seqlens_1d,
                                    cos_cache.get_node_shared_ptr(),
                                    sin_cache.get_node_shared_ptr(),
                                    head_size_node,
                                    rotary_interleaved);
        K = detail::rotaryEmbedding(K,
                                    past_sequence_length,
                                    seqlens_1d,
                                    cos_cache.get_node_shared_ptr(),
                                    sin_cache.get_node_shared_ptr(),
                                    head_size_node,
                                    rotary_interleaved);
    }

    // present = concat(past[:past_len], current[:current_len]) along axis 2
    auto construct_kv_cache = [&](const ov::Output<ov::Node>& past, const ov::Output<ov::Node>& current) {
        auto past_datas = std::make_shared<v8::Slice>(past, zero, past_sequence_length, one, two);
        auto curr_datas = std::make_shared<v8::Slice>(current, zero, current_sequence_length, one, two);
        return std::make_shared<v0::Concat>(ov::NodeVector{past_datas, curr_datas}, 2);
    };
    K = construct_kv_cache(past_key, K);
    V = construct_kv_cache(past_value, V);
    auto present_k = K;
    auto present_v = V;

    const size_t kv_num_heads_factor = num_heads / kv_num_heads;
    if (kv_num_heads_factor > 1) {
        // Repeat every KV head `kv_num_heads_factor` times: insert a unit axis after the first
        // two dimensions, concatenate copies along it, then reshape using Q's first two dimensions.
        const auto kv_shape = std::make_shared<v3::ShapeOf>(K);
        const auto kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1});
        const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3});
        auto new_kv_shape = std::make_shared<v0::Concat>(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0);
        K = std::make_shared<v1::Reshape>(K, new_kv_shape, false);
        V = std::make_shared<v1::Reshape>(V, new_kv_shape, false);
        K = std::make_shared<v0::Concat>(ov::OutputVector(kv_num_heads_factor, K), 2);
        V = std::make_shared<v0::Concat>(ov::OutputVector(kv_num_heads_factor, V), 2);
        // Shape of the (possibly rotary-embedded) Q; named differently to avoid shadowing `q_shape`.
        const auto current_q_shape = std::make_shared<v3::ShapeOf>(Q);
        const auto q_shape_prev_2 = detail::get_dimensions(current_q_shape, {0, 1});
        auto extended_kv_shape = std::make_shared<v0::Concat>(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0);
        K = std::make_shared<v1::Reshape>(K, extended_kv_shape, false);
        V = std::make_shared<v1::Reshape>(V, extended_kv_shape, false);
    }

    // need to apply low-triangle mask to attention score.
    // two steps, construct the total_sequence x total_sequence triangle, then slice the current length
    auto seqlens_1d_scalar = std::make_shared<v1::Reshape>(seqlens_1d, one_without_shape, false);
    std::shared_ptr<ov::Node> mask_per_line_node =
        std::make_shared<v4::Range>(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}),
                                    seqlens_1d_scalar,
                                    one_without_shape,
                                    ov::element::i64);
    auto hori_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, zero);
    auto vert_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, one);
    auto triu = std::make_shared<v1::Greater>(hori_range, vert_range);
    auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0});
    // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp
    std::shared_ptr<ov::Node> minus_inf;
    if (T == ov::element::f32) {
        minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()});
    } else {
        // T == f16 here (guaranteed by the assertion above); the lowest finite value is used.
        minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits<ov::float16>::lowest()});
    }
    auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf, typed_zero);
    auto atten_mask_sliced = std::make_shared<v8::Slice>(atten_mask, past_sequence_length, seqlens_1d, one, zero);

    // A zero `scale` attribute means "not set": SDPA is then created without an explicit scale input.
    std::shared_ptr<ov::Node> qga_output;
    if (scale != 0.0f) {
        auto scale_node = v0::Constant::create(T, Shape{}, {scale});
        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, scale_node, false);
    } else {
        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, false);
    }

    // transpose the result from (batch_size, num_heads, sequence_length, head_size)
    // to (batch_size, sequence_length, num_heads * head_size)
    auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3});
    auto qga_output_transposed = std::make_shared<v1::Transpose>(qga_output, perm);
    auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1});
    auto output = std::make_shared<v1::Reshape>(qga_output_transposed, dim_merge_shape, true)->output(0);

    return {output, present_k, present_v};
}
|
||
namespace ov { | ||
namespace detail { | ||
namespace { | ||
// Splits `value` into `num_splits` parts along `axis` using v1::Split and returns all outputs.
// NOTE: this helper is a copy-paste from the ONNX FE. TODO: move it to one place
ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis) {
    const auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {axis});
    return std::make_shared<ov::op::v1::Split>(value, split_axis, num_splits)->outputs();
}
|
||
// Gathers the entries listed in `dims` from the output of a ShapeOf node (Gather along axis 0).
std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
                                         const std::vector<int>& dims) {
    const auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
    const auto wanted_dims = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
    return std::make_shared<ov::op::v8::Gather>(shape, wanted_dims, gather_axis);
}
|
||
// Convenience overload: wraps `node` in a v3::ShapeOf and forwards to the ShapeOf overload.
std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims) {
    const auto shape_of_node = std::make_shared<ov::op::v3::ShapeOf>(node);
    return get_dimensions(shape_of_node, dims);
}
|
||
// Applies rotary embedding to `input` using slices of the cos/sin caches.
//
// input         - tensor to rotate; its first three dimensions are kept as-is.
// past_seqlen   - start index of the cache slice (axis 0 of the caches).
// seqlen_k      - stop index of the cache slice (axis 0 of the caches).
// cos_cache     - cosine table; its last dimension is used as the half head size.
// sin_cache     - sine table, sliced the same way as cos_cache.
// dim_head_size - currently unused; kept so the signature matches existing callers.
// interleaved   - true: rotated pairs are adjacent elements of the last dimension;
//                 false: pairs are formed from the two halves of the last dimension.
//
// Returns a node producing the rotated tensor.
std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
                                          ov::Output<ov::Node> past_seqlen,
                                          std::shared_ptr<ov::Node> seqlen_k,
                                          std::shared_ptr<ov::Node> cos_cache,
                                          std::shared_ptr<ov::Node> sin_cache,
                                          std::shared_ptr<ov::Node> dim_head_size,
                                          bool interleaved) {
    using namespace ov::op;
    const auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    const auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});

    // Cache rows [past_seqlen, seqlen_k) correspond to the tokens being processed.
    const auto cos = std::make_shared<v8::Slice>(cos_cache, past_seqlen, seqlen_k, one, zero);
    const auto sin = std::make_shared<v8::Slice>(sin_cache, past_seqlen, seqlen_k, one, zero);

    // Rotation shared by both layouts:
    //   res_0 = x0 * cos - x1 * sin
    //   res_1 = x0 * sin + x1 * cos
    const auto rotate = [&cos, &sin](const ov::Output<ov::Node>& x0, const ov::Output<ov::Node>& x1) {
        const auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(x0, cos),
                                                          std::make_shared<v1::Multiply>(x1, sin));
        const auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(x0, sin),
                                                     std::make_shared<v1::Multiply>(x1, cos));
        return ov::NodeVector{res_0, res_1};
    };

    if (!interleaved) {
        const auto in_split = make_split(input, 2, -1);
        return std::make_shared<v0::Concat>(rotate(in_split[0], in_split[1]), -1);
    }

    const auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
    const auto half_last_dim = get_dimensions(cos_cache, {-1});
    const auto input_shape = std::make_shared<v3::ShapeOf>(input);
    const auto dim_bns = get_dimensions(input_shape, {0, 1, 2});

    // View the last dimension as (half, 2) and split the pairs apart.
    const auto paired_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, two}, 0);
    const auto reshaped_input = std::make_shared<v1::Reshape>(input, paired_shape, false);
    const auto in_split = make_split(reshaped_input, 2, -1);

    // Drop the trailing unit axis so the operands broadcast against the cos/sin slices.
    const auto half_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim}, 0);
    const auto in_split_0 = std::make_shared<v1::Reshape>(in_split[0], half_shape, false);
    const auto in_split_1 = std::make_shared<v1::Reshape>(in_split[1], half_shape, false);

    const auto rotated = rotate(in_split_0, in_split_1);

    // Restore the trailing unit axis, interleave the two results and go back to the input shape.
    const auto unit_tail_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, one}, 0);
    const auto res_0_5d = std::make_shared<v1::Reshape>(rotated[0], unit_tail_shape, false);
    const auto res_1_5d = std::make_shared<v1::Reshape>(rotated[1], unit_tail_shape, false);

    const auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0_5d, res_1_5d}, -1);
    return std::make_shared<v1::Reshape>(concat_ret, input_shape, false);
}
} // namespace | ||
} // namespace detail | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright (C) 2018-2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
|
||
// This is an experimental operation that is implemented in the plugins.
// NOTE(review): per the review discussion on this change, no plugin has GQA kernels at the
// moment, so the decomposition into ScaledDotProductAttention is always needed — confirm
// before relying on native plugin support.
class OPENVINO_API GroupQueryAttention : public Op {
public:
    OPENVINO_OP("GroupQueryAttention");

    GroupQueryAttention() = default;

    /// @param args               Input values of the operation.
    /// @param num_heads          Number of query heads.
    /// @param kv_num_heads       Number of key/value heads.
    /// @param scale              Attention scale; 0 is treated as "not set" by the decomposition.
    /// @param do_rotary          Whether rotary embedding is applied to Q and K.
    /// @param rotary_interleaved Whether the rotary embedding uses the interleaved layout.
    GroupQueryAttention(const ov::OutputVector& args,
                        int64_t num_heads,
                        int64_t kv_num_heads,
                        float scale,
                        bool do_rotary,
                        bool rotary_interleaved);
    void validate_and_infer_types() override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    int64_t get_num_heads() const {
        return m_num_heads;
    }
    int64_t get_kv_num_heads() const {
        return m_kv_num_heads;
    }
    float get_scale() const {
        return m_scale;
    }
    bool get_do_rotary() const {
        return m_do_rotary;
    }
    bool get_rotary_interleaved() const {
        return m_rotary_interleaved;
    }
    int64_t get_head_size() const {
        return m_head_size;
    }

private:
    // Every member has a default so a default-constructed op never exposes indeterminate values.
    int64_t m_num_heads = 0;
    int64_t m_kv_num_heads = 0;
    float m_scale = 0;
    bool m_do_rotary = false;
    bool m_rotary_interleaved = false;
    // Not a constructor argument; presumably derived during validate_and_infer_types() — TODO confirm.
    int64_t m_head_size = 0;
};
|
||
} // namespace op | ||
} // namespace ov |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current approach within the transformations is to add every node to the NodeRegistry like:
openvino/src/common/transformations/src/transformations/op_conversions/scaled_dot_product_attention_decomposition.cpp
Lines 63 to 66 in 9980e86
It is used to copy runtime info before replacement:
openvino/src/common/transformations/src/transformations/op_conversions/scaled_dot_product_attention_decomposition.cpp
Lines 148 to 151 in 9980e86
cc: @itikhono
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I make this change in this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've already approved, so I won't insist on this change in this PR (but it should be applied as a follow-up).
@itikhono Do you consider it as a blocker?