
Commit

add test and fix comments
sbalandi committed Feb 7, 2025
1 parent 5572219 commit 77c2688
Showing 2 changed files with 71 additions and 24 deletions.
27 changes: 9 additions & 18 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -224,6 +224,11 @@ EncodedResults StatefulLLMPipeline::generate(
    OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
        "Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");

    if (!is_chat_conversation) {
        reset_kv_state();
        m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
    }

    auto start_time = std::chrono::steady_clock::now();
    ov::Tensor input_ids;
    ov::Tensor attention_mask;
@@ -354,7 +359,6 @@ EncodedResults StatefulLLMPipeline::generate(

        std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
    } else {
        reset_kv_state();
        m_last_disappeared_token = std::nullopt;
    }
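Taken together, the two hunks above move the KV-cache reset for non-chat calls from the tail of generate() to its start: when is_chat_conversation is false, the pipeline now clears its state and the attention mask before serving the new request rather than after the previous one. A minimal usage sketch of the behaviour this targets, assuming the openvino_genai Python package and an illustrative local model directory (neither is taken from the commit):

```python
import openvino_genai as ov_genai

# Illustrative path to an OpenVINO-converted chat model (assumption, not from the commit).
pipe = ov_genai.LLMPipeline("TinyLlama-1.1B-Chat-v1.0-ov", "CPU")

config = ov_genai.GenerationConfig()
config.max_new_tokens = 20

# Outside of start_chat()/finish_chat(), each generate() call is independent:
# the stateful pipeline resets its KV cache before running the next prompt,
# so no hidden context leaks between the two calls below.
print(pipe.generate("What is OpenVINO?", config))
print(pipe.generate("What was my previous question?", config))
```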

@@ -373,22 +377,8 @@
}

void StatefulLLMPipeline::start_chat(const std::string& system_message) {
    bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
    if (have_state) {
        m_model_runner.reset_state();
        m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
    }
    is_chat_conversation = true;
    m_trust_encoded_history = true;
    m_kv_history_manager.reset();
    m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
    m_last_disappeared_token = std::nullopt;
    if (!m_tokenized_chat_history.empty()) {
        reset_kv_state();
        m_history = {};
        m_templated_chat_history.clear();
        m_tokenized_chat_history.clear();
    }
    finish_chat();

    if (system_message.empty())
        return;
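With this rewrite, start_chat() delegates all history and KV-cache cleanup to finish_chat() before (re)entering chat mode, so beginning a new chat on top of an unfinished one no longer leaves stale state behind. A short sketch of the calling pattern this enables, again assuming the openvino_genai Python package and an illustrative model path:

```python
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("TinyLlama-1.1B-Chat-v1.0-ov", "CPU")  # illustrative path

pipe.start_chat("You are a helpful assistant.")
print(pipe.generate("1+1=", max_new_tokens=10))

# Starting another chat without calling finish_chat() first: start_chat()
# now performs the same cleanup as finish_chat(), so the second chat begins
# from an empty history and an empty KV cache.
pipe.start_chat()
print(pipe.generate("What was my last question?", max_new_tokens=10))
pipe.finish_chat()
```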

@@ -416,7 +406,8 @@ void StatefulLLMPipeline::finish_chat() {
    m_kv_history_manager.reset();
    m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
    m_last_disappeared_token = std::nullopt;
    if (!m_tokenized_chat_history.empty()) {
    bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
    if (!m_tokenized_chat_history.empty() || have_state) {
        reset_kv_state();
        m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
        m_history.clear();
68 changes: 62 additions & 6 deletions tests/python_tests/test_llm_pipeline.py
@@ -109,9 +109,10 @@ def test_empty_encoded_inputs_throw():
# Chat scenario
#

generation_configs = [
    dict(max_new_tokens=20),
    dict(max_new_tokens=10, num_beam_groups=3, num_beams=15, num_return_sequences=1, diversity_penalty=1.0)
chat_inputs = [
    (dict(max_new_tokens=20), ""),
    (dict(max_new_tokens=20), "You are a helpful assistant."),
    (dict(max_new_tokens=10, num_beam_groups=3, num_beams=15, num_return_sequences=1, diversity_penalty=1.0), "")
]

questions = [
@@ -121,20 +122,24 @@ def test_empty_encoded_inputs_throw():
    'What was my first question?'
]

@pytest.mark.parametrize("generation_config_kwargs", generation_configs)
@pytest.mark.parametrize("intpus", chat_intpus)
@pytest.mark.parametrize("model_descr", get_chat_models_list())
@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_scenario(model_descr, generation_config_kwargs: Dict):
def test_chat_scenario(model_descr, inputs):
    chat_history_hf = []
    chat_history_ov = []

    model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1]))

    generation_config_kwargs, system_message = inputs

    ov_generation_config = GenerationConfig(**generation_config_kwargs)
    hf_generation_config = convert_to_hf(opt_model.generation_config, ov_generation_config)

    ov_pipe.start_chat()
    ov_pipe.start_chat(system_message)
    # start_chat() skips an empty system message, so only add it to the reference histories when it is set
    if system_message:
        chat_history_hf.append({'role': 'system', 'content': system_message})
        chat_history_ov.append({'role': 'system', 'content': system_message})
    for prompt in questions:
        chat_history_hf.append({'role': 'user', 'content': prompt})
        chat_history_ov.append({'role': 'user', 'content': prompt})
@@ -159,6 +164,57 @@ def test_chat_scenario(model_descr, generation_config_kwargs: Dict):
    assert chat_history_ov == chat_history_hf


@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_scenario_several_chats_in_series():
    model_descr = get_chat_models_list()[0]
    model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1]))

    generation_config_kwargs, _ = chat_inputs[0]
    ov_generation_config = GenerationConfig(**generation_config_kwargs)
    hf_generation_config = convert_to_hf(opt_model.generation_config, ov_generation_config)

    for i in range(2):
        chat_history_hf = []
        chat_history_ov = []
        ov_pipe.start_chat()
        for prompt in questions[:2]:
            chat_history_hf.append({'role': 'user', 'content': prompt})
            chat_history_ov.append({'role': 'user', 'content': prompt})

            chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True)
            tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)
            prompt_len = tokenized['input_ids'].numel()

            answer = opt_model.generate(**tokenized, generation_config=hf_generation_config).sequences[0]
            answer_str = tokenizer.decode(answer[prompt_len:], skip_special_tokens=True)
            chat_history_hf.append({'role': 'assistant', 'content': answer_str})

            answer_ov = ov_pipe.generate(prompt, generation_config=ov_generation_config)
            chat_history_ov.append({'role': 'assistant', 'content': answer_ov})

        ov_pipe.finish_chat()

        if chat_history_ov != chat_history_hf:
            print(f'hf_output: {chat_history_hf}')
            print(f'ov_output: {chat_history_ov}')

        assert chat_history_ov == chat_history_hf


@pytest.mark.precommit
@pytest.mark.nightly
def test_chat_scenario_several_start():
    ov_pipe = read_model(get_chat_models_list()[0])[4]

    generation_config_kwargs, _ = chat_inputs[0]
    ov_generation_config = GenerationConfig(**generation_config_kwargs)

    ov_pipe.start_chat()
    ov_pipe.start_chat()
    ov_pipe.generate(questions[0], generation_config=ov_generation_config)
    ov_pipe.finish_chat()

#
# Streaming with callback
#
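For context on the reference side of the chat tests above: the Hugging Face baseline renders the accumulated history, including the newly parametrized system message, through the tokenizer's chat template and decodes only the freshly generated tokens. A self-contained sketch of that pattern (the model id is illustrative, not taken from the test's model list):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative chat model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "1+1="},
]

# Render the whole history into one prompt string, tokenize it, and generate.
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
prompt_len = tokenized["input_ids"].numel()

output = model.generate(**tokenized, max_new_tokens=10)
answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
history.append({"role": "assistant", "content": answer})
print(answer)
```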
