Skip to content

Commit

Permalink
[Frontend] Added support for HF's new continue_final_message parame…
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljannai21 authored and liuyanyi committed Oct 6, 2024
1 parent 593a52b commit 8b228f5
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 31 deletions.
30 changes: 23 additions & 7 deletions tests/entrypoints/openai/test_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,28 @@

# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""")
What is the capital of"""),
("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of"""),
]

TEST_MESSAGES = [
Expand All @@ -42,6 +50,10 @@
'content': 'What is the capital of'
},
]
ASSISTANT_MESSAGE_TO_CONTINUE = {
'role': 'assistant',
'content': 'The capital of'
}


def test_load_chat_template():
Expand Down Expand Up @@ -73,26 +85,30 @@ def test_no_load_chat_template_literallike():


@pytest.mark.parametrize(
"model,template,add_generation_prompt,expected_output",
"model,template,add_generation_prompt,continue_final_message,expected_output",
MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
continue_final_message, expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
template_content = load_chat_template(chat_template=template)

# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
model=model,
messages=TEST_MESSAGES,
add_generation_prompt=add_generation_prompt)
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
if continue_final_message else TEST_MESSAGES,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)

# Call the function and get the result
result = apply_hf_chat_template(
tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
)

# Test assertion
Expand Down
56 changes: 34 additions & 22 deletions tests/entrypoints/openai/test_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,28 +104,40 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
"role": "user",
"content": "Can I ask a question? vllm1"
}]

prompt = tokenizer.apply_chat_template(
add_generation_prompt=add_generation,
conversation=conversation,
tokenize=False)
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

response = requests.post(base_url + "/tokenize",
json={
"add_generation_prompt":
add_generation,
"add_special_tokens": add_special,
"messages": conversation,
"model": model_name
})
response.raise_for_status()

assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}
for continue_final in [False, True]:
if add_generation and continue_final:
continue
if continue_final:
conversation.append({
"role": "assistant",
"content": "Sure,"
})

prompt = tokenizer.apply_chat_template(
add_generation_prompt=add_generation,
continue_final_message=continue_final,
conversation=conversation,
tokenize=False)
tokens = tokenizer.encode(prompt,
add_special_tokens=add_special)

response = requests.post(base_url + "/tokenize",
json={
"add_generation_prompt":
add_generation,
"continue_final_message":
continue_final,
"add_special_tokens": add_special,
"messages": conversation,
"model": model_name
})
response.raise_for_status()

assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}


@pytest.mark.asyncio
Expand Down
8 changes: 8 additions & 0 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,14 @@ def apply_mistral_chat_template(
if chat_template is not None:
logger.warning(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")

return tokenizer.apply_chat_template(
messages=messages,
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def chat(
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
) -> List[RequestOutput]:
"""
Expand Down Expand Up @@ -528,6 +529,9 @@ def chat(
If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
Returns:
A list of ``RequestOutput`` objects containing the generated
Expand Down Expand Up @@ -559,6 +563,7 @@ def chat(
messages=msgs,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
else:
Expand All @@ -567,6 +572,7 @@ def chat(
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)

Expand Down
28 changes: 28 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
Expand Down Expand Up @@ -431,6 +440,15 @@ def check_tool_usage(cls, data):
" of the specified `tools`")
return data

@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data


class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
Expand Down Expand Up @@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
messages: List[ChatCompletionMessageParam]

add_generation_prompt: bool = Field(default=True)
continue_final_message: bool = Field(default=False)
add_special_tokens: bool = Field(default=False)

@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]

Expand Down
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ async def create_chat_completion(
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
Expand All @@ -150,6 +151,7 @@ async def create_chat_completion(
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
Expand Down Expand Up @@ -371,7 +373,7 @@ async def chat_completion_stream_generator(

# Send response to echo the input portion of the
# last message
if request.echo:
if request.echo or request.continue_final_message:
last_msg_content: str = ""
if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role:
Expand Down Expand Up @@ -731,7 +733,7 @@ async def chat_completion_full_generator(
stop_reason=output.stop_reason)
choices.append(choice_data)

if request.echo:
if request.echo or request.continue_final_message:
last_msg_content = ""
if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role:
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/serving_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ async def create_tokenize(
messages=request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
)
else:
prompt = request.prompt
Expand Down

0 comments on commit 8b228f5

Please sign in to comment.