From fb96c1e98c05ffa35dd48416f68e88edb2f9eb34 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 15 Mar 2024 16:37:01 -0700 Subject: [PATCH] Asynchronous tokenization (#2879) --- .buildkite/test-pipeline.yaml | 2 +- tests/async_engine/test_api_server.py | 16 +- tests/conftest.py | 11 ++ tests/lora/test_tokenizer.py | 69 -------- tests/lora/test_tokenizer_group.py | 53 ++++++ tests/tokenization/__init__.py | 0 tests/tokenization/test_cached_tokenizer.py | 20 +++ .../test_detokenize.py | 0 tests/tokenization/test_tokenizer_group.py | 100 +++++++++++ vllm/config.py | 57 ++++++ vllm/engine/arg_utils.py | 43 ++++- vllm/engine/llm_engine.py | 15 +- vllm/transformers_utils/tokenizer.py | 99 ++++------- .../tokenizer_group/__init__.py | 32 ++++ .../tokenizer_group/base_tokenizer_group.py | 48 +++++ .../tokenizer_group/ray_tokenizer_group.py | 166 ++++++++++++++++++ .../tokenizer_group/tokenizer_group.py | 80 +++++++++ 17 files changed, 658 insertions(+), 153 deletions(-) delete mode 100644 tests/lora/test_tokenizer.py create mode 100644 tests/lora/test_tokenizer_group.py create mode 100644 tests/tokenization/__init__.py create mode 100644 tests/tokenization/test_cached_tokenizer.py rename tests/{engine => tokenization}/test_detokenize.py (100%) create mode 100644 tests/tokenization/test_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/__init__.py create mode 100644 vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/tokenizer_group.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6a130f6fadcc..8badc16d0cb7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -28,7 +28,7 @@ steps: num_gpus: 2 # only support 1 or 2 for now. - label: Engine Test - command: pytest -v -s engine test_sequence.py + command: pytest -v -s engine tokenization test_sequence.py - label: Entrypoints Test command: pytest -v -s entrypoints diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index ed9017c1e3e9..248bfbc8ab5c 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(): +def api_server(tokenizer_pool_size: int): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() uvicorn_process = subprocess.Popen([ - sys.executable, - "-u", - str(script_path), - "--model", - "facebook/opt-125m", - "--host", - "127.0.0.1", + sys.executable, "-u", + str(script_path), "--model", "facebook/opt-125m", "--host", + "127.0.0.1", "--tokenizer-pool-size", + str(tokenizer_pool_size) ]) yield uvicorn_process.terminate() -def test_api_server(api_server): +@pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) +def test_api_server(api_server, tokenizer_pool_size: int): """ Run the API server and test it. 
diff --git a/tests/conftest.py b/tests/conftest.py index 6eb8159837d5..c06b271e6c7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.config import TokenizerPoolConfig _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -258,3 +259,13 @@ def generate_beam_search( @pytest.fixture def vllm_runner(): return VllmRunner + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig(pool_size=1, + pool_type="ray", + extra_config={}) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py deleted file mode 100644 index 6c4c91fce812..000000000000 --- a/tests/lora/test_tokenizer.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer - - -@pytest.mark.asyncio -async def test_transformers_tokenizer(): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - assert reference_tokenizer.encode("prompt") == tokenizer.encode( - request_id="request_id", prompt="prompt", lora_request=None) - assert reference_tokenizer.encode( - "prompt") == await tokenizer.encode_async(request_id="request_id", - prompt="prompt", - lora_request=None) - assert isinstance(tokenizer.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - None) == await tokenizer.get_lora_tokenizer_async(None) - - -@pytest.mark.asyncio -async def test_transformers_tokenizer_lora(sql_lora_files): - reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=True, - max_num_seqs=1, - max_input_length=None, - ) - lora_request = LoRARequest("1", 1, sql_lora_files) - assert reference_tokenizer.encode("prompt") == tokenizer.encode( - request_id="request_id", prompt="prompt", lora_request=lora_request) - assert reference_tokenizer.encode( - "prompt") == await tokenizer.encode_async(request_id="request_id", - prompt="prompt", - lora_request=lora_request) - assert isinstance(tokenizer.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - None) == await tokenizer.get_lora_tokenizer_async(None) - - assert isinstance(tokenizer.get_lora_tokenizer(lora_request), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - lora_request) != tokenizer.get_lora_tokenizer(None) - assert tokenizer.get_lora_tokenizer( - lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) - - -def test_get_lora_tokenizer(sql_lora_files, tmpdir): - lora_request = None - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - lora_request = LoRARequest("1", 1, sql_lora_files) - tokenizer = get_lora_tokenizer(lora_request) - assert tokenizer.get_added_vocab() - - lora_request = LoRARequest("1", 1, str(tmpdir)) - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py new file mode 100644 index 
000000000000..5fec3f179925 --- /dev/null +++ b/tests/lora/test_tokenizer_group.py @@ -0,0 +1,53 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer import get_lora_tokenizer +from ..conftest import get_tokenizer_pool_config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer_group.encode_async( + request_id="request_id", + prompt="prompt", + lora_request=lora_request) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + None) == await tokenizer_group.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + lora_request) != tokenizer_group.get_lora_tokenizer(None) + assert tokenizer_group.get_lora_tokenizer( + lora_request) == await tokenizer_group.get_lora_tokenizer_async( + lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/tokenization/__init__.py b/tests/tokenization/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py new file mode 100644 index 000000000000..181e80032512 --- /dev/null +++ b/tests/tokenization/test_cached_tokenizer.py @@ -0,0 +1,20 @@ +from copy import deepcopy +from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from transformers import AutoTokenizer + + +def test_cached_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + reference_tokenizer.add_special_tokens({"cls_token": ""}) + reference_tokenizer.add_special_tokens( + {"additional_special_tokens": [""]}) + cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) + + assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( + "prompt") + assert set(reference_tokenizer.all_special_ids) == set( + cached_tokenizer.all_special_ids) + assert set(reference_tokenizer.all_special_tokens) == set( + cached_tokenizer.all_special_tokens) + assert set(reference_tokenizer.all_special_tokens_extended) == set( + cached_tokenizer.all_special_tokens_extended) diff --git a/tests/engine/test_detokenize.py b/tests/tokenization/test_detokenize.py similarity index 100% rename from tests/engine/test_detokenize.py rename to tests/tokenization/test_detokenize.py diff --git 
a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py new file mode 100644 index 000000000000..d0788ee87563 --- /dev/null +++ b/tests/tokenization/test_tokenizer_group.py @@ -0,0 +1,100 @@ +import os +import pytest +import asyncio +from unittest.mock import patch + +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( + RayTokenizerGroupPool) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from ..conftest import get_tokenizer_pool_config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +async def test_tokenizer_group(tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer_group.encode_async( + request_id="request_id", prompt="prompt", lora_request=None) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + None) == await tokenizer_group.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) +async def test_tokenizer_group_pool(tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer_group_pool = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + # Send multiple requests to the tokenizer group pool + # (more than the pool size) + # and check that all requests are processed correctly. 
+ num_requests = tokenizer_group_pool.pool_size * 5 + requests = [ + tokenizer_group_pool.encode_async(request_id=str(i), + prompt=f"prompt {i}", + lora_request=None) + for i in range(num_requests) + ] + results = await asyncio.gather(*requests) + expected_results = [ + reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests) + ] + assert results == expected_results + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) +async def test_tokenizer_group_ray_pool_env_var_propagation( + tokenizer_group_type): + """Test that env vars from caller process are propagated to + tokenizer Ray actors.""" + env_var = "MY_ENV_VAR" + + class EnvVarCheckerTokenizerGroup(TokenizerGroup): + + def ping(self): + assert os.environ.get(env_var) == "1" + return super().ping() + + class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): + _worker_cls = EnvVarCheckerTokenizerGroup + + tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) + tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( + tokenizer_pool_config, + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None) + with pytest.raises(AssertionError): + tokenizer_pool.ping() + + with patch.dict(os.environ, {env_var: "1"}): + tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) + tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( + tokenizer_pool_config, + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None) + tokenizer_pool.ping() diff --git a/vllm/config.py b/vllm/config.py index de687395a000..f792e8909524 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,6 +3,7 @@ import os from packaging.version import Version +import json import torch from transformers import PretrainedConfig @@ -389,6 +390,58 @@ def verify_with_parallel_config( logger.warning("Possibly too large swap space. " + msg) +@dataclass +class TokenizerPoolConfig: + """Configuration for the tokenizer pool. + + Args: + pool_size: Number of tokenizer workers in the pool. + pool_type: Type of the pool. + extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. + """ + pool_size: int + pool_type: str + extra_config: dict + + def __post_init__(self): + if self.pool_type not in ("ray", ): + raise ValueError(f"Unknown pool type: {self.pool_type}") + if not isinstance(self.extra_config, dict): + raise ValueError("extra_config must be a dictionary.") + + @classmethod + def create_config( + cls, tokenizer_pool_size: int, tokenizer_pool_type: str, + tokenizer_pool_extra_config: Optional[Union[str, dict]] + ) -> Optional["TokenizerPoolConfig"]: + """Create a TokenizerPoolConfig from the given parameters. + + If tokenizer_pool_size is 0, return None. + + Args: + tokenizer_pool_size: Number of tokenizer workers in the pool. + tokenizer_pool_type: Type of the pool. + tokenizer_pool_extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. This can be a JSON string (will be parsed). 
+ """ + if tokenizer_pool_size: + if isinstance(tokenizer_pool_extra_config, str): + tokenizer_pool_extra_config_parsed = json.loads( + tokenizer_pool_extra_config) + else: + tokenizer_pool_extra_config_parsed = ( + tokenizer_pool_extra_config or {}) + tokenizer_pool_config = cls(tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed) + else: + tokenizer_pool_config = None + return tokenizer_pool_config + + class ParallelConfig: """Configuration for the distributed execution. @@ -403,6 +456,8 @@ class ParallelConfig: parallel and large models. disable_custom_all_reduce: Disable the custom all-reduce kernel and fall back to NCCL. + tokenizer_pool_config: Config for the tokenizer pool. + If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. """ @@ -414,6 +469,7 @@ def __init__( worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, ) -> None: @@ -430,6 +486,7 @@ def __init__( self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce + self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c3dccdd5bb50..3e146d2e6c0c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,8 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) + ParallelConfig, SchedulerConfig, LoRAConfig, + TokenizerPoolConfig) @dataclass @@ -40,6 +41,9 @@ class EngineArgs: enforce_eager: bool = False max_context_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False + tokenizer_pool_size: int = 0 + tokenizer_pool_type: str = "ray" + tokenizer_pool_extra_config: Optional[dict] = None enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -249,6 +253,25 @@ def add_cli_args( action='store_true', default=EngineArgs.disable_custom_all_reduce, help='See ParallelConfig') + parser.add_argument('--tokenizer-pool-size', + type=int, + default=EngineArgs.tokenizer_pool_size, + help='Size of tokenizer pool to use for ' + 'asynchronous tokenization. If 0, will ' + 'use synchronous tokenization.') + parser.add_argument('--tokenizer-pool-type', + type=str, + default=EngineArgs.tokenizer_pool_type, + help='Type of tokenizer pool to use for ' + 'asynchronous tokenization. Ignored ' + 'if tokenizer_pool_size is 0.') + parser.add_argument('--tokenizer-pool-extra-config', + type=str, + default=EngineArgs.tokenizer_pool_extra_config, + help='Extra config for tokenizer pool. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary. 
Ignored if ' + 'tokenizer_pool_size is 0.') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -312,14 +335,16 @@ def create_engine_configs( cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window(), - self.enable_prefix_caching) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers, - self.disable_custom_all_reduce, - self.ray_workers_use_nsight) + model_config.get_sliding_window()) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, self.tensor_parallel_size, + self.worker_use_ray, self.max_parallel_loading_workers, + self.disable_custom_all_reduce, + TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 691c9e83d59c..71798ab7d17c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -17,8 +17,9 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - TokenizerGroup) +from vllm.transformers_utils.tokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, + get_tokenizer_group) from vllm.utils import Counter logger = init_logger(__name__) @@ -102,6 +103,10 @@ def __init__( parallel_config, scheduler_config, device_config, lora_config) + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. 
@@ -152,6 +157,7 @@ def get_tokenizer_for_seq(self, def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( + tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None, @@ -159,8 +165,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: TokenizerGroup = TokenizerGroup( - self.model_config.tokenizer, **init_kwargs) + + self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( + self.parallel_config.tokenizer_pool_config, **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2600ea2642da..f7a1a19a89bc 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,12 +5,48 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import make_async, LRUCache +from vllm.utils import make_async from vllm.transformers_utils.tokenizers import * logger = init_logger(__name__) +def get_cached_tokenizer( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + + class CachedTokenizer(tokenizer.__class__): + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + def get_tokenizer( tokenizer_name: str, *args, @@ -64,7 +100,7 @@ def get_tokenizer( logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. 
Consider using a fast tokenizer instead.") - return tokenizer + return get_cached_tokenizer(tokenizer) def get_lora_tokenizer(lora_request: LoRARequest, *args, @@ -88,65 +124,6 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) -class TokenizerGroup: - """A group of tokenizers that can be used for LoRA adapters.""" - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): - self.tokenizer_id = tokenizer_id - self.tokenizer_config = tokenizer_config - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - if enable_lora: - self.lora_tokenizers = LRUCache(capacity=max_num_seqs) - else: - self.lora_tokenizers = None - - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) - - async def encode_async( - self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None - ) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (get_lora_tokenizer( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None - ) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (await get_lora_tokenizer_async( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py new file mode 100644 index 000000000000..adc8d9b90ddb --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -0,0 +1,32 @@ +from typing import Optional +from vllm.config import TokenizerPoolConfig +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from vllm.engine.ray_utils import ray + +if ray: + from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( + RayTokenizerGroupPool) +else: + RayTokenizerGroupPool = None + + +def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> BaseTokenizerGroup: + if tokenizer_pool_config is None: + return TokenizerGroup(**init_kwargs) + if tokenizer_pool_config.pool_type == "ray": + if RayTokenizerGroupPool is None: + raise ImportError( + "RayTokenizerGroupPool is not available. 
Please install " + "the ray package to use the Ray tokenizer group pool.") + return RayTokenizerGroupPool.from_config(tokenizer_pool_config, + **init_kwargs) + else: + raise ValueError( + f"Unknown pool type: {tokenizer_pool_config.pool_type}") + + +__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py new file mode 100644 index 000000000000..99518a606fab --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest + + +class BaseTokenizerGroup(ABC): + """A group of tokenizers that can be used for LoRA adapters.""" + + @abstractmethod + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + pass + + @abstractmethod + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + pass + + @abstractmethod + def encode(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + async def encode_async(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + """Get a tokenizer for a LoRA request.""" + pass + + @abstractmethod + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + """Get a tokenizer for a LoRA request.""" + pass diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py new file mode 100644 index 000000000000..e048ec05bece --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -0,0 +1,166 @@ +import asyncio +import os +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.engine.ray_utils import ray +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + +class RayTokenizerGroupPool(BaseTokenizerGroup): + """A Ray-based pool of TokenizerGroups for async tokenization.""" + + # Class to use for workers making up the pool. + _worker_cls = TokenizerGroup + + @classmethod + def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, + **init_kwargs) -> "RayTokenizerGroupPool": + ray_actor_options = (tokenizer_pool_config.extra_config or { + "num_cpus": 0 + }) + ray_actor_options.setdefault( + "scheduling_strategy", + NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), soft=True)) + + # Carry over the env vars to the actors. + # This is necessary for API keys and such. 
+ ray_actor_options.setdefault("runtime_env", {}) + _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) + + init_kwargs["num_actors"] = tokenizer_pool_config.pool_size + init_kwargs["ray_actor_options"] = ray_actor_options + + return cls(**init_kwargs) + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): + # Store a local copy of the TokenizerGroup for quick access + # to underlying HF tokenizers. + self._local_tokenizer_group = self._worker_cls( + tokenizer_id=tokenizer_id, + enable_lora=enable_lora, + max_num_seqs=max_num_seqs, + max_input_length=max_input_length, + ) + + ray_tokenizer_group_cls = ray.remote( + self._worker_cls).options(**ray_actor_options) + self.tokenizer_actors = [ + ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) + for _ in range(num_actors) + ] + self._idle_actors: Optional[asyncio.Queue] = None + + @property + def pool_size(self) -> int: + return len(self.tokenizer_actors) + + def ping(self): + return ray.get( + [actor.ping.remote() for actor in self.tokenizer_actors]) + + def _ensure_queue_initialized(self): + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + The actor is then put back in the queue for future use. + This is blocking. + """ + self._ensure_queue_initialized() + + if self._idle_actors.empty(): + raise RuntimeError("No idle actors available.") + actor = self._idle_actors.get_nowait() + try: + ret = ray.get( + actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request)) + finally: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + return ret + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + If there are no idle actors, we wait until one becomes + available. + The actor is then put back in the queue for future use. + This is non-blocking. + """ + self._ensure_queue_initialized() + + actor = await self._idle_actors.get() + try: + ret = await actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + finally: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. 
+ self._idle_actors.put_nowait(actor) + return ret + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self._local_tokenizer_group.get_max_input_len(lora_request) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return self._local_tokenizer_group.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return await self._local_tokenizer_group.get_lora_tokenizer_async( + lora_request) + + +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: + """Copy over all current process environment variables to the runtime_env. + + The variables in runtime_env will take precedence over the current process + environment variables. + + runtime_env will be modified in place.""" + env_vars = os.environ.copy() + runtime_env.setdefault("env_vars", {}) + env_vars.update(runtime_env["env_vars"]) + runtime_env["env_vars"] = env_vars diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py new file mode 100644 index 000000000000..3af1334cb5ed --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -0,0 +1,80 @@ +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, + get_lora_tokenizer_async) +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.utils import LRUCache +from vllm.transformers_utils.tokenizer import get_tokenizer + + +class TokenizerGroup(BaseTokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + return True + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + 
return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id)
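
Usage note (not part of the patch): the sketch below shows how the pieces introduced here fit together, mirroring the new tests in tests/tokenization/test_tokenizer_group.py. It assumes Ray is installed and able to start actors on the local node; the "gpt2" tokenizer id, the pool size of 2, and the prompts are illustrative values rather than anything the patch mandates. Passing None instead of a TokenizerPoolConfig selects the plain in-process TokenizerGroup, which is what --tokenizer-pool-size 0 (the default) does.

    import asyncio

    from vllm.config import TokenizerPoolConfig
    from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

    # Equivalent of --tokenizer-pool-size 2 --tokenizer-pool-type ray on the CLI.
    pool_config = TokenizerPoolConfig(pool_size=2, pool_type="ray", extra_config={})

    tokenizer_group = get_tokenizer_group(
        pool_config,                 # None -> synchronous in-process TokenizerGroup
        tokenizer_id="gpt2",         # illustrative tokenizer id
        enable_lora=False,
        max_num_seqs=8,
        max_input_length=None,
    )

    # Liveness check; for the Ray pool this round-trips to every actor
    # (the engine calls this once at startup, see the llm_engine.py hunk above).
    tokenizer_group.ping()

    # Blocking path, as used by the synchronous engine.
    token_ids = tokenizer_group.encode(prompt="Hello, world!",
                                       request_id="0",
                                       lora_request=None)

    async def encode_many(prompts):
        # Non-blocking path: requests beyond the pool size wait for an idle
        # actor instead of failing, so they overlap across the pool.
        return await asyncio.gather(*(
            tokenizer_group.encode_async(prompt=p,
                                         request_id=str(i),
                                         lora_request=None)
            for i, p in enumerate(prompts)))

    print(token_ids)
    print(asyncio.run(encode_many(["first prompt", "second prompt"])))

On the server side the same behaviour is reached through the new flags added in arg_utils.py, for example --tokenizer-pool-size 2 --tokenizer-pool-type ray --tokenizer-pool-extra-config '{"num_cpus": 0}' (the extra config is parsed from JSON and passed to the Ray actors as their options); with a pool size of 0, tokenization stays synchronous exactly as before this patch.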