Skip to content

Commit

Permalink
[TRANSFORMATIONS] Extend PositionIDsReplacer pattern (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#24890)

Extend PositionIDsReplacer pattern to support more models:
 - facebook/opt-350m

Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>

### Tickets:
 - CVS-143065

---------

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
  • Loading branch information
CuriousPanCake and itikhono authored Jun 8, 2024
1 parent f6e6f2a commit 6e67110
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

Expand All @@ -26,7 +28,9 @@ ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output<Node>& position_
auto convert = pattern::wrap_type<v0::Convert>({add_offset});
auto position_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), convert, pattern::any_input()});

auto add = pattern::wrap_type<v1::Add>({input_embed, position_embed});
auto mul = pattern::optional<v0::MatMul>({input_embed, pattern::any_input()});

auto add = pattern::wrap_type<v1::Add>({mul, position_embed});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand Down

0 comments on commit 6e67110

Please sign in to comment.