Skip to content

Commit

Permalink
[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-…
Browse files Browse the repository at this point in the history
…chat model (vllm-project#8405)
  • Loading branch information
sydnash authored Oct 4, 2024
1 parent 2838d6b commit 3dbb215
Show file tree
Hide file tree
Showing 13 changed files with 533 additions and 46 deletions.
3 changes: 2 additions & 1 deletion docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
74 changes: 72 additions & 2 deletions docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers
will continue to be added in the future.
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `llama3_json` or `internlm`. Additional tool parsers
will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`.
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`.
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
Expand Down Expand Up @@ -218,4 +219,73 @@ it works better with vLLM.

Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`

#### Internlm Models
Supported models:
* `internlm/internlm2_5-7b-chat` (confirmed)
* Additional internlm2.5 function-calling models are compatible as well

Known issues:
* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model.

Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`


### How to write a tool parser plugin

A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.

Here is a summary of a plugin file:

```python

# import the required packages

# define a tool parser and register it to vllm
# the name list in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["example"])
class ExampleToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)

# adjust request. e.g.: set skip special tokens
# to False for tool call output.
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
return request

# implement the tool call parse for stream call
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
return delta

# implement the tool parse for non-stream call
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)


```

Then you can use this plugin in the command line like this.
```
--enable-auto-tool-choice \
--tool-parser-plugin <absolute path of the plugin file>
--tool-call-parser example \
--chat-template <your chat template> \
```

60 changes: 60 additions & 0 deletions examples/tool_chat_template_internlm2_tool.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}

{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}

{{- bos_token }}
{%- if system_message is defined %}
{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
{%- endif %}

{%- if tools is not none %}
{{- "<|im_start|>system name=<|plugin|>\n[" }}
{%- for tool in tools %}
{{- tool.function|tojson }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- endif %}

{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message.tool_calls is defined and message.tool_calls is not none %}
{%- set content = message["content"] if message["content"] else "" %}
{{- "<|im_start|>assistant\n" + content }}
{%- for tool_call in message.tool_calls %}
{%- set function=tool_call.function %}
{{- "<|action_start|><|plugin|>\n" }}
{{- '{"name": "' + function.name + '", '}}
{{- '"arguments": ' + function.arguments|tojson + '}' }}
{{- "<|action_end|>" }}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- elif message["role"] == "assistant" %}
{{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
{%- else %}
{{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
14 changes: 13 additions & 1 deletion tests/tool_use/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
"arguments": [
"--tool-call-parser", "internlm", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_internlm2_tool.jinja"),
"--trust_remote_code"
],
"supports_parallel":
False,
}
}

Expand All @@ -109,7 +121,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
"type":
"string",
"description":
"the two-letter abbreviation for the state "
"must the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
},
Expand Down
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
Expand Down Expand Up @@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

valide_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")

temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))

Expand Down
14 changes: 13 additions & 1 deletion vllm/entrypoints/openai/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser


Expand Down Expand Up @@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")

valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes", "llama3_json"],
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")

parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")

parser = AsyncEngineArgs.add_cli_args(parser)

parser.add_argument('--max-log-len',
Expand Down
38 changes: 20 additions & 18 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
Llama3JsonToolParser,
MistralToolParser,
ToolParser)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
Expand Down Expand Up @@ -82,15 +79,13 @@ def __init__(self,

self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
elif tool_parser == "llama3_json":
self.tool_parser = Llama3JsonToolParser
else:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e

async def create_chat_completion(
self,
Expand Down Expand Up @@ -187,6 +182,10 @@ async def create_chat_completion(
raw_request.state.request_metadata = request_metadata

try:
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)

if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
Expand Down Expand Up @@ -282,11 +281,11 @@ async def chat_completion_stream_generator(
num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices

num_prompt_tokens = 0

tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
tool_parsers: List[Optional[ToolParser]] = [
self.tool_parser(tokenizer) if self.tool_parser else None
] * num_choices

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
Expand Down Expand Up @@ -324,7 +323,7 @@ async def chat_completion_stream_generator(
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):

tool_parser = tool_parsers[i]
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
Expand Down Expand Up @@ -399,6 +398,7 @@ async def chat_completion_stream_generator(

for output in res.outputs:
i = output.index
tool_parser = tool_parsers[i]

if finish_reason_sent[i]:
continue
Expand Down Expand Up @@ -446,7 +446,8 @@ async def chat_completion_stream_generator(
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))
delta_token_ids=output.token_ids,
request=request))

# update the previous values for the next iteration
previous_texts[i] = current_text
Expand Down Expand Up @@ -685,7 +686,8 @@ async def chat_completion_full_generator(
and self.tool_parser:

tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(output.text)
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
Expand Down
7 changes: 4 additions & 3 deletions vllm/entrypoints/openai/tool_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .abstract_tool_parser import ToolParser
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser

__all__ = [
"ToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Llama3JsonToolParser"
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
]
Loading

0 comments on commit 3dbb215

Please sign in to comment.