fix
sbalandi committed Feb 10, 2025
1 parent cb3a513 commit e002b73
Showing 3 changed files with 11 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -405,6 +405,7 @@ EncodedResults StatefulLLMPipeline::generate(
 
 void StatefulLLMPipeline::start_chat(const std::string& system_message) {
     finish_chat();
+    is_chat_conversation = true;
 
     if (system_message.empty())
         return;
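
The added line marks the pipeline as being in a chat immediately after finish_chat() clears the previous session, and before the early return for an empty system message. A minimal usage sketch of the intended behavior, assuming the openvino_genai Python bindings; the model path and prompts are placeholders, not part of the commit:

import openvino_genai as ov_genai

# "./llm_model/" is a placeholder path, not taken from the commit.
pipe = ov_genai.LLMPipeline("./llm_model/", "CPU")

config = ov_genai.GenerationConfig()
config.max_new_tokens = 64

# start_chat() without a system message should still switch the pipeline into
# chat mode, so the second turn sees the first one as history.
pipe.start_chat()
print(pipe.generate("What is OpenVINO?", config))
print(pipe.generate("Which hardware does it target?", config))
pipe.finish_chat()
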
13 changes: 8 additions & 5 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -155,9 +155,16 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         VLMPerfMetrics perf_metrics;
         auto& raw_counters = perf_metrics.raw_metrics;
         auto& raw_vlm_counters = perf_metrics.vlm_raw_metrics;
+
+        if (!m_is_chat_conversation) {
+            m_language.reset_state();
+            m_language.get_tensor("attention_mask").set_shape({1, 0});
+        }
+
         // If stop_token_ids were not provided, take value from default m_generation_config
         if (generation_config.stop_token_ids.empty())
             generation_config.stop_token_ids = m_generation_config.stop_token_ids;
+
         // If eos_token_id was not provided, take value from default m_generation_config
         if (generation_config.eos_token_id == -1)
             generation_config.set_eos_token_id(m_generation_config.eos_token_id);
@@ -224,12 +231,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));
 
         std::string decoded_results = decoded.texts.at(0);
-        if (m_is_chat_conversation) {
+        if (m_is_chat_conversation)
             m_inputs_embedder->update_chat_history(decoded_results);
-        } else {
-            m_language.reset_state();
-            m_language.get_tensor("attention_mask").set_shape({1, 0});
-        }
 
         auto generate_end_time = std::chrono::steady_clock::now();
         decoded.perf_metrics = encoded_result.perf_metrics;
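
Taken together, the two hunks move the non-chat cleanup to the start of generate(): when no chat is active, the language model state is reset and the attention mask cleared before the new request, instead of after decoding in the removed else branch, and an empty stop_token_ids in the caller's config is now filled from the pipeline's defaults. A rough sketch of the resulting behavior from the Python side, assuming the openvino_genai bindings; the model directory, the zero-filled image tensor, and the images/generation_config keyword names are illustrative assumptions, not taken from this commit:

import numpy as np
import openvino as ov
import openvino_genai as ov_genai

# Placeholder model folder and image; a real exported VLM and a real picture are needed.
pipe = ov_genai.VLMPipeline("./vlm_model/", "CPU")
image = ov.Tensor(np.zeros((448, 448, 3), dtype=np.uint8))

config = ov_genai.GenerationConfig()
config.max_new_tokens = 32
# stop_token_ids is left empty on purpose: generate() should now fall back to
# the pipeline's default stop_token_ids.

# No start_chat() here, so each call is a standalone request; the state reset
# now happens up front, which keeps the two calls independent of each other.
first = pipe.generate("Describe the image.", images=[image], generation_config=config)
second = pipe.generate("Name the main object.", images=[image], generation_config=config)
print(first.texts[0])
print(second.texts[0])
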
4 changes: 2 additions & 2 deletions tests/python_tests/test_llm_pipeline.py
@@ -138,8 +138,8 @@ def test_chat_scenario(model_descr, intpus):
     hf_generation_config = convert_to_hf(opt_model.generation_config, ov_generation_config)
 
     ov_pipe.start_chat(system_massage)
-    chat_history_hf.append({"role", "system", "content", system_massage})
-    chat_history_ov.append({"role", "system", "content", system_massage})
+    chat_history_hf.append({"role": "system", "content": system_massage})
+    chat_history_ov.append({"role": "system", "content": system_massage})
     for prompt in questions:
         chat_history_hf.append({'role': 'user', 'content': prompt})
         chat_history_ov.append({'role': 'user', 'content': prompt})
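
The test fix replaces what were accidentally set literals with proper dicts: with commas instead of colons, the braces build a set of four strings rather than the role/content message that the chat template expects. A small standalone illustration in plain Python:

system_massage = "You are a helpful assistant."
wrong = {"role", "system", "content", system_massage}   # a set of four strings
right = {"role": "system", "content": system_massage}   # the intended message dict
print(type(wrong).__name__)   # set
print(type(right).__name__)   # dict
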
