Asynchronous tokenization (vllm-project#2879)
Yard1 authored Mar 15, 2024
1 parent 8fa7357 commit fb96c1e
Showing 17 changed files with 658 additions and 153 deletions.
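This commit introduces asynchronous tokenization backed by a pool of tokenizer workers. For orientation before the diff, a minimal sketch of constructing the new TokenizerPoolConfig (added to vllm/config.py below), mirroring the helper the tests add to tests/conftest.py; how the engine consumes the config is not shown in the hunks rendered here:

from vllm.config import TokenizerPoolConfig

# One Ray-backed tokenizer worker; "ray" is the only pool_type this
# commit accepts (enforced in TokenizerPoolConfig.__post_init__).
pool_config = TokenizerPoolConfig(pool_size=1,
                                  pool_type="ray",
                                  extra_config={})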
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -28,7 +28,7 @@ steps:
     num_gpus: 2 # only support 1 or 2 for now.
 
 - label: Engine Test
-  command: pytest -v -s engine test_sequence.py
+  command: pytest -v -s engine tokenization test_sequence.py
 
 - label: Entrypoints Test
   command: pytest -v -s entrypoints
16 changes: 7 additions & 9 deletions tests/async_engine/test_api_server.py
@@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict:
 
 
 @pytest.fixture
-def api_server():
+def api_server(tokenizer_pool_size: int):
     script_path = Path(__file__).parent.joinpath(
         "api_server_async_engine.py").absolute()
     uvicorn_process = subprocess.Popen([
-        sys.executable,
-        "-u",
-        str(script_path),
-        "--model",
-        "facebook/opt-125m",
-        "--host",
-        "127.0.0.1",
+        sys.executable, "-u",
+        str(script_path), "--model", "facebook/opt-125m", "--host",
+        "127.0.0.1", "--tokenizer-pool-size",
+        str(tokenizer_pool_size)
     ])
     yield
     uvicorn_process.terminate()
 
 
-def test_api_server(api_server):
+@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
+def test_api_server(api_server, tokenizer_pool_size: int):
     """
     Run the API server and test it.
11 changes: 11 additions & 0 deletions tests/conftest.py
@@ -7,6 +7,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.config import TokenizerPoolConfig
 
 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -258,3 +259,13 @@ def generate_beam_search(
 @pytest.fixture
 def vllm_runner():
     return VllmRunner
+
+
+def get_tokenizer_pool_config(tokenizer_group_type):
+    if tokenizer_group_type is None:
+        return None
+    if tokenizer_group_type == "ray":
+        return TokenizerPoolConfig(pool_size=1,
+                                   pool_type="ray",
+                                   extra_config={})
+    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
69 changes: 0 additions & 69 deletions tests/lora/test_tokenizer.py

This file was deleted.

53 changes: 53 additions & 0 deletions tests/lora/test_tokenizer_group.py
@@ -0,0 +1,53 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from ..conftest import get_tokenizer_pool_config


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)
    assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
        request_id="request_id", prompt="prompt", lora_request=lora_request)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer_group.encode_async(
            request_id="request_id",
            prompt="prompt",
            lora_request=lora_request)
    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)

    assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        lora_request) != tokenizer_group.get_lora_tokenizer(None)
    assert tokenizer_group.get_lora_tokenizer(
        lora_request) == await tokenizer_group.get_lora_tokenizer_async(
            lora_request)


def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    lora_request = None
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer

    lora_request = LoRARequest("1", 1, sql_lora_files)
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer.get_added_vocab()

    lora_request = LoRARequest("1", 1, str(tmpdir))
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer
Empty file added tests/tokenization/__init__.py
20 changes: 20 additions & 0 deletions tests/tokenization/test_cached_tokenizer.py
@@ -0,0 +1,20 @@
from copy import deepcopy
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from transformers import AutoTokenizer


def test_cached_tokenizer():
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<SEP>"]})
    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))

    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
        "prompt")
    assert set(reference_tokenizer.all_special_ids) == set(
        cached_tokenizer.all_special_ids)
    assert set(reference_tokenizer.all_special_tokens) == set(
        cached_tokenizer.all_special_tokens)
    assert set(reference_tokenizer.all_special_tokens_extended) == set(
        cached_tokenizer.all_special_tokens_extended)
File renamed without changes.
100 changes: 100 additions & 0 deletions tests/tokenization/test_tokenizer_group.py
@@ -0,0 +1,100 @@
import os
import pytest
import asyncio
from unittest.mock import patch

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
    RayTokenizerGroupPool)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
    TokenizerGroup)
from ..conftest import get_tokenizer_pool_config


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer_group.encode_async(
            request_id="request_id", prompt="prompt", lora_request=None)
    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    # Send multiple requests to the tokenizer group pool
    # (more than the pool size)
    # and check that all requests are processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    requests = [
        tokenizer_group_pool.encode_async(request_id=str(i),
                                          prompt=f"prompt {i}",
                                          lora_request=None)
        for i in range(num_requests)
    ]
    results = await asyncio.gather(*requests)
    expected_results = [
        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
    ]
    assert results == expected_results


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors."""
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None)
    with pytest.raises(AssertionError):
        tokenizer_pool.ping()

    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)
        tokenizer_pool.ping()
57 changes: 57 additions & 0 deletions vllm/config.py
@@ -3,6 +3,7 @@
 import os
 from packaging.version import Version
 
+import json
 import torch
 from transformers import PretrainedConfig
 
@@ -389,6 +390,58 @@ def verify_with_parallel_config(
             logger.warning("Possibly too large swap space. " + msg)
 
 
+@dataclass
+class TokenizerPoolConfig:
+    """Configuration for the tokenizer pool.
+
+    Args:
+        pool_size: Number of tokenizer workers in the pool.
+        pool_type: Type of the pool.
+        extra_config: Additional config for the pool.
+            The way the config will be used depends on the
+            pool type.
+    """
+    pool_size: int
+    pool_type: str
+    extra_config: dict
+
+    def __post_init__(self):
+        if self.pool_type not in ("ray", ):
+            raise ValueError(f"Unknown pool type: {self.pool_type}")
+        if not isinstance(self.extra_config, dict):
+            raise ValueError("extra_config must be a dictionary.")
+
+    @classmethod
+    def create_config(
+        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
+        tokenizer_pool_extra_config: Optional[Union[str, dict]]
+    ) -> Optional["TokenizerPoolConfig"]:
+        """Create a TokenizerPoolConfig from the given parameters.
+
+        If tokenizer_pool_size is 0, return None.
+
+        Args:
+            tokenizer_pool_size: Number of tokenizer workers in the pool.
+            tokenizer_pool_type: Type of the pool.
+            tokenizer_pool_extra_config: Additional config for the pool.
+                The way the config will be used depends on the
+                pool type. This can be a JSON string (will be parsed).
+        """
+        if tokenizer_pool_size:
+            if isinstance(tokenizer_pool_extra_config, str):
+                tokenizer_pool_extra_config_parsed = json.loads(
+                    tokenizer_pool_extra_config)
+            else:
+                tokenizer_pool_extra_config_parsed = (
+                    tokenizer_pool_extra_config or {})
+            tokenizer_pool_config = cls(tokenizer_pool_size,
+                                        tokenizer_pool_type,
+                                        tokenizer_pool_extra_config_parsed)
+        else:
+            tokenizer_pool_config = None
+        return tokenizer_pool_config
+
+
 class ParallelConfig:
     """Configuration for the distributed execution.
 
@@ -403,6 +456,8 @@ class ParallelConfig:
         parallel and large models.
         disable_custom_all_reduce: Disable the custom all-reduce kernel and
             fall back to NCCL.
+        tokenizer_pool_config: Config for the tokenizer pool.
+            If None, will use synchronous tokenization.
         ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
             https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     """
@@ -414,6 +469,7 @@ def __init__(
         worker_use_ray: bool,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
+        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
         ray_workers_use_nsight: bool = False,
         placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
@@ -430,6 +486,7 @@ def __init__(
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
+        self.tokenizer_pool_config = tokenizer_pool_config
         self.ray_workers_use_nsight = ray_workers_use_nsight
         self.placement_group = placement_group
 
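To illustrate the create_config classmethod added to vllm/config.py above: it returns None when tokenizer_pool_size is 0 (keeping tokenization synchronous) and parses extra config passed as a JSON string. A short sketch of both paths; the "foo" key is an arbitrary placeholder, not a recognized option:

from vllm.config import TokenizerPoolConfig

# Pool size 0 means no pool: create_config returns None and the engine
# falls back to synchronous tokenization.
assert TokenizerPoolConfig.create_config(
    tokenizer_pool_size=0,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config=None) is None

# A JSON string for the extra config is parsed into a dict.
config = TokenizerPoolConfig.create_config(
    tokenizer_pool_size=2,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config='{"foo": 1}')
assert config.extra_config == {"foo": 1}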