From 83caf35e082b2657dce5f71ff965a13653a763b0 Mon Sep 17 00:00:00 2001 From: Guillaume Calmettes Date: Thu, 3 Oct 2024 10:44:52 +0200 Subject: [PATCH] [BugFix] Enforce Mistral ToolCall id constraint when using the Mistral tool call parser (#9020) --- tests/tool_use/test_parallel_tool_calls.py | 4 ++-- tests/tool_use/test_tool_calls.py | 4 ++-- .../tool_parsers/mistral_tool_parser.py | 20 +++++++++++++++++-- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index ed7ac8afe1b4..cff3c8a556ca 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, assert tool_call.type == "function" assert tool_call.function is not None assert isinstance(tool_call.id, str) - assert len(tool_call.id) > 16 + assert len(tool_call.id) >= 9 # make sure the weather tool was called correctly assert tool_call.function.name == WEATHER_TOOL["function"]["name"] @@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, if tool_call.id: tool_call_id_count += 1 assert (isinstance(tool_call.id, str) - and (len(tool_call.id) > 16)) + and (len(tool_call.id) >= 9)) # if parts of the function start being streamed if tool_call.function: diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index c3abe9e1f506..9e6d715f44fc 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert tool_calls[0].type == 'function' assert tool_calls[0].function is not None assert isinstance(tool_calls[0].id, str) - assert len(tool_calls[0].id) > 16 + assert len(tool_calls[0].id) >= 9 # make sure the weather tool was called (classic example) with arguments assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] @@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert finish_reason_count == 1 assert role_name == 'assistant' - assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9) # validate the name and arguments assert function_name == WEATHER_TOOL["function"]["name"] diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 4b0e1c91df97..b61ad40a697e 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,9 +1,12 @@ import json import re +from random import choices +from string import ascii_letters, digits from typing import Dict, List, Sequence, Union import partial_json_parser from partial_json_parser.core.options import Allow +from pydantic import Field from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -19,6 +22,19 @@ logger = init_logger(__name__) +ALPHANUMERIC = ascii_letters + digits + + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + class MistralToolParser(ToolParser): """ @@ -71,8 +87,8 @@ def extract_tool_calls(self, # load the JSON, and then use it to build the Function and # Tool Call function_call_arr = json.loads(raw_tool_call) - tool_calls: List[ToolCall] = [ - ToolCall( + tool_calls: List[MistralToolCall] = [ + MistralToolCall( type="function", function=FunctionCall( name=raw_function_call["name"],