Make initialization of tokenizer and detokenizer optional #3748

Merged: 47 commits, Apr 21, 2024

Commits (47)
0ca0b3d
[optional-tokenizer] make tokenzier optional in initialization
EricDingNVD Mar 28, 2024
a50f0c7
[tokenizer] make detokenization optional
EricDingNVD Mar 30, 2024
e144c2e
[tokenizer] fix parameter description
EricDingNVD Mar 30, 2024
5fb16f2
[tokenizer] fix initialize engine args
EricDingNVD Mar 31, 2024
904edcc
[tokenizer] fix format
EricDingNVD Mar 31, 2024
cfc2660
[tokenization] fix arg parser field
EricDingNVD Mar 31, 2024
013a36a
[tokenizer] fix the order of initializing tokenizer and de-tokenizer
EricDingNVD Apr 1, 2024
fb3eefd
[tokenizer] Never disable tok in LLM initialization
EricDingNVD Apr 1, 2024
07cc2e5
[tokenizer] Add flag value to log info to help debug
EricDingNVD Apr 2, 2024
5b30825
[tokenizer] fix type
EricDingNVD Apr 2, 2024
676256f
[tokenizer] fix yapf errors
EricDingNVD Apr 2, 2024
f7cd883
[tokenizer] fix formatting
EricDingNVD Apr 4, 2024
8dfb59b
Merge branch 'vllm-project:main' into optional-tokenizer
GeauxEric Apr 7, 2024
0ea8446
[optional-tokenizer] make tokenzier optional in initialization
EricDingNVD Mar 28, 2024
a7be734
[tokenizer] make detokenization optional
EricDingNVD Mar 30, 2024
1e613f6
[tokenizer] fix parameter description
EricDingNVD Mar 30, 2024
3b94adb
[tokenizer] fix initialize engine args
EricDingNVD Mar 31, 2024
eab1dd7
[tokenizer] fix format
EricDingNVD Mar 31, 2024
58ccf64
[tokenization] fix arg parser field
EricDingNVD Mar 31, 2024
a0d1405
[tokenizer] fix the order of initializing tokenizer and de-tokenizer
EricDingNVD Apr 1, 2024
0af6b47
[tokenizer] Never disable tok in LLM initialization
EricDingNVD Apr 1, 2024
a497ed9
[tokenizer] Add flag value to log info to help debug
EricDingNVD Apr 2, 2024
af078a8
[tokenizer] fix type
EricDingNVD Apr 2, 2024
ad2c920
[tokenizer] fix yapf errors
EricDingNVD Apr 2, 2024
78d4091
[tokenizer] fix formatting
EricDingNVD Apr 4, 2024
4f67490
[tokenizer] fix EngineArgs
EricDingNVD Apr 7, 2024
a093589
Merge branch 'optional-tokenizer' of github.com:GeauxEric/vllm into o…
EricDingNVD Apr 7, 2024
59fc5eb
[tokenizer] fix init LLM
EricDingNVD Apr 7, 2024
ac7a3d4
[tokenizer] rename the flag
EricDingNVD Apr 10, 2024
400224d
[tokenizer] rename the flag
EricDingNVD Apr 10, 2024
6941628
[tokenizer] add integration test
EricDingNVD Apr 11, 2024
7e851ac
Merge branch 'optional-tokenizer' of github.com:GeauxEric/vllm into o…
EricDingNVD Apr 11, 2024
ad4da7c
[tokenizer] add integration test
EricDingNVD Apr 11, 2024
3f60c11
Merge branch 'optional-tokenizer' of github.com:GeauxEric/vllm into o…
EricDingNVD Apr 11, 2024
4b3c5e3
[tokenizer] test generate based on prompt token ids
EricDingNVD Apr 11, 2024
87c695b
[tokenizer] more tests
EricDingNVD Apr 11, 2024
aa5ec54
[tokenizer] consider finialize sequence
EricDingNVD Apr 12, 2024
4943cbd
Merge branch 'main' of github.com:GeauxEric/vllm into optional-tokenizer
EricDingNVD Apr 13, 2024
c208c77
Merge branch 'main' into optional-tokenizer
EricDingNVD Apr 17, 2024
68f77b1
Merge branch 'main' into optional-tokenizer
EricDingNVD Apr 17, 2024
c0951f3
[tokenizer] fix integration test
EricDingNVD Apr 17, 2024
5f8b5fd
Merge branch 'main' of github.com:GeauxEric/vllm into optional-tokenizer
EricDingNVD Apr 17, 2024
47dce6e
[tokenizer] merge with main
EricDingNVD Apr 17, 2024
50a7fad
[tokenizer] log warning if eos_token_id is None
EricDingNVD Apr 18, 2024
7c25549
Merge branch 'main' of github.com:GeauxEric/vllm into optional-tokenizer
EricDingNVD Apr 18, 2024
5f2b8ed
[tokenizer] work around mypy errors
EricDingNVD Apr 18, 2024
e584673
Merge remote-tracking branch 'upstream/main' into optional-tokenizer
ywang96 Apr 21, 2024
23 changes: 23 additions & 0 deletions tests/engine/test_skip_tokenizer_init.py
@@ -0,0 +1,23 @@
import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_skip_tokenizer_initialization(model: str):
    # This test checks if the flag skip_tokenizer_init skips the initialization
    # of tokenizer and detokenizer. The generated output is expected to contain
    # token ids.
    llm = LLM(model=model, skip_tokenizer_init=True)
    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
    with pytest.raises(ValueError) as err:
        llm.generate("abc", sampling_params)
    assert "prompts must be None if" in str(err.value)
    outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
                           sampling_params=sampling_params)
    assert len(outputs) > 0
    completions = outputs[0].outputs
    assert len(completions) > 0
    assert completions[0].text == ""
    assert completions[0].token_ids
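
For context beyond the test above, a minimal usage sketch (not part of this PR) of how a caller might tokenize prompts and decode outputs externally when skip_tokenizer_init is set; the use of transformers.AutoTokenizer here is an assumption for illustration:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Sketch only: the engine skips tokenizer/detokenizer setup, so prompts are
# pre-tokenized and outputs are decoded with an external tokenizer.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

prompt_token_ids = tokenizer("Hello, my name is").input_ids

outputs = llm.generate(prompt_token_ids=[prompt_token_ids],
                       sampling_params=SamplingParams(max_tokens=16))

for output in outputs:
    # output.outputs[0].text stays empty because detokenization was skipped;
    # decode the generated token ids with the external tokenizer instead.
    print(tokenizer.decode(output.outputs[0].token_ids,
                           skip_special_tokens=True))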
7 changes: 6 additions & 1 deletion vllm/config.py
@@ -66,6 +66,8 @@ class ModelConfig:
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
    """

    def __init__(
@@ -85,6 +87,7 @@ def __init__(
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
        skip_tokenizer_init: bool = False,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
@@ -99,14 +102,16 @@ def __init__(
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        self.max_logprobs = max_logprobs
        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                     max_model_len)
        self._verify_tokenizer_mode()
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()
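
As a quick illustration of the plumbing (a sketch, not from the PR): the flag set on EngineArgs ends up on this ModelConfig via create_engine_config, which is the change in arg_utils.py below. Building the config fetches the HF model config, so network access is assumed.

from vllm.engine.arg_utils import EngineArgs

# Sketch: verify that skip_tokenizer_init propagates from EngineArgs into
# the ModelConfig created by create_engine_config (see arg_utils.py below).
engine_args = EngineArgs(model="facebook/opt-125m", skip_tokenizer_init=True)
engine_config = engine_args.create_engine_config()
assert engine_config.model_config.skip_tokenizer_init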
7 changes: 6 additions & 1 deletion vllm/engine/arg_utils.py
@@ -16,6 +16,7 @@ class EngineArgs:
    """Arguments for vLLM engine."""
    model: str
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
@@ -93,6 +94,10 @@ def add_cli_args(
            type=str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=str,
@@ -453,7 +458,7 @@ def create_engine_config(self, ) -> EngineConfig:
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_logprobs)
            self.max_logprobs, self.skip_tokenizer_init)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
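
On the CLI side, a small sketch (assumptions noted in comments) of how --skip-tokenizer-init is parsed back into EngineArgs:

import argparse

from vllm.engine.arg_utils import EngineArgs

# Sketch: build a parser with the engine arguments, including the flag added
# by this PR, and read it back through from_cli_args.
parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--skip-tokenizer-init"])

engine_args = EngineArgs.from_cli_args(args)
assert engine_args.skip_tokenizer_init is True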
29 changes: 21 additions & 8 deletions vllm/engine/llm_engine.py
@@ -100,6 +100,7 @@ def __init__(
            f"model={model_config.model!r}, "
            f"speculative_config={speculative_config!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
@@ -132,8 +133,14 @@ def __init__(
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer: BaseTokenizerGroup
            self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.detokenizer = None
            self.tokenizer = None

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)
@@ -187,9 +194,10 @@ def __init__(
                    parallel_config.disable_custom_all_reduce,
                })

        # Ping the tokenizer to ensure liveness if it runs in a
        # different process.
        self.tokenizer.ping()
        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
@@ -296,7 +304,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision)
        init_kwargs.update(tokenizer_init_kwargs)
        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
        self.tokenizer = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **init_kwargs)

    def _verify_args(self) -> None:
@@ -393,8 +401,13 @@ def add_request(
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.tokenizer.get_lora_tokenizer(
            lora_request).eos_token_id
        eos_token_id = None
        if self.tokenizer:
            eos_token_id = self.tokenizer.get_lora_tokenizer(
                lora_request).eos_token_id

Review comment on lines +404 to +407 (Member):

IMO we should add a warning here - WDYT?

This will also affect components that use eos_token_id. For example,

    # inject the eos token id into the sampling_params to support min_tokens
    # processing
    sampling_params.eos_token_id = seq.eos_token_id

Review comment (Member):

I still think it's worth having a warning here about eos_token_id being None.

        else:
            logger.warning("Use None for EOS token id because tokenizer is "
                           "not initialized")
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       eos_token_id, lora_request)

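
Regarding the review discussion above: with the tokenizer skipped, eos_token_id falls back to None and EOS-based stopping is unavailable, so a caller who needs it can pass stop token ids explicitly. A hedged sketch, assuming the model's EOS id is known to the caller:

from vllm import LLM, SamplingParams

# Sketch: without a tokenizer the engine cannot look up the model's EOS token,
# so request-level stop_token_ids (and max_tokens) take over that role.
# The EOS id 2 for facebook/opt-125m is an assumption for illustration.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
params = SamplingParams(max_tokens=32, stop_token_ids=[2])

outputs = llm.generate(prompt_token_ids=[[2, 100, 200]],
                       sampling_params=params)
print(outputs[0].outputs[0].token_ids)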
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/single_step.py
@@ -59,7 +59,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
        if prompt_logprobs is not None and \
                seq_group.sampling_params.detokenize and self.detokenizer:
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs
@@ -105,7 +106,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize:
            if seq_group.sampling_params.detokenize and self.detokenizer:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
9 changes: 9 additions & 0 deletions vllm/entrypoints/llm.py
@@ -32,6 +32,9 @@ class LLM:
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
@@ -76,6 +79,7 @@ def __init__(
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
@@ -96,6 +100,7 @@ def __init__(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
@@ -160,6 +165,10 @@ def generate(
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if self.llm_engine.model_config.skip_tokenizer_init \
                and prompts is not None:
            raise ValueError("prompts must be None if skip_tokenizer_init "
                             "is True")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
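
Finally, a sketch of the same flag on the lower-level engine API; the add_request signature shown (prompt, sampling_params, prompt_token_ids) matches the engine at the time of this PR and may differ in later versions:

from vllm import EngineArgs, LLMEngine, SamplingParams

# Sketch: drive LLMEngine directly with skip_tokenizer_init; prompts must be
# supplied as token ids and only token ids come back.
engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", skip_tokenizer_init=True))

engine.add_request(request_id="0",
                   prompt=None,
                   sampling_params=SamplingParams(max_tokens=8),
                   prompt_token_ids=[2, 100, 200])

while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            # .text stays empty without a detokenizer; use token_ids instead.
            print(request_output.outputs[0].token_ids)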