From 8b228f5c4f100307b33d1a4a90271a5a7e8c3b8b Mon Sep 17 00:00:00 2001
From: danieljannai21 <100521221+danieljannai21@users.noreply.github.com>
Date: Sun, 29 Sep 2024 20:59:47 +0300
Subject: [PATCH] [Frontend] Added support for HF's new `continue_final_message` parameter (#8942)

---
 .../entrypoints/openai/test_chat_template.py  | 30 +++++++---
 tests/entrypoints/openai/test_tokenization.py | 56 +++++++++++--------
 vllm/entrypoints/chat_utils.py                |  8 +++
 vllm/entrypoints/llm.py                       |  6 ++
 vllm/entrypoints/openai/protocol.py           | 28 ++++++++++
 vllm/entrypoints/openai/serving_chat.py       |  6 +-
 .../openai/serving_tokenization.py            |  2 +
 7 files changed, 105 insertions(+), 31 deletions(-)

diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py
index b98ab2e30d78d..e1e1dcff7475d 100644
--- a/tests/entrypoints/openai/test_chat_template.py
+++ b/tests/entrypoints/openai/test_chat_template.py
@@ -12,7 +12,7 @@
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
@@ -20,12 +20,20 @@
 What is the capital of<|im_end|>
 <|im_start|>assistant
 """),
-    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
-What is the capital of""")
+What is the capital of"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+The capital of"""),
 ]
 
 TEST_MESSAGES = [
@@ -42,6 +50,10 @@
         'content': 'What is the capital of'
     },
 ]
+ASSISTANT_MESSAGE_TO_CONTINUE = {
+    'role': 'assistant',
+    'content': 'The capital of'
+}
 
 
 def test_load_chat_template():
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
 
 
 @pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
+    "model,template,add_generation_prompt,continue_final_message,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
-                        expected_output):
+                        continue_final_message, expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     template_content = load_chat_template(chat_template=template)
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
         model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
+        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
+        if continue_final_message else TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+    )
 
     # Call the function and get the result
     result = apply_hf_chat_template(
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
+        continue_final_message=mock_request.continue_final_message,
     )
 
     # Test assertion
diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py
index 316ca11b8e95a..859a676a9c777 100644
--- a/tests/entrypoints/openai/test_tokenization.py
+++ b/tests/entrypoints/openai/test_tokenization.py
@@ -104,28 +104,40 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
                 "role": "user",
                 "content": "Can I ask a question? vllm1"
             }]
-
-            prompt = tokenizer.apply_chat_template(
-                add_generation_prompt=add_generation,
-                conversation=conversation,
-                tokenize=False)
-            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-            response = requests.post(base_url + "/tokenize",
-                                     json={
-                                         "add_generation_prompt":
-                                         add_generation,
-                                         "add_special_tokens": add_special,
-                                         "messages": conversation,
-                                         "model": model_name
-                                     })
-            response.raise_for_status()
-
-            assert response.json() == {
-                "tokens": tokens,
-                "count": len(tokens),
-                "max_model_len": 8192
-            }
+            for continue_final in [False, True]:
+                if add_generation and continue_final:
+                    continue
+                if continue_final:
+                    conversation.append({
+                        "role": "assistant",
+                        "content": "Sure,"
+                    })
+
+                prompt = tokenizer.apply_chat_template(
+                    add_generation_prompt=add_generation,
+                    continue_final_message=continue_final,
+                    conversation=conversation,
+                    tokenize=False)
+                tokens = tokenizer.encode(prompt,
+                                          add_special_tokens=add_special)
+
+                response = requests.post(base_url + "/tokenize",
+                                         json={
+                                             "add_generation_prompt":
+                                             add_generation,
+                                             "continue_final_message":
+                                             continue_final,
+                                             "add_special_tokens": add_special,
+                                             "messages": conversation,
+                                             "model": model_name
+                                         })
+                response.raise_for_status()
+
+                assert response.json() == {
+                    "tokens": tokens,
+                    "count": len(tokens),
+                    "max_model_len": 8192
+                }
 
 
 @pytest.mark.asyncio
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 4a575ae8f8537..130f3ba49f3e1 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -542,6 +542,14 @@ def apply_mistral_chat_template(
     if chat_template is not None:
         logger.warning(
             "'chat_template' cannot be overridden for mistral tokenizer.")
+    if "add_generation_prompt" in kwargs:
+        logger.warning(
+            "'add_generation_prompt' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    if "continue_final_message" in kwargs:
+        logger.warning(
+            "'continue_final_message' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
 
     return tokenizer.apply_chat_template(
         messages=messages,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 5a10e72e5c165..bd009ae915c93 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -501,6 +501,7 @@ def chat(
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
@@ -528,6 +529,9 @@
                 If not provided, the model's default chat template will be used.
             add_generation_prompt: If True, adds a generation template
                 to each message.
+            continue_final_message: If True, continues the final message in
+                the conversation instead of starting a new one. Cannot be `True`
+                if `add_generation_prompt` is also `True`.
 
         Returns:
             A list of ``RequestOutput`` objects containing the generated
@@ -559,6 +563,7 @@
                     messages=msgs,
                     chat_template=chat_template,
                     add_generation_prompt=add_generation_prompt,
+                    continue_final_message=continue_final_message,
                     tools=tools,
                 )
             else:
@@ -567,6 +572,7 @@
                     conversation=conversation,
                     chat_template=chat_template,
                     add_generation_prompt=add_generation_prompt,
+                    continue_final_message=continue_final_message,
                     tools=tools,
                 )
 
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 646aa4537999e..f716e4a0458bf 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
     )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
     add_special_tokens: bool = Field(
         default=False,
         description=(
@@ -431,6 +440,15 @@ def check_tool_usage(cls, data):
                     " of the specified `tools`")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
 
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(default=True)
+    continue_final_message: bool = Field(default=False)
     add_special_tokens: bool = Field(default=False)
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
 
 TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index b78d8f7c6f088..c245fd9b97cfa 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -140,6 +140,7 @@ async def create_chat_completion(
                     messages=request.messages,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -150,6 +151,7 @@ async def create_chat_completion(
                     conversation=conversation,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -371,7 +373,7 @@ async def chat_completion_stream_generator(
 
             # Send response to echo the input portion of the
             # last message
-            if request.echo:
+            if request.echo or request.continue_final_message:
                 last_msg_content: str = ""
                 if conversation and "content" in conversation[
                         -1] and conversation[-1].get("role") == role:
@@ -731,7 +733,7 @@ async def chat_completion_full_generator(
                 stop_reason=output.stop_reason)
             choices.append(choice_data)
 
-        if request.echo:
+        if request.echo or request.continue_final_message:
             last_msg_content = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 6d9a1ae088079..a269c94c7ec0d 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -87,6 +87,7 @@ async def create_tokenize(
                     messages=request.messages,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
             else:
                 prompt = apply_hf_chat_template(
@@ -94,6 +95,7 @@ async def create_tokenize(
                     conversation=conversation,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
         else:
            prompt = request.prompt