Skip to content

Commit

Permalink
[GPU] Support GeLU Tanh for Phi-2 (#27213)
Browse files Browse the repository at this point in the history
### Details:
 - Previously, GeLU Tanh fusion was supported only for the pattern x * (0.5 * (1 + tanh)).
 - Support the pattern (x * 0.5) * (1 + tanh) too.

### Tickets:
 - 155576
  • Loading branch information
yeonbok authored Oct 28, 2024
1 parent 443078c commit 7a80fe8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "openvino/op/parameter.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/tanh.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

Expand Down Expand Up @@ -280,9 +281,16 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
auto add_1 = ov::pass::pattern::wrap_type<ov::op::v1::Add>({tanh, add_1_constant});

auto mul_2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto mul_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});

auto mul_3 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2});
// x * (0.5 * (1 + tanh))
auto mul_2_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
auto mul_3_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_1});

// (x * 0.5) * (1 + tanh)
auto mul_2_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_constant});
auto mul_3_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_2});

auto mul_3 = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul_3_1, mul_3_2});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
Expand All @@ -298,7 +306,6 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul_2_constant).get_node_shared_ptr());
auto add_1_constant_value =
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_1_constant).get_node_shared_ptr());

if (!pow_constant_value || !add_1_constant_value || !mul_0_constant_value || !mul_1_constant_value ||
!mul_2_constant_value) {
return false;
Expand All @@ -318,18 +325,17 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output, op::GeluApproximationMode::TANH);

gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(
{
pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(mul_0).get_node_shared_ptr(),
pattern_to_output.at(mul_1).get_node_shared_ptr(),
pattern_to_output.at(mul_2).get_node_shared_ptr(),
pattern_to_output.at(mul_3).get_node_shared_ptr(),
pattern_to_output.at(tanh).get_node_shared_ptr(),
pattern_to_output.at(add_0).get_node_shared_ptr(),
pattern_to_output.at(add_1).get_node_shared_ptr(),
},
gelu);

std::vector<std::shared_ptr<ov::Node>> pattern_nodes =
{pow, mul_0, mul_1, tanh, add_0, add_1, mul_2_1, mul_2_2, mul_3_1, mul_3_2};
std::vector<std::shared_ptr<ov::Node>> cp_rt_info_nodes;
for (const auto& pattern_node : pattern_nodes) {
if (pattern_to_output.count(pattern_node)) {
cp_rt_info_nodes.push_back(pattern_to_output.at(pattern_node).get_node_shared_ptr());
}
}
ov::copy_runtime_info(cp_rt_info_nodes, gelu);

ov::replace_node(m.get_match_root(), gelu);
return true;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,44 @@ TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value) {
}
}

// Builds the tanh-based GeLU subgraph in the "(x * 0.5) * (1 + tanh(...))" arrangement,
// with the Power exponent written as 3 + 1e-8, registers GeluFusionWithTanh on it, and
// sets a single Gelu (TANH approximation) as the reference model.
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value_2) {
    // Creates a one-element f32 Constant (Shape{1}) holding `value`.
    const auto make_f32_const = [](float value) {
        return std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        auto x = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});

        // x + 0.044715 * pow(x, 3 + 1e-8)
        auto exponent = make_f32_const(3.0f + 1.0e-8f);
        auto x_pow = std::make_shared<ov::op::v1::Power>(x, exponent);
        auto cubic_coeff = make_f32_const(0.044715f);
        auto scaled_pow = std::make_shared<ov::op::v1::Multiply>(x_pow, cubic_coeff);
        auto inner_sum = std::make_shared<ov::op::v1::Add>(x, scaled_pow);

        // tanh(sqrt(2 / pi) * inner_sum)
        auto sqrt_2_over_pi = make_f32_const(static_cast<float>(std::sqrt(2.0 / M_PI)));
        auto tanh_arg = std::make_shared<ov::op::v1::Multiply>(inner_sum, sqrt_2_over_pi);
        auto tanh_out = std::make_shared<ov::op::v0::Tanh>(tanh_arg);

        // 1 + tanh(...)
        auto one = make_f32_const(1.0f);
        auto one_plus_tanh = std::make_shared<ov::op::v1::Add>(tanh_out, one);

        // The 0.5 factor is applied to the input here, not to (1 + tanh(...)).
        auto half = make_f32_const(0.5f);
        auto half_x = std::make_shared<ov::op::v1::Multiply>(x, half);

        // (1 + tanh(...)) * (x * 0.5)
        auto result = std::make_shared<ov::op::v1::Multiply>(one_plus_tanh, half_x);

        model = std::make_shared<Model>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<ov::pass::GeluFusionWithTanh>();
    }

    {
        // Reference: a single Gelu node using the TANH approximation mode.
        auto ref_input = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
        auto ref_gelu = std::make_shared<ov::op::v7::Gelu>(ref_input, op::GeluApproximationMode::TANH);
        model_ref = std::make_shared<Model>(NodeVector{ref_gelu}, ParameterVector{ref_input});
    }
}

TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_pow_value) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
Expand Down

0 comments on commit 7a80fe8

Please sign in to comment.