Fix error after second start_chat() for StatefulLLMPipeline #1684

Merged (3 commits) on Feb 12, 2025
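This PR makes start_chat() on StatefulLLMPipeline (and the analogous path in VLMPipeline) finish any previously opened chat before starting a new one, which fixes the error previously raised by a second start_chat(). A minimal repro sketch of the fixed scenario, assuming the openvino_genai Python API and a placeholder model directory:

```python
import openvino_genai as ov_genai

# "chat_model_dir" is a placeholder; any exported chat model directory works here.
pipe = ov_genai.LLMPipeline("chat_model_dir", "CPU")

pipe.start_chat()
pipe.start_chat()  # second call used to fail; it now behaves like finish_chat() followed by start_chat()
print(pipe.generate("What is OpenVINO?", max_new_tokens=10))
pipe.finish_chat()
```

This mirrors the test_chat_scenario_several_start test added below.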
21 changes: 10 additions & 11 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -244,6 +244,11 @@ EncodedResults StatefulLLMPipeline::generate(
OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
"Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");

if (!is_chat_conversation) {
reset_kv_state();
m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
}

auto start_time = std::chrono::steady_clock::now();
ov::Tensor input_ids;
ov::Tensor attention_mask;
@@ -384,7 +389,6 @@ EncodedResults StatefulLLMPipeline::generate(
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}
} else {
reset_kv_state();
m_last_disappeared_token = std::nullopt;
}

@@ -400,16 +404,9 @@
}

void StatefulLLMPipeline::start_chat(const std::string& system_message) {
finish_chat();
is_chat_conversation = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history = {};
m_templated_chat_history.clear();
m_tokenized_chat_history.clear();
}

if (system_message.empty())
return;

@@ -436,8 +433,10 @@ void StatefulLLMPipeline::finish_chat() {
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
if (!m_tokenized_chat_history.empty() || have_state) {
reset_kv_state();
m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history.clear();
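With the change above, the KV-cache and attention-mask reset is consolidated into finish_chat() and into the top of a non-chat generate() call, rather than being done at the end of generation and conditionally inside start_chat(). A hedged usage sketch of the interleaving this supports, again with a placeholder model directory:

```python
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("chat_model_dir", "CPU")  # placeholder path

pipe.start_chat()
pipe.generate("What is OpenVINO?", max_new_tokens=20)  # chat turn: KV cache is kept across turns
pipe.finish_chat()                                      # resets the KV state and clears the attention mask

# A stateless call after the chat: the non-chat generate() path now resets leftover state up front.
pipe.generate("Translate 'hello' into French.", max_new_tokens=20)
```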
13 changes: 8 additions & 5 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -155,9 +155,16 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
VLMPerfMetrics perf_metrics;
auto& raw_counters = perf_metrics.raw_metrics;
auto& raw_vlm_counters = perf_metrics.vlm_raw_metrics;

if (!m_is_chat_conversation) {
m_language.reset_state();
m_language.get_tensor("attention_mask").set_shape({1, 0});
}

// If stop_token_ids were not provided, take value from default m_generation_config
if (generation_config.stop_token_ids.empty())
generation_config.stop_token_ids = m_generation_config.stop_token_ids;

// If eos_token_id was not provided, take value from default m_generation_config
if (generation_config.eos_token_id == -1)
generation_config.set_eos_token_id(m_generation_config.eos_token_id);
@@ -224,12 +231,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));

std::string decoded_results = decoded.texts.at(0);
if (m_is_chat_conversation) {
if (m_is_chat_conversation)
m_inputs_embedder->update_chat_history(decoded_results);
} else {
m_language.reset_state();
m_language.get_tensor("attention_mask").set_shape({1, 0});
}

auto generate_end_time = std::chrono::steady_clock::now();
decoded.perf_metrics = encoded_result.perf_metrics;
68 changes: 62 additions & 6 deletions tests/python_tests/test_llm_pipeline.py
@@ -109,9 +109,10 @@ def test_empty_encoded_inputs_throw():
# Chat scenario
#

generation_configs = [
dict(max_new_tokens=20),
dict(max_new_tokens=10, num_beam_groups=3, num_beams=15, num_return_sequences=1, diversity_penalty=1.0)
chat_inputs = [
(dict(max_new_tokens=20), ""),
(dict(max_new_tokens=20), "You are a helpful assistant."),
(dict(max_new_tokens=10, num_beam_groups=3, num_beams=15, num_return_sequences=1, diversity_penalty=1.0), "")
]

questions = [
@@ -121,20 +122,24 @@ def test_empty_encoded_inputs_throw():
'What was my first question?'
]

@pytest.mark.parametrize("generation_config_kwargs", generation_configs)
@pytest.mark.parametrize("inputs", chat_inputs)
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_scenario(model_descr, generation_config_kwargs: Dict):
def test_chat_scenario(model_descr, inputs):
chat_history_hf = []
chat_history_ov = []

model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1]))

generation_config_kwargs, system_message = inputs

ov_generation_config = GenerationConfig(**generation_config_kwargs)
hf_generation_config = convert_to_hf(opt_model.generation_config, ov_generation_config)

ov_pipe.start_chat()
ov_pipe.start_chat(system_message)
if system_message:
    chat_history_hf.append({"role": "system", "content": system_message})
    chat_history_ov.append({"role": "system", "content": system_message})
for prompt in questions:
chat_history_hf.append({'role': 'user', 'content': prompt})
chat_history_ov.append({'role': 'user', 'content': prompt})
@@ -159,6 +164,57 @@ def test_chat_scenario(model_descr, generation_config_kwargs: Dict):
assert chat_history_ov == chat_history_hf


@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_scenario_several_chats_in_series():
model_descr = get_chat_models_list()[0]
model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1]))

generation_config_kwargs, _ = chat_inputs[0]
ov_generation_config = GenerationConfig(**generation_config_kwargs)
hf_generation_config = convert_to_hf(opt_model.generation_config, ov_generation_config)

for i in range(2):
chat_history_hf = []
chat_history_ov = []
ov_pipe.start_chat()
for prompt in questions[:2]:
chat_history_hf.append({'role': 'user', 'content': prompt})
chat_history_ov.append({'role': 'user', 'content': prompt})

chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)
prompt_len = tokenized['input_ids'].numel()

answer = opt_model.generate(**tokenized, generation_config=hf_generation_config).sequences[0]
answer_str = tokenizer.decode(answer[prompt_len:], skip_special_tokens=True)
chat_history_hf.append({'role': 'assistant', 'content': answer_str})

answer_ov = ov_pipe.generate(prompt, generation_config=ov_generation_config)
chat_history_ov.append({'role': 'assistant', 'content': answer_ov})

ov_pipe.finish_chat()

if chat_history_ov != chat_history_hf:
print(f'hf_output: {chat_history_hf}')
print(f'ov_output: {chat_history_ov}')

assert chat_history_ov == chat_history_hf


@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_scenario_several_start():
ov_pipe = read_model(get_chat_models_list()[0])[4]

generation_config_kwargs, _ = chat_inputs[0]
ov_generation_config = GenerationConfig(**generation_config_kwargs)

ov_pipe.start_chat()
ov_pipe.start_chat()
ov_pipe.generate(questions[0], generation_config=ov_generation_config)
ov_pipe.finish_chat()

#
# Streaming with callback
#