Skip to content

Commit

Permalink
[Model] Add Mistral Tokenization to improve robustness and chat encod…
Browse files Browse the repository at this point in the history
…ing (#7739)
  • Loading branch information
patrickvonplaten authored Aug 27, 2024
1 parent 9606c71 commit 6fc4e6e
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
1 change: 1 addition & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
4 changes: 3 additions & 1 deletion tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
Expand Down
7 changes: 4 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class ModelConfig:
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option
Expand Down Expand Up @@ -246,10 +247,10 @@ def _init_multimodal_config(

def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
if tokenizer_mode not in ["auto", "slow", "mistral"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode

def _verify_embedding_mode(self) -> None:
Expand Down
5 changes: 3 additions & 2 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
choices=['auto', 'slow', 'mistral'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
Expand Down
4 changes: 1 addition & 3 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def apply_chat_template(
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
) -> Union[str, List[int]]:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
Expand All @@ -280,6 +280,4 @@ def apply_chat_template(
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)

return prompt
12 changes: 9 additions & 3 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,21 @@ def chat(
conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)

prompts = apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt)

inputs: PromptInputs
if isinstance(prompt, list) and isinstance(prompt[0], int):
inputs = TokensPrompt(prompt_token_ids=prompt)
else:
inputs = TextPrompt(prompt=prompt)

return self.generate(
prompts,
sampling_params,
inputs,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
Expand Down
26 changes: 18 additions & 8 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
PromptAdapterPath,
TextTokensPrompt)
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
Expand Down Expand Up @@ -130,13 +131,22 @@ async def create_chat_completion(
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))

prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

assert prompt_inputs is not None

sampling_params = request.to_sampling_params(
tokenizer,
Expand Down
2 changes: 1 addition & 1 deletion vllm/transformers_utils/detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset


Expand Down
94 changes: 58 additions & 36 deletions vllm/transformers_utils/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings
from pathlib import Path
from typing import Optional, Union

Expand All @@ -9,12 +10,14 @@
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
MistralTokenizer)
from vllm.utils import make_async

logger = init_logger(__name__)

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]


def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
Expand Down Expand Up @@ -99,45 +102,64 @@ def get_tokenizer(
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
except AttributeError as e:
if "BaichuanTokenizer" in str(e):
# This is for the error "'BaichuanTokenizer' object has no
# attribute 'sp_model'".
tokenizer = BaichuanTokenizer.from_pretrained(
# if tokenizer is from official mistral org
is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
'It is strongly recommended to run mistral models with '
'`--tokenizer_mode "mistral"` to ensure correct '
'encoding and decoding.',
FutureWarning,
stacklevel=2)

if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
else:
raise e
**kwargs,
)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported,
# suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)):
err_msg = ("Failed to load the tokenizer. If the tokenizer "
"is a custom tokenizer not yet available in the "
"HuggingFace transformers library, consider "
"setting `trust_remote_code=True` in LLM or using "
"the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
except AttributeError as e:
if "BaichuanTokenizer" in str(e):
# This is for the error "'BaichuanTokenizer' object has no
# attribute 'sp_model'".
tokenizer = BaichuanTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
raise e

if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
tokenizer = get_cached_tokenizer(tokenizer)

if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
return get_cached_tokenizer(tokenizer)
return tokenizer


def get_lora_tokenizer(lora_request: LoRARequest, *args,
Expand Down
5 changes: 2 additions & 3 deletions vllm/transformers_utils/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

__all__ = [
"BaichuanTokenizer",
]
__all__ = ["BaichuanTokenizer", "MistralTokenizer"]
Loading

0 comments on commit 6fc4e6e

Please sign in to comment.