diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index a0b8e81f666c..6916d63a56e1 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -519,11 +519,14 @@ def apply_hf_chat_template(
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: List[ChatCompletionMessageParam],
-    chat_template: Optional[str],
+    chat_template: Optional[str] = None,
     **kwargs: Any,
 ) -> List[int]:
+    if chat_template is not None:
+        logger.warning(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
+
     return tokenizer.apply_chat_template(
         messages=messages,
-        chat_template=chat_template,
         **kwargs,
     )
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 17e318cb5e04..ea1910ed20ec 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -45,26 +45,25 @@ class MistralTokenizer:
 
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
-
-        self._is_tekken = is_tekken
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
+
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
@@ -102,6 +101,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                         revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
@@ -117,9 +148,12 @@ def __call__(
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
@@ -141,7 +175,7 @@ def apply_chat_template(self,
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
@@ -151,14 +185,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
@@ -170,6 +201,3 @@ def convert_ids_to_tokens(
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size