From e323b923fd668367846eaa805c4b928bc09d6635 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 14 Feb 2024 14:35:36 -0800 Subject: [PATCH 01/14] Asynchronous tokenization --- tests/async_engine/test_api_server.py | 16 ++- vllm/config.py | 8 ++ vllm/engine/arg_utils.py | 40 ++++++- vllm/engine/llm_engine.py | 36 +++++- vllm/transformers_utils/tokenizer.py | 165 +++++++++++++++++++++++++- 5 files changed, 244 insertions(+), 21 deletions(-) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index ed9017c1e3e9..326682ab18c2 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(): +def api_server(num_tokenizer_actors: int): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() uvicorn_process = subprocess.Popen([ - sys.executable, - "-u", - str(script_path), - "--model", - "facebook/opt-125m", - "--host", - "127.0.0.1", + sys.executable, "-u", + str(script_path), "--model", "facebook/opt-125m", "--host", + "127.0.0.1", "--num-tokenizer-actors", + str(num_tokenizer_actors) ]) yield uvicorn_process.terminate() -def test_api_server(api_server): +@pytest.mark.parametrize("num_tokenizer_actors", [0, 2]) +def test_api_server(api_server, num_tokenizer_actors: int): """ Run the API server and test it. diff --git a/vllm/config.py b/vllm/config.py index 27c61d4d5043..6c9293d57370 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -363,6 +363,10 @@ class ParallelConfig: parallel and large models. disable_custom_all_reduce: Disable the custom all-reduce kernel and fall back to NCCL. + num_tokenizer_actors: Number of tokenizer actors to use for + asynchronous tokenization with Ray. If 0, will use + synchronous tokenization. + tokenizer_actor_options: Options for tokenizer Ray Actors. 
""" def __init__( @@ -372,12 +376,16 @@ def __init__( worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, + num_tokenizer_actors: int = 0, + tokenizer_actor_options: Optional[dict] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce + self.num_tokenizer_actors = num_tokenizer_actors + self.tokenizer_actor_options = tokenizer_actor_options self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d5e63e25d6e8..6fc72c586b95 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,4 +1,5 @@ import argparse +import json import dataclasses from dataclasses import dataclass from typing import Optional, Tuple @@ -7,6 +8,19 @@ ParallelConfig, SchedulerConfig, LoRAConfig) +class _StoreJsonAction(argparse._StoreAction): + + def __call__(self, parser, namespace, values, option_string=None): + json_values = [] + for value in values: + try: + json_values.append(json.loads(value)) + except json.JSONDecodeError as e: + raise argparse.ArgumentTypeError( + f'Invalid JSON string: {value}') from e + setattr(namespace, self.dest, json_values) + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" @@ -37,6 +51,8 @@ class EngineArgs: enforce_eager: bool = False max_context_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False + num_tokenizer_actors: int = 0 + tokenizer_actor_options: Optional[dict] = None enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -224,6 +240,20 @@ def add_cli_args( action='store_true', default=EngineArgs.disable_custom_all_reduce, help='See ParallelConfig') + parser.add_argument('--num-tokenizer-actors', + type=int, + default=EngineArgs.num_tokenizer_actors, + help='Number of tokenizer actors to use for ' + 'asynchronous tokenization with Ray. If 0, will ' + 'use synchronous tokenization.') + parser.add_argument('--tokenizer-actor-options', + type=str, + default=EngineArgs.tokenizer_actor_options, + action=_StoreJsonAction, + help='Options for tokenizer Ray actors. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary. 
Ignored if ' + 'num_tokenizer_actors is 0.') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -290,11 +320,11 @@ def create_engine_configs( self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers, - self.disable_custom_all_reduce) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, self.tensor_parallel_size, + self.worker_use_ray, self.max_parallel_loading_workers, + self.disable_custom_all_reduce, self.num_tokenizer_actors, + self.tokenizer_actor_options) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 86f092520930..989000733bcd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -19,11 +19,13 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - TokenizerGroup) + BaseTokenizerGroup, + TokenizerGroup, + RayTokenizerGroupPool) from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method if ray: - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -119,6 +121,9 @@ def __init__( else: self._init_workers() + # Make sure the tokenizer actors are alive + self.tokenizer.ping() + # Profile the memory usage and initialize the cache. self._init_cache() @@ -172,8 +177,31 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: TokenizerGroup = TokenizerGroup( - self.model_config.tokenizer, **init_kwargs) + + if self.parallel_config.num_tokenizer_actors > 0: + if not RayTokenizerGroupPool: + raise ImportError( + "RayTokenizerGroupPool is not available. 
Please install " + "the ray package to use the tokenizer actors or " + "set `num_tokenizer_actors` to 0.") + ray_actor_options = (self.parallel_config.tokenizer_actor_options + or { + "num_cpus": 0 + }) + ray_actor_options.setdefault( + "scheduling_strategy", + NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=True)) + + init_kwargs[ + "num_actors"] = self.parallel_config.num_tokenizer_actors + init_kwargs["ray_actor_options"] = ray_actor_options + self.tokenizer: BaseTokenizerGroup = RayTokenizerGroupPool( + self.model_config.tokenizer, **init_kwargs) + else: + self.tokenizer: TokenizerGroup = TokenizerGroup( + self.model_config.tokenizer, **init_kwargs) def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 6edc225cdfc8..7d79aa43352c 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,3 +1,6 @@ +import asyncio +import os +from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, @@ -6,11 +9,46 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.utils import make_async, LRUCache +from vllm.engine.ray_utils import ray from vllm.transformers_utils.tokenizers import * logger = init_logger(__name__) +def _get_cached_tokenizer( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Get tokenizer with cached properties. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + + class CachedTokenizer(tokenizer.__class__): + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + def get_tokenizer( tokenizer_name: str, *args, @@ -64,7 +102,7 @@ def get_tokenizer( logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. 
Consider using a fast tokenizer instead.") - return tokenizer + return _get_cached_tokenizer(tokenizer) def get_lora_tokenizer(lora_request: LoRARequest, *args, @@ -88,8 +126,7 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) -class TokenizerGroup: - """A group of tokenizers that can be used for LoRA adapters.""" +class BaseTokenizerGroup(ABC): def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int], **tokenizer_config): @@ -103,6 +140,40 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, else: self.lora_tokenizers = None + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + return self.max_input_length + + def ping(self): + return True + + @abstractmethod + def encode(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + pass + + async def encode_async(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + return self.encode(prompt=prompt, + request_id=request_id, + lora_request=lora_request) + + @abstractmethod + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + ... + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + return self.get_lora_tokenizer(lora_request) + + +class TokenizerGroup(BaseTokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + def encode(self, prompt: str, request_id: Optional[str] = None, @@ -145,6 +216,94 @@ async def get_lora_tokenizer_async( return self.lora_tokenizers.get(lora_request.lora_int_id) +if ray: + RayTokenizerGroup = ray.remote(TokenizerGroup) + + class RayTokenizerGroupPool(BaseTokenizerGroup): + """A pool of TokenizerGroups for async tokenization.""" + + def __init__( # pylint: disable=super-init-not-called + self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): + self.tokenizer = TokenizerGroup(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) + self.max_input_length = max_input_length + + # Carry over the env vars to the actors. + # This is necessary for API keys and such. 
+ ray_actor_options.setdefault("runtime_env", {}) + env_vars = os.environ.copy() + ray_actor_options["runtime_env"].setdefault("env_vars", {}) + env_vars.update(ray_actor_options["runtime_env"]["env_vars"]) + ray_actor_options["runtime_env"]["env_vars"] = env_vars + + ray_tokenizer_cls = RayTokenizerGroup.options(**ray_actor_options) + self.tokenizer_actors = [ + ray_tokenizer_cls.remote(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) + for _ in range(num_actors) + ] + self._idle_actors = None + + def ping(self): + return ray.get( + [actor.ping.remote() for actor in self.tokenizer_actors]) + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + if self._idle_actors.empty(): + raise RuntimeError("No idle actors available.") + actor = self._idle_actors.get_nowait() + try: + ret = ray.get( + actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request)) + finally: + self._idle_actors.put_nowait(actor) + return ret + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + actor = await self._idle_actors.get() + try: + ret = await actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + finally: + self._idle_actors.put_nowait(actor) + return ret + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + return self.tokenizer.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + +else: + RayTokenizerGroupPool = None + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], From a29f4e6245f466b2e2c8e10e56dc48463933c188 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 14:58:37 -0700 Subject: [PATCH 02/14] WIP --- tests/async_engine/test_api_server.py | 10 +- vllm/config.py | 32 ++- vllm/engine/arg_utils.py | 40 ++-- vllm/engine/llm_engine.py | 10 +- vllm/transformers_utils/tokenizer.py | 186 +----------------- .../tokenizer_group/__init__.py | 25 +++ .../tokenizer_group/base_tokenizer_group.py | 59 ++++++ .../tokenizer_group/ray_tokenizer_group.py | 139 +++++++++++++ .../tokenizer_group/tokenizer_group.py | 56 ++++++ 9 files changed, 342 insertions(+), 215 deletions(-) create mode 100644 vllm/transformers_utils/tokenizer_group/__init__.py create mode 100644 vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py create mode 100644 vllm/transformers_utils/tokenizer_group/tokenizer_group.py diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 326682ab18c2..248bfbc8ab5c 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,21 +25,21 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(num_tokenizer_actors: int): +def 
api_server(tokenizer_pool_size: int): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() uvicorn_process = subprocess.Popen([ sys.executable, "-u", str(script_path), "--model", "facebook/opt-125m", "--host", - "127.0.0.1", "--num-tokenizer-actors", - str(num_tokenizer_actors) + "127.0.0.1", "--tokenizer-pool-size", + str(tokenizer_pool_size) ]) yield uvicorn_process.terminate() -@pytest.mark.parametrize("num_tokenizer_actors", [0, 2]) -def test_api_server(api_server, num_tokenizer_actors: int): +@pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) +def test_api_server(api_server, tokenizer_pool_size: int): """ Run the API server and test it. diff --git a/vllm/config.py b/vllm/config.py index b88ba220faf1..c287f47eb580 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -379,6 +379,26 @@ def verify_with_parallel_config( logger.warning("Possibly too large swap space. " + msg) +@dataclass +class TokenizerPoolConfig: + """Configuration for the tokenizer pool. + + Args: + pool_size: Number of tokenizer workers in the pool. + pool_type: Type of the pool. + extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. + """ + pool_size: int + pool_type: str + extra_config: dict + + def __post_init__(self): + if self.pool_type not in ("ray"): + raise ValueError(f"Unknown pool type: {self.pool_type}") + + class ParallelConfig: """Configuration for the distributed execution. @@ -393,10 +413,8 @@ class ParallelConfig: parallel and large models. disable_custom_all_reduce: Disable the custom all-reduce kernel and fall back to NCCL. - num_tokenizer_actors: Number of tokenizer actors to use for - asynchronous tokenization with Ray. If 0, will use - synchronous tokenization. - tokenizer_actor_options: Options for tokenizer Ray Actors. + tokenizer_pool_config: Config for the tokenizer pool. + If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. 
""" @@ -408,8 +426,7 @@ def __init__( worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, - num_tokenizer_actors: int = 0, - tokenizer_actor_options: Optional[dict] = None, + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, ) -> None: @@ -426,8 +443,7 @@ def __init__( self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce - self.num_tokenizer_actors = num_tokenizer_actors - self.tokenizer_actor_options = tokenizer_actor_options + self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 930077771829..ecb0a6d41529 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,7 +5,8 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) + ParallelConfig, SchedulerConfig, LoRAConfig, + TokenizerPoolConfig) class _StoreJsonAction(argparse._StoreAction): @@ -54,8 +55,9 @@ class EngineArgs: enforce_eager: bool = False max_context_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False - num_tokenizer_actors: int = 0 - tokenizer_actor_options: Optional[dict] = None + tokenizer_pool_size: int = 0 + tokenizer_pool_type: str = "ray" + tokenizer_pool_config: Optional[dict] = None enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -265,20 +267,26 @@ def add_cli_args( action='store_true', default=EngineArgs.disable_custom_all_reduce, help='See ParallelConfig') - parser.add_argument('--num-tokenizer-actors', + parser.add_argument('--tokenizer-pool-size', type=int, - default=EngineArgs.num_tokenizer_actors, - help='Number of tokenizer actors to use for ' - 'asynchronous tokenization with Ray. If 0, will ' + default=EngineArgs.tokenizer_pool_size, + help='Size of tokenizer pool to use for ' + 'asynchronous tokenization. If 0, will ' 'use synchronous tokenization.') - parser.add_argument('--tokenizer-actor-options', + parser.add_argument('--tokenizer-pool-type', type=str, - default=EngineArgs.tokenizer_actor_options, + default=EngineArgs.tokenizer_pool_type, + help='Type of tokenizer pool to use for ' + 'asynchronous tokenization. Ignored ' + 'if tokenizer_pool_size is 0.') + parser.add_argument('--tokenizer-pool-config', + type=str, + default=EngineArgs.tokenizer_pool_config, action=_StoreJsonAction, - help='Options for tokenizer Ray actors. ' + help='Config for tokenizer pool. ' 'This should be a JSON string that will be ' 'parsed into a dictionary. 
Ignored if ' - 'num_tokenizer_actors is 0.') + 'tokenizer_pool_size is 0.') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -343,11 +351,17 @@ def create_engine_configs( self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) + if self.tokenizer_pool_size: + tokenizer_pool_config = TokenizerPoolConfig( + self.tokenizer_pool_size, self.tokenizer_pool_type, + self.tokenizer_pool_config or {}) + else: + tokenizer_pool_config = None parallel_config = ParallelConfig( self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, self.max_parallel_loading_workers, - self.disable_custom_all_reduce, self.num_tokenizer_actors, - self.tokenizer_actor_options, self.ray_workers_use_nsight) + self.disable_custom_all_reduce, tokenizer_pool_config, + self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 849be0d4658e..a5a22f2f52b5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -17,8 +17,9 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - TokenizerGroup) +from vllm.transformers_utils.tokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, + get_tokenizer_group) from vllm.utils import Counter logger = init_logger(__name__) @@ -155,6 +156,7 @@ def get_tokenizer_for_seq(self, def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( + tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None, @@ -163,8 +165,8 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: TokenizerGroup = TokenizerGroup( - self.model_config.tokenizer, **init_kwargs) + self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( + self.parallel_config.tokenizer_pool_config, **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 16c734568670..746713e24323 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,6 +1,3 @@ -import asyncio -import os -from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, @@ -8,8 +5,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import make_async, LRUCache -from vllm.engine.ray_utils import ray +from vllm.utils import make_async from vllm.transformers_utils.tokenizers import * logger = init_logger(__name__) @@ -126,186 +122,6 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) -class BaseTokenizerGroup(ABC): - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): - self.tokenizer_id = tokenizer_id - self.tokenizer_config = tokenizer_config - self.enable_lora = enable_lora - 
self.max_input_length = max_input_length - self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - if enable_lora: - self.lora_tokenizers = LRUCache(capacity=max_num_seqs) - else: - self.lora_tokenizers = None - - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - return self.max_input_length - - def ping(self): - return True - - @abstractmethod - def encode(self, prompt: str, request_id: Optional[str], - lora_request: Optional[LoRARequest]) -> List[int]: - pass - - async def encode_async(self, prompt: str, request_id: Optional[str], - lora_request: Optional[LoRARequest]) -> List[int]: - return self.encode(prompt=prompt, - request_id=request_id, - lora_request=lora_request) - - @abstractmethod - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - ... - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - return self.get_lora_tokenizer(lora_request) - - -class TokenizerGroup(BaseTokenizerGroup): - """A group of tokenizers that can be used for LoRA adapters.""" - - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) - - async def encode_async( - self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None - ) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (get_lora_tokenizer( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None - ) -> "PreTrainedTokenizer": - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (await get_lora_tokenizer_async( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers.get(lora_request.lora_int_id) - - -if ray: - RayTokenizerGroup = ray.remote(TokenizerGroup) - - class RayTokenizerGroupPool(BaseTokenizerGroup): - """A pool of TokenizerGroups for async tokenization.""" - - def __init__( # pylint: disable=super-init-not-called - self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], num_actors: int, - ray_actor_options: dict, **tokenizer_config): - self.tokenizer = TokenizerGroup(tokenizer_id, enable_lora, - max_num_seqs, max_input_length, - **tokenizer_config) - self.max_input_length = max_input_length - - # Carry over the env vars to the actors. - # This is necessary for API keys and such. 
- ray_actor_options.setdefault("runtime_env", {}) - env_vars = os.environ.copy() - ray_actor_options["runtime_env"].setdefault("env_vars", {}) - env_vars.update(ray_actor_options["runtime_env"]["env_vars"]) - ray_actor_options["runtime_env"]["env_vars"] = env_vars - - ray_tokenizer_cls = RayTokenizerGroup.options(**ray_actor_options) - self.tokenizer_actors = [ - ray_tokenizer_cls.remote(tokenizer_id, enable_lora, - max_num_seqs, max_input_length, - **tokenizer_config) - for _ in range(num_actors) - ] - self._idle_actors = None - - def ping(self): - return ray.get( - [actor.ping.remote() for actor in self.tokenizer_actors]) - - def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - if self._idle_actors is None: - self._idle_actors = asyncio.Queue() - for actor in self.tokenizer_actors: - self._idle_actors.put_nowait(actor) - if self._idle_actors.empty(): - raise RuntimeError("No idle actors available.") - actor = self._idle_actors.get_nowait() - try: - ret = ray.get( - actor.encode.remote(request_id=request_id, - prompt=prompt, - lora_request=lora_request)) - finally: - self._idle_actors.put_nowait(actor) - return ret - - async def encode_async( - self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: - if self._idle_actors is None: - self._idle_actors = asyncio.Queue() - for actor in self.tokenizer_actors: - self._idle_actors.put_nowait(actor) - actor = await self._idle_actors.get() - try: - ret = await actor.encode.remote(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - finally: - self._idle_actors.put_nowait(actor) - return ret - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(lora_request) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - -else: - RayTokenizerGroupPool = None - - def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py new file mode 100644 index 000000000000..3fb98bfa8d5a --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -0,0 +1,25 @@ +from typing import Optional +from vllm.config import TokenizerPoolConfig +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup +from vllm.engine.ray_utils import ray + +if ray: + from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import RayTokenizerGroupPool +else: + RayTokenizerGroupPool = None + +def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: + if tokenizer_pool_config is None: + return TokenizerGroup(**init_kwargs) + else: + if tokenizer_pool_config.pool_type == "ray": + if RayTokenizerGroupPool is None: + raise ImportError( + "RayTokenizerGroupPool is not available. 
Please install " + "the ray package to use the Ray tokenizer group pool.") + return RayTokenizerGroupPool.from_config(tokenizer_pool_config, **init_kwargs) + else: + raise ValueError(f"Unknown pool type: {tokenizer_pool_config.pool_type}") + +__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py new file mode 100644 index 000000000000..468b00819a3a --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -0,0 +1,59 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest +from vllm.utils import LRUCache +from vllm.transformers_utils.tokenizer import get_tokenizer + +class BaseTokenizerGroup(ABC): + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + + def ping(self): + """Check if the tokenizer group is alive.""" + return True + + @abstractmethod + def encode(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + async def encode_async(self, prompt: str, request_id: Optional[str], + lora_request: Optional[LoRARequest]) -> List[int]: + """Encode a prompt using the tokenizer group.""" + return self.encode(prompt=prompt, + request_id=request_id, + lora_request=lora_request) + + @abstractmethod + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + ... 
+ """Get a tokenizer for a LoRA request.""" + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + """Get a tokenizer for a LoRA request.""" + return self.get_lora_tokenizer(lora_request) + diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py new file mode 100644 index 000000000000..f1205bd634aa --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -0,0 +1,139 @@ +import asyncio +import os +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.engine.ray_utils import ray +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +RayTokenizerGroup = ray.remote(TokenizerGroup) + +class RayTokenizerGroupPool(BaseTokenizerGroup): + """A Ray-based pool of TokenizerGroups for async tokenization.""" + + @classmethod + def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, **init_kwargs) -> "RayTokenizerGroupPool": + ray_actor_options = (tokenizer_pool_config.extra_config + or { + "num_cpus": 0 + }) + ray_actor_options.setdefault( + "scheduling_strategy", + NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=True)) + + # Carry over the env vars to the actors. + # This is necessary for API keys and such. + ray_actor_options.setdefault("runtime_env", {}) + env_vars = os.environ.copy() + ray_actor_options["runtime_env"].setdefault("env_vars", {}) + env_vars.update(ray_actor_options["runtime_env"]["env_vars"]) + ray_actor_options["runtime_env"]["env_vars"] = env_vars + + init_kwargs[ + "num_actors"] = tokenizer_pool_config.pool_size + init_kwargs["ray_actor_options"] = ray_actor_options + + return RayTokenizerGroupPool(**init_kwargs) + + def __init__( # pylint: disable=super-init-not-called + self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): + # Store a local copy of the TokenizerGroup for quick access + # to underlying HF tokenizers. + self.tokenizer = TokenizerGroup(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) + self.max_input_length = max_input_length + + ray_tokenizer_cls = RayTokenizerGroup.options(**ray_actor_options) + self.tokenizer_actors = [ + ray_tokenizer_cls.remote(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) + for _ in range(num_actors) + ] + self._idle_actors: Optional[asyncio.Queue] = None + + def ping(self): + return ray.get( + [actor.ping.remote() for actor in self.tokenizer_actors]) + + def _maybe_init_queue(self): + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + The actor is then put back in the queue for future use. + This is blocking. 
+ """ + self._maybe_init_queue() + + if self._idle_actors.empty(): + raise RuntimeError("No idle actors available.") + actor = self._idle_actors.get_nowait() + try: + ret = ray.get( + actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request)) + finally: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + return ret + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + If there are no idle actors, we wait until one becomes + available. + The actor is then put back in the queue for future use. + This is non-blocking. + """ + self._maybe_init_queue() + + actor = await self._idle_actors.get() + try: + ret = await actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + finally: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + return ret + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + return self.tokenizer.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + return await self.tokenizer.get_lora_tokenizer_async(lora_request) \ No newline at end of file diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py new file mode 100644 index 000000000000..6f0a0f7acfc1 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -0,0 +1,56 @@ +from typing import List, Optional + +from transformers import PreTrainedTokenizer + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_lora_tokenizer_async +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup + + + +class TokenizerGroup(BaseTokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + if not lora_request or not 
self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + From 5cfa7fc70fd5eb37cad7ac0dde5656d2c4a38ebc Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 14:59:32 -0700 Subject: [PATCH 03/14] Update vllm/engine/llm_engine.py --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a5a22f2f52b5..c23584f630f5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -103,7 +103,7 @@ def __init__( parallel_config, scheduler_config, device_config, lora_config) - # Make sure the tokenizer actors are alive + # Ping the tokenizer to ensure liveness if it runs in a different process. self.tokenizer.ping() # Create the scheduler. From ba38a0b5b598147e7546f190b119af40ece44a48 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 15:42:47 -0700 Subject: [PATCH 04/14] WIP --- .buildkite/test-pipeline.yaml | 2 +- tests/conftest.py | 11 ++ tests/lora/test_tokenizer.py | 69 ------------ tests/lora/test_tokenizer_group.py | 53 +++++++++ tests/tokenization/__init__.py | 0 tests/tokenization/test_cached_tokenizer.py | 17 +++ .../test_detokenize.py | 0 tests/tokenization/test_tokenizer_group.py | 102 ++++++++++++++++++ vllm/engine/llm_engine.py | 3 +- .../tokenizer_group/__init__.py | 20 ++-- .../tokenizer_group/base_tokenizer_group.py | 2 +- .../tokenizer_group/ray_tokenizer_group.py | 50 +++++---- .../tokenizer_group/tokenizer_group.py | 8 +- 13 files changed, 234 insertions(+), 103 deletions(-) delete mode 100644 tests/lora/test_tokenizer.py create mode 100644 tests/lora/test_tokenizer_group.py create mode 100644 tests/tokenization/__init__.py create mode 100644 tests/tokenization/test_cached_tokenizer.py rename tests/{engine => tokenization}/test_detokenize.py (100%) create mode 100644 tests/tokenization/test_tokenizer_group.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 42a1eacb6de5..5a82726a6cbe 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -28,7 +28,7 @@ steps: num_gpus: 2 # only support 1 or 2 for now. 
- label: Engine Test - command: pytest -v -s engine test_sequence.py + command: pytest -v -s engine tokenization test_sequence.py - label: Entrypoints Test command: pytest -v -s entrypoints diff --git a/tests/conftest.py b/tests/conftest.py index 6eb8159837d5..c06b271e6c7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.config import TokenizerPoolConfig _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -258,3 +259,13 @@ def generate_beam_search( @pytest.fixture def vllm_runner(): return VllmRunner + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig(pool_size=1, + pool_type="ray", + extra_config={}) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py deleted file mode 100644 index 6c4c91fce812..000000000000 --- a/tests/lora/test_tokenizer.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer - - -@pytest.mark.asyncio -async def test_transformers_tokenizer(): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - assert reference_tokenizer.encode("prompt") == tokenizer.encode( - request_id="request_id", prompt="prompt", lora_request=None) - assert reference_tokenizer.encode( - "prompt") == await tokenizer.encode_async(request_id="request_id", - prompt="prompt", - lora_request=None) - assert isinstance(tokenizer.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - None) == await tokenizer.get_lora_tokenizer_async(None) - - -@pytest.mark.asyncio -async def test_transformers_tokenizer_lora(sql_lora_files): - reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=True, - max_num_seqs=1, - max_input_length=None, - ) - lora_request = LoRARequest("1", 1, sql_lora_files) - assert reference_tokenizer.encode("prompt") == tokenizer.encode( - request_id="request_id", prompt="prompt", lora_request=lora_request) - assert reference_tokenizer.encode( - "prompt") == await tokenizer.encode_async(request_id="request_id", - prompt="prompt", - lora_request=lora_request) - assert isinstance(tokenizer.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - None) == await tokenizer.get_lora_tokenizer_async(None) - - assert isinstance(tokenizer.get_lora_tokenizer(lora_request), - PreTrainedTokenizerBase) - assert tokenizer.get_lora_tokenizer( - lora_request) != tokenizer.get_lora_tokenizer(None) - assert tokenizer.get_lora_tokenizer( - lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) - - -def test_get_lora_tokenizer(sql_lora_files, tmpdir): - lora_request = None - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - lora_request = LoRARequest("1", 1, sql_lora_files) - tokenizer = get_lora_tokenizer(lora_request) - assert tokenizer.get_added_vocab() - - lora_request = LoRARequest("1", 1, str(tmpdir)) - 
tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py new file mode 100644 index 000000000000..5fec3f179925 --- /dev/null +++ b/tests/lora/test_tokenizer_group.py @@ -0,0 +1,53 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer import get_lora_tokenizer +from ..conftest import get_tokenizer_pool_config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer_group.encode_async( + request_id="request_id", + prompt="prompt", + lora_request=lora_request) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + None) == await tokenizer_group.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + lora_request) != tokenizer_group.get_lora_tokenizer(None) + assert tokenizer_group.get_lora_tokenizer( + lora_request) == await tokenizer_group.get_lora_tokenizer_async( + lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/tokenization/__init__.py b/tests/tokenization/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py new file mode 100644 index 000000000000..a2ddff51ce1b --- /dev/null +++ b/tests/tokenization/test_cached_tokenizer.py @@ -0,0 +1,17 @@ +from copy import deepcopy +from vllm.transformers_utils.tokenizer import _get_cached_tokenizer +from transformers import AutoTokenizer + + +def test_cached_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + cached_tokenizer = _get_cached_tokenizer(deepcopy(reference_tokenizer)) + + assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( + "prompt") + assert set(reference_tokenizer.all_special_ids) == set( + cached_tokenizer.all_special_ids) + assert set(reference_tokenizer.all_special_tokens) == set( + cached_tokenizer.all_special_tokens) + assert set(reference_tokenizer.all_special_tokens_extended) == set( + cached_tokenizer.all_special_tokens_extended) diff --git a/tests/engine/test_detokenize.py b/tests/tokenization/test_detokenize.py similarity index 100% rename from tests/engine/test_detokenize.py rename to 
tests/tokenization/test_detokenize.py diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py new file mode 100644 index 000000000000..ea4a2181c2fe --- /dev/null +++ b/tests/tokenization/test_tokenizer_group.py @@ -0,0 +1,102 @@ +import os +import pytest +import asyncio +import ray +from unittest.mock import patch + +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( + RayTokenizerGroupPool) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) +from ..conftest import get_tokenizer_pool_config + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +async def test_tokenizer_group(tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer_group.encode_async( + request_id="request_id", prompt="prompt", lora_request=None) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer_group.get_lora_tokenizer( + None) == await tokenizer_group.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) +async def test_tokenizer_group_pool(tokenizer_group_type): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer_group_pool = get_tokenizer_group( + get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + # Send multiple requests to the tokenizer group pool + # (more than the pool size) + # and check that all requests are processed correctly. 
+ num_requests = tokenizer_group_pool.pool_size * 5 + requests = [ + tokenizer_group_pool.encode_async(request_id=str(i), + prompt=f"prompt {i}", + lora_request=None) + for i in range(num_requests) + ] + results = await asyncio.gather(*requests) + expected_results = [ + reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests) + ] + assert results == expected_results + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) +async def test_tokenizer_group_ray_pool_env_var_propagation( + tokenizer_group_type): + """Test that env vars from caller process are propagated to + tokenizer Ray actors.""" + env_var = "MY_ENV_VAR" + + @ray.remote + class EnvVarCheckerRayTokenizerGroup(TokenizerGroup): + + def ping(self): + assert os.environ.get(env_var) == "1" + return super().ping() + + class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): + _ray_tokenizer_group_cls = EnvVarCheckerRayTokenizerGroup + + tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) + tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( + tokenizer_pool_config, + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None) + with pytest.raises(AssertionError): + tokenizer_pool.ping() + + with patch.dict(os.environ, {env_var: "1"}): + tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) + tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( + tokenizer_pool_config, + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None) + tokenizer_pool.ping() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c23584f630f5..c691c913c13b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -103,7 +103,8 @@ def __init__( parallel_config, scheduler_config, device_config, lora_config) - # Ping the tokenizer to ensure liveness if it runs in a different process. + # Ping the tokenizer to ensure liveness if it runs in a + # different process. self.tokenizer.ping() # Create the scheduler. diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 3fb98bfa8d5a..48a8d75a5649 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,15 +1,20 @@ from typing import Optional from vllm.config import TokenizerPoolConfig -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup -from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) from vllm.engine.ray_utils import ray if ray: - from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import RayTokenizerGroupPool + from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( + RayTokenizerGroupPool) else: RayTokenizerGroupPool = None -def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: + +def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> BaseTokenizerGroup: if tokenizer_pool_config is None: return TokenizerGroup(**init_kwargs) else: @@ -18,8 +23,11 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], ** raise ImportError( "RayTokenizerGroupPool is not available. 
Please install " "the ray package to use the Ray tokenizer group pool.") - return RayTokenizerGroupPool.from_config(tokenizer_pool_config, **init_kwargs) + return RayTokenizerGroupPool.from_config(tokenizer_pool_config, + **init_kwargs) else: - raise ValueError(f"Unknown pool type: {tokenizer_pool_config.pool_type}") + raise ValueError( + f"Unknown pool type: {tokenizer_pool_config.pool_type}") + __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index 468b00819a3a..62c600af52af 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -7,6 +7,7 @@ from vllm.utils import LRUCache from vllm.transformers_utils.tokenizer import get_tokenizer + class BaseTokenizerGroup(ABC): def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, @@ -56,4 +57,3 @@ async def get_lora_tokenizer_async( lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": """Get a tokenizer for a LoRA request.""" return self.get_lora_tokenizer(lora_request) - diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index f1205bd634aa..37969e6f54eb 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -7,26 +7,30 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest from vllm.engine.ray_utils import ray -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup -from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( + TokenizerGroup) from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy RayTokenizerGroup = ray.remote(TokenizerGroup) + class RayTokenizerGroupPool(BaseTokenizerGroup): """A Ray-based pool of TokenizerGroups for async tokenization.""" + _ray_tokenizer_group_cls = RayTokenizerGroup + @classmethod - def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, **init_kwargs) -> "RayTokenizerGroupPool": - ray_actor_options = (tokenizer_pool_config.extra_config - or { - "num_cpus": 0 - }) + def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, + **init_kwargs) -> "RayTokenizerGroupPool": + ray_actor_options = (tokenizer_pool_config.extra_config or { + "num_cpus": 0 + }) ray_actor_options.setdefault( "scheduling_strategy", NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), - soft=True)) + node_id=ray.get_runtime_context().get_node_id(), soft=True)) # Carry over the env vars to the actors. # This is necessary for API keys and such. 
@@ -36,11 +40,10 @@ def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, **init_kwargs) env_vars.update(ray_actor_options["runtime_env"]["env_vars"]) ray_actor_options["runtime_env"]["env_vars"] = env_vars - init_kwargs[ - "num_actors"] = tokenizer_pool_config.pool_size + init_kwargs["num_actors"] = tokenizer_pool_config.pool_size init_kwargs["ray_actor_options"] = ray_actor_options - - return RayTokenizerGroupPool(**init_kwargs) + + return cls(**init_kwargs) def __init__( # pylint: disable=super-init-not-called self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, @@ -53,15 +56,20 @@ def __init__( # pylint: disable=super-init-not-called **tokenizer_config) self.max_input_length = max_input_length - ray_tokenizer_cls = RayTokenizerGroup.options(**ray_actor_options) + ray_tokenizer_group_cls = self._ray_tokenizer_group_cls.options( + **ray_actor_options) self.tokenizer_actors = [ - ray_tokenizer_cls.remote(tokenizer_id, enable_lora, - max_num_seqs, max_input_length, - **tokenizer_config) + ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora, + max_num_seqs, max_input_length, + **tokenizer_config) for _ in range(num_actors) ] self._idle_actors: Optional[asyncio.Queue] = None + @property + def pool_size(self) -> int: + return len(self.tokenizer_actors) + def ping(self): return ray.get( [actor.ping.remote() for actor in self.tokenizer_actors]) @@ -73,9 +81,9 @@ def _maybe_init_queue(self): self._idle_actors.put_nowait(actor) def encode(self, - prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: """Encode a prompt using the tokenizer group. We pick an idle actor and use it to encode the prompt. 
@@ -136,4 +144,4 @@ def get_lora_tokenizer( async def get_lora_tokenizer_async( self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - return await self.tokenizer.get_lora_tokenizer_async(lora_request) \ No newline at end of file + return await self.tokenizer.get_lora_tokenizer_async(lora_request) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 6f0a0f7acfc1..75449c6e2a0f 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -3,9 +3,10 @@ from transformers import PreTrainedTokenizer from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_lora_tokenizer_async -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup - +from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, + get_lora_tokenizer_async) +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) class TokenizerGroup(BaseTokenizerGroup): @@ -53,4 +54,3 @@ async def get_lora_tokenizer_async( return tokenizer else: return self.lora_tokenizers.get(lora_request.lora_int_id) - From 368e28b32d352b41e360949e67401abf4628e5f0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 15:47:48 -0700 Subject: [PATCH 05/14] Improve test --- tests/tokenization/test_cached_tokenizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index a2ddff51ce1b..2a4859a169aa 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -5,6 +5,8 @@ def test_cached_tokenizer(): reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + reference_tokenizer.add_special_tokens({"cls_token": ""}) + reference_tokenizer.add_special_tokens({"additional_special_tokens": [""]}) cached_tokenizer = _get_cached_tokenizer(deepcopy(reference_tokenizer)) assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( From 9d233732fabddae71b3ede8baabe8614e5a43753 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 16:14:32 -0700 Subject: [PATCH 06/14] WIP --- tests/tokenization/test_cached_tokenizer.py | 3 ++- vllm/config.py | 2 ++ vllm/engine/arg_utils.py | 21 ++++++--------------- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index 2a4859a169aa..6d164bfa92b9 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -6,7 +6,8 @@ def test_cached_tokenizer(): reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") reference_tokenizer.add_special_tokens({"cls_token": ""}) - reference_tokenizer.add_special_tokens({"additional_special_tokens": [""]}) + reference_tokenizer.add_special_tokens( + {"additional_special_tokens": [""]}) cached_tokenizer = _get_cached_tokenizer(deepcopy(reference_tokenizer)) assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( diff --git a/vllm/config.py b/vllm/config.py index c287f47eb580..214132f20564 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -397,6 +397,8 @@ class TokenizerPoolConfig: def __post_init__(self): if self.pool_type not in ("ray"): raise ValueError(f"Unknown pool type: {self.pool_type}") + if not isinstance(self.extra_config, dict): + raise 
ValueError("extra_config must be a dictionary.") class ParallelConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ecb0a6d41529..45effa46543f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,19 +9,6 @@ TokenizerPoolConfig) -class _StoreJsonAction(argparse._StoreAction): - - def __call__(self, parser, namespace, values, option_string=None): - json_values = [] - for value in values: - try: - json_values.append(json.loads(value)) - except json.JSONDecodeError as e: - raise argparse.ArgumentTypeError( - f'Invalid JSON string: {value}') from e - setattr(namespace, self.dest, json_values) - - @dataclass class EngineArgs: """Arguments for vLLM engine.""" @@ -282,7 +269,6 @@ def add_cli_args( parser.add_argument('--tokenizer-pool-config', type=str, default=EngineArgs.tokenizer_pool_config, - action=_StoreJsonAction, help='Config for tokenizer pool. ' 'This should be a JSON string that will be ' 'parsed into a dictionary. Ignored if ' @@ -352,9 +338,14 @@ def create_engine_configs( self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) if self.tokenizer_pool_size: + if isinstance(self.tokenizer_pool_config, str): + tokenizer_pool_config_parsed = json.loads( + self.tokenizer_pool_config) + else: + tokenizer_pool_config_parsed = self.tokenizer_pool_config or {} tokenizer_pool_config = TokenizerPoolConfig( self.tokenizer_pool_size, self.tokenizer_pool_type, - self.tokenizer_pool_config or {}) + tokenizer_pool_config_parsed) else: tokenizer_pool_config = None parallel_config = ParallelConfig( From 100074f9694b561ca0979a6afc75a902281c1a9a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 16:15:42 -0700 Subject: [PATCH 07/14] Fix --- vllm/engine/arg_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 45effa46543f..f9415766cc21 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,7 +44,7 @@ class EngineArgs: disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" - tokenizer_pool_config: Optional[dict] = None + tokenizer_pool_extra_config: Optional[dict] = None enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -266,10 +266,10 @@ def add_cli_args( help='Type of tokenizer pool to use for ' 'asynchronous tokenization. Ignored ' 'if tokenizer_pool_size is 0.') - parser.add_argument('--tokenizer-pool-config', + parser.add_argument('--tokenizer-pool-extra-config', type=str, - default=EngineArgs.tokenizer_pool_config, - help='Config for tokenizer pool. ' + default=EngineArgs.tokenizer_pool_extra_config, + help='Extra config for tokenizer pool. ' 'This should be a JSON string that will be ' 'parsed into a dictionary. 
Ignored if ' 'tokenizer_pool_size is 0.') @@ -338,14 +338,15 @@ def create_engine_configs( self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) if self.tokenizer_pool_size: - if isinstance(self.tokenizer_pool_config, str): - tokenizer_pool_config_parsed = json.loads( - self.tokenizer_pool_config) + if isinstance(self.tokenizer_pool_extra_config, str): + tokenizer_pool_extra_config_parsed = json.loads( + self.tokenizer_pool_extra_config) else: - tokenizer_pool_config_parsed = self.tokenizer_pool_config or {} + tokenizer_pool_extra_config_parsed = ( + self.tokenizer_pool_extra_config or {}) tokenizer_pool_config = TokenizerPoolConfig( self.tokenizer_pool_size, self.tokenizer_pool_type, - tokenizer_pool_config_parsed) + tokenizer_pool_extra_config_parsed) else: tokenizer_pool_config = None parallel_config = ParallelConfig( From e303c5f399b9993e259cddb1a59dec1d5f0c6ec7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 17:12:08 -0700 Subject: [PATCH 08/14] Feedback --- .../tokenizer_group/__init__.py | 19 +++++++++--------- .../tokenizer_group/ray_tokenizer_group.py | 20 +++---------------- 2 files changed, 12 insertions(+), 27 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 48a8d75a5649..4288dab24ceb 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -17,17 +17,16 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: if tokenizer_pool_config is None: return TokenizerGroup(**init_kwargs) + if tokenizer_pool_config.pool_type == "ray": + if RayTokenizerGroupPool is None: + raise ImportError( + "RayTokenizerGroupPool is not available. Please install " + "the ray package to use the Ray tokenizer group pool.") + return RayTokenizerGroupPool.from_config(tokenizer_pool_config, + **init_kwargs) else: - if tokenizer_pool_config.pool_type == "ray": - if RayTokenizerGroupPool is None: - raise ImportError( - "RayTokenizerGroupPool is not available. 
Please install " - "the ray package to use the Ray tokenizer group pool.") - return RayTokenizerGroupPool.from_config(tokenizer_pool_config, - **init_kwargs) - else: - raise ValueError( - f"Unknown pool type: {tokenizer_pool_config.pool_type}") + raise ValueError( + f"Unknown pool type: {tokenizer_pool_config.pool_type}") __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 37969e6f54eb..31d55696a80a 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -7,8 +7,6 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest from vllm.engine.ray_utils import ray -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( TokenizerGroup) from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -16,7 +14,7 @@ RayTokenizerGroup = ray.remote(TokenizerGroup) -class RayTokenizerGroupPool(BaseTokenizerGroup): +class RayTokenizerGroupPool(TokenizerGroup): """A Ray-based pool of TokenizerGroups for async tokenization.""" _ray_tokenizer_group_cls = RayTokenizerGroup @@ -51,10 +49,8 @@ def __init__( # pylint: disable=super-init-not-called ray_actor_options: dict, **tokenizer_config): # Store a local copy of the TokenizerGroup for quick access # to underlying HF tokenizers. - self.tokenizer = TokenizerGroup(tokenizer_id, enable_lora, - max_num_seqs, max_input_length, - **tokenizer_config) - self.max_input_length = max_input_length + super().__init__(tokenizer_id, enable_lora, max_num_seqs, + max_input_length, **tokenizer_config) ray_tokenizer_group_cls = self._ray_tokenizer_group_cls.options( **ray_actor_options) @@ -135,13 +131,3 @@ async def encode_async( # is raised. 
self._idle_actors.put_nowait(actor) return ret - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(lora_request) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - return await self.tokenizer.get_lora_tokenizer_async(lora_request) From dd0162d92eccc97045c97cbd97278eaa2a39817d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Mar 2024 17:13:32 -0700 Subject: [PATCH 09/14] Lint --- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 31d55696a80a..f030dda26252 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -2,8 +2,6 @@ import os from typing import List, Optional -from transformers import PreTrainedTokenizer - from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest from vllm.engine.ray_utils import ray From 5602685806935a1f1751f5aad9e646115625b562 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Mar 2024 10:30:36 -0700 Subject: [PATCH 10/14] Apply feedback from code review --- tests/tokenization/test_tokenizer_group.py | 6 +- vllm/config.py | 33 +++++++- vllm/engine/arg_utils.py | 21 ++---- .../tokenizer_group/__init__.py | 2 +- .../tokenizer_group/base_tokenizer_group.py | 34 +++------ .../tokenizer_group/ray_tokenizer_group.py | 75 ++++++++++++++----- .../tokenizer_group/tokenizer_group.py | 24 ++++++ 7 files changed, 133 insertions(+), 62 deletions(-) diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py index ea4a2181c2fe..d0788ee87563 100644 --- a/tests/tokenization/test_tokenizer_group.py +++ b/tests/tokenization/test_tokenizer_group.py @@ -1,7 +1,6 @@ import os import pytest import asyncio -import ray from unittest.mock import patch from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -71,15 +70,14 @@ async def test_tokenizer_group_ray_pool_env_var_propagation( tokenizer Ray actors.""" env_var = "MY_ENV_VAR" - @ray.remote - class EnvVarCheckerRayTokenizerGroup(TokenizerGroup): + class EnvVarCheckerTokenizerGroup(TokenizerGroup): def ping(self): assert os.environ.get(env_var) == "1" return super().ping() class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): - _ray_tokenizer_group_cls = EnvVarCheckerRayTokenizerGroup + _worker_cls = EnvVarCheckerTokenizerGroup tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( diff --git a/vllm/config.py b/vllm/config.py index 214132f20564..a4f35b2a7401 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,6 +3,7 @@ import os from packaging.version import Version +import json import torch from transformers import PretrainedConfig @@ -395,11 +396,41 @@ class TokenizerPoolConfig: extra_config: dict def __post_init__(self): - if self.pool_type not in ("ray"): + if self.pool_type not in ("ray", ): raise ValueError(f"Unknown pool type: {self.pool_type}") if not isinstance(self.extra_config, dict): raise ValueError("extra_config must be a dictionary.") + @classmethod + def create_config( + cls, tokenizer_pool_size: int, tokenizer_pool_type: str, + tokenizer_pool_extra_config: Optional[Union[str, dict]] + ) -> 
Optional["TokenizerPoolConfig"]: + """Create a TokenizerPoolConfig from the given parameters. + + If tokenizer_pool_size is 0, return None. + + Args: + tokenizer_pool_size: Number of tokenizer workers in the pool. + tokenizer_pool_type: Type of the pool. + tokenizer_pool_extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. This can be a JSON string (will be parsed). + """ + if tokenizer_pool_size: + if isinstance(tokenizer_pool_extra_config, str): + tokenizer_pool_extra_config_parsed = json.loads( + tokenizer_pool_extra_config) + else: + tokenizer_pool_extra_config_parsed = ( + tokenizer_pool_extra_config or {}) + tokenizer_pool_config = cls(tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed) + else: + tokenizer_pool_config = None + return tokenizer_pool_config + class ParallelConfig: """Configuration for the distributed execution. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f9415766cc21..3e146d2e6c0c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,5 +1,4 @@ import argparse -import json import dataclasses from dataclasses import dataclass from typing import Optional, Tuple @@ -337,23 +336,15 @@ def create_engine_configs( self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) - if self.tokenizer_pool_size: - if isinstance(self.tokenizer_pool_extra_config, str): - tokenizer_pool_extra_config_parsed = json.loads( - self.tokenizer_pool_extra_config) - else: - tokenizer_pool_extra_config_parsed = ( - self.tokenizer_pool_extra_config or {}) - tokenizer_pool_config = TokenizerPoolConfig( - self.tokenizer_pool_size, self.tokenizer_pool_type, - tokenizer_pool_extra_config_parsed) - else: - tokenizer_pool_config = None parallel_config = ParallelConfig( self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, self.max_parallel_loading_workers, - self.disable_custom_all_reduce, tokenizer_pool_config, - self.ray_workers_use_nsight) + self.disable_custom_all_reduce, + TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 4288dab24ceb..adc8d9b90ddb 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -23,7 +23,7 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], "RayTokenizerGroupPool is not available. 
Please install " "the ray package to use the Ray tokenizer group pool.") return RayTokenizerGroupPool.from_config(tokenizer_pool_config, - **init_kwargs) + **init_kwargs) else: raise ValueError( f"Unknown pool type: {tokenizer_pool_config.pool_type}") diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index 62c600af52af..f89e66bb6410 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -4,33 +4,22 @@ from transformers import PreTrainedTokenizer from vllm.lora.request import LoRARequest -from vllm.utils import LRUCache -from vllm.transformers_utils.tokenizer import get_tokenizer class BaseTokenizerGroup(ABC): + """A group of tokenizers that can be used for LoRA adapters.""" - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): - self.tokenizer_id = tokenizer_id - self.tokenizer_config = tokenizer_config - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - if enable_lora: - self.lora_tokenizers = LRUCache(capacity=max_num_seqs) - else: - self.lora_tokenizers = None + @abstractmethod + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + pass + @abstractmethod def get_max_input_len(self, lora_request: Optional[LoRARequest] = None ) -> Optional[int]: """Get the maximum input length for the LoRA request.""" - return self.max_input_length - - def ping(self): - """Check if the tokenizer group is alive.""" - return True + pass @abstractmethod def encode(self, prompt: str, request_id: Optional[str], @@ -38,12 +27,11 @@ def encode(self, prompt: str, request_id: Optional[str], """Encode a prompt using the tokenizer group.""" pass + @abstractmethod async def encode_async(self, prompt: str, request_id: Optional[str], lora_request: Optional[LoRARequest]) -> List[int]: """Encode a prompt using the tokenizer group.""" - return self.encode(prompt=prompt, - request_id=request_id, - lora_request=lora_request) + pass @abstractmethod def get_lora_tokenizer( @@ -51,9 +39,11 @@ def get_lora_tokenizer( lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": ... 
"""Get a tokenizer for a LoRA request.""" + pass + @abstractmethod async def get_lora_tokenizer_async( self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": """Get a tokenizer for a LoRA request.""" - return self.get_lora_tokenizer(lora_request) + pass diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index f030dda26252..8030aec0226b 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -2,20 +2,37 @@ import os from typing import List, Optional +from transformers import PreTrainedTokenizer + from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest from vllm.engine.ray_utils import ray +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( TokenizerGroup) from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -RayTokenizerGroup = ray.remote(TokenizerGroup) +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> dict: + """Copy over all current process environment variables to the runtime_env. + + The variables in runtime_env will take precedence over the current process + environment variables. + + runtime_env will be modified in place.""" + env_vars = os.environ.copy() + runtime_env.setdefault("env_vars", {}) + env_vars.update(runtime_env["env_vars"]) + runtime_env["env_vars"] = env_vars + return runtime_env -class RayTokenizerGroupPool(TokenizerGroup): + +class RayTokenizerGroupPool(BaseTokenizerGroup): """A Ray-based pool of TokenizerGroups for async tokenization.""" - _ray_tokenizer_group_cls = RayTokenizerGroup + # Class to use for workers making up the pool. + _worker_cls = TokenizerGroup @classmethod def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, @@ -31,27 +48,28 @@ def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, # Carry over the env vars to the actors. # This is necessary for API keys and such. ray_actor_options.setdefault("runtime_env", {}) - env_vars = os.environ.copy() - ray_actor_options["runtime_env"].setdefault("env_vars", {}) - env_vars.update(ray_actor_options["runtime_env"]["env_vars"]) - ray_actor_options["runtime_env"]["env_vars"] = env_vars + ray_actor_options["runtime_env"] = _carry_over_env_vars_to_runtime_env( + ray_actor_options["runtime_env"]) init_kwargs["num_actors"] = tokenizer_pool_config.pool_size init_kwargs["ray_actor_options"] = ray_actor_options return cls(**init_kwargs) - def __init__( # pylint: disable=super-init-not-called - self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], num_actors: int, - ray_actor_options: dict, **tokenizer_config): + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): # Store a local copy of the TokenizerGroup for quick access # to underlying HF tokenizers. 
- super().__init__(tokenizer_id, enable_lora, max_num_seqs, - max_input_length, **tokenizer_config) - - ray_tokenizer_group_cls = self._ray_tokenizer_group_cls.options( - **ray_actor_options) + self._local_tokenizer_group = self._worker_cls( + tokenizer_id=tokenizer_id, + enable_lora=enable_lora, + max_num_seqs=max_num_seqs, + max_input_length=max_input_length, + ) + + ray_tokenizer_group_cls = ray.remote( + self._worker_cls).options(**ray_actor_options) self.tokenizer_actors = [ ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora, max_num_seqs, max_input_length, @@ -68,7 +86,7 @@ def ping(self): return ray.get( [actor.ping.remote() for actor in self.tokenizer_actors]) - def _maybe_init_queue(self): + def _ensure_queue_initialized(self): if self._idle_actors is None: self._idle_actors = asyncio.Queue() for actor in self.tokenizer_actors: @@ -84,7 +102,7 @@ def encode(self, The actor is then put back in the queue for future use. This is blocking. """ - self._maybe_init_queue() + self._ensure_queue_initialized() if self._idle_actors.empty(): raise RuntimeError("No idle actors available.") @@ -115,7 +133,7 @@ async def encode_async( The actor is then put back in the queue for future use. This is non-blocking. """ - self._maybe_init_queue() + self._ensure_queue_initialized() actor = await self._idle_actors.get() try: @@ -129,3 +147,22 @@ async def encode_async( # is raised. self._idle_actors.put_nowait(actor) return ret + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self._local_tokenizer_group.get_max_input_len(lora_request) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return self._local_tokenizer_group.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return await self._local_tokenizer_group.get_lora_tokenizer_async( + lora_request) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 75449c6e2a0f..3af1334cb5ed 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -7,11 +7,35 @@ get_lora_tokenizer_async) from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) +from vllm.utils import LRUCache +from vllm.transformers_utils.tokenizer import get_tokenizer class TokenizerGroup(BaseTokenizerGroup): """A group of tokenizers that can be used for LoRA adapters.""" + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + return True + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + def encode(self, prompt: str, request_id: Optional[str] = None, From 
4cd7769435a13cce5c9757f6c682ea18a1b3ee1f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Mar 2024 10:31:33 -0700 Subject: [PATCH 11/14] Nit --- tests/tokenization/test_cached_tokenizer.py | 4 ++-- vllm/transformers_utils/tokenizer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index 6d164bfa92b9..181e80032512 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -1,5 +1,5 @@ from copy import deepcopy -from vllm.transformers_utils.tokenizer import _get_cached_tokenizer +from vllm.transformers_utils.tokenizer import get_cached_tokenizer from transformers import AutoTokenizer @@ -8,7 +8,7 @@ def test_cached_tokenizer(): reference_tokenizer.add_special_tokens({"cls_token": ""}) reference_tokenizer.add_special_tokens( {"additional_special_tokens": [""]}) - cached_tokenizer = _get_cached_tokenizer(deepcopy(reference_tokenizer)) + cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( "prompt") diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 746713e24323..d98df08a85af 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -11,7 +11,7 @@ logger = init_logger(__name__) -def _get_cached_tokenizer( +def get_cached_tokenizer( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Get tokenizer with cached properties. @@ -98,7 +98,7 @@ def get_tokenizer( logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead.") - return _get_cached_tokenizer(tokenizer) + return get_cached_tokenizer(tokenizer) def get_lora_tokenizer(lora_request: LoRARequest, *args, From 84cade19afbdf8d84be292dd2a64953ee0b50c29 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Mar 2024 10:32:44 -0700 Subject: [PATCH 12/14] Update vllm/transformers_utils/tokenizer.py --- vllm/transformers_utils/tokenizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index d98df08a85af..f7a1a19a89bc 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -16,6 +16,8 @@ def get_cached_tokenizer( ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Get tokenizer with cached properties. + This will patch the tokenizer object in place. + By default, transformers will recompute multiple tokenizer properties each time they are called, leading to a significant slowdown. 
This function caches these properties for faster access.""" From fc0b04c15d134f8bd3611004fb764d6cf364d2f8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Mar 2024 11:48:45 -0700 Subject: [PATCH 13/14] Tweak --- .../tokenizer_group/ray_tokenizer_group.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 8030aec0226b..34980a0410a7 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -14,7 +14,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> dict: +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: """Copy over all current process environment variables to the runtime_env. The variables in runtime_env will take precedence over the current process @@ -25,7 +25,6 @@ def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> dict: runtime_env.setdefault("env_vars", {}) env_vars.update(runtime_env["env_vars"]) runtime_env["env_vars"] = env_vars - return runtime_env class RayTokenizerGroupPool(BaseTokenizerGroup): @@ -48,8 +47,7 @@ def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, # Carry over the env vars to the actors. # This is necessary for API keys and such. ray_actor_options.setdefault("runtime_env", {}) - ray_actor_options["runtime_env"] = _carry_over_env_vars_to_runtime_env( - ray_actor_options["runtime_env"]) + _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) init_kwargs["num_actors"] = tokenizer_pool_config.pool_size init_kwargs["ray_actor_options"] = ray_actor_options From e8241f774b534aa77eed8bd52c3bdc4be16ffba8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Mar 2024 11:49:52 -0700 Subject: [PATCH 14/14] Nits --- .../tokenizer_group/base_tokenizer_group.py | 1 - .../tokenizer_group/ray_tokenizer_group.py | 26 +++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index f89e66bb6410..99518a606fab 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -37,7 +37,6 @@ async def encode_async(self, prompt: str, request_id: Optional[str], def get_lora_tokenizer( self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": - ... """Get a tokenizer for a LoRA request.""" pass diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 34980a0410a7..e048ec05bece 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -14,19 +14,6 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: - """Copy over all current process environment variables to the runtime_env. - - The variables in runtime_env will take precedence over the current process - environment variables. 
- - runtime_env will be modified in place.""" - env_vars = os.environ.copy() - runtime_env.setdefault("env_vars", {}) - env_vars.update(runtime_env["env_vars"]) - runtime_env["env_vars"] = env_vars - - class RayTokenizerGroupPool(BaseTokenizerGroup): """A Ray-based pool of TokenizerGroups for async tokenization.""" @@ -164,3 +151,16 @@ async def get_lora_tokenizer_async( ) -> "PreTrainedTokenizer": return await self._local_tokenizer_group.get_lora_tokenizer_async( lora_request) + + +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: + """Copy over all current process environment variables to the runtime_env. + + The variables in runtime_env will take precedence over the current process + environment variables. + + runtime_env will be modified in place.""" + env_vars = os.environ.copy() + runtime_env.setdefault("env_vars", {}) + env_vars.update(runtime_env["env_vars"]) + runtime_env["env_vars"] = env_vars
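
The end state of this series can be exercised directly through TokenizerPoolConfig.create_config and get_tokenizer_group. The following is a minimal sketch of that final API, assuming Ray is installed; the model name, pool size, and max_num_seqs below are placeholder values chosen for illustration, not taken from the patches.

# Illustrative sketch only: example values, not part of the patch series.
import asyncio

from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

# pool_size=0 would fall back to in-process tokenization; "ray" is the only
# pool type wired up in this series.
pool_config = TokenizerPoolConfig.create_config(
    tokenizer_pool_size=2,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config=None)

# init kwargs mirror the TokenizerGroup constructor from the patches.
tokenizer_group = get_tokenizer_group(
    pool_config,
    tokenizer_id="gpt2",
    enable_lora=False,
    max_num_seqs=8,
    max_input_length=None)
tokenizer_group.ping()  # verifies the Ray tokenizer actors are alive


async def main() -> None:
    # Tokenization is dispatched to an idle Ray actor and awaited.
    token_ids = await tokenizer_group.encode_async(
        prompt="Hello, world", request_id="0", lora_request=None)
    print(token_ids)


asyncio.run(main())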