Skip to content

Commit

Permalink
Remove redundant type check
Browse files Browse the repository at this point in the history
  • Loading branch information
wine99 committed Feb 7, 2025
1 parent f11cb98 commit 250d228
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
out
# build/artifact dirs
_*
[Bb]uild*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::Shape
const std::vector<int>& dims);
ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis);
ov::OutputVector make_split(const ov::Output<ov::Node>& value, const std::vector<int64_t>& split_lengths, int64_t axis);
std::shared_ptr<ov::Node> create_minus_inf(const ov::element::Type& T);

ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() {
MATCHER_SCOPE(GroupQeuryAttentionDecomposition);
Expand Down Expand Up @@ -191,7 +190,12 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
auto vert_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, one); // 12x1 or 13x1
auto triu = std::make_shared<v1::Greater>(hori_range, vert_range); // 12x12 or 13x13
auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0});
auto minus_inf = create_minus_inf(T);
// cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp
std::shared_ptr<ov::Node> minus_inf = nullptr;
if (T == ov::element::f32)
minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()});
else if (T == ov::element::f16)
minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits<ov::float16>::lowest()});
auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf, typed_zero); // 12x12 or 13x13
auto atten_mask_sliced = std::make_shared<v8::Slice>(atten_mask,
past_sequence_length,
Expand Down Expand Up @@ -309,14 +313,3 @@ ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_split

return split->outputs();
}

/// @brief Creates the scalar constant used as the "masked-out" fill value.
///
/// @param T Element type of the constant to create.
/// @return A scalar (`ov::Shape{}`) constant of type `T` holding
///         `-std::numeric_limits<float>::infinity()` when `T` is f32, or
///         `std::numeric_limits<ov::float16>::lowest()` when `T` is f16.
/// @throws Raised through OPENVINO_THROW when `T` is neither f32 nor f16.
std::shared_ptr<ov::Node> create_minus_inf(const ov::element::Type& T) {
    // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp
    if (T == ov::element::f32) {
        const auto fill_value = -std::numeric_limits<float>::infinity();
        return ov::op::v0::Constant::create(T, ov::Shape{}, {fill_value});
    }

    if (T == ov::element::f16) {
        // f16 uses lowest() rather than -infinity, mirroring the helper referenced above.
        const auto fill_value = std::numeric_limits<ov::float16>::lowest();
        return ov::op::v0::Constant::create(T, ov::Shape{}, {fill_value});
    }

    // Any other element type is rejected explicitly.
    OPENVINO_THROW("GroupQueryAttention only supports f32 and f16");
}

0 comments on commit 250d228

Please sign in to comment.