Skip to content

Commit

Permalink
Apply comment
Browse files Browse the repository at this point in the history
  • Loading branch information
olpipi committed Sep 20, 2024
1 parent 7d4cc49 commit e741a2c
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token

return {new_input_ids, new_attention_mask};
}

/// Narrow the first input of the model's final MatMul to a single position on axis 1.
///
/// The node feeding the model's first output is looked up and must be an
/// ov::op::v0::MatMul; otherwise OPENVINO_ASSERT fires with "Cannot find matmul op.".
/// When the MatMul's first input has rank 3, an ov::op::v8::Slice with
/// start=-1, stop=-2, step=-1 on axis 1 (i.e. the last element along that axis)
/// is inserted between the original producer and the MatMul. For any other rank
/// the model is left untouched.
///
/// @param model  Model whose graph is modified in place.
void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
    auto* producer = model->output(0).get_node()->input_value(0).get_node();
    auto* matmul = dynamic_cast<ov::op::v0::MatMul*>(producer);
    OPENVINO_ASSERT(matmul, "Cannot find matmul op.");

    // Only rank-3 inputs are sliced.
    if (matmul->input(0).get_partial_shape().rank().get_length() != 3) {
        return;
    }

    // Every Slice argument is a one-element i64 constant.
    const auto scalar_i64 = [](int64_t value) {
        return std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{value});
    };
    auto start = scalar_i64(-1);
    auto stop = scalar_i64(-2);
    auto step = scalar_i64(-1);
    auto axis = scalar_i64(1);

    auto last_position = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis);
    matmul->input(0).replace_source_output(last_position);
}
}

namespace ov {
Expand Down Expand Up @@ -68,22 +82,6 @@ std::pair<EncodedResults, int32_t> beam_search(
);

class StatefulLLMPipeline final : public LLMPipelineImplBase {
private:
// Inserts a Slice in front of the model's final MatMul when one is present.
//
// The node feeding the model's first output is inspected; if it is not an
// ov::op::v0::MatMul the model is left unchanged (no error is raised).
// The Slice (start=-1, stop=-2, step=-1, axis=1) is added only when the MatMul's
// first input has rank 3 and its dimension 1 is not equal to 1.
void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
    auto* producer = model->output(0).get_node()->input_value(0).get_node();
    auto* matmul = dynamic_cast<ov::op::v0::MatMul*>(producer);
    if (!matmul) {
        return;
    }

    auto input_shape = matmul->input(0).get_partial_shape();
    // Short-circuit keeps input_shape[1] from being read unless the rank is 3.
    const bool needs_slice = input_shape.rank().get_length() == 3 && input_shape[1] != 1;
    if (!needs_slice) {
        return;
    }

    const ov::Shape one_element{1};
    auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, one_element, std::vector<int64_t>{-1});
    auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, one_element, std::vector<int64_t>{-2});
    auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, one_element, std::vector<int64_t>{-1});
    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, one_element, std::vector<int64_t>{1});

    auto sliced_input = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis);
    matmul->input(0).replace_source_output(sliced_input);
}

public:
ov::InferRequest m_model_runner;

Expand Down

0 comments on commit e741a2c

Please sign in to comment.