From e5c6c4b2ad60a1acd2d6d381e9900677891c2edc Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 9 Oct 2024 22:59:57 +0800
Subject: [PATCH] [Bugfix] Access `get_vocab` instead of `vocab` in tool parsers (#9188)

---
 .../openai/tool_parsers/abstract_tool_parser.py            | 7 +++++++
 vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py | 7 +++----
 .../entrypoints/openai/tool_parsers/mistral_tool_parser.py | 3 +--
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
index 7e55532bc7297..5ce31bd4d941b 100644
--- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
@@ -1,6 +1,7 @@
 import importlib
 import importlib.util
 import os
+from functools import cached_property
 from typing import Callable, Dict, List, Optional, Sequence, Type, Union
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -29,6 +30,12 @@ def __init__(self, tokenizer: AnyTokenizer):
 
         self.model_tokenizer = tokenizer
 
+    @cached_property
+    def vocab(self) -> Dict[str, int]:
+        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
+        # whereas all tokenizers have .get_vocab()
+        return self.model_tokenizer.get_vocab()
+
     def adjust_request(
             self, request: ChatCompletionRequest) -> ChatCompletionRequest:
         """
diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
index 6c5bcc7dd59b1..bcbcda3fa528a 100644
--- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -50,10 +50,9 @@ def __init__(self, tokenizer: AnyTokenizer):
             raise ValueError(
                 "The model tokenizer must be passed to the ToolParser "
                 "constructor during construction.")
-        self.tool_call_start_token_id: int = self.model_tokenizer.vocab.get(
-            self.tool_call_start_token, None)
-        self.tool_call_end_token_id: int = self.model_tokenizer.vocab.get(
-            self.tool_call_end_token, None)
+        self.tool_call_start_token_id = self.vocab.get(
+            self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
         if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
             raise RuntimeError(
                 "Hermes 2 Pro Tool parser could not locate tool call start/end "
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index 9580fa115c6b3..c6dc0688e38f9 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -61,8 +61,7 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.streamed_args_for_tool: List[str] = [
         ]  # map what has been streamed for each tool so far to a list
         self.bot_token = "[TOOL_CALLS]"
-        self.bot_token_id = self.model_tokenizer.get_vocab().get(
-            self.bot_token, None)
+        self.bot_token_id = self.vocab.get(self.bot_token)
         self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
         if not self.bot_token_id:
             raise RuntimeError(
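
Note (commentary, not part of the commit): the sketch below is a minimal, self-contained illustration of the pattern this patch introduces, namely caching the result of get_vocab() on the base parser so that subclasses look up special-token ids through a single property. DummyTokenizer and DummyToolParser are hypothetical stand-ins, not vLLM classes.

# Minimal sketch of the cached get_vocab() pattern; names are hypothetical.
from functools import cached_property
from typing import Dict


class DummyTokenizer:
    """Stand-in for a tokenizer that exposes get_vocab() but not .vocab."""

    def get_vocab(self) -> Dict[str, int]:
        # A real tokenizer may rebuild this dict on every call, which is
        # one reason to cache the result rather than call it repeatedly.
        return {"<tool_call>": 32001, "</tool_call>": 32002}


class DummyToolParser:
    def __init__(self, tokenizer: DummyTokenizer):
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        # Always go through get_vocab(), since only fast tokenizers are
        # guaranteed to expose a .vocab attribute.
        return self.model_tokenizer.get_vocab()


parser = DummyToolParser(DummyTokenizer())
# Subclasses resolve special-token ids via the cached property; dict.get()
# returns None for a missing token, letting the parser raise a clear error.
print(parser.vocab.get("<tool_call>"))   # 32001
print(parser.vocab.get("[TOOL_CALLS]"))  # None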