[Tokenizer] Add an option to specify tokenizer (vllm-project#284)
WoosukKwon authored Jun 28, 2023
1 parent 33834ff commit 1775c69
Showing 7 changed files with 42 additions and 31 deletions.
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -16,6 +16,7 @@ class ModelConfig:
     Args:
         model: Name or path of the huggingface model to use.
+        tokenizer: Name or path of the huggingface tokenizer to use.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
         use_np_weights: Save a numpy copy of model weights for faster loading.
@@ -30,13 +31,15 @@ class ModelConfig:
     def __init__(
         self,
         model: str,
+        tokenizer: Optional[str],
         download_dir: Optional[str],
         use_np_weights: bool,
         use_dummy_weights: bool,
         dtype: str,
         seed: int,
     ) -> None:
         self.model = model
+        self.tokenizer = tokenizer
         self.download_dir = download_dir
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
7 changes: 6 additions & 1 deletion vllm/engine/arg_utils.py
@@ -11,6 +11,7 @@
 class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
+    tokenizer: Optional[str] = None
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
@@ -27,6 +28,8 @@ class EngineArgs:
     disable_log_stats: bool = False
 
     def __post_init__(self):
+        if self.tokenizer is None:
+            self.tokenizer = self.model
         self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
@@ -37,6 +40,8 @@ def add_cli_args(
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
+        parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
+                            help='name or path of the huggingface tokenizer to use')
         parser.add_argument('--download-dir', type=str,
                             default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
@@ -104,7 +109,7 @@ def create_engine_configs(
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
         model_config = ModelConfig(
-            self.model, self.download_dir, self.use_np_weights,
+            self.model, self.tokenizer, self.download_dir, self.use_np_weights,
             self.use_dummy_weights, self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                    self.swap_space)
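The `__post_init__` change above makes the tokenizer optional at every layer: when it is left unset (the CLI default), it falls back to the model name before the engine configs are built. A minimal sketch of that fallback, assuming `EngineArgs` can be constructed with only `model` and that the remaining dataclass fields keep their defaults:

from vllm.engine.arg_utils import EngineArgs

# Tokenizer omitted: __post_init__ falls back to the model name.
args = EngineArgs(model="facebook/opt-125m")
assert args.tokenizer == "facebook/opt-125m"

# Tokenizer given explicitly: kept and later forwarded to ModelConfig.
args = EngineArgs(model="facebook/opt-125m", tokenizer="facebook/opt-350m")
assert args.tokenizer == "facebook/opt-350m"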
6 changes: 4 additions & 2 deletions vllm/engine/llm_engine.py
@@ -6,11 +6,12 @@
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
-from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
+                                               get_tokenizer)
 from vllm.utils import Counter
 from vllm.worker.worker import Worker
 
@@ -59,6 +60,7 @@ def __init__(
         logger.info(
             "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
+            f"tokenizer={model_config.tokenizer!r}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -75,7 +77,7 @@ def __init__(
         self.log_stats = log_stats
         self._verify_args()
 
-        self.tokenizer = get_tokenizer(model_config.model)
+        self.tokenizer = get_tokenizer(model_config.tokenizer)
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
@@ -25,6 +25,7 @@ class LLM:
     Args:
         model: The name or path of a HuggingFace Transformers model.
+        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -38,6 +39,7 @@ class LLM:
     def __init__(
         self,
         model: str,
+        tokenizer: Optional[str] = None,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
@@ -47,6 +49,7 @@ def __init__(
             kwargs["disable_log_stats"] = True
         engine_args = EngineArgs(
             model=model,
+            tokenizer=tokenizer,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
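With the new `tokenizer` parameter, `LLM` can load weights from one repository and the tokenizer from another; omitting it keeps the old behavior of using the model name for both. A hedged usage sketch (the model and tokenizer names are only examples, and the package-level `LLM` export is assumed):

from vllm import LLM

# Model weights come from the OpenLLaMA repo; the tokenizer comes from
# the pre-converted fast LLaMA tokenizer referenced in this commit.
llm = LLM(model="openlm-research/open_llama_7b",
          tokenizer="hf-internal-testing/llama-tokenizer")
outputs = llm.generate("Hello, my name is")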
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -15,14 +15,14 @@
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.tokenizer_utils import get_tokenizer
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
     LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
vllm/transformers_utils/__init__.py: Empty file.
vllm/engine/tokenizer_utils.py → vllm/transformers_utils/tokenizer.py
@@ -1,46 +1,44 @@
 from typing import List, Tuple, Union
 
-from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
-_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
+# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 
 def get_tokenizer(
-    model_name: str,
+    tokenizer_name: str,
     *args,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
-    config = AutoConfig.from_pretrained(model_name)
-    if "open_llama" in model_name:
-        kwargs["use_fast"] = False
+    if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
         logger.info(
-            "OpenLLaMA models do not support the fast tokenizer. "
-            "Using the slow tokenizer instead.")
-    elif config.model_type == "llama" and kwargs.get("use_fast", True):
-        # LLaMA fast tokenizer causes protobuf errors in some environments.
-        # However, we found that the below LLaMA fast tokenizer works well in
-        # most environments.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        logger.info(
-            f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
-            "potential protobuf errors.")
-    elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
-        if kwargs.get("use_fast", False) == True:
-            raise ValueError(
-                f"Cannot use the fast tokenizer for {config.model_type} due to "
-                "bugs in the fast tokenizer.")
-        logger.info(
-            f"Using the slow tokenizer for {config.model_type} due to bugs in "
-            "the fast tokenizer. This could potentially lead to performance "
-            "degradation.")
-        kwargs["use_fast"] = False
-    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+            "For some LLaMA-based models, initializing the fast tokenizer may "
+            "take a long time. To eliminate the initialization time, consider "
+            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+            "tokenizer.")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args,
+                                                  **kwargs)
+    except TypeError as e:
+        # The LLaMA tokenizer causes a protobuf error in some environments.
+        err_msg = (
+            "Failed to load the tokenizer. If you are using a LLaMA-based "
+            f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+            "tokenizer.")
+        raise RuntimeError(err_msg) from e
+
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        logger.warning(
+            "Using a slow tokenizer. This might cause a significant "
+            "slowdown. Consider using a fast tokenizer instead.")
+    return tokenizer
 
 
 def detokenize_incrementally(
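The rewritten `get_tokenizer` no longer inspects the model config: it loads whatever name it is given, turns the known protobuf `TypeError` into a clearer `RuntimeError`, and warns when the result is a slow tokenizer. A small sketch of calling it directly (the OPT name is only an illustration):

from vllm.transformers_utils.tokenizer import get_tokenizer

# A LLaMA-style name triggers the informational hint about the
# pre-converted fast tokenizer and returns a PreTrainedTokenizerFast.
tok = get_tokenizer("hf-internal-testing/llama-tokenizer")

# Forcing the slow path still works but emits the slowdown warning.
slow_tok = get_tokenizer("facebook/opt-125m", use_fast=False)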
