diff --git a/tests/entrypoints/openai/test_disable_mp.py b/tests/entrypoints/openai/test_disable_mp.py new file mode 100644 index 000000000000..12c805413311 --- /dev/null +++ b/tests/entrypoints/openai/test_disable_mp.py @@ -0,0 +1,715 @@ +""" +Repeat of tests in test_completion.py with the non-mp backend. +""" + +# imports for guided decoding tests +import json +import re +import shutil +from tempfile import TemporaryDirectory +from typing import List + +import jsonschema +import openai # use the official client for correctness check +import pytest +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoTokenizer + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically these adapters use a different base model, +# but we're not testing generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def zephyr_pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", + "128", + "--disable-frontend-multiprocessing" + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name,num_virtual_tokens", + [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], +) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, + num_virtual_tokens: int): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + 
max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, + prompt_tokens=6 + num_virtual_tokens, + total_tokens=11 + num_virtual_tokens) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + + +@pytest.mark.asyncio +async def test_added_lora_tokens(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model="zephyr-lora2", + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should appear in tokenized prompt + assert completion.choices[0].text.startswith("vllm1vllm2vllm3") + + +@pytest.mark.asyncio +async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should not appear in tokenized prompt + assert "vllm" not in completion.choices[0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora and 1 pa hereafter + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + 
temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is an LLM?" + + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" 
+ + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + 
stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +async def test_logits_bias(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 5 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + token_id = 1000 + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token_id): 100}, + seed=42, + ) + assert len(completion.choices[0].text) >= 5 + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), + add_special_tokens=False)["input_ids"] + assert all([ + response == expected + for response, expected in zip(response_tokens, expected_tokens) + ]) + + # Test ban + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + ) + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + first_response = completion.choices[0].text + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token): -100 + for token in response_tokens}, + ) + assert first_response != completion.choices[0].text + + +@pytest.mark.asyncio +async def test_allowed_token_ids(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 1 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + allowed_ids = [21555, 21557, 21558] + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + seed=42, + extra_body=dict(allowed_token_ids=allowed_ids), + logprobs=1, + ) + response_tokens = completion.choices[0].logprobs.tokens + assert len(response_tokens) == 1 + assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}", + n=3, + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_json=sample_json_schema, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_regex): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {sample_regex}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict(guided_regex=sample_regex, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + assert re.fullmatch(sample_regex, + completion.choices[i].text) is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_guided_choice): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="The best language for type-safe systems programming is ", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict(guided_choice=sample_guided_choice, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 2 + for i in range(2): + assert completion.choices[i].text in sample_guided_choice + + +@pytest.mark.asyncio +async def test_guided_grammar(client: openai.AsyncOpenAI, + sample_sql_statements): + + completion = await client.completions.create( + model=MODEL_NAME, + prompt=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_grammar=sample_sql_statements)) + + content = completion.choices[0].text + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(content) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") + + assert content.strip() == ground_truth + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = 
tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, sample_regex): + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex", + extra_body=dict(guided_regex=sample_regex, + guided_json=sample_json_schema)) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d3f9a0ab00f1..c39caca25cc7 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,8 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -928,6 +929,14 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_parallel_config.remote( # type: ignore + ) + else: + return self.engine.get_parallel_config() + async def get_decoding_config(self) -> DecodingConfig: """Get the decoding configuration of the vLLM engine.""" if self.engine_use_ray: @@ -936,6 +945,22 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_scheduler_config.remote( # type: ignore + ) + else: + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_lora_config.remote( # type: ignore + ) + else: + return self.engine.get_lora_config() + async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1efe2206abe8..3747f93b16cd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -38,9 +38,8 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from 
vllm.transformers_utils.tokenizer_group import (AnyTokenizer, - BaseTokenizerGroup, - get_tokenizer_group) +from vllm.transformers_utils.tokenizer_group import ( + AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -485,19 +484,12 @@ def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: - init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision) - init_kwargs.update(tokenizer_init_kwargs) - - return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, - **init_kwargs) + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -759,10 +751,22 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + def get_decoding_config(self) -> DecodingConfig: """Gets the decoding configuration.""" return self.decoding_config + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return sum(scheduler.get_num_unfinished_seq_groups() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 000000000000..fc94ef6662e0 --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,84 @@ +from typing import (AsyncIterator, List, Mapping, Optional, Protocol, + runtime_checkable) + +from transformers import PreTrainedTokenizer + +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput + + +@runtime_checkable +class AsyncEngineClient(Protocol): + """Protocol class for Clients to AsyncLLMEngine""" + + @property + def is_running(self) -> bool: + ... + + @property + def is_stopped(self) -> bool: + ... + + @property + def errored(self) -> bool: + ... 
+ + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Generates outputs for a request""" + + async def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model.""" + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Args: + request_id: The unique id of the request. + """ + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> PreTrainedTokenizer: + """Get the appropriate Tokenizer for the request""" + + async def is_tracing_enabled(self) -> bool: + pass + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + """Raise if unhealthy""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0fe4dd245b5e..e330ee81f7e4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,7 +5,8 @@ import signal from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Optional, Set +from multiprocessing import Process +from typing import AsyncIterator, Set import fastapi import uvicorn @@ -17,8 +18,10 @@ from starlette.routing import Mount import vllm.envs as envs +from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -31,6 +34,8 @@ EmbeddingRequest, ErrorResponse, TokenizeRequest, TokenizeResponse) +from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient +from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -39,12 +44,12 @@ OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, get_open_port from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds -engine: AsyncLLMEngine +async_engine_client: AsyncEngineClient engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion @@ -56,13 +61,22 @@ _running_tasks: Set[asyncio.Task] = set() +def model_is_embedding(model_name: str) -> bool: + return ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16").embedding_mode + + @asynccontextmanager 
async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await engine.do_log_stats() + await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) @@ -72,6 +86,52 @@ async def _force_log(): yield +@asynccontextmanager +async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: + # Context manager to handle async_engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + global engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) + + # Backend itself still global for the silly lil' health handler + global async_engine_client + + # If manually triggered or embedding model, use AsyncLLMEngine in process. + # TODO: support embedding model via RPC. + if (model_is_embedding(args.model) + or args.disable_frontend_multiprocessing): + async_engine_client = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + yield async_engine_client + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + # Start RPCServer in separate process (holds the AsyncLLMEngine). + port = get_open_port(envs.VLLM_RPC_PORT) + rpc_server_process = Process(target=run_rpc_server, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + port)) + rpc_server_process.start() + + # Build RPCClient, which conforms to AsyncEngineClient Protocol. + async_engine_client = AsyncEngineRPCClient(port) + await async_engine_client.setup() + + try: + yield async_engine_client + finally: + # Ensure rpc server process was terminated + rpc_server_process.terminate() + + # Close all open connections to the backend + async_engine_client.close() + + # Wait for server process to join + rpc_server_process.join() + + router = APIRouter() @@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI): @router.get("/health") async def health() -> Response: """Health check.""" - await openai_serving_chat.engine.check_health() + await async_engine_client.check_health() return Response(status_code=200) @@ -215,8 +275,8 @@ async def authentication(request: Request, call_next): async def build_server( + async_engine_client: AsyncEngineClient, args, - llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs, ) -> uvicorn.Server: app = build_app(args) @@ -226,14 +286,7 @@ async def build_server( else: served_model_names = [args.model] - global engine, engine_args - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) - - model_config = await engine.get_model_config() + model_config = await async_engine_client.get_model_config() if args.disable_log_requests: request_logger = None @@ -246,7 +299,7 @@ async def build_server( global openai_serving_tokenization openai_serving_chat = OpenAIServingChat( - engine, + async_engine_client, model_config, served_model_names, args.response_role, @@ -257,7 +310,7 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_completion = OpenAIServingCompletion( - engine, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -266,13 +319,13 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_embedding = OpenAIServingEmbedding( - engine, + async_engine_client, model_config, served_model_names, request_logger=request_logger, 
) openai_serving_tokenization = OpenAIServingTokenization( - engine, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -304,32 +357,39 @@ async def build_server( return uvicorn.Server(config) -async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None: +async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - server = await build_server( - args, - llm_engine, - **uvicorn_kwargs, - ) + shutdown_task = None + async with build_async_engine_client(args) as async_engine_client: + + server = await build_server( + async_engine_client, + args, + **uvicorn_kwargs, + ) + + loop = asyncio.get_running_loop() - loop = asyncio.get_running_loop() + server_task = loop.create_task(server.serve()) - server_task = loop.create_task(server.serve()) + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) + try: + await server_task + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + shutdown_task = server.shutdown() - try: - await server_task - except asyncio.CancelledError: - print("Gracefully stopping http server") - await server.shutdown() + if shutdown_task: + # NB: Await server shutdown only after the backend context is exited + await shutdown_task if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a4192937980f..1facedac72ca 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--return-tokens-as-token-ids", action="store_true", - help="When --max-logprobs is specified, represents single tokens as" - "strings of the form 'token_id:{token_id}' so that tokens that" + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index f8e04e7f18e0..84871fc83ef5 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -1,4 +1,4 @@ -from functools import lru_cache +from functools import lru_cache, partial from typing import Dict, FrozenSet, Iterable, List, Optional, Union import torch @@ -40,6 +40,14 @@ def _get_allowed_token_ids_logits_processor( return AllowedTokenIdsLogitsProcessor(allowed_token_ids) +def logit_bias_logits_processor(logit_bias: Dict[str, + float], token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + def get_logits_processors( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], 
allowed_token_ids: Optional[List[int]], @@ -64,13 +72,8 @@ def get_logits_processors( raise ValueError("token_id in logit_bias contains " "out-of-vocab token id") - def logit_bias_logits_processor(token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: - for token_id, bias in clamped_logit_bias.items(): - logits[token_id] += bias - return logits - - logits_processors.append(logit_bias_logits_processor) + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) if allowed_token_ids is not None: logits_processors.append( diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py new file mode 100644 index 000000000000..8a7b12201cab --- /dev/null +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" +VLLM_RPC_HEALTHY_STR = "HEALTHY" + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCUtilityRequest(Enum): + IS_SERVER_READY = 1 + GET_MODEL_CONFIG = 2 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + CHECK_HEALTH = 8 + IS_TRACING_ENABLED = 9 + + +RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, + RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py new file mode 100644 index 000000000000..45bf88b5bf57 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/client.py @@ -0,0 +1,248 @@ +from contextlib import contextmanager +from typing import Any, AsyncIterator, Mapping, Optional + +import cloudpickle +import zmq +import zmq.asyncio + +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, + VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + + +class AsyncEngineRPCClient: + + def __init__(self, port: int): + self.context = zmq.asyncio.Context() + self.path = f"tcp://localhost:{port}" + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self.wait_for_server() + + # Get the configs. + self.model_config = await self._get_model_config_rpc() + self.decoding_config = await self._get_decoding_config_rpc() + self.tracing_flag = await self._is_tracing_enabled_rpc() + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc()), + parallel_config=(await self._get_parallel_config_rpc()), + enable_lora=bool(await self._get_lora_config_rpc()), + ) + + def close(self): + """Destroy the ZeroMQ Context.""" + self.context.destroy() + + @contextmanager + def socket(self): + # Ensure client sockets are always closed after use + + # Connect to RPC socket for Request-Reply pattern, + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.path) + yield socket + finally: + socket.close() + + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + with self.socket() as socket: + + # Ping RPCServer with a request. + await socket.send(cloudpickle.dumps(request)) + + # Await the data from the Server. + data = cloudpickle.loads(await socket.recv()) + + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + else: + raise ValueError(error_message) + + return data + + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, + error_message: str): + """Send one-way RPC request to trigger an action.""" + with self.socket() as socket: + # Ping RPC Server with request. + await socket.send(cloudpickle.dumps(request)) + + # Await acknowledgement from RPCServer. + response = cloudpickle.loads(await socket.recv()) + + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + raise ValueError(error_message) + + return response + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def wait_for_server(self): + """Wait for the RPCServer to start up.""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_READY, + error_message="Unable to start RPC Server.") + + async def _get_model_config_rpc(self) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_decoding_config_rpc(self) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") + + async def _get_parallel_config_rpc(self) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ParallelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC 
Server") + + async def _get_lora_config_rpc(self): + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server") + + async def _is_tracing_enabled_rpc(self) -> ParallelConfig: + """Get is_tracing_enabled flag from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.IS_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled flag from RPC " + "Server") + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + with self.socket() as socket: + + # Send RPCGenerateRequest to the RPCServer. + await socket.send_multipart([ + cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + ]) + + # Stream back the results from the RPC Server. + while True: + message = await socket.recv() + request_output = cloudpickle.loads(message) + + if isinstance(request_output, Exception): + raise request_output + + if request_output.finished: + break + yield request_output + + yield request_output + + async def check_health(self) -> None: + """Raise if unhealthy""" + + with self.socket() as socket: + + # Ping RPCServer with CHECK_HEALTH request. + await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH) + ) + + # Await the reply from the server. + # TODO: do we need an internal timeout here? + # Or do we expect the external probe to timeout and let this chill? 
+ health_message = cloudpickle.loads(await socket.recv()) + + if isinstance(health_message, Exception): + raise health_message + + if health_message != VLLM_RPC_HEALTHY_STR: + raise ValueError("Expected healthy response from backend but got " + "f{health_message}") + + async def encode(self, *args, + **kwargs) -> AsyncIterator[EmbeddingRequestOutput]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py new file mode 100644 index 000000000000..7a72a6f732c9 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/server.py @@ -0,0 +1,216 @@ +import asyncio +import signal +from typing import Any, Coroutine + +import cloudpickle +import zmq +import zmq.asyncio +from typing_extensions import Never + +from vllm import AsyncEngineArgs, AsyncLLMEngine +from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext + +logger = init_logger(__name__) + + +class AsyncEngineRPCServer: + + def __init__(self, async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int): + # Initialize engine first. + self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, + usage_context) + + # Initialize context. + self.context = zmq.asyncio.Context() + + # Init socket for readiness state. + self.socket = self.context.socket(zmq.constants.ROUTER) + self.socket.bind(f"tcp://localhost:{port}") + + def cleanup(self): + """Cleanup all resources.""" + self.socket.close() + self.context.destroy() + + async def get_model_config(self, identity): + """Send the ModelConfig""" + model_config = await self.engine.get_model_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(model_config)]) + + async def get_decoding_config(self, identity): + """Send the DecodingConfig""" + decoding_config = await self.engine.get_decoding_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(decoding_config)]) + + async def get_lora_config(self, identity): + lora_config = await self.engine.get_lora_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(lora_config)]) + + async def get_scheduler_config(self, identity): + """Send the SchedulerConfig""" + parallel_config = await self.engine.get_scheduler_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def get_parallel_config(self, identity): + """Send the ParallelConfig""" + parallel_config = await self.engine.get_parallel_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def is_tracing_enabled(self, identity): + """Send the is_tracing_enabled flag""" + tracing_flag = await self.engine.is_tracing_enabled() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(tracing_flag)]) + + async def do_log_stats(self, identity): + """Log stats and confirm success.""" + await self.engine.do_log_stats() + + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def is_server_ready(self, identity): + """Notify the client that we are ready.""" + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def abort(self, identity, request: RPCAbortRequest): + """Abort request and notify the client of success.""" + # Abort 
the request in the llm engine. + await self.engine.abort(request.request_id) + + # Send confirmation to the client. + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def generate(self, identity, generate_request: RPCGenerateRequest): + try: + results_generator = self.engine.generate( + generate_request.inputs, + sampling_params=generate_request.sampling_params, + request_id=generate_request.request_id, + lora_request=generate_request.lora_request, + trace_headers=generate_request.trace_headers, + prompt_adapter_request=generate_request.prompt_adapter_request) + + async for request_output in results_generator: + await self.socket.send_multipart( + [identity, cloudpickle.dumps(request_output)]) + + except Exception as e: + ### Notify client of all failures + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + async def check_health(self, identity): + try: + await self.engine.check_health() + await self.socket.send_multipart( + [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) + except Exception as e: + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + def _make_handler_coro(self, identity, + message) -> Coroutine[Any, Any, Never]: + """Route the zmq message to the handler coroutine.""" + + request = cloudpickle.loads(message) + + if isinstance(request, RPCGenerateRequest): + return self.generate(identity, request) + + elif isinstance(request, RPCAbortRequest): + return self.abort(identity, request) + + elif isinstance(request, RPCUtilityRequest): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + return self.get_model_config(identity) + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.get_parallel_config(identity) + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.get_scheduler_config(identity) + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.get_lora_config(identity) + elif request == RPCUtilityRequest.DO_LOG_STATS: + return self.do_log_stats(identity) + elif request == RPCUtilityRequest.IS_SERVER_READY: + return self.is_server_ready(identity) + elif request == RPCUtilityRequest.CHECK_HEALTH: + return self.check_health(identity) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + return self.is_tracing_enabled(identity) + else: + raise ValueError(f"Unknown RPCUtilityRequest type: {request}") + + else: + raise ValueError(f"Unknown RPCRequest type: {request}") + + async def run_server_loop(self): + """Inner RPC Server Loop""" + + running_tasks = set() + while True: + # Wait for a request. + identity, message = await self.socket.recv_multipart() + + # Process the request async. + task = asyncio.create_task( + self._make_handler_coro(identity, message)) + + # We need to keep around a strong reference to the task, + # to avoid the task disappearing mid-execution as running tasks + # can be GC'ed. Below is a common "fire-and-forget" tasks + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) + + +async def run_server(server: AsyncEngineRPCServer): + # Put the server task into the asyncio loop. + loop = asyncio.get_running_loop() + server_task = loop.create_task(server.run_server_loop()) + + # Interruption handling. 
+ def signal_handler() -> None: + # Kill the server on interrupt / terminate + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + except asyncio.CancelledError: + logger.info("vLLM ZMQ RPC Server was interrupted.") + finally: + # Clean up all resources. + server.cleanup() + + +def run_rpc_server(async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int): + server = AsyncEngineRPCServer(async_engine_args, usage_context, port) + asyncio.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c832cf2a24b5..ebb1d57fbb9a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -50,7 +50,7 @@ def __init__( chat_template: Optional[str], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -89,7 +89,8 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -161,7 +162,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -169,7 +171,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.engine.generate( + result_generator = self.async_engine_client.generate( engine_inputs, sampling_params, request_id, @@ -441,7 +443,7 @@ async def chat_completion_full_generator( async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.engine.abort(request_id) + await self.async_engine_client.abort(request_id) return self.create_error_response("Client disconnected") final_res = res assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7765c5903f34..edc83d83fbba 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -51,7 +51,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -91,7 +91,8 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -119,7 +120,8 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -127,7 +129,7 @@ async def create_completion(self, request: CompletionRequest, raw_request.headers): log_tracing_disabled_warning() - generator = self.engine.generate( + generator = self.async_engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, @@ -168,7 +170,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res @@ -230,7 +232,8 @@ async def completion_stream_generator( # Abort the request if the client disconnects. 
if await raw_request.is_disconnected(): - await self.engine.abort(f"{request_id}-{prompt_idx}") + await self.async_engine_client.abort( + f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bccc90894e79..e61c82f9a8a6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -6,7 +6,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -99,7 +99,8 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) pooling_params = request.to_pooling_params() @@ -124,7 +125,7 @@ async def create_embedding(self, request: EmbeddingRequest, "Prompt adapter is not supported " "for embedding models") - generator = self.engine.encode( + generator = self.async_engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, @@ -146,7 +147,7 @@ async def create_embedding(self, request: EmbeddingRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
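The disconnect handling repeated across these handlers always has the same shape: while streaming results, poll the incoming FastAPI request and abort the engine-side request once the client has gone away. A hedged sketch of that loop, using a hypothetical client object rather than the real serving classes:

    from fastapi import Request


    async def collect_final_result(raw_request: Request, client,
                                   result_generator, request_id: str):
        # `client` is any object exposing an async abort(request_id) method.
        final_res = None
        async for res in result_generator:
            if await raw_request.is_disconnected():
                # Stop engine-side work as soon as the client disconnects.
                await client.abort(request_id)
                return None
            final_res = res
        return final_res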
- await self.engine.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8c7929a12e9a..df4932d8fe18 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -61,7 +61,7 @@ class OpenAIServing: def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -72,7 +72,7 @@ def __init__( ): super().__init__() - self.engine = engine + self.async_engine_client = async_engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -155,7 +155,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.engine.get_decoding_config() + decoding_config = await self.async_engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 94e1b03ed403..c4350881a27a 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,9 +1,9 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine # yapf conflicts with isort for this block # yapf: disable +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -32,7 +32,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -57,7 +57,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -113,7 +113,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index 9bcb26f8e5a6..01461512343d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: VLLM_HOST_IP: str = 
"" VLLM_PORT: Optional[int] = None + VLLM_RPC_PORT: int = 5570 VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_INSTANCE_ID: Optional[str] = None @@ -139,6 +140,11 @@ def get_default_config_root(): lambda: int(os.getenv('VLLM_PORT', '0')) if 'VLLM_PORT' in os.environ else None, + # used when the frontend api server is running in multi-processing mode, + # to communicate with the backend engine process over ZMQ. + 'VLLM_RPC_PORT': + lambda: int(os.getenv('VLLM_PORT', '5570')), + # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers "VLLM_USE_MODELSCOPE": diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 1c8f6cccb3e9..554dcc0ed43e 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,6 +21,8 @@ from typing import Callable, DefaultDict, Dict, List, Union import torch +from lark import Lark +from outlines import grammars from outlines.caching import cache from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.json_schema import build_regex_from_schema @@ -44,6 +46,23 @@ def __call__(self, input_ids: List[int], last_seq_id = hash(tuple(input_ids[:-1])) self._fsm_state[seq_id] = self._guide.get_next_state( state=self._fsm_state[last_seq_id], token_id=last_token) + else: + # Note: this is a hack. + # Lark pickling does not work properly (silent failure), + # which breaks the RPC (which uses python pickleing). + # We need to find a better solution. + # On the first time this is called, we simply re-create + # the Lark object. 
+ if isinstance(self._guide, CFGGuide): + self._guide.parser = Lark( + self._guide.cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + import_paths=[grammars.GRAMMAR_PATH], + ) instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) diff --git a/vllm/tracing.py b/vllm/tracing.py index dc8377f2396f..7ac38e6a0f66 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -60,7 +60,7 @@ def get_span_exporter(endpoint): OTLPSpanExporter) elif protocol == "http/protobuf": from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) + OTLPSpanExporter) # type: ignore else: raise ValueError( f"Unsupported OTLP protocol '{protocol}' is configured") diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 7a0436dd1fb1..eeab19899b02 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,6 +1,7 @@ from typing import Optional, Type -from vllm.config import TokenizerPoolConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + TokenizerPoolConfig) from vllm.executor.ray_utils import ray from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ -13,6 +14,22 @@ RayTokenizerGroupPool = None # type: ignore +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: tokenizer_cls: Type[BaseTokenizerGroup] diff --git a/vllm/utils.py b/vllm/utils.py index c4c17bfbefc6..51bd72977a22 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -290,6 +290,10 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: return _async_wrapper +class ProducerFinished: + pass + + def merge_async_iterators( *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: """Merge multiple asynchronous iterators into a single iterator. @@ -298,9 +302,10 @@ def merge_async_iterators( When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. 
""" - queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, + Exception]] = asyncio.Queue() - finished = [False] * len(iterators) + producers = len(iterators) async def producer(i: int, iterator: AsyncIterator[T]): try: @@ -308,7 +313,8 @@ async def producer(i: int, iterator: AsyncIterator[T]): await queue.put((i, item)) except Exception as e: await queue.put(e) - finished[i] = True + # Signal to the consumer that we've finished + await queue.put(ProducerFinished()) _tasks = [ asyncio.create_task(producer(i, iterator)) @@ -316,9 +322,17 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): + remaining = producers try: - while not all(finished) or not queue.empty(): + while remaining or not queue.empty(): + # we think there is a race condition here item = await queue.get() + + if isinstance(item, ProducerFinished): + # Signal that a producer finished- not a real item + remaining -= 1 + continue + if isinstance(item, Exception): raise item yield item @@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port() -> int: - port = envs.VLLM_PORT +def get_open_port(port: Optional[int] = None) -> int: + if port is None: + # Default behavior here is to return a port for multi-gpu communication + port = envs.VLLM_PORT if port is not None: while True: try: