Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Feature] Add jamba tool parser #9154

Merged
merged 22 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cbd955a
first working version of jamba tool parsing
tomeras91 Sep 26, 2024
1a8c4e1
lint and format
tomeras91 Sep 26, 2024
0da420d
fix: We don't want to add content if it's an empty string
tomeras91 Oct 1, 2024
310535c
add initial tests for jamba tool parser
tomeras91 Oct 1, 2024
f5c9d09
reduce code duplication with use of parametrize
tomeras91 Oct 1, 2024
6b04e35
fix model outputs to match jamba expected output
tomeras91 Oct 1, 2024
c25cd51
add tests for jamba tool parsing with streaming
tomeras91 Oct 1, 2024
d551be0
Merge branch 'main' into add-jamba-tool-parser
tomeras91 Oct 8, 2024
d31e688
adjust JambaToolParser to changes in upstream
tomeras91 Oct 8, 2024
6a27eb3
Add adjust_request function to JambaToolParser since we need to set s…
tomeras91 Oct 8, 2024
bc16953
update comments and remove unused code
tomeras91 Oct 8, 2024
25d839d
lint & format + adjust tests to new tool parser API
tomeras91 Oct 8, 2024
16542bc
dummy for build
tomeras91 Oct 9, 2024
2a25f10
Revert "dummy for build"
tomeras91 Oct 9, 2024
a935865
Merge branch 'main' into add-jamba-tool-parser
DarkLight1337 Oct 9, 2024
3c757c5
Use #9188 and improve validation
DarkLight1337 Oct 9, 2024
0db1408
removed done TODO
tomeras91 Oct 10, 2024
e5c2878
Merge branch 'add-jamba-tool-parser' of github.com:tomeras91/vllm int…
tomeras91 Oct 10, 2024
20aeb6d
Added Jamba tool calling to docs
tomeras91 Oct 17, 2024
54efc40
Apply #9461
DarkLight1337 Oct 18, 2024
ae9a0b7
Trigger build with fix typo
DarkLight1337 Oct 18, 2024
d5fefe9
Fix missing option
DarkLight1337 Oct 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ Known issues:

Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`

#### Jamba Models
AI21's Jamba-1.5 models are supported.
* `ai21labs/AI21-Jamba-1.5-Mini`
* `ai21labs/AI21-Jamba-1.5-Large`


Flags: `--tool-call-parser jamba`


### How to write a tool parser plugin

Expand Down
275 changes: 275 additions & 0 deletions tests/tool_use/test_jamba_tool_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import json
from typing import Generator, List, Optional

import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

MODEL = "ai21labs/Jamba-tiny-dev"


@pytest.fixture(scope="module")
def jamba_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def jamba_tool_parser(jamba_tokenizer):
return JambaToolParser(jamba_tokenizer)


def assert_tool_calls(actual_tool_calls: List[ToolCall],
expected_tool_calls: List[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)

for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16

assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function


def stream_delta_message_generator(
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer,
model_output: str) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output,
add_special_tokens=False)

previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]

(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=jamba_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)

current_text = previous_text + delta_text

delta_message = jamba_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=None, # type: ignore[arg-type]
)
if delta_message:
yield delta_message

previous_text = current_text
previous_tokens = previous_tokens + new_tokens if previous_tokens\
else new_tokens
prefix_offset = new_prefix_offset
read_offset = new_read_offset


def test_extract_tool_calls_no_tools(jamba_tool_parser):
model_output = "This is a test"
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
ids=[
"single_tool",
"single_tool_with_content",
"parallel_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
None),
(
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
" Sure! let me call the tool for you."),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
],
None)
],
)
def test_extract_tool_calls(jamba_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called

assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool",
"single_tool_with_content",
"parallel_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
('''This is a test''', [], '''This is a test'''),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
" "),
(
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
" Sure! let me call the tool for you."),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
],
" ")
],
)
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
model_output, expected_tool_calls,
expected_content):
other_content: str = ''
function_names: List[str] = []
function_args_strs: List[str] = []
tool_call_idx: int = -1
tool_call_ids: List[Optional[str]] = []

for delta_message in stream_delta_message_generator(
jamba_tool_parser, jamba_tokenizer, model_output):
# role should never be streamed from tool parser
assert not delta_message.role

if delta_message.content:
other_content += delta_message.content

streamed_tool_calls = delta_message.tool_calls

if streamed_tool_calls and len(streamed_tool_calls) > 0:
# make sure only one diff is present - correct even for parallel
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]

# if a new tool is being called, set up empty arguments
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
function_args_strs.append("")
tool_call_ids.append(None)

# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id and not tool_call_ids[tool_call.index]:
tool_call_ids[tool_call.index] = tool_call.id

# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert isinstance(tool_call.function.name, str)
function_names.append(tool_call.function.name)

if tool_call.function.arguments:
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)

function_args_strs[
tool_call.index] += tool_call.function.arguments

assert other_content == expected_content

actual_tool_calls = [
ToolCall(id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR)))
for tool_call_id, function_name, function_args_str in zip(
tool_call_ids, function_names, function_args_strs)
]
assert_tool_calls(actual_tool_calls, expected_tool_calls)
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser

__all__ = [
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",
"JambaToolParser"
]
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def __init__(self, tokenizer: AnyTokenizer):
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
Expand Down
Loading