Skip to content

Commit

Permalink
[Model] Add mistral function calling format to all models loaded with…
Browse files Browse the repository at this point in the history
… "mistral" format (vllm-project#8515)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
2 people authored and garg-amit committed Oct 28, 2024
1 parent 50fdbe9 commit 5fb85e1
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 9 deletions.
138 changes: 138 additions & 0 deletions examples/offline_chat_with_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# ruff: noqa
import json
import random
import string

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for function calling
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Mistral-7B-Instruct-v0.3"
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python demo.py simple
# python demo.py advanced

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability

sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(model=model_name,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral")


def generate_random_id(length=9):
characters = string.ascii_letters + string.digits
random_id = ''.join(random.choice(characters) for _ in range(length))
return random_id


# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: 'str'):
return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
"partly cloudly, with highs in the 90's.")


tool_funtions = {"get_current_weather": get_current_weather}

tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]

messages = [{
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]

outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()

# append the assistant message
messages.append({
"role": "assistant",
"content": output,
})

# let's now actually parse and execute the model's output simulating an API call by using the
# above defined function
tool_calls = json.loads(output)
tool_answers = [
tool_funtions[call['name']](**call['arguments']) for call in tool_calls
]

# append the answer as a tool message and let the LLM give you an answer
messages.append({
"role": "tool",
"content": "\n\n".join(tool_answers),
"tool_call_id": generate_random_id(),
})

outputs = llm.chat(messages, sampling_params, tools=tools)

print(outputs[0].outputs[0].text.strip())
# yields
# 'The weather in Dallas, TX is 85 degrees fahrenheit. '
# 'It is partly cloudly, with highs in the 90's.'
67 changes: 67 additions & 0 deletions tests/models/decoder_only/language/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,61 @@
"""
import pytest

from vllm import SamplingParams

from ...utils import check_logprobs_close

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
# Mistral-Nemo is to big for CI, but passes locally
# "mistralai/Mistral-Nemo-Instruct-2407"
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)

# for function calling
TOOLS = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
MSGS = [{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}]
EXPECTED_FUNC_CALL = (
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
Expand Down Expand Up @@ -81,3 +129,22 @@ def test_mistral_format(
name_0="hf",
name_1="mistral",
)


@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling
def test_mistral_function_calling(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral") as vllm_model:
outputs = vllm_model.model.chat(MSGS,
tools=TOOLS,
sampling_params=SAMPLING_PARAMS)

assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
6 changes: 5 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
overload)

from tqdm import tqdm

Expand Down Expand Up @@ -357,6 +358,7 @@ def chat(
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
tools: Optional[List[Dict[str, Any]]] = None,
) -> List[RequestOutput]:
"""
Generate responses for a chat conversation.
Expand Down Expand Up @@ -401,13 +403,15 @@ def chat(
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
tools=tools,
)

inputs: PromptInputs
Expand Down
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ async def create_chat_completion(
]

prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
Expand Down Expand Up @@ -159,10 +160,10 @@ async def create_chat_completion(
return self.create_error_response(
"tool_choice = \"required\" is not supported!")

# "auto" tools requires --enable-auto-tool-choice
# and --tool-call-parser
if request.tool_choice == "auto" and not (
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
Expand Down
8 changes: 4 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,18 @@ def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]:
assert tools is None, "`tools` are not yet supported."

request = ChatCompletionRequest(
messages=messages) # type: ignore[type-var]
request = ChatCompletionRequest(messages=messages,
tools=tools) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request)

# encode-decode to get clean prompt
return encoded.tokens

def convert_tokens_to_string(self, tokens: List[str]) -> str:
if isinstance(self.tokenizer, Tekkenizer):
return "".join(tokens)
return "".join(t for t in tokens
if t not in self.tokenizer._all_special_tokens)
else:
return self.tokenizer.decode(tokens) # type: ignore[arg-type]

Expand Down

0 comments on commit 5fb85e1

Please sign in to comment.