[Core] Support thread-based async tokenizer pools
vllm-project#2879 added support for using Ray to offload tokenization from the asyncio event loop.

This PR extends that to support using a thread pool instead of Ray, and makes the thread pool the default, with the default pool size determined from the number of available CPU cores and the tensor parallel size.
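
For reference, when no pool size is given, the default is derived from the host CPU count, reserving one core per tensor-parallel worker plus one spare and capping the pool at 16 (see TokenizerPoolConfig.create_config below). Roughly:

    import os

    def default_tokenizer_pool_size(tensor_parallel_size: int) -> int:
        # Reserve one core per tensor-parallel worker plus one spare;
        # cap the pool at 16 workers and never go below 1.
        return max(1, min(16, os.cpu_count() - tensor_parallel_size - 1))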

The main thing to note is that a separate tokenizer instance is used per thread, because the HF tokenizers are not officially thread-safe. In practice I think they are unless padding/truncation is used, which we aren't doing currently but may want to soon (see for example vllm-project#3144).
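
The pattern used for this (a condensed sketch of the new ThreadPoolTokenizerGroup below; the "gpt2" tokenizer id is only illustrative) gives each pool thread its own tokenizer via the executor's initializer and thread-local storage:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    from transformers import AutoTokenizer

    local = threading.local()

    def init_worker():
        # Each pool thread builds and keeps its own tokenizer instance,
        # so no tokenizer object is ever shared between threads.
        local.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    executor = ThreadPoolExecutor(max_workers=4,
                                  thread_name_prefix="tokenizer_thread",
                                  initializer=init_worker)

    def encode(prompt):
        # Runs on a pool thread and therefore uses that thread's tokenizer.
        return executor.submit(lambda: local.tokenizer.encode(prompt)).result()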

Also includes some type hint additions to related parts of the code.

This replaces the original PR vllm-project#3206 from before vllm-project#2879 was reworked and merged.
njhill committed Mar 19, 2024
1 parent b37cdce commit ebf6967
Showing 13 changed files with 152 additions and 69 deletions.
50 changes: 30 additions & 20 deletions vllm/config.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Union, ClassVar
from typing import TYPE_CHECKING, Optional, Union, ClassVar, Literal
from dataclasses import dataclass
import os
from packaging.version import Version
@@ -393,7 +393,7 @@ def verify_with_parallel_config(
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
@@ -402,44 +402,54 @@ class TokenizerPoolConfig:
pool type.
"""
pool_size: int
pool_type: str
pool_type: Literal["ray", "thread"]
extra_config: dict

def __post_init__(self):
if self.pool_type not in ("ray", ):
if self.pool_type not in ("ray", "thread"):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")

@classmethod
def create_config(
cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
tokenizer_pool_extra_config: Optional[Union[str, dict]]
cls,
tokenizer_pool_size: Optional[int],
tokenizer_pool_type: Literal["ray", "thread"],
tokenizer_pool_extra_config: Optional[Union[str, dict]],
tensor_parallel_size: int,
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
tokenizer_pool_extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type. This can be a JSON string (will be parsed).
tensor_parallel_size: Used in default pool size calculation.
"""
if tokenizer_pool_size:
if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
if tokenizer_pool_size == 0:
return None

if tokenizer_pool_size is None:
# Default based on CPU count
tokenizer_pool_size = min(
16,
os.cpu_count() - tensor_parallel_size - 1)
tokenizer_pool_size = max(1, tokenizer_pool_size)

if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
tokenizer_pool_extra_config_parsed = (tokenizer_pool_extra_config
or {})

return cls(tokenizer_pool_size, tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)


class ParallelConfig:
11 changes: 7 additions & 4 deletions vllm/engine/arg_utils.py
@@ -1,7 +1,7 @@
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Literal

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig,
@@ -41,8 +41,8 @@ class EngineArgs:
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
tokenizer_pool_size: Optional[int] = None
tokenizer_pool_type: Literal["ray", "thread"] = "thread"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1
@@ -257,11 +257,13 @@ def add_cli_args(
type=int,
default=EngineArgs.tokenizer_pool_size,
help='Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'asynchronous tokenization. Default chosen '
'based on available CPU cores. If 0, will '
'use synchronous tokenization.')
parser.add_argument('--tokenizer-pool-type',
type=str,
default=EngineArgs.tokenizer_pool_type,
choices=['thread', 'ray'],
help='Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
@@ -344,6 +346,7 @@ def create_engine_configs(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
self.tensor_parallel_size,
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
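
For example, the new behaviour can be tuned per deployment through EngineArgs (model name and sizes below are illustrative; the equivalent CLI flags are the ones defined in add_cli_args above):

    from vllm.engine.arg_utils import EngineArgs

    # New default: thread pool, with the size picked from CPU count and TP size.
    args = EngineArgs(model="facebook/opt-125m")

    # Keep the previous Ray-based pool instead
    # (equivalent to --tokenizer-pool-size 8 --tokenizer-pool-type ray):
    ray_args = EngineArgs(model="facebook/opt-125m",
                          tokenizer_pool_size=8,
                          tokenizer_pool_type="ray")

    # Disable async tokenization entirely (--tokenizer-pool-size 0):
    sync_args = EngineArgs(model="facebook/opt-125m", tokenizer_pool_size=0)
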
7 changes: 7 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -15,6 +15,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = int(
@@ -366,6 +367,12 @@ def _error_callback(self, exc: Exception) -> None:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)

async def get_tokenizer_group(self) -> BaseTokenizerGroup:
if self.engine_use_ray:
return await self.engine.get_tokenizer_group.remote()
else:
return self.engine.get_tokenizer_group()

async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -148,6 +148,9 @@ def __reduce__(self):
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")

def get_tokenizer_group(self) -> BaseTokenizerGroup:
return self.tokenizer

def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer()

9 changes: 6 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -61,10 +61,13 @@ async def create_chat_completion(

request_id = f"cmpl-{random_uuid()}"
try:
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
token_ids = await self._validate_prompt_and_tokenize(
request,
request_id=request_id,
lora_request=lora_request,
prompt=prompt)
sampling_params = request.to_sampling_params()
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer()))
15 changes: 8 additions & 7 deletions vllm/entrypoints/openai/serving_completion.py
@@ -136,17 +136,18 @@ async def create_completion(self, request: CompletionRequest,
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
sub_request_id = f"{request_id}-{i}"
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
input_ids = await self._validate_prompt_and_tokenize(
request,
request_id=sub_request_id,
lora_request=lora_request,
**{prompt_arg: prompt})

generators.append(
self.engine.generate(prompt,
sampling_params,
f"{request_id}-{i}",
sub_request_id,
prompt_token_ids=input_ids,
lora_request=lora_request))
except ValueError as e:
15 changes: 11 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -160,19 +160,26 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
# if _check_model has been called earlier, this will be unreachable
raise ValueError("The model `{request.model}` does not exist.")

def _validate_prompt_and_tokenize(
async def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")

input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
if prompt_ids is None:
tokenizer = await self.engine.get_tokenizer_group()
input_ids = await tokenizer.encode_async(prompt, request_id,
lora_request)
else:
input_ids = prompt_ids

token_num = len(input_ids)

if request.max_tokens is None:
6 changes: 3 additions & 3 deletions vllm/lora/models.py
@@ -4,7 +4,7 @@
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)

import safetensors.torch
import torch
@@ -535,14 +535,14 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
replacement_loras)


class LoRALRUCache(LRUCache):
class LoRALRUCache(LRUCache[LoRAModel]):

def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn

def _on_remove(self, key: Hashable, value: Any):
def _on_remove(self, key: Hashable, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
9 changes: 6 additions & 3 deletions vllm/transformers_utils/tokenizer_group/__init__.py
@@ -2,6 +2,8 @@
from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.thread_tokenizer_group import (
ThreadPoolTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from vllm.engine.ray_utils import ray
@@ -24,9 +26,10 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
"the ray package to use the Ray tokenizer group pool.")
return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
**init_kwargs)
else:
raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
if tokenizer_pool_config.pool_type == "thread":
return ThreadPoolTokenizerGroup(
max_workers=tokenizer_pool_config.pool_size, **init_kwargs)
raise ValueError(f"Unknown pool type: {tokenizer_pool_config.pool_type}")


__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
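
A sketch of how the factory now resolves to the thread pool (tokenizer id and sizes are illustrative):

    from vllm.config import TokenizerPoolConfig
    from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

    pool_config = TokenizerPoolConfig(pool_size=4,
                                      pool_type="thread",
                                      extra_config={})

    # Passing None instead of a pool config falls back to the plain,
    # synchronous TokenizerGroup.
    tokenizer_group = get_tokenizer_group(pool_config,
                                          tokenizer_id="facebook/opt-125m",
                                          enable_lora=False,
                                          max_num_seqs=256,
                                          max_input_length=None)

    token_ids = tokenizer_group.encode("Hello, world!")
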
19 changes: 13 additions & 6 deletions vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -22,27 +22,34 @@ def get_max_input_len(self,
pass

@abstractmethod
def encode(self, prompt: str, request_id: Optional[str],
lora_request: Optional[LoRARequest]) -> List[int]:
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass

@abstractmethod
async def encode_async(self, prompt: str, request_id: Optional[str],
lora_request: Optional[LoRARequest]) -> List[int]:
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass

@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass

@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass
37 changes: 37 additions & 0 deletions vllm/transformers_utils/tokenizer_group/thread_tokenizer_group.py
@@ -0,0 +1,37 @@
import threading
from concurrent.futures import ThreadPoolExecutor

from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from vllm.utils import make_async

logger = init_logger(__name__)


class ThreadPoolTokenizerGroup(TokenizerGroup):
"""A threadpool of TokenizerGroups for async tokenization."""

def __init__(self, *args, max_workers: int, **tokenizer_config):
super().__init__(*args, **tokenizer_config)
self.local = threading.local()

def init_tokenizer():
logger.info(
f"Starting tokenizer thread {threading.current_thread().name}")
self.local.tokenizer = TokenizerGroup(*args, **tokenizer_config)

self.executor = ThreadPoolExecutor(
max_workers=max_workers,
thread_name_prefix='tokenizer_thread',
initializer=init_tokenizer,
)

self.encode_async = make_async(self._encode_local, self.executor)

def _encode_local(self, *args, **kwargs):
return self.local.tokenizer.encode(*args, **kwargs)

def encode(self, *args, **kwargs):
return self.executor.submit(self._encode_local, *args,
**kwargs).result()
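
A minimal usage sketch of the new class (tokenizer id, pool size and prompt are illustrative):

    import asyncio

    from vllm.transformers_utils.tokenizer_group.thread_tokenizer_group import (
        ThreadPoolTokenizerGroup)

    group = ThreadPoolTokenizerGroup(tokenizer_id="facebook/opt-125m",
                                     enable_lora=False,
                                     max_num_seqs=256,
                                     max_input_length=None,
                                     max_workers=4)

    # Blocking call: tokenization happens on one of the pool threads.
    token_ids = group.encode("Hello, world!")

    # Non-blocking call: the same work, kept off the asyncio event loop.
    async def main():
        return await group.encode_async("Hello, world!")

    token_ids = asyncio.run(main())
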
12 changes: 4 additions & 8 deletions vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -21,10 +21,8 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None

def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
@@ -62,8 +60,7 @@ def get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers.get(lora_request.lora_int_id)

async def get_lora_tokenizer_async(
self,
@@ -76,5 +73,4 @@ async def get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers.get(lora_request.lora_int_id)