
Asynchronous tokenization #2879

Merged: 18 commits, Mar 15, 2024
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -28,7 +28,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.

- label: Engine Test
-  command: pytest -v -s engine test_sequence.py
+  command: pytest -v -s engine tokenization test_sequence.py

- label: Entrypoints Test
command: pytest -v -s entrypoints
16 changes: 7 additions & 9 deletions tests/async_engine/test_api_server.py
@@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict:


@pytest.fixture
-def api_server():
+def api_server(tokenizer_pool_size: int):
script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute()
    uvicorn_process = subprocess.Popen([
-        sys.executable,
-        "-u",
-        str(script_path),
-        "--model",
-        "facebook/opt-125m",
-        "--host",
-        "127.0.0.1",
+        sys.executable, "-u",
+        str(script_path), "--model", "facebook/opt-125m", "--host",
+        "127.0.0.1", "--tokenizer-pool-size",
+        str(tokenizer_pool_size)
    ])
yield
uvicorn_process.terminate()


-def test_api_server(api_server):
+@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
+def test_api_server(api_server, tokenizer_pool_size: int):
"""
Run the API server and test it.

11 changes: 11 additions & 0 deletions tests/conftest.py
@@ -7,6 +7,7 @@

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.config import TokenizerPoolConfig

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -258,3 +259,13 @@ def generate_beam_search(
@pytest.fixture
def vllm_runner():
return VllmRunner


def get_tokenizer_pool_config(tokenizer_group_type):
if tokenizer_group_type is None:
return None
if tokenizer_group_type == "ray":
return TokenizerPoolConfig(pool_size=1,
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
69 changes: 0 additions & 69 deletions tests/lora/test_tokenizer.py

This file was deleted.

53 changes: 53 additions & 0 deletions tests/lora/test_tokenizer_group.py
@@ -0,0 +1,53 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from ..conftest import get_tokenizer_pool_config


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
request_id="request_id",
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)

assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
lora_request) != tokenizer_group.get_lora_tokenizer(None)
assert tokenizer_group.get_lora_tokenizer(
lora_request) == await tokenizer_group.get_lora_tokenizer_async(
lora_request)


def test_get_lora_tokenizer(sql_lora_files, tmpdir):
lora_request = None
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer

lora_request = LoRARequest("1", 1, sql_lora_files)
tokenizer = get_lora_tokenizer(lora_request)
assert tokenizer.get_added_vocab()

lora_request = LoRARequest("1", 1, str(tmpdir))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
Empty file added tests/tokenization/__init__.py
20 changes: 20 additions & 0 deletions tests/tokenization/test_cached_tokenizer.py
@@ -0,0 +1,20 @@
from copy import deepcopy
from vllm.transformers_utils.tokenizer import _get_cached_tokenizer
from transformers import AutoTokenizer


def test_cached_tokenizer():
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
reference_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<SEP>"]})
cached_tokenizer = _get_cached_tokenizer(deepcopy(reference_tokenizer))

assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
"prompt")
assert set(reference_tokenizer.all_special_ids) == set(
cached_tokenizer.all_special_ids)
assert set(reference_tokenizer.all_special_tokens) == set(
cached_tokenizer.all_special_tokens)
assert set(reference_tokenizer.all_special_tokens_extended) == set(
cached_tokenizer.all_special_tokens_extended)
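
The implementation of _get_cached_tokenizer is not shown in this diff. As an assumption about the idea the test above exercises (a sketch, not the actual vLLM code): precompute the special-token properties once at wrap time, then serve the frozen values from a dynamically substituted subclass, since these HuggingFace properties are recomputed on every access.

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def cached_tokenizer_sketch(
        tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
    """Freeze the special-token properties computed at wrap time."""
    cached_all_special_ids = frozenset(tokenizer.all_special_ids)
    cached_all_special_tokens = tuple(tokenizer.all_special_tokens)
    cached_all_special_tokens_extended = tuple(
        tokenizer.all_special_tokens_extended)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self):
            return cached_all_special_ids

        @property
        def all_special_tokens(self):
            return cached_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return cached_all_special_tokens_extended

    # Swap the instance's class so existing references see cached values.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


# e.g. tok = cached_tokenizer_sketch(AutoTokenizer.from_pretrained("gpt2"))

Because the values are captured at wrap time, tokens registered after caching would not be reflected, which is consistent with the test adding <CLS> and <SEP> before calling _get_cached_tokenizer.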
File renamed without changes.
102 changes: 102 additions & 0 deletions tests/tokenization/test_tokenizer_group.py
@@ -0,0 +1,102 @@
import os
import pytest
import asyncio
import ray
from unittest.mock import patch

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from ..conftest import get_tokenizer_pool_config


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=None)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
request_id="request_id", prompt="prompt", lora_request=None)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group_pool = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
# Send multiple requests to the tokenizer group pool
# (more than the pool size)
# and check that all requests are processed correctly.
num_requests = tokenizer_group_pool.pool_size * 5
requests = [
tokenizer_group_pool.encode_async(request_id=str(i),
prompt=f"prompt {i}",
lora_request=None)
for i in range(num_requests)
]
results = await asyncio.gather(*requests)
expected_results = [
reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
]
assert results == expected_results


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
tokenizer_group_type):
"""Test that env vars from caller process are propagated to
tokenizer Ray actors."""
env_var = "MY_ENV_VAR"

@ray.remote
class EnvVarCheckerRayTokenizerGroup(TokenizerGroup):

def ping(self):
assert os.environ.get(env_var) == "1"
return super().ping()

class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
_ray_tokenizer_group_cls = EnvVarCheckerRayTokenizerGroup

tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None)
with pytest.raises(AssertionError):
tokenizer_pool.ping()

with patch.dict(os.environ, {env_var: "1"}):
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None)
tokenizer_pool.ping()
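
A side note on test_tokenizer_group_pool above: its final assertion relies on asyncio.gather returning results in the order the awaitables were passed in, not the order they complete. A standalone illustration of that guarantee:

import asyncio


async def delayed_identity(i: int) -> int:
    # Later-indexed requests finish first.
    await asyncio.sleep(0.01 * (5 - i))
    return i


async def main():
    results = await asyncio.gather(*(delayed_identity(i) for i in range(5)))
    assert results == [0, 1, 2, 3, 4]  # argument order, not completion order


asyncio.run(main())
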
26 changes: 26 additions & 0 deletions vllm/config.py
@@ -379,6 +379,28 @@ def verify_with_parallel_config(
logger.warning("Possibly too large swap space. " + msg)


@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.

Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type.
"""
pool_size: int
pool_type: str
extra_config: dict

def __post_init__(self):
if self.pool_type not in ("ray"):
Yard1 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")


class ParallelConfig:
"""Configuration for the distributed execution.

@@ -393,6 +415,8 @@ class ParallelConfig:
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
tokenizer_pool_config: Config for the tokenizer pool.
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
"""
@@ -404,6 +428,7 @@ def __init__(
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
) -> None:
Expand All @@ -420,6 +445,7 @@ def __init__(
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group
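
As a usage sketch of the new config plumbing (illustrative values; only the TokenizerPoolConfig fields shown in this diff are assumed):

from vllm.config import TokenizerPoolConfig

# Four Ray tokenizer workers, no pool-specific extras.
pool_config = TokenizerPoolConfig(pool_size=4,
                                  pool_type="ray",
                                  extra_config={})

# Passing tokenizer_pool_config=None to ParallelConfig keeps the default
# synchronous tokenization path.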
