From c726fe5f9ffb815e14276dc5671d62c8fd56b0ba Mon Sep 17 00:00:00 2001
From: Oleg Pipikin
Date: Mon, 2 Sep 2024 14:00:34 +0000
Subject: [PATCH] Slice the last matmul in stateful llm pipeline

---
 src/cpp/src/greedy_decoding.cpp |  1 -
 src/cpp/src/llm_pipeline.cpp    | 23 ++++++++++++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp
index 95a1843645..2f1ed3f89d 100644
--- a/src/cpp/src/greedy_decoding.cpp
+++ b/src/cpp/src/greedy_decoding.cpp
@@ -73,7 +73,6 @@ EncodedResults greedy_decoding(
 
     bool all_are_eos = std::all_of(eos_met.begin(), eos_met.end(), [](int elem) { return elem == 1; });
     if (!generation_config.ignore_eos && all_are_eos)
        return results;
-
    for (size_t i = 0; i < max_new_tokens - 1; ++i) {
        if (position_ids.has_value())
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 66e2890671..60c9964238 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -16,6 +16,9 @@
 #include "utils.hpp"
 #include "text_callback_streamer.hpp"
 
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/slice.hpp"
+
 namespace {
 
 ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& fisrt, const ov::genai::TokenizedInputs& second){
@@ -65,6 +68,22 @@ std::pair<EncodedResults, int32_t> beam_search(
 );
 
 class StatefulLLMPipeline final : public LLMPipelineImplBase {
+private:
+    void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
+        auto last_node = model->output(0).get_node()->input_value(0).get_node();
+        if (auto matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node)) {
+            auto shape = matmul->input(0).get_partial_shape();
+            if (shape.rank().get_length() == 3 && shape[1] != 1) {
+                auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+                auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
+                auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+                auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
+                auto slice = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis);
+                matmul->input(0).replace_source_output(slice);
+            }
+        }
+    }
+
 public:
     ov::InferRequest m_model_runner;
 
@@ -94,7 +113,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     {
         ov::Core core;
         core.set_property(device, plugin_config);
-        m_model_runner = core.compile_model(model_path / "openvino_model.xml", device).create_infer_request();
+        auto model = core.read_model(model_path / "openvino_model.xml");
+        slice_matmul_statefull_model(model);
+        m_model_runner = core.compile_model(model, device).create_infer_request();
 
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1)
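
Note (editor's sketch, not part of the patch): the new slice_matmul_statefull_model()
pass assumes the logits are produced by a rank-3 MatMul whose input is
[batch, seq_len, hidden]; inserting a Slice over axis 1 in front of it makes the
final vocabulary projection run for the last token only instead of the whole prompt,
which is all that greedy and beam search consume. Below is a minimal, self-contained
check of that effect. The model path is an assumption, and the transformation body
mirrors the function from the diff above (openvino/op/constant.hpp is included
explicitly here, since the sketch stands alone).

#include <iostream>
#include <memory>
#include <vector>

#include "openvino/openvino.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/slice.hpp"

int main() {
    ov::Core core;
    // Assumed path to an exported stateful LLM.
    auto model = core.read_model("openvino_model.xml");

    // Logits output before the pass: [batch, seq_len, vocab_size].
    std::cout << "before: " << model->output(0).get_partial_shape() << "\n";

    auto last_node = model->output(0).get_node()->input_value(0).get_node();
    if (auto matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node)) {
        auto shape = matmul->input(0).get_partial_shape();
        if (shape.rank().get_length() == 3 && shape[1] != 1) {
            // start=-1, stop=-2, step=-1 selects exactly the last position
            // along the sequence axis (axis 1).
            auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
            auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
            auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
            auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
            auto slice = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis);
            matmul->input(0).replace_source_output(slice);
        }
    }
    // Re-run shape inference so the new output shape is reflected.
    model->validate_nodes_and_infer_types();

    // After the pass the seq_len dimension collapses to 1: [batch, 1, vocab_size].
    std::cout << "after:  " << model->output(0).get_partial_shape() << "\n";
    return 0;
}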