From bed649a62aaf143b0c8cb71f2bc63bfbc5d6b9c1 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 25 Jul 2024 11:44:28 -0600 Subject: [PATCH 01/80] :alembic: add backend proto file Signed-off-by: Joe Runde --- vllm/proto/generate.proto | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 vllm/proto/generate.proto diff --git a/vllm/proto/generate.proto b/vllm/proto/generate.proto new file mode 100644 index 000000000000..258386c6535b --- /dev/null +++ b/vllm/proto/generate.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package generate.v1; + +service TextGenerationService { + rpc Generate (GenerateRequest) returns (GenerateResponse); // will need a streaming version later- unary for POC + + rpc Health (HealthRequest) returns (HealthResponse); +} + + +message GenerateRequest { + PromptInputs prompt_inputs = 1; + // SamplingParams sampling_params = 2; + string request_id = 3; +} + +message PromptInputs { + string prompt = 1; +} + +message GenerateResponse { + repeated CompletionOutput outputs = 1; // 1 per generation from a single prompt +} + +message CompletionOutput { + uint64 index = 1; + repeated uint64 token_ids = 2; +} + From 7de9d4926ed47ee52056e0759ec330951257cbc5 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 25 Jul 2024 11:57:37 -0600 Subject: [PATCH 02/80] :recycle: move proto to grpc/pb Signed-off-by: Joe Runde --- vllm/{proto => grpc/pb}/generate.proto | 1 - 1 file changed, 1 deletion(-) rename vllm/{proto => grpc/pb}/generate.proto (99%) diff --git a/vllm/proto/generate.proto b/vllm/grpc/pb/generate.proto similarity index 99% rename from vllm/proto/generate.proto rename to vllm/grpc/pb/generate.proto index 258386c6535b..8919522ac9ae 100644 --- a/vllm/proto/generate.proto +++ b/vllm/grpc/pb/generate.proto @@ -27,4 +27,3 @@ message CompletionOutput { uint64 index = 1; repeated uint64 token_ids = 2; } - From 9394a62266155d64acad8fcb95edb64a82e7207a Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 25 Jul 2024 11:59:43 -0600 Subject: [PATCH 03/80] :sparkles: add proto compilation Signed-off-by: Joe Runde --- setup.py | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 72ef26f15e40..1588280ba3bb 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,12 @@ from shutil import which from typing import Dict, List +from shlex import split +from subprocess import CalledProcessError, check_call +from textwrap import dedent +from setuptools.command.build_py import build_py +from setuptools.errors import SetupError + import torch from packaging.version import Version, parse from setuptools import Extension, find_packages, setup @@ -28,6 +34,39 @@ def load_module_from_path(module_name, path): logger = logging.getLogger(__name__) +class BuildPyAndGenerateGrpc(build_py): + """build python module using protoc to prepare generated files.""" + + proto_source = "vllm/grpc/pb/generation.proto" + + def run(self): + print(f"Invoking protoc on {self.proto_source}") + + # NOTE: imports in generated files will be broken unless some care is given in + # how --proto_path, --*_out and .proto paths are given. 
+ # + # See https://github.com/grpc/grpc/issues/9575#issuecomment-293934506 + try: + check_call( + split( + dedent( + f""" + python -m grpc_tools.protoc \ + --proto_path=src \ + --python_out=src/ \ + --grpc_python_out=src/ \ + --mypy_out=src/ \ + {self.proto_source} + """, + ), + ) + ) + except CalledProcessError as exc: + raise SetupError(f"protoc failed, exit code {exc.returncode}") from exc + + super().run() + + def embed_commit_hash(): try: if "BUILDKITE_COMMIT" in os.environ: @@ -486,7 +525,7 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ "tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, + cmdclass={"build_ext": cmake_build_ext, "build_py": BuildPyAndGenerateGrpc} if _build_custom_ops() else {"build_py": BuildPyAndGenerateGrpc}, package_data=package_data, entry_points={ "console_scripts": [ From dd8bf96b014f954193393856d0dd027dfca00d63 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 25 Jul 2024 20:26:39 +0000 Subject: [PATCH 04/80] updated --- .gitignore | 3 ++ requirements-common.txt | 2 + setup.py | 9 ++-- vllm/entrypoints/openai/api_server.py | 12 +++-- vllm/grpc/__init__.py | 0 vllm/grpc/client.py | 51 ++++++++++++++++++ vllm/grpc/pb/__init__.py | 0 vllm/grpc/pb/generate.proto | 10 ++-- vllm/grpc/server.py | 74 +++++++++++++++++++++++++++ 9 files changed, 145 insertions(+), 16 deletions(-) create mode 100644 vllm/grpc/__init__.py create mode 100644 vllm/grpc/client.py create mode 100644 vllm/grpc/pb/__init__.py create mode 100644 vllm/grpc/server.py diff --git a/.gitignore b/.gitignore index 17184b19127c..fffd4683e3dc 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,6 @@ hip_compat.h # Benchmark dataset *.json + +# Protobuf +pb2 diff --git a/requirements-common.txt b/requirements-common.txt index 3b8d473c1fe7..b3d4736f282f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -22,3 +22,5 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq +grpcio +grpcio-tools diff --git a/setup.py b/setup.py index 1588280ba3bb..cf920065a822 100644 --- a/setup.py +++ b/setup.py @@ -52,10 +52,9 @@ def run(self): dedent( f""" python -m grpc_tools.protoc \ - --proto_path=src \ - --python_out=src/ \ - --grpc_python_out=src/ \ - --mypy_out=src/ \ + --proto_path=. \ + --python_out=. \ + --grpc_python_out=. 
\ {self.proto_source} """, ), @@ -525,7 +524,7 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ "tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext, "build_py": BuildPyAndGenerateGrpc} if _build_custom_ops() else {"build_py": BuildPyAndGenerateGrpc}, + cmdclass={"build_py": BuildPyAndGenerateGrpc, "build_ext": cmake_build_ext,} if _build_custom_ops() else {"build_py": BuildPyAndGenerateGrpc}, package_data=package_data, entry_points={ "console_scripts": [ diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0fe4dd245b5e..41cee898b4ac 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -221,12 +221,14 @@ async def build_server( ) -> uvicorn.Server: app = build_app(args) - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] + # if args.served_model_name is not None: + # served_model_names = args.served_model_name + # else: + # served_model_names = [args.model] + + served_model_names = "meta-llama/Meta-Llama-3-8B-Instruct" - global engine, engine_args + # global engine, engine_args engine_args = AsyncEngineArgs.from_cli_args(args) engine = (llm_engine diff --git a/vllm/grpc/__init__.py b/vllm/grpc/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py new file mode 100644 index 000000000000..5a08646018d3 --- /dev/null +++ b/vllm/grpc/client.py @@ -0,0 +1,51 @@ +from vllm import AsyncLLMEngine +import grpc +from .pb import generate_pb2_grpc, generate_pb2 +from typing import AsyncIterator, Optional, Mapping + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + + + +class TextGenerationClient(AsyncLLMEngine): + def __init__(self): + channel = grpc.insecure_channel("localhost:5543") + self.stub = generate_pb2_grpc.TextGenerationServiceStub(channel) + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + + generate_stream = self.stub.Generate( + generate_pb2.GenerateRequest( + prompt_inputs=generate_pb2.PromptInputs(prompt=inputs.prompt), + request_id=request_id, + ) + ) + + async for generate_response in generate_stream: + completion_outputs = [ + CompletionOutput( + index=output.index, + text=output.text, + token_ids=output.token_ids, + cumulative_logprob=0.0, + ) for output in generate_response.outputs + ] + + yield RequestOutput( + request_id=request_id, + prompt_token_ids=[], + outputs=completion_outputs + ) \ No newline at end of file diff --git a/vllm/grpc/pb/__init__.py b/vllm/grpc/pb/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/grpc/pb/generate.proto b/vllm/grpc/pb/generate.proto index 8919522ac9ae..ea74e4909d2f 100644 --- a/vllm/grpc/pb/generate.proto +++ b/vllm/grpc/pb/generate.proto @@ -3,16 +3,13 @@ syntax = "proto3"; package generate.v1; service TextGenerationService { - rpc Generate (GenerateRequest) returns (GenerateResponse); // will need a streaming version later- unary for POC - - rpc 
Health (HealthRequest) returns (HealthResponse); + rpc Generate (GenerateRequest) returns (stream GenerateResponse); // will need a streaming version later- unary for POC } message GenerateRequest { PromptInputs prompt_inputs = 1; - // SamplingParams sampling_params = 2; - string request_id = 3; + string request_id = 2; } message PromptInputs { @@ -20,10 +17,11 @@ message PromptInputs { } message GenerateResponse { - repeated CompletionOutput outputs = 1; // 1 per generation from a single prompt + repeated CompletionOutput outputs = 1; } message CompletionOutput { uint64 index = 1; repeated uint64 token_ids = 2; + string text = 3; } diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py new file mode 100644 index 000000000000..669e5fe64c5e --- /dev/null +++ b/vllm/grpc/server.py @@ -0,0 +1,74 @@ +from .pb import generate_pb2_grpc, generate_pb2 +from .pb.generate_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR +from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams +from collections.abc import AsyncIterator +from grpc import aio +import asyncio +from grpc_reflection.v1alpha import reflection + +MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" +MAX_TOKENS = 200 +TEMPERATURE = 0 + + +class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): + SERVICE_NAME = _GENERATION_DESCRIPTOR.services_by_name[ + "TextGenerationService" + ].full_name + + def __init__(self): + self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL, enforce_eager=True)) + + async def Generate( + self, request: generate_pb2.GenerateRequest, context + ) -> AsyncIterator[generate_pb2.GenerateResponse]: + + results_generator = self.engine.generate( + request.prompt_inputs.prompt, + sampling_params=SamplingParams(max_tokens=MAX_TOKENS, + temperature=TEMPERATURE), + request_id=request.request_id) + + async for request_output in results_generator: + outputs = [ + generate_pb2.CompletionOutput( + index=output.index, + token_ids=output.token_ids) + for output in request_output.outputs + ] + yield generate_pb2.GenerateResponse(outputs=outputs) + + +async def start_grpc_server() -> aio.Server: + server = aio.server() + generation = TextGenerationService() + generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(generation, server) + + service_names = ( + generation.SERVICE_NAME, + reflection.SERVICE_NAME, + ) + + reflection.enable_server_reflection(service_names, server) + + host = "0.0.0.0" + grpc_port = 5543 + server.add_insecure_port(f"{host}:{grpc_port}") + await server.start() + return server + + +async def run_grpc_server() -> None: + server = await start_grpc_server() + + try: + while True: + await asyncio.sleep(10) + + except asyncio.CancelledError: + print("Gracefully stopping gRPC server") # noqa: T201 + await server.stop(30) # TODO configurable grace + await server.wait_for_termination() + +if __name__ == "__main__": + asyncio.run(run_grpc_server()) From 5c7fbffb5fa1256247045911db72d4952b9bde77 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 25 Jul 2024 20:35:29 +0000 Subject: [PATCH 05/80] kinda working --- examples/openai_completion_client.py | 5 +- vllm/entrypoints/openai/api_server.py | 73 ++++++++++--------- vllm/entrypoints/openai/serving_completion.py | 4 +- vllm/entrypoints/openai/serving_engine.py | 7 +- vllm/grpc/client.py | 1 + vllm/grpc/pb/generate.proto | 1 + vllm/grpc/server.py | 5 +- 7 files changed, 52 insertions(+), 44 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 
58519f978d34..01e5c0d54e1e 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -19,9 +19,8 @@ model=model, prompt="A robot may not injure a human being", echo=False, - n=2, - stream=stream, - logprobs=3) + n=1, + stream=stream) print("Completion results:") if stream: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 41cee898b4ac..31c30f866ef2 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -64,10 +64,10 @@ async def _force_log(): await asyncio.sleep(10) await engine.do_log_stats() - if not engine_args.disable_log_stats: - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) + # if not engine_args.disable_log_stats: + # task = asyncio.create_task(_force_log()) + # _running_tasks.add(task) + # task.add_done_callback(_running_tasks.remove) yield @@ -228,14 +228,17 @@ async def build_server( served_model_names = "meta-llama/Meta-Llama-3-8B-Instruct" + from vllm.grpc.client import TextGenerationClient + engine = TextGenerationClient() + # global engine, engine_args - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) + # engine_args = AsyncEngineArgs.from_cli_args(args) + # engine = (llm_engine + # if llm_engine is not None else AsyncLLMEngine.from_engine_args( + # engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) - model_config = await engine.get_model_config() + # model_config = await engine.get_model_config() if args.disable_log_requests: request_logger = None @@ -247,40 +250,40 @@ async def build_server( global openai_serving_embedding global openai_serving_tokenization - openai_serving_chat = OpenAIServingChat( - engine, - model_config, - served_model_names, - args.response_role, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - request_logger=request_logger, - chat_template=args.chat_template, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) + # openai_serving_chat = OpenAIServingChat( + # engine, + # model_config, + # served_model_names, + # args.response_role, + # lora_modules=args.lora_modules, + # prompt_adapters=args.prompt_adapters, + # request_logger=request_logger, + # chat_template=args.chat_template, + # return_tokens_as_token_ids=args.return_tokens_as_token_ids, + # ) openai_serving_completion = OpenAIServingCompletion( engine, - model_config, + # model_config, served_model_names, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) - openai_serving_embedding = OpenAIServingEmbedding( - engine, - model_config, - served_model_names, - request_logger=request_logger, - ) - openai_serving_tokenization = OpenAIServingTokenization( - engine, - model_config, - served_model_names, - lora_modules=args.lora_modules, - request_logger=request_logger, - chat_template=args.chat_template, - ) + # openai_serving_embedding = OpenAIServingEmbedding( + # engine, + # model_config, + # served_model_names, + # request_logger=request_logger, + # ) + # openai_serving_tokenization = OpenAIServingTokenization( + # engine, + # model_config, + # served_model_names, + # lora_modules=args.lora_modules, + # request_logger=request_logger, + # chat_template=args.chat_template, + # ) app.root_path = 
args.root_path logger.info("Available routes are:") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 73e420141813..d80b73b360e9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -45,7 +45,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, engine: AsyncLLMEngine, - model_config: ModelConfig, + # model_config: ModelConfig, served_model_names: List[str], *, lora_modules: Optional[List[LoRAModulePath]], @@ -54,7 +54,7 @@ def __init__( return_tokens_as_token_ids: bool = False, ): super().__init__(engine=engine, - model_config=model_config, + # model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, prompt_adapters=prompt_adapters, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 321c9ac2c1d5..ebe4b787f98b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -62,7 +62,7 @@ class OpenAIServing: def __init__( self, engine: AsyncLLMEngine, - model_config: ModelConfig, + # model_config: ModelConfig, served_model_names: List[str], *, lora_modules: Optional[List[LoRAModulePath]], @@ -73,8 +73,9 @@ def __init__( super().__init__() self.engine = engine - self.model_config = model_config - self.max_model_len = model_config.max_model_len + # self.model_config = model_config + # self.max_model_len = model_config.max_model_len + self.max_model_len = 4096 self.served_model_names = served_model_names diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 5a08646018d3..55255786b056 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -41,6 +41,7 @@ async def generate( text=output.text, token_ids=output.token_ids, cumulative_logprob=0.0, + finish_reason=output.finish_reason, ) for output in generate_response.outputs ] diff --git a/vllm/grpc/pb/generate.proto b/vllm/grpc/pb/generate.proto index ea74e4909d2f..511129e37f83 100644 --- a/vllm/grpc/pb/generate.proto +++ b/vllm/grpc/pb/generate.proto @@ -24,4 +24,5 @@ message CompletionOutput { uint64 index = 1; repeated uint64 token_ids = 2; string text = 3; + string finish_reason = 4; } diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 669e5fe64c5e..362c080a90c6 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -33,7 +33,9 @@ async def Generate( outputs = [ generate_pb2.CompletionOutput( index=output.index, - token_ids=output.token_ids) + token_ids=output.token_ids, + text=output.text, + finish_reason=output.finish_reason) for output in request_output.outputs ] yield generate_pb2.GenerateResponse(outputs=outputs) @@ -55,6 +57,7 @@ async def start_grpc_server() -> aio.Server: grpc_port = 5543 server.add_insecure_port(f"{host}:{grpc_port}") await server.start() + print("ready") return server From 952e8ef9b802f5601982d60bc05d7ac434d63faf Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 25 Jul 2024 15:34:54 -0600 Subject: [PATCH 06/80] :construction: more wip Signed-off-by: Joe Runde --- setup.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 40 +++++++++--------- vllm/grpc/client.py | 41 +++++++++++++++++-- vllm/grpc/pb/generate.proto | 3 +- vllm/grpc/server.py | 10 ++++- 5 files changed, 69 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index cf920065a822..ec1feb7b76b0 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def load_module_from_path(module_name, path): class BuildPyAndGenerateGrpc(build_py): 
"""build python module using protoc to prepare generated files.""" - proto_source = "vllm/grpc/pb/generation.proto" + proto_source = "vllm/grpc/pb/generate.proto" def run(self): print(f"Invoking protoc on {self.proto_source}") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index d80b73b360e9..95252373689d 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -96,18 +96,18 @@ async def create_completion(self, request: CompletionRequest, tokenizer = await self.engine.get_tokenizer(lora_request) sampling_params = request.to_sampling_params() - decoding_config = await self.engine.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend - guided_decode_logit_processor = ( - await - get_guided_decoding_logits_processor(guided_decoding_backend, - request, tokenizer)) - if guided_decode_logit_processor is not None: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append( - guided_decode_logit_processor) + # decoding_config = await self.engine.get_decoding_config() + # guided_decoding_backend = request.guided_decoding_backend \ + # or decoding_config.guided_decoding_backend + # guided_decode_logit_processor = ( + # await + # get_guided_decoding_logits_processor(guided_decoding_backend, + # request, tokenizer)) + # if guided_decode_logit_processor is not None: + # if sampling_params.logits_processors is None: + # sampling_params.logits_processors = [] + # sampling_params.logits_processors.append( + # guided_decode_logit_processor) prompts = list( self._tokenize_prompt_input_or_inputs( @@ -128,13 +128,13 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.engine.is_tracing_enabled() - trace_headers = None - if is_tracing_enabled: - trace_headers = extract_trace_headers(raw_request.headers) - if not is_tracing_enabled and contains_trace_headers( - raw_request.headers): - log_tracing_disabled_warning() + # is_tracing_enabled = await self.engine.is_tracing_enabled() + # trace_headers = None + # if is_tracing_enabled: + # trace_headers = extract_trace_headers(raw_request.headers) + # if not is_tracing_enabled and contains_trace_headers( + # raw_request.headers): + # log_tracing_disabled_warning() generator = self.engine.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, @@ -142,7 +142,7 @@ async def create_completion(self, request: CompletionRequest, request_id_item, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, + # trace_headers=trace_headers, ) generators.append(generator) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 55255786b056..18defdf1d3f9 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -1,7 +1,8 @@ +import asyncio from vllm import AsyncLLMEngine import grpc from .pb import generate_pb2_grpc, generate_pb2 -from typing import AsyncIterator, Optional, Mapping +from typing import AsyncIterator, List, Optional, Mapping from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest @@ -14,9 +15,37 @@ class TextGenerationClient(AsyncLLMEngine): def __init__(self): - channel = grpc.insecure_channel("localhost:5543") + channel = grpc.aio.insecure_channel("localhost:5543") self.stub = 
generate_pb2_grpc.TextGenerationServiceStub(channel) + self.engine_use_ray = False + self.worker_use_ray = False + self.log_requests = False + self.engine = None + + @property + def is_running(self) -> bool: + return True + + @property + def is_stopped(self) -> bool: + return False + + @property + def errored(self) -> bool: + return False + + async def run_engine_loop(self): + while True: + await asyncio.sleep(1) + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> "PreTrainedTokenizer": + # TODO: what to return :/ + from transformers import AutoTokenizer + return AutoTokenizer.from_pretrained("facebook/opt-125m") + async def generate( self, inputs: PromptInputs, @@ -27,9 +56,15 @@ async def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: + prompt: str = inputs.get('prompt', "") + prompt_token_ids: List[int] = inputs.get('prompt_token_ids', []) + generate_stream = self.stub.Generate( generate_pb2.GenerateRequest( - prompt_inputs=generate_pb2.PromptInputs(prompt=inputs.prompt), + prompt_inputs=generate_pb2.PromptInputs( + prompt=prompt, + prompt_token_ids=prompt_token_ids, + ), request_id=request_id, ) ) diff --git a/vllm/grpc/pb/generate.proto b/vllm/grpc/pb/generate.proto index 511129e37f83..282914e7a60f 100644 --- a/vllm/grpc/pb/generate.proto +++ b/vllm/grpc/pb/generate.proto @@ -3,7 +3,7 @@ syntax = "proto3"; package generate.v1; service TextGenerationService { - rpc Generate (GenerateRequest) returns (stream GenerateResponse); // will need a streaming version later- unary for POC + rpc Generate (GenerateRequest) returns (stream GenerateResponse); } @@ -14,6 +14,7 @@ message GenerateRequest { message PromptInputs { string prompt = 1; + repeated uint64 prompt_token_ids = 2; } message GenerateResponse { diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 362c080a90c6..a4e8179aa320 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -1,3 +1,4 @@ +from vllm.inputs.data import TextPrompt, TokensPrompt from .pb import generate_pb2_grpc, generate_pb2 from .pb.generate_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams @@ -6,7 +7,7 @@ import asyncio from grpc_reflection.v1alpha import reflection -MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL = "facebook/opt-125m" MAX_TOKENS = 200 TEMPERATURE = 0 @@ -22,9 +23,14 @@ def __init__(self): async def Generate( self, request: generate_pb2.GenerateRequest, context ) -> AsyncIterator[generate_pb2.GenerateResponse]: + + if len(request.prompt_inputs.prompt_token_ids) > 0: + inputs = TokensPrompt(prompt_token_ids=request.prompt_inputs.prompt_token_ids) + else: + inputs = TextPrompt(prompt=request.prompt_inputs.prompt) results_generator = self.engine.generate( - request.prompt_inputs.prompt, + inputs, sampling_params=SamplingParams(max_tokens=MAX_TOKENS, temperature=TEMPERATURE), request_id=request.request_id) From e8eac95c17bc28a403e999b2132b2fad5916d4cd Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 25 Jul 2024 22:18:09 +0000 Subject: [PATCH 07/80] fixed --- vllm/entrypoints/openai/serving_completion.py | 5 ++++- vllm/grpc/client.py | 6 +++++- vllm/grpc/server.py | 2 +- vllm/utils.py | 7 ++++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 95252373689d..3080e06be86e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ 
b/vllm/entrypoints/openai/serving_completion.py @@ -152,6 +152,7 @@ async def create_completion(self, request: CompletionRequest, result_generator: AsyncIterator[Tuple[ int, RequestOutput]] = merge_async_iterators(*generators) + # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use @@ -175,6 +176,8 @@ async def create_completion(self, request: CompletionRequest, final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: + # async for res in generators[0]: + # i = 0 if await raw_request.is_disconnected(): # Abort the request if the client disconnects. await self.engine.abort(f"{request_id}-{i}") @@ -189,7 +192,6 @@ async def create_completion(self, request: CompletionRequest, # with the inputs token IDs if final_res.prompt is None: final_res.prompt = prompts[i]["prompt"] - final_res_batch_checked = cast(List[RequestOutput], final_res_batch) @@ -236,6 +238,7 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: + breakpoint() # Abort the request if the client disconnects. if await raw_request.is_disconnected(): diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 18defdf1d3f9..354065970c96 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -76,6 +76,7 @@ async def generate( text=output.text, token_ids=output.token_ids, cumulative_logprob=0.0, + logprobs=None, finish_reason=output.finish_reason, ) for output in generate_response.outputs ] @@ -83,5 +84,8 @@ async def generate( yield RequestOutput( request_id=request_id, prompt_token_ids=[], - outputs=completion_outputs + outputs=completion_outputs, + finished=(completion_outputs[0].finish_reason != ""), + prompt_logprobs=None, + prompt=prompt, ) \ No newline at end of file diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index a4e8179aa320..5e1f2e303e0a 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -8,7 +8,7 @@ from grpc_reflection.v1alpha import reflection MODEL = "facebook/opt-125m" -MAX_TOKENS = 200 +MAX_TOKENS = 10 TEMPERATURE = 0 diff --git a/vllm/utils.py b/vllm/utils.py index 876c3bf90b02..8e6d5da77c12 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -298,14 +298,18 @@ def merge_async_iterators( queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() finished = [False] * len(iterators) + print(f"len(iterators) = {len(iterators)}") async def producer(i: int, iterator: AsyncIterator[T]): try: async for item in iterator: + print(f"{i}: before producer loop") await queue.put((i, item)) + print(f"{i}: after producer await") except Exception as e: await queue.put(e) finished[i] = True + print("producer finished") _tasks = [ asyncio.create_task(producer(i, iterator)) @@ -315,7 +319,8 @@ async def producer(i: int, iterator: AsyncIterator[T]): async def consumer(): try: while not all(finished) or not queue.empty(): - item = await queue.get() + # we think there is a race condition here + item = await queue.get(timeout=0.1) if isinstance(item, Exception): raise item yield item From 938a8431d896f879b8a3c227e1db90105ec759cf Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 25 Jul 2024 16:49:16 -0600 Subject: [PATCH 08/80] :bug: fixup race condition Signed-off-by: Joe Runde --- vllm/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/utils.py b/vllm/utils.py index 8e6d5da77c12..03909eea5d7a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -287,6 +287,9 @@ def _async_wrapper(*args: P.args, **kwargs: 
P.kwargs) -> asyncio.Future: return _async_wrapper +class ProducerFinished: + pass + def merge_async_iterators( *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: """Merge multiple asynchronous iterators into a single iterator. @@ -310,6 +313,8 @@ async def producer(i: int, iterator: AsyncIterator[T]): await queue.put(e) finished[i] = True print("producer finished") + # Signal to the consumer that we've finished + await queue.put(ProducerFinished()) _tasks = [ asyncio.create_task(producer(i, iterator)) @@ -321,6 +326,11 @@ async def consumer(): while not all(finished) or not queue.empty(): # we think there is a race condition here item = await queue.get(timeout=0.1) + + if isinstance(item, ProducerFinished): + # Signal that a producer finished- not a real item + continue + if isinstance(item, Exception): raise item yield item From 2b8d7cd7b95e790a0409090525a794e545e546f7 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 25 Jul 2024 17:33:51 -0600 Subject: [PATCH 09/80] :bug: remove timeout Signed-off-by: Joe Runde --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 03909eea5d7a..7489cc910dc9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -325,7 +325,7 @@ async def consumer(): try: while not all(finished) or not queue.empty(): # we think there is a race condition here - item = await queue.get(timeout=0.1) + item = await queue.get() if isinstance(item, ProducerFinished): # Signal that a producer finished- not a real item From ea02d399da9cd19e6b308aab982bf2c6a078b696 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 26 Jul 2024 13:02:13 +0000 Subject: [PATCH 10/80] format --- vllm/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 7489cc910dc9..9cf6aa9b91d6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -306,13 +306,10 @@ def merge_async_iterators( async def producer(i: int, iterator: AsyncIterator[T]): try: async for item in iterator: - print(f"{i}: before producer loop") await queue.put((i, item)) - print(f"{i}: after producer await") except Exception as e: await queue.put(e) finished[i] = True - print("producer finished") # Signal to the consumer that we've finished await queue.put(ProducerFinished()) From 4a2dc460f270b60ca79449631643234681c2f919 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 15:21:03 +0000 Subject: [PATCH 11/80] streaming --- benchmarks/backend_request_func.py | 3 +++ examples/openai_completion_client.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 7 ++++--- vllm/grpc/server.py | 9 +++++---- vllm/utils.py | 1 - 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fbab547d094f..4d9d148fada7 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -252,6 +252,7 @@ async def async_request_openai_completions( try: async with session.post(url=api_url, json=payload, headers=headers) as response: + # breakpoint() if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -281,6 +282,8 @@ async def async_request_openai_completions( most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] + # print(generated_text) + # breakpoint() output.generated_text = generated_text output.success = True diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 01e5c0d54e1e..cf932e67f9a4 
100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -14,7 +14,7 @@ model = models.data[0].id # Completion API -stream = False +stream = True completion = client.completions.create( model=model, prompt="A robot may not injure a human being", diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 3080e06be86e..60d78edfe47b 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -238,8 +238,6 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: - breakpoint() - # Abort the request if the client disconnects. if await raw_request.is_disconnected(): await self.engine.abort(f"{request_id}-{prompt_idx}") @@ -289,7 +287,7 @@ async def completion_stream_generator( previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - finish_reason = output.finish_reason + finish_reason = None if output.finish_reason == "" else output.finish_reason stop_reason = output.stop_reason chunk = CompletionStreamResponse( @@ -321,7 +319,10 @@ async def completion_stream_generator( else: chunk.usage = None + response_json = chunk.model_dump_json(exclude_unset=False) + print(response_json) + breakpoint() yield f"data: {response_json}\n\n" if (request.stream_options diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 5e1f2e303e0a..0c37e257989b 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from grpc import aio import asyncio -from grpc_reflection.v1alpha import reflection +# from grpc_reflection.v1alpha import reflection MODEL = "facebook/opt-125m" MAX_TOKENS = 10 @@ -22,7 +22,8 @@ def __init__(self): async def Generate( self, request: generate_pb2.GenerateRequest, context - ) -> AsyncIterator[generate_pb2.GenerateResponse]: + # ) -> AsyncIterator[generate_pb2.GenerateResponse]: + ) -> AsyncIterator: if len(request.prompt_inputs.prompt_token_ids) > 0: inputs = TokensPrompt(prompt_token_ids=request.prompt_inputs.prompt_token_ids) @@ -54,10 +55,10 @@ async def start_grpc_server() -> aio.Server: service_names = ( generation.SERVICE_NAME, - reflection.SERVICE_NAME, + # reflection.SERVICE_NAME, ) - reflection.enable_server_reflection(service_names, server) + # reflection.enable_server_reflection(service_names, server) host = "0.0.0.0" grpc_port = 5543 diff --git a/vllm/utils.py b/vllm/utils.py index 9cf6aa9b91d6..9537c07aabc8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -301,7 +301,6 @@ def merge_async_iterators( queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() finished = [False] * len(iterators) - print(f"len(iterators) = {len(iterators)}") async def producer(i: int, iterator: AsyncIterator[T]): try: From 30f2bc9a57a905f27a686134f67e6ecceaa22c42 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 15:23:30 +0000 Subject: [PATCH 12/80] removed breaks --- vllm/entrypoints/openai/serving_completion.py | 4 ++-- vllm/grpc/server.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 60d78edfe47b..315c47ab9d1e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -321,8 +321,8 @@ async def completion_stream_generator( response_json = chunk.model_dump_json(exclude_unset=False) - print(response_json) - 
breakpoint() + # print(response_json) + # breakpoint() yield f"data: {response_json}\n\n" if (request.stream_options diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 0c37e257989b..4662c4910c14 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -7,8 +7,9 @@ import asyncio # from grpc_reflection.v1alpha import reflection -MODEL = "facebook/opt-125m" -MAX_TOKENS = 10 +# MODEL = "facebook/opt-125m" +MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" +MAX_TOKENS = 150 TEMPERATURE = 0 From c718b68d885c509fe818dfa5aa863a2e08f59036 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 16:10:22 +0000 Subject: [PATCH 13/80] pushing current state --- benchmarks/benchmark_serving.py | 3 +++ vllm/grpc/client.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fc0dbf77f16b..c7b728d6045f 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -362,9 +362,12 @@ async def benchmark( ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("TOKENS PER REQUESTS:", + metrics.total_output // metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 354065970c96..94978049f92f 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingParams +MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" class TextGenerationClient(AsyncLLMEngine): def __init__(self): @@ -44,7 +45,7 @@ async def get_tokenizer( ) -> "PreTrainedTokenizer": # TODO: what to return :/ from transformers import AutoTokenizer - return AutoTokenizer.from_pretrained("facebook/opt-125m") + return AutoTokenizer.from_pretrained(MODEL) async def generate( self, From b3d25c6b7ba26d9772fd962a3265ace3b6b7ae15 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 26 Jul 2024 10:19:10 -0600 Subject: [PATCH 14/80] :alembic: try unix sockets Signed-off-by: Joe Runde --- vllm/grpc/client.py | 5 ++++- vllm/grpc/server.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 94978049f92f..af66ca6f3fdd 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -1,6 +1,8 @@ import asyncio from vllm import AsyncLLMEngine import grpc + +from vllm.grpc.server import UNIX_SOCKET from .pb import generate_pb2_grpc, generate_pb2 from typing import AsyncIterator, List, Optional, Mapping @@ -16,7 +18,8 @@ class TextGenerationClient(AsyncLLMEngine): def __init__(self): - channel = grpc.aio.insecure_channel("localhost:5543") + # channel = grpc.aio.insecure_channel("localhost:5543") + channel = grpc.aio.insecure_channel(UNIX_SOCKET) self.stub = generate_pb2_grpc.TextGenerationServiceStub(channel) self.engine_use_ray = False self.worker_use_ray = False diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 4662c4910c14..010119901b61 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -11,7 +11,7 @@ MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" MAX_TOKENS = 150 TEMPERATURE = 0 - +UNIX_SOCKET = "unix:///tmp/ricky-bobby" class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): SERVICE_NAME = 
_GENERATION_DESCRIPTOR.services_by_name[ @@ -63,7 +63,8 @@ async def start_grpc_server() -> aio.Server: host = "0.0.0.0" grpc_port = 5543 - server.add_insecure_port(f"{host}:{grpc_port}") + # server.add_insecure_port(f"{host}:{grpc_port}") + server.add_insecure_port(UNIX_SOCKET) await server.start() print("ready") return server From 2765b17bbdc727a93d7fa0d8782280ba873fcdb1 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 26 Jul 2024 10:27:38 -0600 Subject: [PATCH 15/80] :zap: no background loop Signed-off-by: Joe Runde --- vllm/grpc/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index af66ca6f3fdd..654b8007b73d 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -38,9 +38,9 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return False - async def run_engine_loop(self): - while True: - await asyncio.sleep(1) + def start_background_loop(self): + # TODO something lol + pass async def get_tokenizer( self, From b21977881c4e9fc4d2eebaef1085772e0b5c891b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 17:02:10 +0000 Subject: [PATCH 16/80] spurious change --- benchmarks/backend_request_func.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 4d9d148fada7..fbab547d094f 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -252,7 +252,6 @@ async def async_request_openai_completions( try: async with session.post(url=api_url, json=payload, headers=headers) as response: - # breakpoint() if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -282,8 +281,6 @@ async def async_request_openai_completions( most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] - # print(generated_text) - # breakpoint() output.generated_text = generated_text output.success = True From 932ea230a8be94a698c67d8bd53181666aebf252 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 17:03:03 +0000 Subject: [PATCH 17/80] remove spurious change --- vllm/entrypoints/openai/serving_completion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 315c47ab9d1e..987584ac340c 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -176,8 +176,6 @@ async def create_completion(self, request: CompletionRequest, final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: - # async for res in generators[0]: - # i = 0 if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
await self.engine.abort(f"{request_id}-{i}") From f0291140a7910e77ddaab6164bdb43dd360685ba Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 17:03:47 +0000 Subject: [PATCH 18/80] spurious changes --- vllm/entrypoints/openai/serving_completion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 987584ac340c..02ba5a120326 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -317,10 +317,7 @@ async def completion_stream_generator( else: chunk.usage = None - response_json = chunk.model_dump_json(exclude_unset=False) - # print(response_json) - # breakpoint() yield f"data: {response_json}\n\n" if (request.stream_options From 685475815b2ad5ae0d82cb5ba345676656dc9989 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 26 Jul 2024 17:04:51 +0000 Subject: [PATCH 19/80] spurioous change --- benchmarks/benchmark_serving.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index c7b728d6045f..8a3d55a959b2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -367,7 +367,6 @@ async def benchmark( print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) - print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) From 3b5ff66018c836581e5b57408031c93577fecd4d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 26 Jul 2024 11:13:55 -0600 Subject: [PATCH 20/80] :bug: whoops Signed-off-by: Joe Runde --- vllm/grpc/client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 654b8007b73d..ff55d0af84ff 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -26,6 +26,9 @@ def __init__(self): self.log_requests = False self.engine = None + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(MODEL) + @property def is_running(self) -> bool: return True @@ -47,8 +50,7 @@ async def get_tokenizer( lora_request: Optional[LoRARequest] = None, ) -> "PreTrainedTokenizer": # TODO: what to return :/ - from transformers import AutoTokenizer - return AutoTokenizer.from_pretrained(MODEL) + return self.tokenizer async def generate( self, From 79247c39c790efe567362214ca79839f36f8e034 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 26 Jul 2024 12:28:26 -0600 Subject: [PATCH 21/80] :memo: log stuff Signed-off-by: Joe Runde --- vllm/grpc/client.py | 19 ++++++++++++++++++- vllm/grpc/server.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index ff55d0af84ff..a7f1e9110554 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -13,6 +13,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +import time MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" @@ -62,6 +63,9 @@ async def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: + start = time.time() + first = True + prompt: str = inputs.get('prompt', "") prompt_token_ids: List[int] = inputs.get('prompt_token_ids', []) @@ -75,7 +79,17 @@ async def generate( ) ) + ttft = 0 + tpots = [] async 
for generate_response in generate_stream: + if first: + ttft = time.time() - start + first = False + else: + tpot = time.time() - last + tpots.append(tpot) + last = time.time() + completion_outputs = [ CompletionOutput( index=output.index, @@ -94,4 +108,7 @@ async def generate( finished=(completion_outputs[0].finish_reason != ""), prompt_logprobs=None, prompt=prompt, - ) \ No newline at end of file + ) + + print(f"TTFT: {ttft}") + print(f"TPOT: {sum(tpots)/len(tpots)}") \ No newline at end of file diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 010119901b61..3fbc063a218e 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator from grpc import aio import asyncio +import time # from grpc_reflection.v1alpha import reflection # MODEL = "facebook/opt-125m" @@ -26,6 +27,12 @@ async def Generate( # ) -> AsyncIterator[generate_pb2.GenerateResponse]: ) -> AsyncIterator: + start = time.time() + first = True + ttft = 0 + tpots = [] + + if len(request.prompt_inputs.prompt_token_ids) > 0: inputs = TokensPrompt(prompt_token_ids=request.prompt_inputs.prompt_token_ids) else: @@ -38,6 +45,14 @@ async def Generate( request_id=request.request_id) async for request_output in results_generator: + if first: + ttft = time.time() - start + first = False + else: + tpot = time.time() - last + tpots.append(tpot) + last = time.time() + outputs = [ generate_pb2.CompletionOutput( index=output.index, @@ -47,6 +62,9 @@ async def Generate( for output in request_output.outputs ] yield generate_pb2.GenerateResponse(outputs=outputs) + + print(f"TTFT (backend): {ttft}") + print(f"TPOT (backend): {sum(tpots)/len(tpots)}") async def start_grpc_server() -> aio.Server: From a39ebc092cd90f39b1d968d1e8ea78f2e77320c9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 26 Jul 2024 19:22:21 +0000 Subject: [PATCH 22/80] stash --- vllm/grpc/client.py | 2 +- vllm/grpc/server.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index a7f1e9110554..281aa1e44030 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -111,4 +111,4 @@ async def generate( ) print(f"TTFT: {ttft}") - print(f"TPOT: {sum(tpots)/len(tpots)}") \ No newline at end of file + # print(f"TPOT: {sum(tpots)/len(tpots)}") \ No newline at end of file diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 3fbc063a218e..149a87f92195 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -64,7 +64,7 @@ async def Generate( yield generate_pb2.GenerateResponse(outputs=outputs) print(f"TTFT (backend): {ttft}") - print(f"TPOT (backend): {sum(tpots)/len(tpots)}") + # print(f"TPOT (backend): {sum(tpots)/len(tpots)}") async def start_grpc_server() -> aio.Server: @@ -72,11 +72,10 @@ async def start_grpc_server() -> aio.Server: generation = TextGenerationService() generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(generation, server) - service_names = ( - generation.SERVICE_NAME, - # reflection.SERVICE_NAME, - ) - + # service_names = ( + # generation.SERVICE_NAME, + # reflection.SERVICE_NAME, + # ) # reflection.enable_server_reflection(service_names, server) host = "0.0.0.0" @@ -92,8 +91,7 @@ async def run_grpc_server() -> None: server = await start_grpc_server() try: - while True: - await asyncio.sleep(10) + await server.wait_for_termination() except asyncio.CancelledError: print("Gracefully stopping gRPC server") # noqa: T201 @@ -101,4 +99,7 @@ async def run_grpc_server() -> None: await 
server.wait_for_termination() if __name__ == "__main__": + import uvloop + + # uvloop.run(run_grpc_server()) asyncio.run(run_grpc_server()) From ef257f1047afe68131baad9c80786f89250c5228 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 26 Jul 2024 20:31:55 +0000 Subject: [PATCH 23/80] pushing up --- vllm/entrypoints/openai/api_server.py | 1 + vllm/grpc/client.py | 9 ++++++--- vllm/grpc/server.py | 5 +++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 31c30f866ef2..a34b1d2b7df4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -303,6 +303,7 @@ async def build_server( ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + # workers=8, **uvicorn_kwargs, ) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 281aa1e44030..39fdcd5d72fb 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -81,6 +81,7 @@ async def generate( ttft = 0 tpots = [] + text = "" async for generate_response in generate_stream: if first: ttft = time.time() - start @@ -89,11 +90,13 @@ async def generate( tpot = time.time() - last tpots.append(tpot) last = time.time() - + text += "test_" completion_outputs = [ CompletionOutput( index=output.index, - text=output.text, + # text=output.text, + # text=self.tokenizer.decode(output.token_ids), + text=text, token_ids=output.token_ids, cumulative_logprob=0.0, logprobs=None, @@ -110,5 +113,5 @@ async def generate( prompt=prompt, ) - print(f"TTFT: {ttft}") + # print(f"TTFT: {ttft}") # print(f"TPOT: {sum(tpots)/len(tpots)}") \ No newline at end of file diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 149a87f92195..cdc023091b1d 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -41,7 +41,8 @@ async def Generate( results_generator = self.engine.generate( inputs, sampling_params=SamplingParams(max_tokens=MAX_TOKENS, - temperature=TEMPERATURE), + temperature=TEMPERATURE, + detokenize=False), request_id=request.request_id) async for request_output in results_generator: @@ -63,7 +64,7 @@ async def Generate( ] yield generate_pb2.GenerateResponse(outputs=outputs) - print(f"TTFT (backend): {ttft}") + # print(f"TTFT (backend): {ttft}") # print(f"TPOT (backend): {sum(tpots)/len(tpots)}") From a6c9bc5f4f3bd30e6adaa9d837d513e107190d7b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 15:26:53 +0000 Subject: [PATCH 24/80] stash --- vllm/entrypoints/openai/api_server.py | 4 +- vllm/entrypoints/openai/serving_completion.py | 3 +- vllm/grpc/client.py | 96 ++++----- vllm/grpc/server.py | 194 +++++++++++------- 4 files changed, 158 insertions(+), 139 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a34b1d2b7df4..e16eaad5aed7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -228,8 +228,8 @@ async def build_server( served_model_names = "meta-llama/Meta-Llama-3-8B-Instruct" - from vllm.grpc.client import TextGenerationClient - engine = TextGenerationClient() + from vllm.grpc.client import RPCClient + engine = RPCClient() # global engine, engine_args diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 02ba5a120326..2f583c600b18 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -285,7 +285,8 @@ async def 
completion_stream_generator( previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - finish_reason = None if output.finish_reason == "" else output.finish_reason + # finish_reason = None if output.finish_reason == "" else output.finish_reason + finish_reason = output.finish_reason stop_reason = output.stop_reason chunk = CompletionStreamResponse( diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 39fdcd5d72fb..c200cbfc0ffd 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -2,7 +2,7 @@ from vllm import AsyncLLMEngine import grpc -from vllm.grpc.server import UNIX_SOCKET +# from vllm.grpc.server import UNIX_SOCKET from .pb import generate_pb2_grpc, generate_pb2 from typing import AsyncIterator, List, Optional, Mapping @@ -12,24 +12,35 @@ from vllm.outputs import CompletionOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from transformers import AutoTokenizer +from dataclasses import dataclass import time +import zmq +import zmq.asyncio +import pickle MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" -class TextGenerationClient(AsyncLLMEngine): +@dataclass +class RCPRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + + +class RPCClient(AsyncLLMEngine): def __init__(self): - # channel = grpc.aio.insecure_channel("localhost:5543") - channel = grpc.aio.insecure_channel(UNIX_SOCKET) - self.stub = generate_pb2_grpc.TextGenerationServiceStub(channel) self.engine_use_ray = False self.worker_use_ray = False self.log_requests = False self.engine = None - - from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(MODEL) + self.context = zmq.asyncio.Context() + + @property def is_running(self) -> bool: return True @@ -42,16 +53,16 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return False - def start_background_loop(self): - # TODO something lol - pass - async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> "PreTrainedTokenizer": # TODO: what to return :/ return self.tokenizer + + def start_background_loop(self): + # TODO something lol + pass async def generate( self, @@ -62,56 +73,19 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: - - start = time.time() - first = True - - prompt: str = inputs.get('prompt', "") - prompt_token_ids: List[int] = inputs.get('prompt_token_ids', []) - - generate_stream = self.stub.Generate( - generate_pb2.GenerateRequest( - prompt_inputs=generate_pb2.PromptInputs( - prompt=prompt, - prompt_token_ids=prompt_token_ids, - ), - request_id=request_id, - ) - ) - - ttft = 0 - tpots = [] - text = "" - async for generate_response in generate_stream: - if first: - ttft = time.time() - start - first = False - else: - tpot = time.time() - last - tpots.append(tpot) - last = time.time() - text += "test_" - completion_outputs = [ - CompletionOutput( - index=output.index, - # text=output.text, - # text=self.tokenizer.decode(output.token_ids), - text=text, - token_ids=output.token_ids, - cumulative_logprob=0.0, - logprobs=None, - finish_reason=output.finish_reason, - ) for output in generate_response.outputs - ] + socket = self.context.socket(zmq.DEALER) + socket.connect('tcp://localhost:5570') - yield RequestOutput( - request_id=request_id, - prompt_token_ids=[], - outputs=completion_outputs, - finished=(completion_outputs[0].finish_reason != ""), - 
prompt_logprobs=None, - prompt=prompt, + await socket.send_multipart([ + pickle.dumps( + RCPRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id + ), pickle.HIGHEST_PROTOCOL ) + ]) - # print(f"TTFT: {ttft}") - # print(f"TPOT: {sum(tpots)/len(tpots)}") \ No newline at end of file + while True: + message = await socket.recv() + yield pickle.loads(message) diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index cdc023091b1d..a5f9ffbff283 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -4,6 +4,8 @@ from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams from collections.abc import AsyncIterator from grpc import aio +import zmq +import zmq.asyncio import asyncio import time # from grpc_reflection.v1alpha import reflection @@ -12,95 +14,137 @@ MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" MAX_TOKENS = 150 TEMPERATURE = 0 -UNIX_SOCKET = "unix:///tmp/ricky-bobby" +# UNIX_SOCKET = "unix:///tmp/ricky-bobby" -class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): - SERVICE_NAME = _GENERATION_DESCRIPTOR.services_by_name[ - "TextGenerationService" - ].full_name +# class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): +# SERVICE_NAME = _GENERATION_DESCRIPTOR.services_by_name[ +# "TextGenerationService" +# ].full_name - def __init__(self): - self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL, enforce_eager=True)) +# def __init__(self): +# self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL, enforce_eager=True)) - async def Generate( - self, request: generate_pb2.GenerateRequest, context - # ) -> AsyncIterator[generate_pb2.GenerateResponse]: - ) -> AsyncIterator: +# async def Generate( +# self, request: generate_pb2.GenerateRequest, context +# # ) -> AsyncIterator[generate_pb2.GenerateResponse]: +# ) -> AsyncIterator: - start = time.time() - first = True - ttft = 0 - tpots = [] +# start = time.time() +# first = True +# ttft = 0 +# tpots = [] - if len(request.prompt_inputs.prompt_token_ids) > 0: - inputs = TokensPrompt(prompt_token_ids=request.prompt_inputs.prompt_token_ids) - else: - inputs = TextPrompt(prompt=request.prompt_inputs.prompt) - - results_generator = self.engine.generate( - inputs, - sampling_params=SamplingParams(max_tokens=MAX_TOKENS, - temperature=TEMPERATURE, - detokenize=False), - request_id=request.request_id) +# if len(request.prompt_inputs.prompt_token_ids) > 0: +# inputs = TokensPrompt(prompt_token_ids=request.prompt_inputs.prompt_token_ids) +# else: +# inputs = TextPrompt(prompt=request.prompt_inputs.prompt) + +# results_generator = self.engine.generate( +# inputs, +# sampling_params=SamplingParams(max_tokens=MAX_TOKENS, +# temperature=TEMPERATURE, +# detokenize=False), +# request_id=request.request_id) - async for request_output in results_generator: - if first: - ttft = time.time() - start - first = False - else: - tpot = time.time() - last - tpots.append(tpot) - last = time.time() - - outputs = [ - generate_pb2.CompletionOutput( - index=output.index, - token_ids=output.token_ids, - text=output.text, - finish_reason=output.finish_reason) - for output in request_output.outputs - ] - yield generate_pb2.GenerateResponse(outputs=outputs) - - # print(f"TTFT (backend): {ttft}") - # print(f"TPOT (backend): {sum(tpots)/len(tpots)}") +# async for request_output in results_generator: +# if first: +# ttft = time.time() - start +# first = False +# else: +# tpot = time.time() - last +# tpots.append(tpot) +# last = time.time() + +# outputs = [ 
+# generate_pb2.CompletionOutput( +# index=output.index, +# token_ids=output.token_ids, +# text=output.text, +# finish_reason=output.finish_reason) +# for output in request_output.outputs +# ] +# yield generate_pb2.GenerateResponse(outputs=outputs) + +# # print(f"TTFT (backend): {ttft}") +# # print(f"TPOT (backend): {sum(tpots)/len(tpots)}") -async def start_grpc_server() -> aio.Server: - server = aio.server() - generation = TextGenerationService() - generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(generation, server) +# async def start_grpc_server() -> aio.Server: +# server = aio.server() +# generation = TextGenerationService() +# generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(generation, server) - # service_names = ( - # generation.SERVICE_NAME, - # reflection.SERVICE_NAME, - # ) - # reflection.enable_server_reflection(service_names, server) +# # service_names = ( +# # generation.SERVICE_NAME, +# # reflection.SERVICE_NAME, +# # ) +# # reflection.enable_server_reflection(service_names, server) - host = "0.0.0.0" - grpc_port = 5543 - # server.add_insecure_port(f"{host}:{grpc_port}") - server.add_insecure_port(UNIX_SOCKET) - await server.start() - print("ready") - return server +# host = "0.0.0.0" +# grpc_port = 5543 +# # server.add_insecure_port(f"{host}:{grpc_port}") +# server.add_insecure_port(UNIX_SOCKET) +# await server.start() +# print("ready") +# return server -async def run_grpc_server() -> None: - server = await start_grpc_server() +# async def run_grpc_server() -> None: +# server = await start_grpc_server() - try: - await server.wait_for_termination() +# try: +# await server.wait_for_termination() - except asyncio.CancelledError: - print("Gracefully stopping gRPC server") # noqa: T201 - await server.stop(30) # TODO configurable grace - await server.wait_for_termination() +# except asyncio.CancelledError: +# print("Gracefully stopping gRPC server") # noqa: T201 +# await server.stop(30) # TODO configurable grace +# await server.wait_for_termination() -if __name__ == "__main__": - import uvloop +import asyncio +import pickle + +class RPCServer: + def __init__(self): + self.context = zmq.asyncio.Context() + self.socket = self.context.socket(zmq.ROUTER) + self.socket.bind('tcp://*:5570') + + self.running_tasks = set() + self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL, + enforce_eager=True)) - # uvloop.run(run_grpc_server()) - asyncio.run(run_grpc_server()) + async def generate(self, identity, message): + request = pickle.loads(message) + results_generator = self.engine.generate( + request.inputs, + sampling_params=request.sampling_params, + request_id=request.request_id) + + async for request_output in results_generator: + self.socket.send_multipart([ + identity, + pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) + ]) + + print("All done!") + + async def run_loop(self): + while True: + identity, message = await self.socket.recv_multipart() + + # Process the request in the background. + task = asyncio.create_task(self.generate(identity=identity, + message=message)) + + # We need to keep around a strong reference to the task, + # to avoid the task disappearing mid-execution as running tasks + # can be GC'ed. 
Below is a common "fire-and-forget" tasks + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + self.running_tasks.add(task) + task.add_done_callback(self.running_tasks.discard) + + +if __name__ == "__main__": + server = RPCServer() + asyncio.run(server.run_loop()) From d7490bc5e4f00722a6102592d1db940956f8fd87 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 15:34:57 +0000 Subject: [PATCH 25/80] actually working --- vllm/grpc/client.py | 9 ++++++++- vllm/grpc/server.py | 2 -- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index c200cbfc0ffd..651d5147bfcb 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -88,4 +88,11 @@ async def generate( while True: message = await socket.recv() - yield pickle.loads(message) + request_output = pickle.loads(message) + + if request_output.finished: + break + yield request_output + + socket.close() + yield request_output diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index a5f9ffbff283..aa8f261b1583 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -127,8 +127,6 @@ async def generate(self, identity, message): pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) - print("All done!") - async def run_loop(self): while True: identity, message = await self.socket.recv_multipart() From f68fd60c1a1e296dcbf07cf57b865a3d4e858309 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 15:57:06 +0000 Subject: [PATCH 26/80] cleanup --- vllm/grpc/pb/__init__.py | 0 vllm/grpc/pb/generate.proto | 29 ----------------------------- 2 files changed, 29 deletions(-) delete mode 100644 vllm/grpc/pb/__init__.py delete mode 100644 vllm/grpc/pb/generate.proto diff --git a/vllm/grpc/pb/__init__.py b/vllm/grpc/pb/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/grpc/pb/generate.proto b/vllm/grpc/pb/generate.proto deleted file mode 100644 index 282914e7a60f..000000000000 --- a/vllm/grpc/pb/generate.proto +++ /dev/null @@ -1,29 +0,0 @@ -syntax = "proto3"; - -package generate.v1; - -service TextGenerationService { - rpc Generate (GenerateRequest) returns (stream GenerateResponse); -} - - -message GenerateRequest { - PromptInputs prompt_inputs = 1; - string request_id = 2; -} - -message PromptInputs { - string prompt = 1; - repeated uint64 prompt_token_ids = 2; -} - -message GenerateResponse { - repeated CompletionOutput outputs = 1; -} - -message CompletionOutput { - uint64 index = 1; - repeated uint64 token_ids = 2; - string text = 3; - string finish_reason = 4; -} From 38b5b9ca8019dcc6f995ed014fb91d9c5babd3f6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 15:58:13 +0000 Subject: [PATCH 27/80] more cleanup --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index fffd4683e3dc..17184b19127c 100644 --- a/.gitignore +++ b/.gitignore @@ -190,6 +190,3 @@ hip_compat.h # Benchmark dataset *.json - -# Protobuf -pb2 From bc54311db636f7d089040e4747100edd6fc6c578 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 15:58:39 +0000 Subject: [PATCH 28/80] cleanup --- requirements-common.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index b3d4736f282f..3b8d473c1fe7 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -22,5 +22,3 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 
pyzmq -grpcio -grpcio-tools From 3cccebb724adbed0a3673d7df77b618253b40772 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 16:00:54 +0000 Subject: [PATCH 29/80] stash --- setup.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/setup.py b/setup.py index ec1feb7b76b0..816ab3f93bbd 100644 --- a/setup.py +++ b/setup.py @@ -9,12 +9,6 @@ from shutil import which from typing import Dict, List -from shlex import split -from subprocess import CalledProcessError, check_call -from textwrap import dedent -from setuptools.command.build_py import build_py -from setuptools.errors import SetupError - import torch from packaging.version import Version, parse from setuptools import Extension, find_packages, setup @@ -34,38 +28,6 @@ def load_module_from_path(module_name, path): logger = logging.getLogger(__name__) -class BuildPyAndGenerateGrpc(build_py): - """build python module using protoc to prepare generated files.""" - - proto_source = "vllm/grpc/pb/generate.proto" - - def run(self): - print(f"Invoking protoc on {self.proto_source}") - - # NOTE: imports in generated files will be broken unless some care is given in - # how --proto_path, --*_out and .proto paths are given. - # - # See https://github.com/grpc/grpc/issues/9575#issuecomment-293934506 - try: - check_call( - split( - dedent( - f""" - python -m grpc_tools.protoc \ - --proto_path=. \ - --python_out=. \ - --grpc_python_out=. \ - {self.proto_source} - """, - ), - ) - ) - except CalledProcessError as exc: - raise SetupError(f"protoc failed, exit code {exc.returncode}") from exc - - super().run() - - def embed_commit_hash(): try: if "BUILDKITE_COMMIT" in os.environ: From 4b78e299e6d28469ef99299ea8d65734949ed1c2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 16:01:41 +0000 Subject: [PATCH 30/80] more cleanup --- vllm/entrypoints/openai/api_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e16eaad5aed7..99af28d554cb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -303,7 +303,6 @@ async def build_server( ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, - # workers=8, **uvicorn_kwargs, ) From 345bfdd3dce8d1db0e359437d97ca019b591d011 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 16:02:50 +0000 Subject: [PATCH 31/80] setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 816ab3f93bbd..72ef26f15e40 100644 --- a/setup.py +++ b/setup.py @@ -486,7 +486,7 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ "tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_py": BuildPyAndGenerateGrpc, "build_ext": cmake_build_ext,} if _build_custom_ops() else {"build_py": BuildPyAndGenerateGrpc}, + cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, package_data=package_data, entry_points={ "console_scripts": [ From cfbb0015ebe74263c4b9dbd523f3e2447ff4a0d5 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 16:04:29 +0000 Subject: [PATCH 32/80] cleanup --- vllm/entrypoints/openai/serving_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 2f583c600b18..91515e8d46d9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ 
b/vllm/entrypoints/openai/serving_completion.py @@ -152,7 +152,6 @@ async def create_completion(self, request: CompletionRequest, result_generator: AsyncIterator[Tuple[ int, RequestOutput]] = merge_async_iterators(*generators) - # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use @@ -190,6 +189,7 @@ async def create_completion(self, request: CompletionRequest, # with the inputs token IDs if final_res.prompt is None: final_res.prompt = prompts[i]["prompt"] + final_res_batch_checked = cast(List[RequestOutput], final_res_batch) From d811b42b7b26425cd18e081694f191b281b750fe Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 16:04:52 +0000 Subject: [PATCH 33/80] format --- vllm/entrypoints/openai/serving_completion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 91515e8d46d9..78c0539d6eb6 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -236,6 +236,7 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: + # Abort the request if the client disconnects. if await raw_request.is_disconnected(): await self.engine.abort(f"{request_id}-{prompt_idx}") From 852534ebe0c1376b6178c7fba6fec894c2f34248 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 16:06:49 +0000 Subject: [PATCH 34/80] cleaning up --- vllm/grpc/server.py | 107 +++----------------------------------------- 1 file changed, 5 insertions(+), 102 deletions(-) diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index aa8f261b1583..9064891e9755 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -1,108 +1,11 @@ -from vllm.inputs.data import TextPrompt, TokensPrompt -from .pb import generate_pb2_grpc, generate_pb2 -from .pb.generate_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR -from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams -from collections.abc import AsyncIterator -from grpc import aio +from vllm import AsyncEngineArgs, AsyncLLMEngine +import asyncio +import pickle import zmq import zmq.asyncio -import asyncio -import time -# from grpc_reflection.v1alpha import reflection -# MODEL = "facebook/opt-125m" MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" -MAX_TOKENS = 150 -TEMPERATURE = 0 -# UNIX_SOCKET = "unix:///tmp/ricky-bobby" - -# class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): -# SERVICE_NAME = _GENERATION_DESCRIPTOR.services_by_name[ -# "TextGenerationService" -# ].full_name - -# def __init__(self): -# self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL, enforce_eager=True)) - -# async def Generate( -# self, request: generate_pb2.GenerateRequest, context -# # ) -> AsyncIterator[generate_pb2.GenerateResponse]: -# ) -> AsyncIterator: - -# start = time.time() -# first = True -# ttft = 0 -# tpots = [] - - -# if len(request.prompt_inputs.prompt_token_ids) > 0: -# inputs = TokensPrompt(prompt_token_ids=request.prompt_inputs.prompt_token_ids) -# else: -# inputs = TextPrompt(prompt=request.prompt_inputs.prompt) - -# results_generator = self.engine.generate( -# inputs, -# sampling_params=SamplingParams(max_tokens=MAX_TOKENS, -# temperature=TEMPERATURE, -# detokenize=False), -# request_id=request.request_id) - -# async for request_output in results_generator: -# if first: -# ttft = time.time() - start -# first = False -# else: -# tpot = time.time() - last 
-# tpots.append(tpot) -# last = time.time() -# outputs = [ -# generate_pb2.CompletionOutput( -# index=output.index, -# token_ids=output.token_ids, -# text=output.text, -# finish_reason=output.finish_reason) -# for output in request_output.outputs -# ] -# yield generate_pb2.GenerateResponse(outputs=outputs) - -# # print(f"TTFT (backend): {ttft}") -# # print(f"TPOT (backend): {sum(tpots)/len(tpots)}") - - -# async def start_grpc_server() -> aio.Server: -# server = aio.server() -# generation = TextGenerationService() -# generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(generation, server) - -# # service_names = ( -# # generation.SERVICE_NAME, -# # reflection.SERVICE_NAME, -# # ) -# # reflection.enable_server_reflection(service_names, server) - -# host = "0.0.0.0" -# grpc_port = 5543 -# # server.add_insecure_port(f"{host}:{grpc_port}") -# server.add_insecure_port(UNIX_SOCKET) -# await server.start() -# print("ready") -# return server - - -# async def run_grpc_server() -> None: -# server = await start_grpc_server() - -# try: -# await server.wait_for_termination() - -# except asyncio.CancelledError: -# print("Gracefully stopping gRPC server") # noqa: T201 -# await server.stop(30) # TODO configurable grace -# await server.wait_for_termination() - -import asyncio -import pickle class RPCServer: def __init__(self): @@ -111,8 +14,8 @@ def __init__(self): self.socket.bind('tcp://*:5570') self.running_tasks = set() - self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL, - enforce_eager=True)) + self.engine = AsyncLLMEngine.from_engine_args( + AsyncEngineArgs(model=MODEL)) async def generate(self, identity, message): request = pickle.loads(message) From e42be96fc9f03b09a89d19e0a57ee936bf8cdc3a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 23:05:18 +0000 Subject: [PATCH 35/80] zlib --- vllm/grpc/client.py | 20 +++++--------------- vllm/grpc/server.py | 9 +++++---- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 651d5147bfcb..44105ec1ac37 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -1,22 +1,14 @@ -import asyncio from vllm import AsyncLLMEngine -import grpc - -# from vllm.grpc.server import UNIX_SOCKET -from .pb import generate_pb2_grpc, generate_pb2 -from typing import AsyncIterator, List, Optional, Mapping +from typing import AsyncIterator, Optional, Mapping from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput -from vllm.outputs import CompletionOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from transformers import AutoTokenizer from dataclasses import dataclass - -import time -import zmq +import zmq, zlib import zmq.asyncio import pickle @@ -35,9 +27,7 @@ def __init__(self): self.worker_use_ray = False self.log_requests = False self.engine = None - self.tokenizer = AutoTokenizer.from_pretrained(MODEL) - self.context = zmq.asyncio.Context() @@ -77,18 +67,18 @@ async def generate( socket.connect('tcp://localhost:5570') await socket.send_multipart([ - pickle.dumps( + zlib.compress(pickle.dumps( RCPRequest( inputs=inputs, sampling_params=sampling_params, request_id=request_id ), pickle.HIGHEST_PROTOCOL - ) + )) ]) while True: message = await socket.recv() - request_output = pickle.loads(message) + request_output = pickle.loads(zlib.decompress(message)) if request_output.finished: break diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py 
index 9064891e9755..8699e419ff8f 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -1,7 +1,7 @@ from vllm import AsyncEngineArgs, AsyncLLMEngine import asyncio import pickle -import zmq +import zmq, zlib import zmq.asyncio MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" @@ -15,10 +15,10 @@ def __init__(self): self.running_tasks = set() self.engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model=MODEL)) + AsyncEngineArgs(model=MODEL, enable_chunked_prefill=True)) async def generate(self, identity, message): - request = pickle.loads(message) + request = pickle.loads(zlib.decompress(message)) results_generator = self.engine.generate( request.inputs, sampling_params=request.sampling_params, @@ -27,7 +27,8 @@ async def generate(self, identity, message): async for request_output in results_generator: self.socket.send_multipart([ identity, - pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) + zlib.compress( + pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)) ]) async def run_loop(self): From 5202a596e48180a4b94f9c62053b09957a34d341 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 23:05:29 +0000 Subject: [PATCH 36/80] Revert "zlib" This reverts commit e42be96fc9f03b09a89d19e0a57ee936bf8cdc3a. --- vllm/grpc/client.py | 20 +++++++++++++++----- vllm/grpc/server.py | 9 ++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 44105ec1ac37..651d5147bfcb 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -1,14 +1,22 @@ +import asyncio from vllm import AsyncLLMEngine -from typing import AsyncIterator, Optional, Mapping +import grpc + +# from vllm.grpc.server import UNIX_SOCKET +from .pb import generate_pb2_grpc, generate_pb2 +from typing import AsyncIterator, List, Optional, Mapping from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from transformers import AutoTokenizer from dataclasses import dataclass -import zmq, zlib + +import time +import zmq import zmq.asyncio import pickle @@ -27,7 +35,9 @@ def __init__(self): self.worker_use_ray = False self.log_requests = False self.engine = None + self.tokenizer = AutoTokenizer.from_pretrained(MODEL) + self.context = zmq.asyncio.Context() @@ -67,18 +77,18 @@ async def generate( socket.connect('tcp://localhost:5570') await socket.send_multipart([ - zlib.compress(pickle.dumps( + pickle.dumps( RCPRequest( inputs=inputs, sampling_params=sampling_params, request_id=request_id ), pickle.HIGHEST_PROTOCOL - )) + ) ]) while True: message = await socket.recv() - request_output = pickle.loads(zlib.decompress(message)) + request_output = pickle.loads(message) if request_output.finished: break diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 8699e419ff8f..9064891e9755 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -1,7 +1,7 @@ from vllm import AsyncEngineArgs, AsyncLLMEngine import asyncio import pickle -import zmq, zlib +import zmq import zmq.asyncio MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" @@ -15,10 +15,10 @@ def __init__(self): self.running_tasks = set() self.engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model=MODEL, enable_chunked_prefill=True)) + AsyncEngineArgs(model=MODEL)) async def generate(self, identity, message): - request = pickle.loads(zlib.decompress(message)) + request = 
pickle.loads(message) results_generator = self.engine.generate( request.inputs, sampling_params=request.sampling_params, @@ -27,8 +27,7 @@ async def generate(self, identity, message): async for request_output in results_generator: self.socket.send_multipart([ identity, - zlib.compress( - pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)) + pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) async def run_loop(self): From 71b1bf92ad03f39bc0669181e1f4acd5104f8e84 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 28 Jul 2024 23:33:37 +0000 Subject: [PATCH 37/80] turn on chunked prefill --- vllm/grpc/client.py | 8 ++++---- vllm/grpc/server.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/grpc/client.py b/vllm/grpc/client.py index 651d5147bfcb..0952602aa113 100644 --- a/vllm/grpc/client.py +++ b/vllm/grpc/client.py @@ -1,9 +1,7 @@ -import asyncio from vllm import AsyncLLMEngine -import grpc # from vllm.grpc.server import UNIX_SOCKET -from .pb import generate_pb2_grpc, generate_pb2 +# from .pb import generate_pb2_grpc, generate_pb2 from typing import AsyncIterator, List, Optional, Mapping from vllm.inputs import PromptInputs @@ -21,6 +19,8 @@ import pickle MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" +ADDRESS = "ipc:///tmp/zmqtest" +# ADDRESS = "tcp://localhost:5570" @dataclass class RCPRequest: @@ -74,7 +74,7 @@ async def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: socket = self.context.socket(zmq.DEALER) - socket.connect('tcp://localhost:5570') + socket.connect(ADDRESS) await socket.send_multipart([ pickle.dumps( diff --git a/vllm/grpc/server.py b/vllm/grpc/server.py index 9064891e9755..4fe5053dccaf 100644 --- a/vllm/grpc/server.py +++ b/vllm/grpc/server.py @@ -4,18 +4,18 @@ import zmq import zmq.asyncio -MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" - +from .client import MODEL, ADDRESS class RPCServer: def __init__(self): self.context = zmq.asyncio.Context() self.socket = self.context.socket(zmq.ROUTER) - self.socket.bind('tcp://*:5570') + self.socket.bind(ADDRESS) self.running_tasks = set() self.engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model=MODEL)) + AsyncEngineArgs(model=MODEL, + enable_chunked_prefill=True)) async def generate(self, identity, message): request = pickle.loads(message) From a49907925859d32f9bdc687c26eff95312ececa8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 02:48:16 +0000 Subject: [PATCH 38/80] move RPC code into oai server --- vllm/{grpc => entrypoints/openai/rpc}/__init__.py | 0 vllm/{grpc => entrypoints/openai/rpc}/client.py | 8 +------- vllm/{grpc => entrypoints/openai/rpc}/server.py | 0 3 files changed, 1 insertion(+), 7 deletions(-) rename vllm/{grpc => entrypoints/openai/rpc}/__init__.py (100%) rename vllm/{grpc => entrypoints/openai/rpc}/client.py (90%) rename vllm/{grpc => entrypoints/openai/rpc}/server.py (100%) diff --git a/vllm/grpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py similarity index 100% rename from vllm/grpc/__init__.py rename to vllm/entrypoints/openai/rpc/__init__.py diff --git a/vllm/grpc/client.py b/vllm/entrypoints/openai/rpc/client.py similarity index 90% rename from vllm/grpc/client.py rename to vllm/entrypoints/openai/rpc/client.py index 0952602aa113..fc9d1728f30f 100644 --- a/vllm/grpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,26 +1,20 @@ from vllm import AsyncLLMEngine - -# from vllm.grpc.server import UNIX_SOCKET -# from .pb import generate_pb2_grpc, 
generate_pb2 -from typing import AsyncIterator, List, Optional, Mapping +from typing import AsyncIterator, Optional, Mapping from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput -from vllm.outputs import CompletionOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from transformers import AutoTokenizer from dataclasses import dataclass -import time import zmq import zmq.asyncio import pickle MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" ADDRESS = "ipc:///tmp/zmqtest" -# ADDRESS = "tcp://localhost:5570" @dataclass class RCPRequest: diff --git a/vllm/grpc/server.py b/vllm/entrypoints/openai/rpc/server.py similarity index 100% rename from vllm/grpc/server.py rename to vllm/entrypoints/openai/rpc/server.py From 88a1d089586280a42da3badb01c93bc8a055a397 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:08:51 +0000 Subject: [PATCH 39/80] format --- vllm/entrypoints/openai/api_server.py | 46 +++++----- vllm/entrypoints/openai/rpc/__init__.py | 25 ++++++ vllm/entrypoints/openai/rpc/client.py | 68 +++++---------- vllm/entrypoints/openai/rpc/server.py | 84 +++++++++++++++---- vllm/entrypoints/openai/serving_completion.py | 4 +- 5 files changed, 138 insertions(+), 89 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 99af28d554cb..cde40b6d08b9 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,6 +4,7 @@ import re import signal from contextlib import asynccontextmanager +from multiprocessing import Process from http import HTTPStatus from typing import Optional, Set @@ -37,8 +38,9 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.rpc.client import RPCClient +from vllm.entrypoints.openai.rpc.server import run_rpc_server from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION @@ -216,29 +218,19 @@ async def authentication(request: Request, call_next): async def build_server( args, - llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs, ) -> uvicorn.Server: app = build_app(args) - # if args.served_model_name is not None: - # served_model_names = args.served_model_name - # else: - # served_model_names = [args.model] - - served_model_names = "meta-llama/Meta-Llama-3-8B-Instruct" - - from vllm.grpc.client import RPCClient - engine = RPCClient() - - # global engine, engine_args - - # engine_args = AsyncEngineArgs.from_cli_args(args) - # engine = (llm_engine - # if llm_engine is not None else AsyncLLMEngine.from_engine_args( - # engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) - - # model_config = await engine.get_model_config() + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + print("HERE") + rpc_client = RPCClient() + model_config = await rpc_client.get_model_config() + print("HERE2") if args.disable_log_requests: request_logger = None @@ -309,13 +301,17 @@ async def build_server( return uvicorn.Server(config) -async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None: +async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version 
%s", VLLM_VERSION) logger.info("args: %s", args) - + + logger.info("Starting RPC Server") + rpc_server_process = Process(target=run_rpc_server, + args=(AsyncEngineArgs.from_cli_args(args),)) + rpc_server_process.start() + server = await build_server( args, - llm_engine, **uvicorn_kwargs, ) @@ -332,10 +328,12 @@ def signal_handler() -> None: try: await server_task + rpc_server_process.join() except asyncio.CancelledError: print("Gracefully stopping http server") await server.shutdown() - + rpc_server_process.join() + if __name__ == "__main__": # NOTE(simon): diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index e69de29bb2d1..f6f129bef2a3 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import Optional, Mapping +from enum import Enum + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_GENERATE_RPC_PATH = "tcp://localhost:5570" +VLLM_GET_DATA_RPC_PATH = "tcp://localhost:5571" +VLLM_IS_READY_RPC_PATH = "tcp://localhost:5572" + +@dataclass +class GenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +class GetDataRequest(Enum): + MODEL_CONFIG = 1 \ No newline at end of file diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index fc9d1728f30f..9cd36b86ee88 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,62 +1,31 @@ -from vllm import AsyncLLMEngine from typing import AsyncIterator, Optional, Mapping +from vllm.config import ModelConfig from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from transformers import AutoTokenizer -from dataclasses import dataclass +from vllm.entrypoints.openai.rpc import ( + VLLM_GENERATE_RPC_PATH, VLLM_GET_DATA_RPC_PATH, GenerateRequest, GetDataRequest) import zmq import zmq.asyncio import pickle -MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" -ADDRESS = "ipc:///tmp/zmqtest" -@dataclass -class RCPRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - - -class RPCClient(AsyncLLMEngine): +class RPCClient: def __init__(self): - self.engine_use_ray = False - self.worker_use_ray = False - self.log_requests = False - self.engine = None - - self.tokenizer = AutoTokenizer.from_pretrained(MODEL) - self.context = zmq.asyncio.Context() - - - @property - def is_running(self) -> bool: - return True + self.is_ready_socket = self.context.socket(zmq.REP) + self.get_data_socket = self.context.socket(zmq.REQ) + self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) - @property - def is_stopped(self) -> bool: - return False - @property - def errored(self) -> bool: - return False - - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> "PreTrainedTokenizer": - # TODO: what to return :/ - return self.tokenizer - - def start_background_loop(self): - # TODO something lol - pass + async def get_model_config(self) -> ModelConfig: + 
self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG)) + return pickle.loads(await self.get_data_socket.recv()) + async def generate( self, @@ -67,19 +36,28 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: + + # Connect to RPC socket for Request-Reply pattern, + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. socket = self.context.socket(zmq.DEALER) - socket.connect(ADDRESS) + socket.connect(VLLM_GENERATE_RPC_PATH) + # Send GenerateRequest to the RPC Server. await socket.send_multipart([ pickle.dumps( - RCPRequest( + GenerateRequest( inputs=inputs, sampling_params=sampling_params, - request_id=request_id + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request ), pickle.HIGHEST_PROTOCOL ) ]) + # Stream back the results from the RPC Server. while True: message = await socket.recv() request_output = pickle.loads(message) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 4fe5053dccaf..8f63f02ecb8a 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,21 +1,52 @@ -from vllm import AsyncEngineArgs, AsyncLLMEngine import asyncio import pickle import zmq import zmq.asyncio -from .client import MODEL, ADDRESS +from vllm import AsyncLLMEngine +from vllm.usage.usage_lib import UsageContext +from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, + VLLM_GET_DATA_RPC_PATH, + VLLM_IS_READY_RPC_PATH, + GetDataRequest) class RPCServer: - def __init__(self): + def __init__(self, async_engine_args): + # Initialize engine first. + self.engine = AsyncLLMEngine.from_engine_args( + async_engine_args, UsageContext.OPENAI_API_SERVER) + + # Initialize context. self.context = zmq.asyncio.Context() - self.socket = self.context.socket(zmq.ROUTER) - self.socket.bind(ADDRESS) + + # Init socket for readiness state. + self.is_ready_socket = self.context.socket(zmq.REP) + self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH) + + # Init socket for generation. + self.generate_socket = self.context.socket(zmq.ROUTER) + self.generate_socket.bind(VLLM_GENERATE_RPC_PATH) + + # TODO (robertgshaw2-neuralmagic): + # add socket for generation without streaming + + # Init socket for simple data requests. + self.get_data_socket = self.context.socket(zmq.REP) + self.get_data_socket.bind(VLLM_GET_DATA_RPC_PATH) + + # Setup polling so we can listen on both sockets. + self.poller = zmq.asyncio.Poller() + self.poller.register(self.generate_socket, zmq.POLLIN) + self.poller.register(self.get_data_socket, zmq.POLLIN) + + + async def get_data(self, message): + request_type = pickle.loads(message) + if request_type == GetDataRequest.MODEL_CONFIG: + return await self.engine.get_model_config() + else: + raise ValueError(f"Unknown request type: {request_type}") - self.running_tasks = set() - self.engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model=MODEL, - enable_chunked_prefill=True)) async def generate(self, identity, message): request = pickle.loads(message) @@ -29,23 +60,40 @@ async def generate(self, identity, message): identity, pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) - + async def run_loop(self): + # Notify the RPC client that we are ready to recieve requests. 
+ await self.is_ready_socket.send_string("Ready!") + self.is_ready_socket.close() + + # Avoid GC of running tasks. + running_tasks = set() while True: - identity, message = await self.socket.recv_multipart() + try: + socks = dict(await self.poller.poll()) + except KeyboardInterrupt: + # TODO: should there be some other exception here? + break - # Process the request in the background. - task = asyncio.create_task(self.generate(identity=identity, - message=message)) + task = None + if self.generate_socket in socks: + identity, message = await self.generate_socket.recv_multipart() + task = asyncio.create_task(self.generate(identity, message)) + + elif self.get_data_socket in socks: + message = await self.get_data_socket.recv() + task = asyncio.create_task(self.get_data(message)) # We need to keep around a strong reference to the task, # to avoid the task disappearing mid-execution as running tasks # can be GC'ed. Below is a common "fire-and-forget" tasks # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - self.running_tasks.add(task) - task.add_done_callback(self.running_tasks.discard) + if task is not None: + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) + # TODO: Do I need to close the generate / get_data sockets? -if __name__ == "__main__": - server = RPCServer() +def run_rpc_server(async_engine_args): + server = RPCServer(async_engine_args=async_engine_args) asyncio.run(server.run_loop()) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 78c0539d6eb6..40ae4bbc8087 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -45,7 +45,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, engine: AsyncLLMEngine, - # model_config: ModelConfig, + model_config: ModelConfig, served_model_names: List[str], *, lora_modules: Optional[List[LoRAModulePath]], @@ -54,7 +54,7 @@ def __init__( return_tokens_as_token_ids: bool = False, ): super().__init__(engine=engine, - # model_config=model_config, + model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, prompt_adapters=prompt_adapters, From 13ce2f1c888c98cf3627fdb993496b9fa6ef6960 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:14:36 +0000 Subject: [PATCH 40/80] format --- vllm/entrypoints/openai/api_server.py | 10 +++++----- vllm/entrypoints/openai/rpc/client.py | 14 ++++++++++---- vllm/entrypoints/openai/rpc/server.py | 2 +- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cde40b6d08b9..512d729d6204 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -227,10 +227,10 @@ async def build_server( else: served_model_names = [args.model] - print("HERE") rpc_client = RPCClient() - model_config = await rpc_client.get_model_config() - print("HERE2") + rpc_client.wait_for_server() + logger.info("RPC Client connected to RPC server.") + model_config = rpc_client.get_model_config() if args.disable_log_requests: request_logger = None @@ -255,7 +255,7 @@ async def build_server( # ) openai_serving_completion = OpenAIServingCompletion( engine, - # model_config, + model_config, served_model_names, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, @@ -305,7 +305,7 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version 
%s", VLLM_VERSION) logger.info("args: %s", args) - logger.info("Starting RPC Server") + logger.info("Starting RPC Server.") rpc_server_process = Process(target=run_rpc_server, args=(AsyncEngineArgs.from_cli_args(args),)) rpc_server_process.start() diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 9cd36b86ee88..34f7c2226bb9 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -6,8 +6,10 @@ from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.entrypoints.openai.rpc import ( - VLLM_GENERATE_RPC_PATH, VLLM_GET_DATA_RPC_PATH, GenerateRequest, GetDataRequest) +from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, + VLLM_GET_DATA_RPC_PATH, + VLLM_IS_READY_RPC_PATH, + GenerateRequest, GetDataRequest) import zmq import zmq.asyncio @@ -17,10 +19,14 @@ class RPCClient: def __init__(self): self.context = zmq.asyncio.Context() - self.is_ready_socket = self.context.socket(zmq.REP) + self.is_ready_socket = self.context.socket(zmq.PULL) + self.is_ready_socket.connect(VLLM_GET_DATA_RPC_PATH) self.get_data_socket = self.context.socket(zmq.REQ) self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) - + + async def wait_for_server(self): + await self.is_ready_socket.recv() + async def get_model_config(self) -> ModelConfig: self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG)) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 8f63f02ecb8a..e414452319bd 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -20,7 +20,7 @@ def __init__(self, async_engine_args): self.context = zmq.asyncio.Context() # Init socket for readiness state. - self.is_ready_socket = self.context.socket(zmq.REP) + self.is_ready_socket = self.context.socket(zmq.PUSH) self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH) # Init socket for generation. 
From bb8ac060c5c9ff1667d426e2175a85f226ec4739 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:33:13 +0000 Subject: [PATCH 41/80] trying to flow it through --- vllm/entrypoints/openai/api_server.py | 4 +- vllm/entrypoints/openai/rpc/client.py | 14 +++-- vllm/entrypoints/openai/rpc/server.py | 17 ++++-- vllm/entrypoints/openai/serving_completion.py | 53 ++++++++++--------- vllm/entrypoints/openai/serving_engine.py | 14 +++-- 5 files changed, 58 insertions(+), 44 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 512d729d6204..0b84b7c9c154 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -228,9 +228,9 @@ async def build_server( served_model_names = [args.model] rpc_client = RPCClient() - rpc_client.wait_for_server() + await rpc_client.wait_for_server() logger.info("RPC Client connected to RPC server.") - model_config = rpc_client.get_model_config() + model_config = await rpc_client.get_model_config() if args.disable_log_requests: request_logger = None diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 34f7c2226bb9..a5cd97e83998 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -19,18 +19,26 @@ class RPCClient: def __init__(self): self.context = zmq.asyncio.Context() - self.is_ready_socket = self.context.socket(zmq.PULL) - self.is_ready_socket.connect(VLLM_GET_DATA_RPC_PATH) + + # TODO: check if opening all these is an antipattern? + + # Socket to check if the RPC server is ready. + self.is_ready_socket = self.context.socket(zmq.REP) + self.is_ready_socket.connect(VLLM_IS_READY_RPC_PATH) + + # Socket to query data (e.g. get_model_config) self.get_data_socket = self.context.socket(zmq.REQ) self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) + async def wait_for_server(self): await self.is_ready_socket.recv() async def get_model_config(self) -> ModelConfig: self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG)) - return pickle.loads(await self.get_data_socket.recv()) + model_config = await self.get_data_socket.recv() + return pickle.loads(model_config) async def generate( diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index e414452319bd..767d63a021b1 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -20,15 +20,14 @@ def __init__(self, async_engine_args): self.context = zmq.asyncio.Context() # Init socket for readiness state. - self.is_ready_socket = self.context.socket(zmq.PUSH) + self.is_ready_socket = self.context.socket(zmq.REQ) self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH) # Init socket for generation. self.generate_socket = self.context.socket(zmq.ROUTER) self.generate_socket.bind(VLLM_GENERATE_RPC_PATH) - # TODO (robertgshaw2-neuralmagic): - # add socket for generation without streaming + # TODO: add socket for generation without streaming # Init socket for simple data requests. 
self.get_data_socket = self.context.socket(zmq.REP) @@ -42,25 +41,32 @@ def __init__(self, async_engine_args): async def get_data(self, message): request_type = pickle.loads(message) + if request_type == GetDataRequest.MODEL_CONFIG: - return await self.engine.get_model_config() + data = await self.engine.get_model_config() else: raise ValueError(f"Unknown request type: {request_type}") + + await self.get_data_socket.send_multipart([ + pickle.dumps(data, pickle.HIGHEST_PROTOCOL) + ]) async def generate(self, identity, message): request = pickle.loads(message) + results_generator = self.engine.generate( request.inputs, sampling_params=request.sampling_params, request_id=request.request_id) async for request_output in results_generator: - self.socket.send_multipart([ + self.generate_socket.send_multipart([ identity, pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) + async def run_loop(self): # Notify the RPC client that we are ready to recieve requests. await self.is_ready_socket.send_string("Ready!") @@ -94,6 +100,7 @@ async def run_loop(self): # TODO: Do I need to close the generate / get_data sockets? + def run_rpc_server(async_engine_args): server = RPCServer(async_engine_args=async_engine_args) asyncio.run(server.run_loop()) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 40ae4bbc8087..1d3f20339b2e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -20,6 +20,7 @@ CompletionStreamResponse, UsageInfo) # yapf: enable +from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) @@ -44,7 +45,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + rpc_client: RPCClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -53,7 +54,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(rpc_client=rpc_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -93,21 +94,21 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.rpc_client.get_tokenizer(lora_request) sampling_params = request.to_sampling_params() - # decoding_config = await self.engine.get_decoding_config() - # guided_decoding_backend = request.guided_decoding_backend \ - # or decoding_config.guided_decoding_backend - # guided_decode_logit_processor = ( - # await - # get_guided_decoding_logits_processor(guided_decoding_backend, - # request, tokenizer)) - # if guided_decode_logit_processor is not None: - # if sampling_params.logits_processors is None: - # sampling_params.logits_processors = [] - # sampling_params.logits_processors.append( - # guided_decode_logit_processor) + decoding_config = await self.rpc_client.get_decoding_config() + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend + guided_decode_logit_processor = ( + await + get_guided_decoding_logits_processor(guided_decoding_backend, + request, tokenizer)) + if guided_decode_logit_processor is not None: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append( + 
guided_decode_logit_processor) prompts = list( self._tokenize_prompt_input_or_inputs( @@ -128,21 +129,21 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - # is_tracing_enabled = await self.engine.is_tracing_enabled() - # trace_headers = None - # if is_tracing_enabled: - # trace_headers = extract_trace_headers(raw_request.headers) - # if not is_tracing_enabled and contains_trace_headers( - # raw_request.headers): - # log_tracing_disabled_warning() + is_tracing_enabled = await self.rcp_client.is_tracing_enabled() + trace_headers = None + if is_tracing_enabled: + trace_headers = extract_trace_headers(raw_request.headers) + if not is_tracing_enabled and contains_trace_headers( + raw_request.headers): + log_tracing_disabled_warning() - generator = self.engine.generate( + generator = self.rpc_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - # trace_headers=trace_headers, + trace_headers=trace_headers, ) generators.append(generator) @@ -177,7 +178,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") + await self.rcp_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res @@ -239,7 +240,7 @@ async def completion_stream_generator( # Abort the request if the client disconnects. if await raw_request.is_disconnected(): - await self.engine.abort(f"{request_id}-{prompt_idx}") + await self.rpc_client.abort(f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ebe4b787f98b..98ac152069f2 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,10 +9,10 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable +from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, @@ -61,8 +61,8 @@ class OpenAIServing: def __init__( self, - engine: AsyncLLMEngine, - # model_config: ModelConfig, + rpc_client: RPCClient, + model_config: ModelConfig, served_model_names: List[str], *, lora_modules: Optional[List[LoRAModulePath]], @@ -72,11 +72,9 @@ def __init__( ): super().__init__() - self.engine = engine - # self.model_config = model_config - # self.max_model_len = model_config.max_model_len - self.max_model_len = 4096 - + self.rcp_client = rpc_client + self.model_config = model_config + self.max_model_len = model_config.max_model_len self.served_model_names = served_model_names self.lora_requests = [] From 6ebdb3db5b7f4102022fec3070ddba5d78d4ed41 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:34:32 +0000 Subject: [PATCH 42/80] cleaning --- vllm/entrypoints/openai/serving_completion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 
1d3f20339b2e..e3a1dc56e2d0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -287,7 +287,6 @@ async def completion_stream_generator( previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - # finish_reason = None if output.finish_reason == "" else output.finish_reason finish_reason = output.finish_reason stop_reason = output.stop_reason From 24c8100291e631dae36f9db5e22ef3deb7294ef5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:34:45 +0000 Subject: [PATCH 43/80] cleaning --- vllm/entrypoints/openai/serving_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 98ac152069f2..6b223834f1fa 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -75,6 +75,7 @@ def __init__( self.rcp_client = rpc_client self.model_config = model_config self.max_model_len = model_config.max_model_len + self.served_model_names = served_model_names self.lora_requests = [] From e7070490734b5489d5b8aa3c130d8db3db9b3951 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:35:59 +0000 Subject: [PATCH 44/80] cleaning --- vllm/entrypoints/openai/rpc/client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index a5cd97e83998..2da5f5c086c1 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -30,7 +30,7 @@ def __init__(self): self.get_data_socket = self.context.socket(zmq.REQ) self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) - + async def wait_for_server(self): await self.is_ready_socket.recv() @@ -41,6 +41,12 @@ async def get_model_config(self) -> ModelConfig: return pickle.loads(model_config) + async def abort(self, request_id: str): + pass + + async def get_tokenizer(self, lora_request: LoRARequest): + pass + async def generate( self, inputs: PromptInputs, From baaf6bc1383061074237a17a3ca567d169f73c51 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:37:59 +0000 Subject: [PATCH 45/80] add stubs --- vllm/entrypoints/openai/rpc/client.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 2da5f5c086c1..d5b64c070b4e 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -30,23 +30,29 @@ def __init__(self): self.get_data_socket = self.context.socket(zmq.REQ) self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) - async def wait_for_server(self): await self.is_ready_socket.recv() - async def get_model_config(self) -> ModelConfig: self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG)) model_config = await self.get_data_socket.recv() return pickle.loads(model_config) + async def get_tokenizer(self, lora_request: LoRARequest): + # TODO: handle this via get data? + pass - async def abort(self, request_id: str): + async def get_decoding_config(self, lora_request: LoRARequest): + # TODO: handle this via get data? pass - async def get_tokenizer(self, lora_request: LoRARequest): + async def abort(self, request_id: str): + # TODO: actually handle this with a new socket. 
pass + async def is_tracing_enabled(self): + return False + async def generate( self, inputs: PromptInputs, From 9d19d92c96502a727228fa89eb0a24c73b1751db Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:38:14 +0000 Subject: [PATCH 46/80] format --- vllm/entrypoints/openai/api_server.py | 12 ++++++------ vllm/entrypoints/openai/rpc/__init__.py | 3 ++- vllm/entrypoints/openai/rpc/client.py | 21 ++++++++++---------- vllm/entrypoints/openai/rpc/server.py | 26 ++++++++++++------------- vllm/utils.py | 1 + 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0b84b7c9c154..3faa93a8282a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -226,7 +226,7 @@ async def build_server( served_model_names = args.served_model_name else: served_model_names = [args.model] - + rpc_client = RPCClient() await rpc_client.wait_for_server() logger.info("RPC Client connected to RPC server.") @@ -304,12 +304,12 @@ async def build_server( async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - + logger.info("Starting RPC Server.") - rpc_server_process = Process(target=run_rpc_server, - args=(AsyncEngineArgs.from_cli_args(args),)) + rpc_server_process = Process(target=run_rpc_server, + args=(AsyncEngineArgs.from_cli_args(args), )) rpc_server_process.start() - + server = await build_server( args, **uvicorn_kwargs, @@ -333,7 +333,7 @@ def signal_handler() -> None: print("Gracefully stopping http server") await server.shutdown() rpc_server_process.join() - + if __name__ == "__main__": # NOTE(simon): diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index f6f129bef2a3..3272a99d142e 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -11,6 +11,7 @@ VLLM_GET_DATA_RPC_PATH = "tcp://localhost:5571" VLLM_IS_READY_RPC_PATH = "tcp://localhost:5572" + @dataclass class GenerateRequest: inputs: PromptInputs @@ -22,4 +23,4 @@ class GenerateRequest: class GetDataRequest(Enum): - MODEL_CONFIG = 1 \ No newline at end of file + MODEL_CONFIG = 1 diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index d5b64c070b4e..67c59378569c 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -7,7 +7,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, - VLLM_GET_DATA_RPC_PATH, + VLLM_GET_DATA_RPC_PATH, VLLM_IS_READY_RPC_PATH, GenerateRequest, GetDataRequest) @@ -17,6 +17,7 @@ class RPCClient: + def __init__(self): self.context = zmq.asyncio.Context() @@ -62,7 +63,7 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: - + # Connect to RPC socket for Request-Reply pattern, # Note that we use DEALER to enable asynchronous communication # to enable streaming. @@ -72,15 +73,13 @@ async def generate( # Send GenerateRequest to the RPC Server. 
await socket.send_multipart([ pickle.dumps( - GenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request - ), pickle.HIGHEST_PROTOCOL - ) + GenerateRequest(inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request), + pickle.HIGHEST_PROTOCOL) ]) # Stream back the results from the RPC Server. diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 767d63a021b1..c8eaeb8d7d27 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -10,15 +10,17 @@ VLLM_IS_READY_RPC_PATH, GetDataRequest) + class RPCServer: - def __init__(self, async_engine_args): + + def __init__(self, async_engine_args): # Initialize engine first. self.engine = AsyncLLMEngine.from_engine_args( async_engine_args, UsageContext.OPENAI_API_SERVER) # Initialize context. self.context = zmq.asyncio.Context() - + # Init socket for readiness state. self.is_ready_socket = self.context.socket(zmq.REQ) self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH) @@ -38,7 +40,6 @@ def __init__(self, async_engine_args): self.poller.register(self.generate_socket, zmq.POLLIN) self.poller.register(self.get_data_socket, zmq.POLLIN) - async def get_data(self, message): request_type = pickle.loads(message) @@ -46,26 +47,23 @@ async def get_data(self, message): data = await self.engine.get_model_config() else: raise ValueError(f"Unknown request type: {request_type}") - - await self.get_data_socket.send_multipart([ - pickle.dumps(data, pickle.HIGHEST_PROTOCOL) - ]) + await self.get_data_socket.send_multipart( + [pickle.dumps(data, pickle.HIGHEST_PROTOCOL)]) async def generate(self, identity, message): request = pickle.loads(message) results_generator = self.engine.generate( - request.inputs, + request.inputs, sampling_params=request.sampling_params, request_id=request.request_id) - + async for request_output in results_generator: self.generate_socket.send_multipart([ - identity, + identity, pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) - async def run_loop(self): # Notify the RPC client that we are ready to recieve requests. @@ -80,17 +78,17 @@ async def run_loop(self): except KeyboardInterrupt: # TODO: should there be some other exception here? break - + task = None if self.generate_socket in socks: identity, message = await self.generate_socket.recv_multipart() - task = asyncio.create_task(self.generate(identity, message)) + task = asyncio.create_task(self.generate(identity, message)) elif self.get_data_socket in socks: message = await self.get_data_socket.recv() task = asyncio.create_task(self.get_data(message)) - # We need to keep around a strong reference to the task, + # We need to keep around a strong reference to the task, # to avoid the task disappearing mid-execution as running tasks # can be GC'ed. 
Below is a common "fire-and-forget" tasks # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task diff --git a/vllm/utils.py b/vllm/utils.py index 9537c07aabc8..858423ec8efe 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -290,6 +290,7 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: class ProducerFinished: pass + def merge_async_iterators( *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: """Merge multiple asynchronous iterators into a single iterator. From f1be4b8d27afdd86b6d6482bda38cfb5be164a07 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:52:19 +0000 Subject: [PATCH 47/80] working with single launch... --- vllm/entrypoints/openai/api_server.py | 8 ++++--- vllm/entrypoints/openai/rpc/client.py | 23 +++++++++++-------- vllm/entrypoints/openai/serving_completion.py | 3 +-- vllm/entrypoints/openai/serving_engine.py | 2 +- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3faa93a8282a..6175b8c1fab6 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,6 +17,8 @@ from prometheus_client import make_asgi_app from starlette.routing import Mount +from transformers import AutoTokenizer + import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -227,7 +229,8 @@ async def build_server( else: served_model_names = [args.model] - rpc_client = RPCClient() + # TODO: figure out a way around passing the token + rpc_client = RPCClient(tokenizer=AutoTokenizer.from_pretrained(args.model)) await rpc_client.wait_for_server() logger.info("RPC Client connected to RPC server.") model_config = await rpc_client.get_model_config() @@ -254,7 +257,7 @@ async def build_server( # return_tokens_as_token_ids=args.return_tokens_as_token_ids, # ) openai_serving_completion = OpenAIServingCompletion( - engine, + rpc_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -305,7 +308,6 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - logger.info("Starting RPC Server.") rpc_server_process = Process(target=run_rpc_server, args=(AsyncEngineArgs.from_cli_args(args), )) rpc_server_process.start() diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 67c59378569c..3ed7d73a795f 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,6 +1,6 @@ from typing import AsyncIterator, Optional, Mapping -from vllm.config import ModelConfig +from vllm.config import ModelConfig, DecodingConfig from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -17,11 +17,14 @@ class RPCClient: - - def __init__(self): + + # TODO: check if opening all these sockets is an antipattern? + def __init__(self, tokenizer): self.context = zmq.asyncio.Context() - - # TODO: check if opening all these is an antipattern? + + # TODO: do the tokenizer properly. + self.tokenizer = tokenizer + self.decoding_config = DecodingConfig() # Socket to check if the RPC server is ready. 
self.is_ready_socket = self.context.socket(zmq.REP) @@ -40,12 +43,12 @@ async def get_model_config(self) -> ModelConfig: return pickle.loads(model_config) async def get_tokenizer(self, lora_request: LoRARequest): - # TODO: handle this via get data? - pass + # TODO: handle this via get data? - or avoid doing via RPC + return self.tokenizer - async def get_decoding_config(self, lora_request: LoRARequest): - # TODO: handle this via get data? - pass + async def get_decoding_config(self): + # TODO: handle this via get data? - or avoid doing via RPC + return self.decoding_config async def abort(self, request_id: str): # TODO: actually handle this with a new socket. diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e3a1dc56e2d0..afee0173778c 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,6 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -129,7 +128,7 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.rcp_client.is_tracing_enabled() + is_tracing_enabled = await self.rpc_client.is_tracing_enabled() trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 6b223834f1fa..ccf58b74a498 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -72,7 +72,7 @@ def __init__( ): super().__init__() - self.rcp_client = rpc_client + self.rpc_client = rpc_client self.model_config = model_config self.max_model_len = model_config.max_model_len From 8e417ade495c74147e1abd5e3217ca3b2e86568f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 29 Jul 2024 04:54:20 +0000 Subject: [PATCH 48/80] working end to end - with some hacks --- benchmarks/benchmark_serving.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 8a3d55a959b2..fc0dbf77f16b 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -362,8 +362,6 @@ async def benchmark( ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) - print("{:<40} {:<10}".format("TOKENS PER REQUESTS:", - metrics.total_output // metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) From 4c16c5e2fb971e8901d20d1c4a7e37a08a2c6a2e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 12:29:49 -0600 Subject: [PATCH 49/80] :goal_net: handle shutdown and request errors Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 4 ++ vllm/entrypoints/openai/rpc/client.py | 10 +++- vllm/entrypoints/openai/rpc/server.py | 74 +++++++++++++++++++++------ 3 files changed, 70 insertions(+), 18 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6175b8c1fab6..2daa0a1d1f63 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -54,6 +54,7 @@ openai_serving_completion: 
OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding openai_serving_tokenization: OpenAIServingTokenization +rpc_client: RPCClient logger = init_logger('vllm.entrypoints.openai.api_server') @@ -230,6 +231,7 @@ async def build_server( served_model_names = [args.model] # TODO: figure out a way around passing the token + global rpc_client rpc_client = RPCClient(tokenizer=AutoTokenizer.from_pretrained(args.model)) await rpc_client.wait_for_server() logger.info("RPC Client connected to RPC server.") @@ -334,6 +336,8 @@ def signal_handler() -> None: except asyncio.CancelledError: print("Gracefully stopping http server") await server.shutdown() + print("Cleaning up ZMQ client context") + rpc_client.close() rpc_server_process.join() diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 3ed7d73a795f..aec61655c19c 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -37,6 +37,10 @@ def __init__(self, tokenizer): async def wait_for_server(self): await self.is_ready_socket.recv() + def close(self): + """Destroy the zmq context and close all sockets""" + self.context.destroy() + async def get_model_config(self) -> ModelConfig: self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG)) model_config = await self.get_data_socket.recv() @@ -90,9 +94,13 @@ async def generate( message = await socket.recv() request_output = pickle.loads(message) + if isinstance(request_output, Exception): + socket.close() + raise request_output + if request_output.finished: break yield request_output - socket.close() yield request_output + socket.close() diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index c8eaeb8d7d27..74fb8dbe08b7 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -2,6 +2,7 @@ import pickle import zmq import zmq.asyncio +import signal from vllm import AsyncLLMEngine from vllm.usage.usage_lib import UsageContext @@ -9,6 +10,9 @@ VLLM_GET_DATA_RPC_PATH, VLLM_IS_READY_RPC_PATH, GetDataRequest) +from vllm.logger import init_logger + +logger = init_logger('vllm.entrypoints.openai.rpc.server') class RPCServer: @@ -40,6 +44,13 @@ def __init__(self, async_engine_args): self.poller.register(self.generate_socket, zmq.POLLIN) self.poller.register(self.get_data_socket, zmq.POLLIN) + def cleanup(self): + """Shuts down the zmq context and closes all sockets""" + self.context.destroy() + del self.get_data_socket + del self.generate_socket + del self.is_ready_socket + async def get_data(self, message): request_type = pickle.loads(message) @@ -52,18 +63,26 @@ async def get_data(self, message): [pickle.dumps(data, pickle.HIGHEST_PROTOCOL)]) async def generate(self, identity, message): - request = pickle.loads(message) - - results_generator = self.engine.generate( - request.inputs, - sampling_params=request.sampling_params, - request_id=request.request_id) - - async for request_output in results_generator: + try: + request = pickle.loads(message) + + results_generator = self.engine.generate( + request.inputs, + sampling_params=request.sampling_params, + request_id=request.request_id) + + async for request_output in results_generator: + self.generate_socket.send_multipart([ + identity, + pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) + ]) + except Exception as e: + ### Notify client of all failures self.generate_socket.send_multipart([ - identity, - pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) - ]) + 
identity, + pickle.dumps(e, pickle.HIGHEST_PROTOCOL) + ]) + async def run_loop(self): # Notify the RPC client that we are ready to recieve requests. @@ -73,11 +92,8 @@ async def run_loop(self): # Avoid GC of running tasks. running_tasks = set() while True: - try: - socks = dict(await self.poller.poll()) - except KeyboardInterrupt: - # TODO: should there be some other exception here? - break + self.poll_future = self.poller.poll() + socks = dict(await self.poll_future) task = None if self.generate_socket in socks: @@ -99,6 +115,30 @@ async def run_loop(self): # TODO: Do I need to close the generate / get_data sockets? +async def run_server(server: RPCServer): + # Run with proper interrupt handling + logger.info("Booting up vLLM zmq backend") + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.run_loop()) + def signal_handler() -> None: + # Kill the server on interrupt / terminate + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + except asyncio.CancelledError: + logger.info("ZMQ Backend was interrupted") + finally: + # Clean up all the zmq resources before exiting + server.cleanup() + logger.info("vLLM ZMQ Backend shut down") + + def run_rpc_server(async_engine_args): server = RPCServer(async_engine_args=async_engine_args) - asyncio.run(server.run_loop()) + asyncio.run(run_server(server)) From 6ddd4a730ca87129b2c43423c13fe7b4cc46d7cd Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 14:19:27 -0600 Subject: [PATCH 50/80] :art: fmt and clean up shutdown handler Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 9 ++++++--- vllm/entrypoints/openai/rpc/client.py | 4 ++-- vllm/entrypoints/openai/rpc/server.py | 8 +++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2daa0a1d1f63..126b5c5aae3e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -332,11 +332,14 @@ def signal_handler() -> None: try: await server_task - rpc_server_process.join() + # If the frontend server exited on its own, then terminate the + # backend server too + rpc_server_process.terminate() except asyncio.CancelledError: - print("Gracefully stopping http server") + logger.info("Gracefully stopping http server") await server.shutdown() - print("Cleaning up ZMQ client context") + finally: + logger.info("Cleaning up ZMQ client context") rpc_client.close() rpc_server_process.join() diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index aec61655c19c..a292c703b2bb 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -17,11 +17,11 @@ class RPCClient: - + # TODO: check if opening all these sockets is an antipattern? def __init__(self, tokenizer): self.context = zmq.asyncio.Context() - + # TODO: do the tokenizer properly. 
self.tokenizer = tokenizer self.decoding_config = DecodingConfig() diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 74fb8dbe08b7..225c9a33d545 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -78,11 +78,8 @@ async def generate(self, identity, message): ]) except Exception as e: ### Notify client of all failures - self.generate_socket.send_multipart([ - identity, - pickle.dumps(e, pickle.HIGHEST_PROTOCOL) - ]) - + self.generate_socket.send_multipart( + [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) async def run_loop(self): # Notify the RPC client that we are ready to recieve requests. @@ -122,6 +119,7 @@ async def run_server(server: RPCServer): loop = asyncio.get_running_loop() server_task = loop.create_task(server.run_loop()) + def signal_handler() -> None: # Kill the server on interrupt / terminate server_task.cancel() From 6d7da74521fa0a50e708c6f955e975b1fe3525ec Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 14:20:52 -0600 Subject: [PATCH 51/80] :bug: fixup type hint for queue Signed-off-by: Joe Runde --- vllm/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 858423ec8efe..9ce909d9d79c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -299,7 +299,8 @@ def merge_async_iterators( When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. """ - queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, + Exception]] = asyncio.Queue() finished = [False] * len(iterators) From 97ea04df0535dbfb940c3f0e93e94403fd1975de Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 14:54:35 -0600 Subject: [PATCH 52/80] :sparkles: update chat endpoint Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 22 +++++++++---------- vllm/entrypoints/openai/serving_chat.py | 16 +++++++------- vllm/entrypoints/openai/serving_completion.py | 2 +- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 126b5c5aae3e..57d6eef4a73d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -247,17 +247,17 @@ async def build_server( global openai_serving_embedding global openai_serving_tokenization - # openai_serving_chat = OpenAIServingChat( - # engine, - # model_config, - # served_model_names, - # args.response_role, - # lora_modules=args.lora_modules, - # prompt_adapters=args.prompt_adapters, - # request_logger=request_logger, - # chat_template=args.chat_template, - # return_tokens_as_token_ids=args.return_tokens_as_token_ids, - # ) + openai_serving_chat = OpenAIServingChat( + rpc_client, + model_config, + served_model_names, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) openai_serving_completion = OpenAIServingCompletion( rpc_client, model_config, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 012f70e66110..ef6aa2c7d2d0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -8,7 +8,6 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from 
vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -20,6 +19,7 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, FunctionCall, ToolCall, UsageInfo) +from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) @@ -41,7 +41,7 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + rpc_client: RPCClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -52,7 +52,7 @@ def __init__( chat_template: Optional[str], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(rpc_client=rpc_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -91,7 +91,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.rpc_client.get_tokenizer(lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -135,7 +135,7 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: sampling_params = request.to_sampling_params() - decoding_config = await self.engine.get_decoding_config() + decoding_config = await self.rpc_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( @@ -168,7 +168,7 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = await self.rpc_client.is_tracing_enabled() trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -176,7 +176,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.engine.generate( + result_generator = self.rpc_client.generate( engine_inputs, sampling_params, request_id, @@ -448,7 +448,7 @@ async def chat_completion_full_generator( async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.engine.abort(request_id) + await self.rpc_client.abort(request_id) return self.create_error_response("Client disconnected") final_res = res assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index afee0173778c..6e5b1be5b869 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -177,7 +177,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.rcp_client.abort(f"{request_id}-{i}") + await self.rpc_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res From 6d753a40f6e2e34c40fc2660259181521ff9a7f5 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 14:59:27 -0600 Subject: [PATCH 53/80] :bug: fixup zmq constant types Signed-off-by: Joe Runde --- vllm/entrypoints/openai/rpc/client.py | 6 +++--- vllm/entrypoints/openai/rpc/server.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index a292c703b2bb..f4424a7fb407 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -27,11 +27,11 @@ def __init__(self, tokenizer): self.decoding_config = DecodingConfig() # Socket to check if the RPC server is ready. - self.is_ready_socket = self.context.socket(zmq.REP) + self.is_ready_socket = self.context.socket(zmq.constants.REP) self.is_ready_socket.connect(VLLM_IS_READY_RPC_PATH) # Socket to query data (e.g. get_model_config) - self.get_data_socket = self.context.socket(zmq.REQ) + self.get_data_socket = self.context.socket(zmq.constants.REQ) self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) async def wait_for_server(self): @@ -74,7 +74,7 @@ async def generate( # Connect to RPC socket for Request-Reply pattern, # Note that we use DEALER to enable asynchronous communication # to enable streaming. - socket = self.context.socket(zmq.DEALER) + socket = self.context.socket(zmq.constants.DEALER) socket.connect(VLLM_GENERATE_RPC_PATH) # Send GenerateRequest to the RPC Server. diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 225c9a33d545..3ef721094f4d 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -26,23 +26,23 @@ def __init__(self, async_engine_args): self.context = zmq.asyncio.Context() # Init socket for readiness state. - self.is_ready_socket = self.context.socket(zmq.REQ) + self.is_ready_socket = self.context.socket(zmq.constants.REQ) self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH) # Init socket for generation. - self.generate_socket = self.context.socket(zmq.ROUTER) + self.generate_socket = self.context.socket(zmq.constants.ROUTER) self.generate_socket.bind(VLLM_GENERATE_RPC_PATH) # TODO: add socket for generation without streaming # Init socket for simple data requests. - self.get_data_socket = self.context.socket(zmq.REP) + self.get_data_socket = self.context.socket(zmq.constants.REP) self.get_data_socket.bind(VLLM_GET_DATA_RPC_PATH) # Setup polling so we can listen on both sockets. 
self.poller = zmq.asyncio.Poller() - self.poller.register(self.generate_socket, zmq.POLLIN) - self.poller.register(self.get_data_socket, zmq.POLLIN) + self.poller.register(self.generate_socket, zmq.constants.POLLIN) + self.poller.register(self.get_data_socket, zmq.constants.POLLIN) def cleanup(self): """Shuts down the zmq context and closes all sockets""" From 38e308e45dd3886ddf465890c08fb391e1af2418 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 15:28:47 -0600 Subject: [PATCH 54/80] :sparkles: hook up de/tokenize Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 17 +++++++++-------- vllm/entrypoints/openai/serving_tokenization.py | 9 +++++---- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 57d6eef4a73d..bc988f4912d3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -267,20 +267,21 @@ async def build_server( request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) + # TODO: emebddings should probably just run with a local AsyncLLMEmgine # openai_serving_embedding = OpenAIServingEmbedding( # engine, # model_config, # served_model_names, # request_logger=request_logger, # ) - # openai_serving_tokenization = OpenAIServingTokenization( - # engine, - # model_config, - # served_model_names, - # lora_modules=args.lora_modules, - # request_logger=request_logger, - # chat_template=args.chat_template, - # ) + openai_serving_tokenization = OpenAIServingTokenization( + rpc_client, + model_config, + served_model_names, + lora_modules=args.lora_modules, + request_logger=request_logger, + chat_template=args.chat_template, + ) app.root_path = args.root_path logger.info("Available routes are:") diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 94e1b03ed403..c52ba56d5edd 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -15,6 +15,7 @@ TokenizeRequest, TokenizeResponse) # yapf: enable +from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.utils import random_uuid @@ -24,7 +25,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + rpc_client: RPCClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -32,7 +33,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(engine=engine, + super().__init__(rpc_client=rpc_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -57,7 +58,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.rpc_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -113,7 +114,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.rpc_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, From ec19a7b6e9319ee090c7d36e14fd8d471563ae77 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 29 Jul 2024 16:15:29 -0600 Subject: [PATCH 55/80] :recycle: add 
VLLMBackend protocol Signed-off-by: Joe Runde --- vllm/engine/protocol.py | 79 +++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 28 +++---- vllm/entrypoints/openai/rpc/server.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 16 ++-- vllm/entrypoints/openai/serving_completion.py | 19 ++--- vllm/entrypoints/openai/serving_embedding.py | 12 +-- vllm/entrypoints/openai/serving_engine.py | 6 +- .../openai/serving_tokenization.py | 11 ++- 8 files changed, 126 insertions(+), 47 deletions(-) create mode 100644 vllm/engine/protocol.py diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 000000000000..66f7f02da737 --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,79 @@ +from typing import AsyncIterator, List, Mapping, Optional, Protocol + +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput + +from transformers import PreTrainedTokenizer + +class VLLMBackend(Protocol): + """Protocol class for asynchronous vllm backends""" + + @property + def is_running(self) -> bool: + pass + + @property + def is_stopped(self) -> bool: + pass + + @property + def errored(self) -> bool: + pass + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Generates outputs for a request""" + + async def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model.""" + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Args: + request_id: The unique id of the request. 
+ """ + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> PreTrainedTokenizer: + """Get the appropriate Tokenizer for the request""" + + async def is_tracing_enabled(self) -> bool: + pass + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + pass + \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bc988f4912d3..05139a2ff597 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -6,7 +6,7 @@ from contextlib import asynccontextmanager from multiprocessing import Process from http import HTTPStatus -from typing import Optional, Set +from typing import Set import fastapi import uvicorn @@ -19,9 +19,10 @@ from transformers import AutoTokenizer +from vllm.engine.protocol import VLLMBackend import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine +# from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -48,13 +49,12 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -engine: AsyncLLMEngine engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding openai_serving_tokenization: OpenAIServingTokenization -rpc_client: RPCClient +backend: VLLMBackend logger = init_logger('vllm.entrypoints.openai.api_server') @@ -67,7 +67,7 @@ async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await engine.do_log_stats() + await backend.do_log_stats() # if not engine_args.disable_log_stats: # task = asyncio.create_task(_force_log()) @@ -91,7 +91,7 @@ def mount_metrics(app: fastapi.FastAPI): @router.get("/health") async def health() -> Response: """Health check.""" - await openai_serving_chat.engine.check_health() + await backend.check_health() return Response(status_code=200) @@ -231,11 +231,11 @@ async def build_server( served_model_names = [args.model] # TODO: figure out a way around passing the token - global rpc_client - rpc_client = RPCClient(tokenizer=AutoTokenizer.from_pretrained(args.model)) - await rpc_client.wait_for_server() + global backend + backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(args.model)) + await backend.wait_for_server() logger.info("RPC Client connected to RPC server.") - model_config = await rpc_client.get_model_config() + model_config = await backend.get_model_config() if args.disable_log_requests: request_logger = None @@ -248,7 +248,7 @@ async def build_server( global openai_serving_tokenization openai_serving_chat = OpenAIServingChat( - rpc_client, + backend, model_config, served_model_names, args.response_role, @@ -259,7 +259,7 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_completion = OpenAIServingCompletion( - rpc_client, + backend, model_config, served_model_names, lora_modules=args.lora_modules, @@ -275,7 +275,7 @@ async def build_server( # 
request_logger=request_logger, # ) openai_serving_tokenization = OpenAIServingTokenization( - rpc_client, + backend, model_config, served_model_names, lora_modules=args.lora_modules, @@ -341,7 +341,7 @@ def signal_handler() -> None: await server.shutdown() finally: logger.info("Cleaning up ZMQ client context") - rpc_client.close() + backend.close() rpc_server_process.join() diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 3ef721094f4d..a407d9217744 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -82,7 +82,7 @@ async def generate(self, identity, message): [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) async def run_loop(self): - # Notify the RPC client that we are ready to recieve requests. + # Notify the RPC client that we are ready to receive requests. await self.is_ready_socket.send_string("Ready!") self.is_ready_socket.close() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ef6aa2c7d2d0..e7d69aba40be 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -8,6 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig +from vllm.engine.protocol import VLLMBackend from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -19,7 +20,6 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, FunctionCall, ToolCall, UsageInfo) -from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) @@ -41,7 +41,7 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, - rpc_client: RPCClient, + vllm_backend: VLLMBackend, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -52,7 +52,7 @@ def __init__( chat_template: Optional[str], return_tokens_as_token_ids: bool = False, ): - super().__init__(rpc_client=rpc_client, + super().__init__(vllm_backend=vllm_backend, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -91,7 +91,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.rpc_client.get_tokenizer(lora_request) + tokenizer = await self.vllm_backend.get_tokenizer(lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -135,7 +135,7 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: sampling_params = request.to_sampling_params() - decoding_config = await self.rpc_client.get_decoding_config() + decoding_config = await self.vllm_backend.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( @@ -168,7 +168,7 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.rpc_client.is_tracing_enabled() + is_tracing_enabled = await self.vllm_backend.is_tracing_enabled() trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -176,7 +176,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - 
result_generator = self.rpc_client.generate( + result_generator = self.vllm_backend.generate( engine_inputs, sampling_params, request_id, @@ -448,7 +448,7 @@ async def chat_completion_full_generator( async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.rpc_client.abort(request_id) + await self.vllm_backend.abort(request_id) return self.create_error_response("Client disconnected") final_res = res assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6e5b1be5b869..7ab750bd84af 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig +from vllm.engine.protocol import VLLMBackend from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -19,7 +20,6 @@ CompletionStreamResponse, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) @@ -44,7 +44,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - rpc_client: RPCClient, + vllm_backend: VLLMBackend, model_config: ModelConfig, served_model_names: List[str], *, @@ -53,7 +53,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(rpc_client=rpc_client, + super().__init__(vllm_backend=vllm_backend, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -93,10 +93,10 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.rpc_client.get_tokenizer(lora_request) + tokenizer = await self.vllm_backend.get_tokenizer(lora_request) sampling_params = request.to_sampling_params() - decoding_config = await self.rpc_client.get_decoding_config() + decoding_config = await self.vllm_backend.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( @@ -128,7 +128,8 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.rpc_client.is_tracing_enabled() + is_tracing_enabled = await self.vllm_backend.is_tracing_enabled( + ) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -136,7 +137,7 @@ async def create_completion(self, request: CompletionRequest, raw_request.headers): log_tracing_disabled_warning() - generator = self.rpc_client.generate( + generator = self.vllm_backend.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, @@ -177,7 +178,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.rpc_client.abort(f"{request_id}-{i}") + await self.vllm_backend.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res @@ -239,7 +240,7 @@ async def completion_stream_generator( # Abort the request if the client disconnects. if await raw_request.is_disconnected(): - await self.rpc_client.abort(f"{request_id}-{prompt_idx}") + await self.vllm_backend.abort(f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bccc90894e79..9518c42057cf 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -6,7 +6,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import VLLMBackend from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + vllm_backend: VLLMBackend, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(engine=engine, + super().__init__(vllm_backend=vllm_backend, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -99,7 +99,7 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.vllm_backend.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -124,7 +124,7 @@ async def create_embedding(self, request: EmbeddingRequest, "Prompt adapter is not supported " "for embedding models") - generator = self.engine.encode( + generator = self.vllm_backend.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, @@ -146,7 +146,7 @@ async def create_embedding(self, request: EmbeddingRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.engine.abort(f"{request_id}-{i}") + await self.vllm_backend.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ccf58b74a498..93a6b501de47 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,10 +9,10 @@ from typing_extensions import Annotated from vllm.config import ModelConfig +from vllm.engine.protocol import VLLMBackend from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable -from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, @@ -61,7 +61,7 @@ class OpenAIServing: def __init__( self, - rpc_client: RPCClient, + vllm_backend: VLLMBackend, model_config: ModelConfig, served_model_names: List[str], *, @@ -72,7 +72,7 @@ def __init__( ): super().__init__() - self.rpc_client = rpc_client + self.vllm_backend = vllm_backend self.model_config = model_config self.max_model_len = model_config.max_model_len diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c52ba56d5edd..ab6a01570802 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,9 +1,9 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine # yapf conflicts with isort for this block # yapf: disable +from vllm.engine.protocol import VLLMBackend from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -15,7 +15,6 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.rpc.client import RPCClient from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.utils import random_uuid @@ -25,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - rpc_client: RPCClient, + vllm_backend: VLLMBackend, model_config: ModelConfig, served_model_names: List[str], *, @@ -33,7 +32,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(rpc_client=rpc_client, + super().__init__(vllm_backend=vllm_backend, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -58,7 +57,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.rpc_client.get_tokenizer(lora_request) + tokenizer = await self.vllm_backend.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -114,7 +113,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.rpc_client.get_tokenizer(lora_request) + tokenizer = await self.vllm_backend.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, From 453939b48b4dc9ddbcf5d050c9bae530543109aa Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 30 Jul 2024 14:41:47 -0600 Subject: [PATCH 56/80] Frontend mp flag (#384) @robertgshaw2-neuralmagic This adds the `--disable-frontend-multiprocessing` flag and should also correctly pick up embeddings models to disable the multiprocessing here. 
(Also some unrelated formatting changes) The backend stuff is wrapped up in a context manager that handles the process startup and shutdown at exit as well, so that we don't have to muck around much in the existing server lifecycle code --------- Signed-off-by: Joe Runde --- vllm/engine/protocol.py | 14 ++- vllm/entrypoints/openai/api_server.py | 142 ++++++++++++++++-------- vllm/entrypoints/openai/cli_args.py | 9 +- vllm/entrypoints/openai/rpc/__init__.py | 2 +- vllm/entrypoints/openai/rpc/client.py | 20 ++-- vllm/entrypoints/openai/rpc/server.py | 5 +- 6 files changed, 122 insertions(+), 70 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 66f7f02da737..67b9e5cf5cc0 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,4 +1,7 @@ -from typing import AsyncIterator, List, Mapping, Optional, Protocol +from typing import (AsyncIterator, List, Mapping, Optional, Protocol, + runtime_checkable) + +from transformers import PreTrainedTokenizer from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs @@ -10,22 +13,22 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput -from transformers import PreTrainedTokenizer +@runtime_checkable class VLLMBackend(Protocol): """Protocol class for asynchronous vllm backends""" @property def is_running(self) -> bool: - pass + ... @property def is_stopped(self) -> bool: - pass + ... @property def errored(self) -> bool: - pass + ... async def generate( self, @@ -76,4 +79,3 @@ async def do_log_stats( model_output: Optional[List[SamplerOutput]] = None, ) -> None: pass - \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 05139a2ff597..7de90ff5ec23 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,9 +4,9 @@ import re import signal from contextlib import asynccontextmanager -from multiprocessing import Process from http import HTTPStatus -from typing import Set +from multiprocessing import Process +from typing import AsyncIterator, Set import fastapi import uvicorn @@ -16,13 +16,13 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from starlette.routing import Mount - from transformers import AutoTokenizer -from vllm.engine.protocol import VLLMBackend +from vllm.config import ModelConfig import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs -# from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import VLLMBackend +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -35,15 +35,16 @@ EmbeddingRequest, ErrorResponse, TokenizeRequest, TokenizeResponse) +from vllm.entrypoints.openai.rpc.client import RPCClient +from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) -from vllm.entrypoints.openai.rpc.client import RPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server from vllm.logger import init_logger 
+from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION @@ -77,6 +78,62 @@ async def _force_log(): yield +@asynccontextmanager +async def build_backend(args) -> AsyncIterator[VLLMBackend]: + # Context manager to handle backend lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + # Backend itself still global for the silly lil' health handler + global backend + + # First need to determine if this is an embeddings model + # (no remote backend for those) + model_config = ModelConfig(model=args.model, + tokenizer=args.tokenizer, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16") + if model_config.embedding_mode or args.disable_frontend_multiprocessing: + # local backend + logger.info("Initializing in-process AsyncLLMEmgine") + backend = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + yield backend + # No cleanup + return + + else: + # remote backend + ## First need to start the backend process + logger.info("Initializing AsyncLLMEmgine in separate process") + rpc_server_process = Process(target=run_rpc_server, + args=(engine_args, )) + rpc_server_process.start() + + ## Then build the client for the backend process + # TODO: figure out a way around passing the tokenizer + backend = RPCClient( + tokenizer=AutoTokenizer.from_pretrained(args.model)) + await backend.wait_for_server() + logger.info("RPC Client connected to RPC server.") + + try: + yield backend + finally: + ## Cleanup: + # Ensure backend process was terminated + rpc_server_process.terminate() + + # Close all open connections to the backend + logger.info("Cleaning up ZMQ client context") + backend.close() + + # Wait for server process to join + rpc_server_process.join() + + router = APIRouter() @@ -220,6 +277,7 @@ async def authentication(request: Request, call_next): async def build_server( + backend: VLLMBackend, args, **uvicorn_kwargs, ) -> uvicorn.Server: @@ -230,11 +288,6 @@ async def build_server( else: served_model_names = [args.model] - # TODO: figure out a way around passing the token - global backend - backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(args.model)) - await backend.wait_for_server() - logger.info("RPC Client connected to RPC server.") model_config = await backend.get_model_config() if args.disable_log_requests: @@ -267,13 +320,12 @@ async def build_server( request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) - # TODO: emebddings should probably just run with a local AsyncLLMEmgine - # openai_serving_embedding = OpenAIServingEmbedding( - # engine, - # model_config, - # served_model_names, - # request_logger=request_logger, - # ) + openai_serving_embedding = OpenAIServingEmbedding( + backend, + model_config, + served_model_names, + request_logger=request_logger, + ) openai_serving_tokenization = OpenAIServingTokenization( backend, model_config, @@ -311,38 +363,30 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - rpc_server_process = Process(target=run_rpc_server, - args=(AsyncEngineArgs.from_cli_args(args), )) - rpc_server_process.start() + async with build_backend(args) as backend: - server = await build_server( - args, - **uvicorn_kwargs, - ) + server = await build_server( + backend, + args, + **uvicorn_kwargs, + ) + + loop = 
asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) - loop = asyncio.get_running_loop() - - server_task = loop.create_task(server.serve()) - - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - # If the frontend server exited on its own, then terminate the - # backend server too - rpc_server_process.terminate() - except asyncio.CancelledError: - logger.info("Gracefully stopping http server") - await server.shutdown() - finally: - logger.info("Cleaning up ZMQ client context") - backend.close() - rpc_server_process.join() + try: + await server_task + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + await server.shutdown() if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a4192937980f..e637e20e16f5 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--return-tokens-as-token-ids", action="store_true", - help="When --max-logprobs is specified, represents single tokens as" - "strings of the form 'token_id:{token_id}' so that tokens that" + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "proecss as the model servinge engine.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 3272a99d142e..299821a701ef 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from typing import Optional, Mapping from enum import Enum +from typing import Mapping, Optional from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index f4424a7fb407..db0f15bf53df 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,19 +1,19 @@ -from typing import AsyncIterator, Optional, Mapping +import pickle +from typing import AsyncIterator, Mapping, Optional + +import zmq +import zmq.asyncio -from vllm.config import ModelConfig, DecodingConfig +from vllm.config import DecodingConfig, ModelConfig +from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, + VLLM_GET_DATA_RPC_PATH, + VLLM_IS_READY_RPC_PATH, + GenerateRequest, GetDataRequest) from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, - VLLM_GET_DATA_RPC_PATH, - VLLM_IS_READY_RPC_PATH, - 
GenerateRequest, GetDataRequest) - -import zmq -import zmq.asyncio -import pickle class RPCClient: diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index a407d9217744..5663036dc7b1 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,16 +1,17 @@ import asyncio import pickle +import signal + import zmq import zmq.asyncio -import signal from vllm import AsyncLLMEngine -from vllm.usage.usage_lib import UsageContext from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, VLLM_GET_DATA_RPC_PATH, VLLM_IS_READY_RPC_PATH, GetDataRequest) from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext logger = init_logger('vllm.entrypoints.openai.rpc.server') From 1f33286c238b8780257c461a73b3d99d2cf98ee3 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 31 Jul 2024 12:05:55 -0400 Subject: [PATCH 57/80] Features / Cleanup for MP Frontend (#387) SUMMARY: * refactor to use single socket * cleanup comments / logging * add `do_log_stats` * add `abort` --- vllm/entrypoints/openai/api_server.py | 20 ++- vllm/entrypoints/openai/rpc/__init__.py | 24 +++- vllm/entrypoints/openai/rpc/client.py | 112 ++++++++++----- vllm/entrypoints/openai/rpc/server.py | 184 ++++++++++++++---------- 4 files changed, 212 insertions(+), 128 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7de90ff5ec23..a3b945c00e8c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,11 +18,11 @@ from starlette.routing import Mount from transformers import AutoTokenizer -from vllm.config import ModelConfig import vllm.envs as envs +from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.protocol import VLLMBackend from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import VLLMBackend from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -70,10 +70,10 @@ async def _force_log(): await asyncio.sleep(10) await backend.do_log_stats() - # if not engine_args.disable_log_stats: - # task = asyncio.create_task(_force_log()) - # _running_tasks.add(task) - # task.add_done_callback(_running_tasks.remove) + if not engine_args.disable_log_stats: + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) yield @@ -82,6 +82,7 @@ async def _force_log(): async def build_backend(args) -> AsyncIterator[VLLMBackend]: # Context manager to handle backend lifecycle # Ensures everything is shutdown and cleaned up on error/exit + global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) # Backend itself still global for the silly lil' health handler @@ -97,7 +98,6 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: dtype="float16") if model_config.embedding_mode or args.disable_frontend_multiprocessing: # local backend - logger.info("Initializing in-process AsyncLLMEmgine") backend = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) yield backend @@ -107,9 +107,9 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: else: # remote backend ## First need to start the backend process - logger.info("Initializing AsyncLLMEmgine in separate process") rpc_server_process = 
Process(target=run_rpc_server, - args=(engine_args, )) + args=(engine_args, + UsageContext.OPENAI_API_SERVER)) rpc_server_process.start() ## Then build the client for the backend process @@ -117,7 +117,6 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: backend = RPCClient( tokenizer=AutoTokenizer.from_pretrained(args.model)) await backend.wait_for_server() - logger.info("RPC Client connected to RPC server.") try: yield backend @@ -127,7 +126,6 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: rpc_server_process.terminate() # Close all open connections to the backend - logger.info("Cleaning up ZMQ client context") backend.close() # Wait for server process to join diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 299821a701ef..6a403f48793f 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -1,19 +1,18 @@ from dataclasses import dataclass from enum import Enum -from typing import Mapping, Optional +from typing import Mapping, Optional, Union from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -VLLM_GENERATE_RPC_PATH = "tcp://localhost:5570" -VLLM_GET_DATA_RPC_PATH = "tcp://localhost:5571" -VLLM_IS_READY_RPC_PATH = "tcp://localhost:5572" +VLLM_RPC_PATH = "tcp://localhost:5570" +VLLM_RPC_SUCCESS_STR = "SUCCESS" @dataclass -class GenerateRequest: +class RPCGenerateRequest: inputs: PromptInputs sampling_params: SamplingParams request_id: str @@ -22,5 +21,16 @@ class GenerateRequest: prompt_adapter_request: Optional[PromptAdapterRequest] = None -class GetDataRequest(Enum): - MODEL_CONFIG = 1 +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCUtilityRequest(Enum): + IS_SERVER_READY = 1 + GET_MODEL_CONFIG = 2 + DO_LOG_STATS = 3 + + +RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, + RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index db0f15bf53df..f430e417388a 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -5,10 +5,9 @@ import zmq.asyncio from vllm.config import DecodingConfig, ModelConfig -from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, - VLLM_GET_DATA_RPC_PATH, - VLLM_IS_READY_RPC_PATH, - GenerateRequest, GetDataRequest) +from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, VLLM_RPC_PATH, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -20,31 +19,38 @@ class RPCClient: # TODO: check if opening all these sockets is an antipattern? def __init__(self, tokenizer): + # ZMQ context. self.context = zmq.asyncio.Context() # TODO: do the tokenizer properly. self.tokenizer = tokenizer self.decoding_config = DecodingConfig() - # Socket to check if the RPC server is ready. - self.is_ready_socket = self.context.socket(zmq.constants.REP) - self.is_ready_socket.connect(VLLM_IS_READY_RPC_PATH) + def close(self): + """Destroy the ZeroMQ Context.""" + self.context.destroy() - # Socket to query data (e.g. 
get_model_config) - self.get_data_socket = self.context.socket(zmq.constants.REQ) - self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH) + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, + error_message: str): + """Send one-way RPC request to trigger an action.""" - async def wait_for_server(self): - await self.is_ready_socket.recv() + # Connect to socket. + socket = self.context.socket(zmq.constants.DEALER) + socket.connect(VLLM_RPC_PATH) - def close(self): - """Destroy the zmq context and close all sockets""" - self.context.destroy() + # Ping RPC Server with request. + socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) - async def get_model_config(self) -> ModelConfig: - self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG)) - model_config = await self.get_data_socket.recv() - return pickle.loads(model_config) + # Await acknowledgement from RPCServer. + response = pickle.loads(await socket.recv()) + + if (not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR): + socket.close() + raise ValueError(error_message) + + socket.close() + + return response async def get_tokenizer(self, lora_request: LoRARequest): # TODO: handle this via get data? - or avoid doing via RPC @@ -54,13 +60,53 @@ async def get_decoding_config(self): # TODO: handle this via get data? - or avoid doing via RPC return self.decoding_config - async def abort(self, request_id: str): - # TODO: actually handle this with a new socket. - pass - async def is_tracing_enabled(self): + # TODO: what is this? return False + async def wait_for_server(self): + """Wait for the RPCServer to start up.""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_READY, + error_message="Unable to start RPC Server.") + + async def get_model_config(self) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + # Connect to socket. + socket = self.context.socket(zmq.constants.DEALER) + socket.connect(VLLM_RPC_PATH) + + # Ping RPCServer with GET_MODEL_CONFIG request. + socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) + + # Await the MODEL_CONFIG from the Server. + model_config = pickle.loads(await socket.recv()) + + if not isinstance(model_config, ModelConfig): + socket.close() + raise ValueError("Expected ModelConfig object from RPC, but " + f"got {model_config}") + + socket.close() + + return model_config + + async def abort(self, request_id: str): + """Send an RPCAbortRequest to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") + async def generate( self, inputs: PromptInputs, @@ -70,22 +116,24 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # Connect to RPC socket for Request-Reply pattern, # Note that we use DEALER to enable asynchronous communication # to enable streaming. socket = self.context.socket(zmq.constants.DEALER) - socket.connect(VLLM_GENERATE_RPC_PATH) + socket.connect(VLLM_RPC_PATH) - # Send GenerateRequest to the RPC Server. - await socket.send_multipart([ + # Send RPCGenerateRequest to the RPCServer. 
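The client half of generate() reduces to a small DEALER loop: send one pickled request, keep receiving pickled outputs until one is marked finished, and re-raise anything the server forwarded as an exception. A minimal sketch under those assumptions, with plain dicts standing in for the vLLM request/output classes and the endpoint address assumed:

    import asyncio
    import pickle

    import zmq
    import zmq.asyncio


    async def stream_generate(prompt: str, request_id: str):
        ctx = zmq.asyncio.Context()
        sock = ctx.socket(zmq.DEALER)
        sock.connect("tcp://localhost:5570")      # assumed ROUTER endpoint
        try:
            # One pickled frame per request, mirroring RPCClient.generate.
            await sock.send_multipart([
                pickle.dumps({"prompt": prompt, "request_id": request_id},
                             pickle.HIGHEST_PROTOCOL)
            ])
            while True:
                reply = pickle.loads(await sock.recv())
                if isinstance(reply, Exception):
                    raise reply                   # server forwards failures as-is
                yield reply
                if reply["finished"]:
                    break
        finally:
            sock.close()
            ctx.destroy()

    # Usage (requires a matching ROUTER server to be running):
    #   async for out in stream_generate("Hello", "req-0"):
    #       print(out)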
+ socket.send_multipart([ pickle.dumps( - GenerateRequest(inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request), + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request), pickle.HIGHEST_PROTOCOL) ]) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 5663036dc7b1..0284e5eb91c5 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,15 +1,16 @@ import asyncio import pickle import signal +from typing import Any, Coroutine import zmq import zmq.asyncio +from typing_extensions import Never -from vllm import AsyncLLMEngine -from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH, - VLLM_GET_DATA_RPC_PATH, - VLLM_IS_READY_RPC_PATH, - GetDataRequest) +from vllm import AsyncEngineArgs, AsyncLLMEngine +from vllm.entrypoints.openai.rpc import (VLLM_RPC_PATH, VLLM_RPC_SUCCESS_STR, + RPCAbortRequest, RPCGenerateRequest, + RPCUtilityRequest) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -18,109 +19,136 @@ class RPCServer: - def __init__(self, async_engine_args): + # TODO: check if opening all these sockets is an antipattern. + # Alternative, use a smaller number of sockets with conditioning on the + # data that is passed through the socket. + def __init__(self, async_engine_args: AsyncEngineArgs, + usage_context: UsageContext): # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, UsageContext.OPENAI_API_SERVER) + self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, + usage_context) # Initialize context. self.context = zmq.asyncio.Context() # Init socket for readiness state. - self.is_ready_socket = self.context.socket(zmq.constants.REQ) - self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH) - - # Init socket for generation. - self.generate_socket = self.context.socket(zmq.constants.ROUTER) - self.generate_socket.bind(VLLM_GENERATE_RPC_PATH) - - # TODO: add socket for generation without streaming - - # Init socket for simple data requests. - self.get_data_socket = self.context.socket(zmq.constants.REP) - self.get_data_socket.bind(VLLM_GET_DATA_RPC_PATH) - - # Setup polling so we can listen on both sockets. 
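This patch's single-socket refactor collapses the per-purpose sockets and the Poller into one ROUTER socket: ROUTER prefixes every inbound message with the sender's identity, so a single bind point can serve any number of DEALER clients. A minimal sketch of that loop, with an assumed port and a trivial echo handler in place of the engine:

    import asyncio
    import pickle

    import zmq
    import zmq.asyncio


    async def serve(port: int = 5570) -> None:
        ctx = zmq.asyncio.Context()
        sock = ctx.socket(zmq.ROUTER)
        sock.bind(f"tcp://*:{port}")
        try:
            while True:
                # ROUTER delivers [identity, payload] for each DEALER client.
                identity, message = await sock.recv_multipart()
                request = pickle.loads(message)
                await sock.send_multipart([
                    identity,
                    pickle.dumps({"echo": request}, pickle.HIGHEST_PROTOCOL)
                ])
        finally:
            sock.close()
            ctx.destroy()


    if __name__ == "__main__":
        asyncio.run(serve())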
- self.poller = zmq.asyncio.Poller() - self.poller.register(self.generate_socket, zmq.constants.POLLIN) - self.poller.register(self.get_data_socket, zmq.constants.POLLIN) + self.socket = self.context.socket(zmq.constants.ROUTER) + self.socket.bind(VLLM_RPC_PATH) def cleanup(self): - """Shuts down the zmq context and closes all sockets""" + """Cleanup all resources.""" + self.socket.close() self.context.destroy() - del self.get_data_socket - del self.generate_socket - del self.is_ready_socket - - async def get_data(self, message): - request_type = pickle.loads(message) - if request_type == GetDataRequest.MODEL_CONFIG: - data = await self.engine.get_model_config() - else: - raise ValueError(f"Unknown request type: {request_type}") - - await self.get_data_socket.send_multipart( - [pickle.dumps(data, pickle.HIGHEST_PROTOCOL)]) - - async def generate(self, identity, message): + async def _send_success_message(self, identity): + """Send message to client indicating an action was successful.""" + self.socket.send_multipart([ + identity, + pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + ]) + + async def get_model_config(self, identity): + """Send the ModelConfig """ + model_config = await self.engine.get_model_config() + + self.socket.send_multipart([ + identity, + pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL) + ]) + + async def do_log_stats(self, identity): + await self.engine.do_log_stats() + + self.socket.send_multipart([ + identity, + pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + ]) + + async def is_server_ready(self, identity): + self.socket.send_multipart([ + identity, + pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + ]) + + async def abort(self, identity, request: RPCAbortRequest): + # Abort the request in the llm engine. + await self.engine.abort(request.request_id) + + # Send confirmation to the client. + self.socket.send_multipart([ + identity, + pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + ]) + + async def generate(self, identity, generate_request: RPCGenerateRequest): try: - request = pickle.loads(message) - results_generator = self.engine.generate( - request.inputs, - sampling_params=request.sampling_params, - request_id=request.request_id) + generate_request.inputs, + sampling_params=generate_request.sampling_params, + request_id=generate_request.request_id) async for request_output in results_generator: - self.generate_socket.send_multipart([ + self.socket.send_multipart([ identity, pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) + except Exception as e: ### Notify client of all failures - self.generate_socket.send_multipart( + self.socket.send_multipart( [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) - async def run_loop(self): - # Notify the RPC client that we are ready to receive requests. 
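The generate handler forwards failures by pickling the exception object itself and sending it as an ordinary reply; the client then re-raises whatever unpickles to an Exception. A transport-free sketch of that convention (handle and call are hypothetical stand-ins for the server and client sides):

    import pickle


    def handle(payload: bytes) -> bytes:
        """Server side: run the request, or pickle the failure."""
        try:
            request = pickle.loads(payload)
            result = {"ok": request}                 # stand-in for real work
            return pickle.dumps(result, pickle.HIGHEST_PROTOCOL)
        except Exception as e:                       # forward every failure
            return pickle.dumps(e, pickle.HIGHEST_PROTOCOL)


    def call(payload: bytes):
        """Client side: re-raise anything that unpickles to an Exception."""
        reply = pickle.loads(handle(payload))
        if isinstance(reply, Exception):
            raise reply
        return reply


    print(call(pickle.dumps({"request_id": "req-0"})))
    try:
        call(b"not a pickle")
    except Exception as e:
        print("remote failure surfaced locally:", type(e).__name__)

This keeps the wire format uniform (everything is a pickled object) at the cost of requiring exception types to be picklable.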
- await self.is_ready_socket.send_string("Ready!") - self.is_ready_socket.close() + def _make_handler_coro(self, identity, + message) -> Coroutine[Any, Any, Never]: + """Route the zmq message to the handler coroutine.""" + + request = pickle.loads(message) + + if isinstance(request, RPCGenerateRequest): + return self.generate(identity, request) + + elif isinstance(request, RPCAbortRequest): + return self.abort(identity, request) + + elif isinstance(request, RPCUtilityRequest): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + return self.get_model_config(identity) + elif request == RPCUtilityRequest.DO_LOG_STATS: + return self.do_log_stats(identity) + elif request == RPCUtilityRequest.IS_SERVER_READY: + return self.is_server_ready(identity) + else: + raise ValueError(f"Unknown RPCUtilityRequest type: {request}") + + else: + raise ValueError(f"Unknown RPCRequest type: {request}") + + async def run_server_loop(self): + """Inner RPC Server Loop""" - # Avoid GC of running tasks. running_tasks = set() while True: - self.poll_future = self.poller.poll() - socks = dict(await self.poll_future) - - task = None - if self.generate_socket in socks: - identity, message = await self.generate_socket.recv_multipart() - task = asyncio.create_task(self.generate(identity, message)) + # Wait for a request. + identity, message = await self.socket.recv_multipart() - elif self.get_data_socket in socks: - message = await self.get_data_socket.recv() - task = asyncio.create_task(self.get_data(message)) + # Process the request async. + task = asyncio.create_task( + self._make_handler_coro(identity, message)) # We need to keep around a strong reference to the task, # to avoid the task disappearing mid-execution as running tasks # can be GC'ed. Below is a common "fire-and-forget" tasks # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - if task is not None: - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - # TODO: Do I need to close the generate / get_data sockets? + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) async def run_server(server: RPCServer): - # Run with proper interrupt handling - logger.info("Booting up vLLM zmq backend") - + # Put the server task into the asyncio loop. loop = asyncio.get_running_loop() + server_task = loop.create_task(server.run_server_loop()) - server_task = loop.create_task(server.run_loop()) - + # Interruption handling. def signal_handler() -> None: # Kill the server on interrupt / terminate server_task.cancel() @@ -131,13 +159,13 @@ def signal_handler() -> None: try: await server_task except asyncio.CancelledError: - logger.info("ZMQ Backend was interrupted") + logger.info("vLLM ZMQ RPC Server was interrupted.") finally: - # Clean up all the zmq resources before exiting + # Clean up all resources. 
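The run_server_loop above stores each handler task in a set precisely because asyncio keeps only weak references to tasks; without the strong reference, a still-running handler can be garbage collected mid-flight. A self-contained sketch of that bookkeeping, with a placeholder work() coroutine:

    import asyncio

    running_tasks: set = set()


    async def work(n: int) -> None:
        await asyncio.sleep(0.1)
        print(f"handled request {n}")


    async def main() -> None:
        for n in range(3):
            task = asyncio.create_task(work(n))
            running_tasks.add(task)                        # keep a strong reference
            task.add_done_callback(running_tasks.discard)  # drop it once done
        await asyncio.sleep(0.5)                           # let the handlers finish


    asyncio.run(main())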
server.cleanup() - logger.info("vLLM ZMQ Backend shut down") -def run_rpc_server(async_engine_args): - server = RPCServer(async_engine_args=async_engine_args) +def run_rpc_server(async_engine_args: AsyncEngineArgs, + usage_context: UsageContext): + server = RPCServer(async_engine_args, usage_context) asyncio.run(run_server(server)) From 53629528f95fdd4356cffd8b291501f6a797f64b Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 31 Jul 2024 11:59:00 -0600 Subject: [PATCH 58/80] Use random port for backend (#390) Picks an open port to use and boots both the client and server with it --------- Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 13 ++++++++----- vllm/entrypoints/openai/rpc/__init__.py | 1 - vllm/entrypoints/openai/rpc/client.py | 11 ++++++----- vllm/entrypoints/openai/rpc/server.py | 20 +++++++++----------- vllm/envs.py | 6 ++++++ vllm/utils.py | 6 ++++-- 6 files changed, 33 insertions(+), 24 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a3b945c00e8c..772738351cda 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -45,7 +45,7 @@ OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, get_open_port from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -107,15 +107,18 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: else: # remote backend ## First need to start the backend process + port = get_open_port(envs.VLLM_RPC_PORT) rpc_server_process = Process(target=run_rpc_server, - args=(engine_args, - UsageContext.OPENAI_API_SERVER)) + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + port)) rpc_server_process.start() ## Then build the client for the backend process # TODO: figure out a way around passing the tokenizer - backend = RPCClient( - tokenizer=AutoTokenizer.from_pretrained(args.model)) + backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained( + args.model), + port=port) await backend.wait_for_server() try: diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 6a403f48793f..0f05b59cb2e9 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -7,7 +7,6 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -VLLM_RPC_PATH = "tcp://localhost:5570" VLLM_RPC_SUCCESS_STR = "SUCCESS" diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index f430e417388a..1e8a98d6418f 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -5,7 +5,7 @@ import zmq.asyncio from vllm.config import DecodingConfig, ModelConfig -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, VLLM_RPC_PATH, +from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) from vllm.inputs import PromptInputs @@ -18,13 +18,14 @@ class RPCClient: # TODO: check if opening all these sockets is an antipattern? - def __init__(self, tokenizer): + def __init__(self, tokenizer, port: int): # ZMQ context. self.context = zmq.asyncio.Context() # TODO: do the tokenizer properly. 
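The port handling in this patch amounts to: ask the OS for a free port, start the backend process with that number, and point the client at the same endpoint. A sketch under those assumptions; the helper names here are illustrative rather than the vLLM ones:

    import multiprocessing
    import socket


    def get_open_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))                # port 0 lets the OS pick a free port
            return s.getsockname()[1]


    def run_backend(port: int) -> None:
        print(f"backend would bind tcp://*:{port}")


    if __name__ == "__main__":
        port = get_open_port()
        proc = multiprocessing.Process(target=run_backend, args=(port, ))
        proc.start()
        print(f"frontend would connect to tcp://localhost:{port}")
        proc.join()

There is a small window in which another process could grab the probed port, which is acceptable here because the backend binds it almost immediately.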
self.tokenizer = tokenizer self.decoding_config = DecodingConfig() + self.path = f"tcp://localhost:{port}" def close(self): """Destroy the ZeroMQ Context.""" @@ -36,7 +37,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, # Connect to socket. socket = self.context.socket(zmq.constants.DEALER) - socket.connect(VLLM_RPC_PATH) + socket.connect(self.path) # Ping RPC Server with request. socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) @@ -76,7 +77,7 @@ async def get_model_config(self) -> ModelConfig: # Connect to socket. socket = self.context.socket(zmq.constants.DEALER) - socket.connect(VLLM_RPC_PATH) + socket.connect(self.path) # Ping RPCServer with GET_MODEL_CONFIG request. socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) @@ -122,7 +123,7 @@ async def generate( # Note that we use DEALER to enable asynchronous communication # to enable streaming. socket = self.context.socket(zmq.constants.DEALER) - socket.connect(VLLM_RPC_PATH) + socket.connect(self.path) # Send RPCGenerateRequest to the RPCServer. socket.send_multipart([ diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 0284e5eb91c5..6385eaa1b226 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -8,9 +8,8 @@ from typing_extensions import Never from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.entrypoints.openai.rpc import (VLLM_RPC_PATH, VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, RPCGenerateRequest, - RPCUtilityRequest) +from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -23,7 +22,7 @@ class RPCServer: # Alternative, use a smaller number of sockets with conditioning on the # data that is passed through the socket. def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext): + usage_context: UsageContext, port: int): # Initialize engine first. self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, usage_context) @@ -33,7 +32,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs, # Init socket for readiness state. 
self.socket = self.context.socket(zmq.constants.ROUTER) - self.socket.bind(VLLM_RPC_PATH) + self.socket.bind(f"tcp://localhost:{port}") def cleanup(self): """Cleanup all resources.""" @@ -51,10 +50,9 @@ async def get_model_config(self, identity): """Send the ModelConfig """ model_config = await self.engine.get_model_config() - self.socket.send_multipart([ - identity, - pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL) - ]) + self.socket.send_multipart( + [identity, + pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)]) async def do_log_stats(self, identity): await self.engine.do_log_stats() @@ -166,6 +164,6 @@ def signal_handler() -> None: def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext): - server = RPCServer(async_engine_args, usage_context) + usage_context: UsageContext, port: int): + server = RPCServer(async_engine_args, usage_context, port) asyncio.run(run_server(server)) diff --git a/vllm/envs.py b/vllm/envs.py index 595992e51db8..4670efbee0b8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: VLLM_HOST_IP: str = "" VLLM_PORT: Optional[int] = None + VLLM_RPC_PORT: int = 5570 VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_INSTANCE_ID: Optional[str] = None @@ -142,6 +143,11 @@ def get_default_config_root(): lambda: int(os.getenv('VLLM_PORT', '0')) if 'VLLM_PORT' in os.environ else None, + # used when the frontend api server is running in multi-processing mode, + # to communicate with the backend engine process over ZMQ. + 'VLLM_RPC_PORT': + lambda: int(os.getenv('VLLM_PORT', '5570')), + # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers "VLLM_USE_MODELSCOPE": diff --git a/vllm/utils.py b/vllm/utils.py index 9ce909d9d79c..59ebab1eb380 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -384,8 +384,10 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port() -> int: - port = envs.VLLM_PORT +def get_open_port(port: Optional[int] = None) -> int: + if port is None: + # Default behavior here is to return a port for multi-gpu communication + port = envs.VLLM_PORT if port is not None: while True: try: From 7214fb89a8edaf9f158347a6f0ebbef48dec90a1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 31 Jul 2024 11:26:58 -0700 Subject: [PATCH 59/80] Await socket operations + some other minor cleanup (#391) --- vllm/entrypoints/openai/cli_args.py | 2 +- vllm/entrypoints/openai/rpc/client.py | 8 ++++---- vllm/entrypoints/openai/rpc/server.py | 17 +++++++---------- vllm/utils.py | 7 ++++--- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index e637e20e16f5..1facedac72ca 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -138,7 +138,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--disable-frontend-multiprocessing", action="store_true", help="If specified, will run the OpenAI frontend server in the same " - "proecss as the model servinge engine.") + "process as the model serving engine.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 1e8a98d6418f..ea50338c1f2e 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -40,12 +40,12 @@ 
async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, socket.connect(self.path) # Ping RPC Server with request. - socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) + await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) # Await acknowledgement from RPCServer. response = pickle.loads(await socket.recv()) - if (not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR): + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: socket.close() raise ValueError(error_message) @@ -80,7 +80,7 @@ async def get_model_config(self) -> ModelConfig: socket.connect(self.path) # Ping RPCServer with GET_MODEL_CONFIG request. - socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) + await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) # Await the MODEL_CONFIG from the Server. model_config = pickle.loads(await socket.recv()) @@ -126,7 +126,7 @@ async def generate( socket.connect(self.path) # Send RPCGenerateRequest to the RPCServer. - socket.send_multipart([ + await socket.send_multipart([ pickle.dumps( RPCGenerateRequest( inputs=inputs, diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 6385eaa1b226..17439d1bef96 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -18,9 +18,6 @@ class RPCServer: - # TODO: check if opening all these sockets is an antipattern. - # Alternative, use a smaller number of sockets with conditioning on the - # data that is passed through the socket. def __init__(self, async_engine_args: AsyncEngineArgs, usage_context: UsageContext, port: int): # Initialize engine first. @@ -41,7 +38,7 @@ def cleanup(self): async def _send_success_message(self, identity): """Send message to client indicating an action was successful.""" - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) @@ -50,20 +47,20 @@ async def get_model_config(self, identity): """Send the ModelConfig """ model_config = await self.engine.get_model_config() - self.socket.send_multipart( + await self.socket.send_multipart( [identity, pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)]) async def do_log_stats(self, identity): await self.engine.do_log_stats() - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) async def is_server_ready(self, identity): - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) @@ -73,7 +70,7 @@ async def abort(self, identity, request: RPCAbortRequest): await self.engine.abort(request.request_id) # Send confirmation to the client. 
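The awaits added throughout this patch matter because zmq.asyncio's send() and send_multipart() return futures: without await, the caller never learns whether zmq accepted the frame, and errors or back-pressure go unnoticed. A small self-contained sketch (assumed endpoint; no server is needed because a connected DEALER queues outgoing frames):

    import asyncio
    import pickle

    import zmq
    import zmq.asyncio


    async def main() -> None:
        ctx = zmq.asyncio.Context()
        sock = ctx.socket(zmq.DEALER)
        sock.connect("tcp://localhost:5570")          # assumed RPC endpoint
        # sock.send(...) alone would only create a future; awaiting it waits
        # until zmq has actually accepted the frame, or raises on failure.
        await sock.send(pickle.dumps("SUCCESS", pickle.HIGHEST_PROTOCOL))
        sock.close(linger=0)    # drop the queued frame so the sketch exits cleanly
        ctx.destroy()


    asyncio.run(main())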
- self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) @@ -86,14 +83,14 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): request_id=generate_request.request_id) async for request_output in results_generator: - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) except Exception as e: ### Notify client of all failures - self.socket.send_multipart( + await self.socket.send_multipart( [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) def _make_handler_coro(self, identity, diff --git a/vllm/utils.py b/vllm/utils.py index 59ebab1eb380..b18c3f3e81e6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -302,7 +302,7 @@ def merge_async_iterators( queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, Exception]] = asyncio.Queue() - finished = [False] * len(iterators) + producers = len(iterators) async def producer(i: int, iterator: AsyncIterator[T]): try: @@ -310,7 +310,6 @@ async def producer(i: int, iterator: AsyncIterator[T]): await queue.put((i, item)) except Exception as e: await queue.put(e) - finished[i] = True # Signal to the consumer that we've finished await queue.put(ProducerFinished()) @@ -320,13 +319,15 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): + remaining = producers try: - while not all(finished) or not queue.empty(): + while remaining or not queue.empty(): # we think there is a race condition here item = await queue.get() if isinstance(item, ProducerFinished): # Signal that a producer finished- not a real item + remaining -= 1 continue if isinstance(item, Exception): From 98a7dab9b570d80f94b66af5db1b5ca2274fe6db Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 31 Jul 2024 13:00:47 -0600 Subject: [PATCH 60/80] :sparkles: health check round 2 (#392) With all the extra fun refactors Signed-off-by: Joe Runde --- vllm/engine/protocol.py | 3 +++ vllm/entrypoints/openai/rpc/__init__.py | 2 ++ vllm/entrypoints/openai/rpc/client.py | 24 ++++++++++++++++++++++++ vllm/entrypoints/openai/rpc/server.py | 16 +++++++++++++++- 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 67b9e5cf5cc0..b8f8eea44573 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -79,3 +79,6 @@ async def do_log_stats( model_output: Optional[List[SamplerOutput]] = None, ) -> None: pass + + async def check_health(self) -> None: + """Raise if unhealthy""" diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 0f05b59cb2e9..7187bcdbe77b 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -8,6 +8,7 @@ from vllm.sampling_params import SamplingParams VLLM_RPC_SUCCESS_STR = "SUCCESS" +VLLM_RPC_HEALTHY_STR = "HEALTHY" @dataclass @@ -29,6 +30,7 @@ class RPCUtilityRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 DO_LOG_STATS = 3 + CHECK_HEALTH = 4 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index ea50338c1f2e..9bcdf6c48bbc 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -6,6 +6,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, + VLLM_RPC_HEALTHY_STR, VLLM_RPC_SUCCESS_STR, 
RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) from vllm.inputs import PromptInputs @@ -153,3 +154,26 @@ async def generate( yield request_output socket.close() + + async def check_health(self) -> None: + """Raise if unhealthy""" + + # Connect to socket. + socket = self.context.socket(zmq.constants.DEALER) + socket.connect(self.path) + + # Ping RPCServer with CHECK_HEALTH request. + await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH)) + + # Await the reply from the server. + # TODO: do we need an internal timeout here? + # Or do we expect the external probe to timeout and let this chill? + health_message = pickle.loads(await socket.recv()) + socket.close() + + if isinstance(health_message, Exception): + raise health_message + + if health_message != VLLM_RPC_HEALTHY_STR: + raise ValueError("Expected healthy response from backend but got " + "f{health_message}") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 17439d1bef96..73ae2aae06ea 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -8,7 +8,8 @@ from typing_extensions import Never from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest, +from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -93,6 +94,17 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): await self.socket.send_multipart( [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) + async def check_health(self, identity): + try: + await self.engine.check_health() + await self.socket.send_multipart([ + identity, + pickle.dumps(VLLM_RPC_HEALTHY_STR, pickle.HIGHEST_PROTOCOL) + ]) + except Exception as e: + await self.socket.send_multipart( + [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) + def _make_handler_coro(self, identity, message) -> Coroutine[Any, Any, Never]: """Route the zmq message to the handler coroutine.""" @@ -112,6 +124,8 @@ def _make_handler_coro(self, identity, return self.do_log_stats(identity) elif request == RPCUtilityRequest.IS_SERVER_READY: return self.is_server_ready(identity) + elif request == RPCUtilityRequest.CHECK_HEALTH: + return self.check_health(identity) else: raise ValueError(f"Unknown RPCUtilityRequest type: {request}") From f5f0b45f294745298a2bea810fa4809330fd7b34 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 31 Jul 2024 18:02:22 -0400 Subject: [PATCH 61/80] Add tokenizer (#394) SUMMARY: * add endpoints to request `ModelConfig`, `SchedulerConfig`, `LoRAConfig`, `ParallelConfig` * factor out tokenizer group creation function to be a utility function * create tokenizer_group on client side --- tests/entrypoints/openai/test_completion.py | 1 + vllm/engine/async_llm_engine.py | 27 ++++- vllm/engine/llm_engine.py | 35 +++--- vllm/entrypoints/openai/api_server.py | 8 +- vllm/entrypoints/openai/rpc/__init__.py | 8 +- vllm/entrypoints/openai/rpc/client.py | 114 +++++++++++++----- vllm/entrypoints/openai/rpc/server.py | 56 +++++++-- .../tokenizer_group/__init__.py | 19 ++- 8 files changed, 207 insertions(+), 61 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index fe00640c0021..521a450f1356 100644 --- 
a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -119,6 +119,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, choice = completion.choices[0] assert len(choice.text) >= 5 assert choice.finish_reason == "length" + print(completion.usage) assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6 + num_virtual_tokens, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 16b7bc64a284..0584d8eb6f32 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,8 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -924,6 +925,14 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_parallel_config.remote( # type: ignore + ) + else: + return self.engine.get_parallel_config() + async def get_decoding_config(self) -> DecodingConfig: """Get the decoding configuration of the vLLM engine.""" if self.engine_use_ray: @@ -932,6 +941,22 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_scheduler_config.remote( # type: ignore + ) + else: + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_lora_config.remote( # type: ignore + ) + else: + return self.engine.get_lora_config() + async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 48d530589221..627f028b99d7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,8 +40,8 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - get_tokenizer_group) +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -481,19 +481,12 @@ def get_tokenizer_for_seq(self, return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: - init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision) - init_kwargs.update(tokenizer_init_kwargs) - - return 
get_tokenizer_group(self.parallel_config.tokenizer_pool_config, - **init_kwargs) + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -755,10 +748,22 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + def get_decoding_config(self) -> DecodingConfig: """Gets the decoding configuration.""" return self.decoding_config + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return sum(scheduler.get_num_unfinished_seq_groups() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 772738351cda..104f70f1386a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -16,7 +16,6 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from starlette.routing import Mount -from transformers import AutoTokenizer import vllm.envs as envs from vllm.config import ModelConfig @@ -115,11 +114,8 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: rpc_server_process.start() ## Then build the client for the backend process - # TODO: figure out a way around passing the tokenizer - backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained( - args.model), - port=port) - await backend.wait_for_server() + backend = RPCClient(port) + await backend.setup() try: yield backend diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 7187bcdbe77b..0c055b76fe2a 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -29,8 +29,12 @@ class RPCAbortRequest: class RPCUtilityRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 - DO_LOG_STATS = 3 - CHECK_HEALTH = 4 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + CHECK_HEALTH = 8 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 9bcdf6c48bbc..f69e7c24b449 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,10 +1,11 @@ import pickle -from typing import AsyncIterator, Mapping, Optional +from typing import Any, AsyncIterator, Mapping, Optional import zmq import zmq.asyncio -from vllm.config import DecodingConfig, ModelConfig +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, VLLM_RPC_HEALTHY_STR, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, @@ -14,24 +15,64 @@ from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import 
init_tokenizer_from_configs class RPCClient: - # TODO: check if opening all these sockets is an antipattern? - def __init__(self, tokenizer, port: int): - # ZMQ context. + def __init__(self, port: int): self.context = zmq.asyncio.Context() - - # TODO: do the tokenizer properly. - self.tokenizer = tokenizer - self.decoding_config = DecodingConfig() self.path = f"tcp://localhost:{port}" + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self.wait_for_server() + + # Get the configs. + self.model_config = await self._get_model_config_rpc() + self.decoding_config = await self._get_decoding_config_rpc() + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc()), + parallel_config=(await self._get_parallel_config_rpc()), + enable_lora=bool(await self._get_lora_config_rpc()), + ) + def close(self): """Destroy the ZeroMQ Context.""" self.context.destroy() + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + # Connect to socket. + socket = self.context.socket(zmq.constants.DEALER) + socket.connect(self.path) + + # Ping RPCServer with a request. + await socket.send(pickle.dumps(request)) + + # Await the data from the Server. + data = pickle.loads(await socket.recv()) + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + else: + socket.close() + raise ValueError(error_message) + + socket.close() + + return data + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, error_message: str): """Send one-way RPC request to trigger an action.""" @@ -55,13 +96,14 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, return response async def get_tokenizer(self, lora_request: LoRARequest): - # TODO: handle this via get data? - or avoid doing via RPC - return self.tokenizer + return await self.tokenizer.get_lora_tokenizer_async(lora_request) async def get_decoding_config(self): - # TODO: handle this via get data? - or avoid doing via RPC return self.decoding_config + async def get_model_config(self): + return self.model_config + async def is_tracing_enabled(self): # TODO: what is this? return False @@ -73,30 +115,48 @@ async def wait_for_server(self): request=RPCUtilityRequest.IS_SERVER_READY, error_message="Unable to start RPC Server.") - async def get_model_config(self) -> ModelConfig: + async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" - # Connect to socket. - socket = self.context.socket(zmq.constants.DEALER) - socket.connect(self.path) + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server") - # Ping RPCServer with GET_MODEL_CONFIG request. - await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) + async def _get_decoding_config_rpc(self) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" - # Await the MODEL_CONFIG from the Server. 
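The _send_get_data_rpc_request helper introduced here pickles an enum member as the request and then validates the type of whatever comes back (with a carve-out because LoRAConfig may legitimately be None). A transport-free sketch of that shape, using illustrative stand-ins for the request enum and the config class:

    import enum
    import pickle
    from dataclasses import dataclass


    class UtilityRequest(enum.Enum):
        GET_MODEL_CONFIG = 1


    @dataclass
    class FakeModelConfig:
        model: str


    def server_reply(raw: bytes) -> bytes:
        request = pickle.loads(raw)
        assert request is UtilityRequest.GET_MODEL_CONFIG
        return pickle.dumps(FakeModelConfig(model="facebook/opt-125m"),
                            pickle.HIGHEST_PROTOCOL)


    def get_data(expected_type: type):
        raw = server_reply(pickle.dumps(UtilityRequest.GET_MODEL_CONFIG))
        data = pickle.loads(raw)
        if not isinstance(data, expected_type):
            raise ValueError(f"expected {expected_type}, got {type(data)}")
        return data


    print(get_data(FakeModelConfig))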
- model_config = pickle.loads(await socket.recv()) + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") - if not isinstance(model_config, ModelConfig): - socket.close() - raise ValueError("Expected ModelConfig object from RPC, but " - f"got {model_config}") + async def _get_parallel_config_rpc(self) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" - socket.close() + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC Server") + + async def _get_lora_config_rpc(self): + """Get LoRAConfig from the RPCServer""" - return model_config + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server") async def abort(self, request_id: str): - """Send an RPCAbortRequest to the RPC Server""" + """Send an ABORT_REQUEST signal to the RPC Server""" await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 73ae2aae06ea..ca57295c6996 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -37,22 +37,47 @@ def cleanup(self): self.socket.close() self.context.destroy() - async def _send_success_message(self, identity): - """Send message to client indicating an action was successful.""" - await self.socket.send_multipart([ - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), - ]) - async def get_model_config(self, identity): - """Send the ModelConfig """ + """Send the ModelConfig""" model_config = await self.engine.get_model_config() await self.socket.send_multipart( [identity, pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)]) + async def get_decoding_config(self, identity): + """Send the DecodingConfig""" + decoding_config = await self.engine.get_decoding_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)]) + + async def get_lora_config(self, identity): + lora_config = await self.engine.get_lora_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(lora_config, pickle.HIGHEST_PROTOCOL)]) + + async def get_scheduler_config(self, identity): + """Send the SchedulerConfig""" + parallel_config = await self.engine.get_scheduler_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + + async def get_parallel_config(self, identity): + """Send the ParallelConfig""" + parallel_config = await self.engine.get_parallel_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + async def do_log_stats(self, identity): + """Log stats and confirm success.""" await self.engine.do_log_stats() await self.socket.send_multipart([ @@ -61,12 +86,14 @@ async def do_log_stats(self, identity): ]) async def is_server_ready(self, identity): + """Notify the client that we are 
ready.""" await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) async def abort(self, identity, request: RPCAbortRequest): + """Abort request and notify the client of success.""" # Abort the request in the llm engine. await self.engine.abort(request.request_id) @@ -81,7 +108,10 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): results_generator = self.engine.generate( generate_request.inputs, sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id) + request_id=generate_request.request_id, + lora_request=generate_request.lora_request, + trace_headers=generate_request.trace_headers, + prompt_adapter_request=generate_request.prompt_adapter_request) async for request_output in results_generator: await self.socket.send_multipart([ @@ -120,6 +150,14 @@ def _make_handler_coro(self, identity, elif isinstance(request, RPCUtilityRequest): if request == RPCUtilityRequest.GET_MODEL_CONFIG: return self.get_model_config(identity) + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.get_parallel_config(identity) + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.get_scheduler_config(identity) + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.get_lora_config(identity) elif request == RPCUtilityRequest.DO_LOG_STATS: return self.do_log_stats(identity) elif request == RPCUtilityRequest.IS_SERVER_READY: diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 9f54f5409b18..ae17ccf056b9 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,6 +1,7 @@ from typing import Optional, Type -from vllm.config import TokenizerPoolConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + TokenizerPoolConfig) from vllm.executor.ray_utils import ray from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) @@ -14,6 +15,22 @@ RayTokenizerGroupPool = None # type: ignore +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: tokenizer_cls: Type[BaseTokenizerGroup] From 0b351c00d13961bd3c95c4faf31118d45fade7eb Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 31 Jul 2024 16:21:04 -0600 Subject: [PATCH 62/80] Socket context (#393) Ensures no sockets are leaked on the client-side Also postpones the server shutdown await so that the backend can shutdown concurrently, and all connections can be cleaned up at the same time. 
This prevents hangs where the frontend blocks on remaining connections but the backend has not yet initiated shutdown --------- Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 7 +- vllm/entrypoints/openai/rpc/client.py | 125 +++++++++--------- .../tokenizer_group/__init__.py | 6 +- 3 files changed, 69 insertions(+), 69 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 104f70f1386a..5c8e5c4d76f9 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -360,6 +360,7 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + shutdown_task = None async with build_backend(args) as backend: server = await build_server( @@ -383,7 +384,11 @@ def signal_handler() -> None: await server_task except asyncio.CancelledError: logger.info("Gracefully stopping http server") - await server.shutdown() + shutdown_task = server.shutdown() + + if shutdown_task: + # NB: Await server shutdown only after the backend context is exited + await shutdown_task if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index f69e7c24b449..dd2bccac0e4a 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,4 +1,5 @@ import pickle +from contextlib import contextmanager from typing import Any, AsyncIterator, Mapping, Optional import zmq @@ -47,52 +48,55 @@ def close(self): """Destroy the ZeroMQ Context.""" self.context.destroy() + @contextmanager + def socket(self): + # Ensure client sockets are always closed after use + + # Connect to RPC socket for Request-Reply pattern, + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.path) + yield socket + finally: + socket.close() + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, expected_type: Any, error_message: str) -> Any: """Send an RPC request that is expecting data back.""" - # Connect to socket. - socket = self.context.socket(zmq.constants.DEALER) - socket.connect(self.path) + with self.socket() as socket: + + # Ping RPCServer with a request. + await socket.send(pickle.dumps(request)) - # Ping RPCServer with a request. - await socket.send(pickle.dumps(request)) + # Await the data from the Server. + data = pickle.loads(await socket.recv()) - # Await the data from the Server. - data = pickle.loads(await socket.recv()) if not isinstance(data, expected_type): # LoRAConfig can be None. if expected_type == LoRAConfig and data is None: pass else: - socket.close() raise ValueError(error_message) - socket.close() - return data async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, error_message: str): """Send one-way RPC request to trigger an action.""" + with self.socket() as socket: + # Ping RPC Server with request. + await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) - # Connect to socket. - socket = self.context.socket(zmq.constants.DEALER) - socket.connect(self.path) - - # Ping RPC Server with request. - await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) - - # Await acknowledgement from RPCServer. - response = pickle.loads(await socket.recv()) + # Await acknowledgement from RPCServer. 
+ response = pickle.loads(await socket.recv()) if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - socket.close() raise ValueError(error_message) - socket.close() - return response async def get_tokenizer(self, lora_request: LoRARequest): @@ -180,56 +184,47 @@ async def generate( ) -> AsyncIterator[RequestOutput]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - # Connect to RPC socket for Request-Reply pattern, - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.connect(self.path) - - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart([ - pickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request), - pickle.HIGHEST_PROTOCOL) - ]) - - # Stream back the results from the RPC Server. - while True: - message = await socket.recv() - request_output = pickle.loads(message) - - if isinstance(request_output, Exception): - socket.close() - raise request_output - - if request_output.finished: - break - yield request_output + with self.socket() as socket: + + # Send RPCGenerateRequest to the RPCServer. + await socket.send_multipart([ + pickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request), + pickle.HIGHEST_PROTOCOL) + ]) + + # Stream back the results from the RPC Server. + while True: + message = await socket.recv() + request_output = pickle.loads(message) + + if isinstance(request_output, Exception): + raise request_output + + if request_output.finished: + break + yield request_output - yield request_output - socket.close() + yield request_output async def check_health(self) -> None: """Raise if unhealthy""" - # Connect to socket. - socket = self.context.socket(zmq.constants.DEALER) - socket.connect(self.path) + with self.socket() as socket: - # Ping RPCServer with CHECK_HEALTH request. - await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH)) + # Ping RPCServer with CHECK_HEALTH request. + await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH)) - # Await the reply from the server. - # TODO: do we need an internal timeout here? - # Or do we expect the external probe to timeout and let this chill? - health_message = pickle.loads(await socket.recv()) - socket.close() + # Await the reply from the server. + # TODO: do we need an internal timeout here? + # Or do we expect the external probe to timeout and let this chill? 
+ health_message = pickle.loads(await socket.recv()) if isinstance(health_message, Exception): raise health_message diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index ae17ccf056b9..eea9e42ea4e4 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -16,9 +16,9 @@ def init_tokenizer_from_configs(model_config: ModelConfig, - scheduler_config: SchedulerConfig, - parallel_config: ParallelConfig, - enable_lora: bool): + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): init_kwargs = dict(tokenizer_id=model_config.tokenizer, enable_lora=enable_lora, max_num_seqs=scheduler_config.max_num_seqs, From 79fcc4459f0da5cd3b6bbb59b8136fa80951f735 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 31 Jul 2024 18:57:51 -0400 Subject: [PATCH 63/80] Logit bias (#395) SUMMARY: * fix issue with logit bias loading --- vllm/entrypoints/openai/protocol.py | 24 ++++++++---------------- vllm/entrypoints/openai/utils.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 16 deletions(-) create mode 100644 vllm/entrypoints/openai/utils.py diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c024bbc07c06..1aa4e5554344 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,6 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time +from functools import partial from typing import Any, Dict, List, Literal, Optional, Union import torch @@ -8,6 +9,7 @@ from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.openai.utils import logit_bias_logits_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -229,14 +231,9 @@ def to_sampling_params(self) -> SamplingParams: f"but token_id must be an integer or string " f"representing an integer") from exc - def logit_bias_logits_processor( - token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: - for token_id, bias in logit_bias.items(): - logits[token_id] += bias - return logits - - logits_processors = [logit_bias_logits_processor] + logits_processors = [ + partial(logit_bias_logits_processor, logit_bias) + ] return SamplingParams( n=self.n, @@ -423,14 +420,9 @@ def to_sampling_params(self): f"but token_id must be an integer or string " f"representing an integer") from exc - def logit_bias_logits_processor( - token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: - for token_id, bias in logit_bias.items(): - logits[token_id] += bias - return logits - - logits_processors = [logit_bias_logits_processor] + logits_processors = [ + partial(logit_bias_logits_processor, logit_bias) + ] return SamplingParams( n=self.n, diff --git a/vllm/entrypoints/openai/utils.py b/vllm/entrypoints/openai/utils.py new file mode 100644 index 000000000000..08bc4c36a43e --- /dev/null +++ b/vllm/entrypoints/openai/utils.py @@ -0,0 +1,11 @@ +from typing import Dict, List + +import torch + + +def logit_bias_logits_processor(logit_bias: Dict[str, + float], token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits From 
4c65f747b8cfad788d9e061ddba5e242b8818832 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 31 Jul 2024 17:17:58 -0600 Subject: [PATCH 64/80] :bug: messed up the revert in the merge commit :( Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 40 +++++++++++++++++++-------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0b8d86c51e15..5c8e5c4d76f9 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,20 +2,20 @@ import importlib import inspect import re -from argparse import Namespace +import signal from contextlib import asynccontextmanager from http import HTTPStatus from multiprocessing import Process -import signal from typing import AsyncIterator, Set -from fastapi import APIRouter, FastAPI, Request +import fastapi +import uvicorn +from fastapi import APIRouter, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from starlette.routing import Mount -import uvicorn import vllm.envs as envs from vllm.config import ModelConfig @@ -43,7 +43,6 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger -from vllm.server import serve_http from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_open_port from vllm.version import __version__ as VLLM_VERSION @@ -63,7 +62,7 @@ @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: @@ -135,7 +134,7 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: router = APIRouter() -def mount_metrics(app: FastAPI): +def mount_metrics(app: fastapi.FastAPI): # Add prometheus asgi middleware to route /metrics requests metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics @@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) +def build_app(args): + app = fastapi.FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path @@ -334,7 +333,27 @@ async def build_server( ) app.root_path = args.root_path - return app + logger.info("Available routes are:") + for route in app.routes: + if not hasattr(route, 'methods'): + continue + methods = ', '.join(route.methods) + logger.info("Route: %s, Methods: %s", route.path, methods) + + config = uvicorn.Config( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + return uvicorn.Server(config) async def run_server(args, **uvicorn_kwargs) -> None: @@ -379,5 +398,4 @@ def signal_handler() -> None: description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) args = parser.parse_args() - asyncio.run(run_server(args)) From 9bc97f1a73ce014fb3f76e9258b2d93b33344925 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:39:44 -0400 Subject: 
[PATCH 65/80] fix (#396) SUMMARY: * passed clamped --- vllm/entrypoints/openai/logits_processors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 5ba6589fd477..cad750d3ad46 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -72,7 +72,7 @@ def get_logits_processors( logits_processors.append(partial(logit_bias_logits_processor, - logit_bias)) + clamped_logit_bias)) if allowed_token_ids is not None: logits_processors.append( From 4337fe72d70abf96b388a707029d9d0004008da6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 02:07:03 +0000 Subject: [PATCH 66/80] format --- vllm/entrypoints/openai/logits_processors.py | 7 ++++--- vllm/entrypoints/openai/protocol.py | 1 - .../tokenizer_group/ray_tokenizer_group.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index cad750d3ad46..84871fc83ef5 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -39,6 +39,7 @@ def _get_allowed_token_ids_logits_processor( "out-of-vocab token id") return AllowedTokenIdsLogitsProcessor(allowed_token_ids) + def logit_bias_logits_processor(logit_bias: Dict[str, float], token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: @@ -46,6 +47,7 @@ def logit_bias_logits_processor(logit_bias: Dict[str, logits[token_id] += bias return logits + def get_logits_processors( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], allowed_token_ids: Optional[List[int]], @@ -70,9 +72,8 @@ def get_logits_processors( raise ValueError("token_id in logit_bias contains " "out-of-vocab token id") - - logits_processors.append(partial(logit_bias_logits_processor, - clamped_logit_bias)) + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) if allowed_token_ids is not None: logits_processors.append( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 88f511d5db00..205860aa8e72 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,7 +1,6 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from functools import partial from typing import Any, Dict, List, Literal, Optional, Union import torch diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index eebdf7bf644d..79081c04ddc1 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -3,7 +3,7 @@ from typing import List, Optional try: - from ray.exceptions import ActorDiedError + from ray.exceptions import ActorDiedError # type: ignore except ImportError: # For older versions of Ray from ray.exceptions import RayActorError as ActorDiedError # type: ignore From 779d9bd841ea73bd4df13a2091de81899f75aba8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 02:54:27 +0000 Subject: [PATCH 67/80] stash --- vllm/entrypoints/openai/rpc/client.py | 21 ++++++++++---------- vllm/entrypoints/openai/rpc/server.py | 28 +++++++++++++-------------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git 
a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index dd2bccac0e4a..54058ceba045 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,4 +1,4 @@ -import pickle +import cloudpickle from contextlib import contextmanager from typing import Any, AsyncIterator, Mapping, Optional @@ -70,10 +70,10 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, with self.socket() as socket: # Ping RPCServer with a request. - await socket.send(pickle.dumps(request)) + await socket.send(cloudpickle.dumps(request)) # Await the data from the Server. - data = pickle.loads(await socket.recv()) + data = cloudpickle.loads(await socket.recv()) if not isinstance(data, expected_type): # LoRAConfig can be None. @@ -89,10 +89,10 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, """Send one-way RPC request to trigger an action.""" with self.socket() as socket: # Ping RPC Server with request. - await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) + await socket.send(cloudpickle.dumps(request)) # Await acknowledgement from RPCServer. - response = pickle.loads(await socket.recv()) + response = cloudpickle.loads(await socket.recv()) if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: raise ValueError(error_message) @@ -188,21 +188,20 @@ async def generate( # Send RPCGenerateRequest to the RPCServer. await socket.send_multipart([ - pickle.dumps( + cloudpickle.dumps( RPCGenerateRequest( inputs=inputs, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request), - pickle.HIGHEST_PROTOCOL) + prompt_adapter_request=prompt_adapter_request)) ]) # Stream back the results from the RPC Server. while True: message = await socket.recv() - request_output = pickle.loads(message) + request_output = cloudpickle.loads(message) if isinstance(request_output, Exception): raise request_output @@ -219,12 +218,12 @@ async def check_health(self) -> None: with self.socket() as socket: # Ping RPCServer with CHECK_HEALTH request. - await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH)) + await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)) # Await the reply from the server. # TODO: do we need an internal timeout here? # Or do we expect the external probe to timeout and let this chill? 
- health_message = pickle.loads(await socket.recv()) + health_message = cloudpickle.loads(await socket.recv()) if isinstance(health_message, Exception): raise health_message diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index ca57295c6996..37a70877bb41 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,5 +1,5 @@ import asyncio -import pickle +import cloudpickle import signal from typing import Any, Coroutine @@ -43,7 +43,7 @@ async def get_model_config(self, identity): await self.socket.send_multipart( [identity, - pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)]) + cloudpickle.dumps(model_config)]) async def get_decoding_config(self, identity): """Send the DecodingConfig""" @@ -51,14 +51,14 @@ async def get_decoding_config(self, identity): await self.socket.send_multipart( [identity, - pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)]) + cloudpickle.dumps(decoding_config)]) async def get_lora_config(self, identity): lora_config = await self.engine.get_lora_config() await self.socket.send_multipart( [identity, - pickle.dumps(lora_config, pickle.HIGHEST_PROTOCOL)]) + cloudpickle.dumps(lora_config)]) async def get_scheduler_config(self, identity): """Send the SchedulerConfig""" @@ -66,7 +66,7 @@ async def get_scheduler_config(self, identity): await self.socket.send_multipart( [identity, - pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + cloudpickle.dumps(parallel_config)]) async def get_parallel_config(self, identity): """Send the ParallelConfig""" @@ -74,7 +74,7 @@ async def get_parallel_config(self, identity): await self.socket.send_multipart( [identity, - pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + cloudpickle.dumps(parallel_config)]) async def do_log_stats(self, identity): """Log stats and confirm success.""" @@ -82,14 +82,14 @@ async def do_log_stats(self, identity): await self.socket.send_multipart([ identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), ]) async def is_server_ready(self, identity): """Notify the client that we are ready.""" await self.socket.send_multipart([ identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), ]) async def abort(self, identity, request: RPCAbortRequest): @@ -100,7 +100,7 @@ async def abort(self, identity, request: RPCAbortRequest): # Send confirmation to the client. 
await self.socket.send_multipart([ identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), ]) async def generate(self, identity, generate_request: RPCGenerateRequest): @@ -116,30 +116,30 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): async for request_output in results_generator: await self.socket.send_multipart([ identity, - pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) + cloudpickle.dumps(request_output) ]) except Exception as e: ### Notify client of all failures await self.socket.send_multipart( - [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) + [identity, cloudpickle.dumps(e)]) async def check_health(self, identity): try: await self.engine.check_health() await self.socket.send_multipart([ identity, - pickle.dumps(VLLM_RPC_HEALTHY_STR, pickle.HIGHEST_PROTOCOL) + cloudpickle.dumps(VLLM_RPC_HEALTHY_STR) ]) except Exception as e: await self.socket.send_multipart( - [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) + [identity, cloudpickle.dumps(e)]) def _make_handler_coro(self, identity, message) -> Coroutine[Any, Any, Never]: """Route the zmq message to the handler coroutine.""" - request = pickle.loads(message) + request = cloudpickle.loads(message) if isinstance(request, RPCGenerateRequest): return self.generate(identity, request) From a6044a3552db8edf514f5e99601b9349b00a5e37 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 1 Aug 2024 12:07:19 -0400 Subject: [PATCH 68/80] Fix failed tests (#398) SUMMARY: * hack --- .../outlines_logits_processors.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 1c8f6cccb3e9..de6bf04a16a7 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -26,7 +26,8 @@ from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import PreTrainedTokenizerBase - +from lark import Lark +from outlines import grammars class BaseLogitsProcessor: @@ -44,6 +45,23 @@ def __call__(self, input_ids: List[int], last_seq_id = hash(tuple(input_ids[:-1])) self._fsm_state[seq_id] = self._guide.get_next_state( state=self._fsm_state[last_seq_id], token_id=last_token) + else: + # Note: this is a hack. + # Lark pickling does not work properly (silent failure), + # which breaks the RPC (which uses python pickleing). + # We need to find a better solution. + # On the first time this is called, we simply re-create + # the Lark object. 
+ if isinstance(self._guide, CFGGuide): + self._guide.parser = Lark( + self._guide.cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + import_paths=[grammars.GRAMMAR_PATH], + ) instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) From 0fc8545d5b7d9e4b70dff126cd801a5c426d6731 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 16:12:31 +0000 Subject: [PATCH 69/80] fixed merge conflicts --- vllm/entrypoints/openai/serving_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0ab12362f90c..108f20c29290 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -155,7 +155,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.engine.get_decoding_config() + decoding_config = await self.vllm_backend.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( From 63830910ff94f432ee6ac113b89c95447e994b53 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 16:14:12 +0000 Subject: [PATCH 70/80] updated --- vllm/entrypoints/openai/rpc/client.py | 5 +-- vllm/entrypoints/openai/rpc/server.py | 35 +++++++------------ .../outlines_logits_processors.py | 7 ++-- 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 54058ceba045..08b466086851 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,7 +1,7 @@ -import cloudpickle from contextlib import contextmanager from typing import Any, AsyncIterator, Mapping, Optional +import cloudpickle import zmq import zmq.asyncio @@ -218,7 +218,8 @@ async def check_health(self) -> None: with self.socket() as socket: # Ping RPCServer with CHECK_HEALTH request. - await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)) + await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH) + ) # Await the reply from the server. # TODO: do we need an internal timeout here? 
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 37a70877bb41..c26cca7099b4 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,8 +1,8 @@ import asyncio -import cloudpickle import signal from typing import Any, Coroutine +import cloudpickle import zmq import zmq.asyncio from typing_extensions import Never @@ -42,39 +42,34 @@ async def get_model_config(self, identity): model_config = await self.engine.get_model_config() await self.socket.send_multipart( - [identity, - cloudpickle.dumps(model_config)]) + [identity, cloudpickle.dumps(model_config)]) async def get_decoding_config(self, identity): """Send the DecodingConfig""" decoding_config = await self.engine.get_decoding_config() await self.socket.send_multipart( - [identity, - cloudpickle.dumps(decoding_config)]) + [identity, cloudpickle.dumps(decoding_config)]) async def get_lora_config(self, identity): lora_config = await self.engine.get_lora_config() await self.socket.send_multipart( - [identity, - cloudpickle.dumps(lora_config)]) + [identity, cloudpickle.dumps(lora_config)]) async def get_scheduler_config(self, identity): """Send the SchedulerConfig""" parallel_config = await self.engine.get_scheduler_config() await self.socket.send_multipart( - [identity, - cloudpickle.dumps(parallel_config)]) + [identity, cloudpickle.dumps(parallel_config)]) async def get_parallel_config(self, identity): """Send the ParallelConfig""" parallel_config = await self.engine.get_parallel_config() await self.socket.send_multipart( - [identity, - cloudpickle.dumps(parallel_config)]) + [identity, cloudpickle.dumps(parallel_config)]) async def do_log_stats(self, identity): """Log stats and confirm success.""" @@ -114,26 +109,20 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): prompt_adapter_request=generate_request.prompt_adapter_request) async for request_output in results_generator: - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(request_output) - ]) + await self.socket.send_multipart( + [identity, cloudpickle.dumps(request_output)]) except Exception as e: ### Notify client of all failures - await self.socket.send_multipart( - [identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) async def check_health(self, identity): try: await self.engine.check_health() - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_HEALTHY_STR) - ]) - except Exception as e: await self.socket.send_multipart( - [identity, cloudpickle.dumps(e)]) + [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) + except Exception as e: + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) def _make_handler_coro(self, identity, message) -> Coroutine[Any, Any, Never]: diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index de6bf04a16a7..554dcc0ed43e 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,13 +21,14 @@ from typing import Callable, DefaultDict, Dict, List, Union import torch +from lark import Lark +from outlines import grammars from outlines.caching import cache from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import 
PreTrainedTokenizerBase -from lark import Lark -from outlines import grammars + class BaseLogitsProcessor: @@ -46,7 +47,7 @@ def __call__(self, input_ids: List[int], self._fsm_state[seq_id] = self._guide.get_next_state( state=self._fsm_state[last_seq_id], token_id=last_token) else: - # Note: this is a hack. + # Note: this is a hack. # Lark pickling does not work properly (silent failure), # which breaks the RPC (which uses python pickleing). # We need to find a better solution. From a09f57fda43ab25289506bed405db965c92a9c6e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 16:28:40 +0000 Subject: [PATCH 71/80] cleaning --- examples/openai_completion_client.py | 7 ++++--- tests/entrypoints/openai/test_completion.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index cf932e67f9a4..58519f978d34 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -14,13 +14,14 @@ model = models.data[0].id # Completion API -stream = True +stream = False completion = client.completions.create( model=model, prompt="A robot may not injure a human being", echo=False, - n=1, - stream=stream) + n=2, + stream=stream, + logprobs=3) print("Completion results:") if stream: diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index eb111afd0d67..50add84087a9 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -119,7 +119,6 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, choice = completion.choices[0] assert len(choice.text) >= 5 assert choice.finish_reason == "length" - print(completion.usage) assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6 + num_virtual_tokens, From 1bdbfcb013349f72213c3f6e049fefcd1cc29423 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 1 Aug 2024 11:56:43 -0600 Subject: [PATCH 72/80] :white_check_mark: add test for multiprocessing flag (#399) Signed-off-by: Joe Runde --- tests/entrypoints/openai/test_basic.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 2c721d9ba760..b3bbd10dfe87 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,9 +1,11 @@ +import socket from http import HTTPStatus import openai import pytest import requests +from vllm import envs from vllm.version import __version__ as VLLM_VERSION from ...utils import RemoteOpenAIServer @@ -59,3 +61,20 @@ async def test_log_metrics(client: openai.AsyncOpenAI): response = requests.get(base_url + "/metrics") assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_fronted_multiprocessing_flag(): + # Build server without the flag to disable multiprocessing + with RemoteOpenAIServer("facebook/opt-125m", []), \ + socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s, \ + pytest.raises(OSError, match="Address already in use"): + # Ensure we see the backend port in use + s.bind(("localhost", envs.VLLM_RPC_PORT)) + + # Build server with the flag to disable multiprocessing + with RemoteOpenAIServer("facebook/opt-125m", + ["--disable-frontend-multiprocessing"]), \ + socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # Ensure the backend port is free -> no multiprocessing is happening + s.bind(("localhost", 
envs.VLLM_RPC_PORT)) From f3c0f1c7622c8e63cad1449b7cf68e38578244fc Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 1 Aug 2024 13:30:12 -0600 Subject: [PATCH 73/80] :sparkles: pipe tracing flag (#400) (plus rounding out the protocol with an error on `.encode`) --------- Signed-off-by: Joe Runde --- vllm/entrypoints/openai/rpc/__init__.py | 1 + vllm/entrypoints/openai/rpc/client.py | 28 ++++++++++++++++++------- vllm/entrypoints/openai/rpc/server.py | 9 ++++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 0c055b76fe2a..8a7b12201cab 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -35,6 +35,7 @@ class RPCUtilityRequest(Enum): GET_LORA_CONFIG = 6 DO_LOG_STATS = 7 CHECK_HEALTH = 8 + IS_TRACING_ENABLED = 9 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 08b466086851..0f609ba83424 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -13,7 +13,7 @@ RPCGenerateRequest, RPCUtilityRequest) from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs @@ -34,6 +34,7 @@ async def setup(self): # Get the configs. self.model_config = await self._get_model_config_rpc() self.decoding_config = await self._get_decoding_config_rpc() + self.tracing_flag = await self._is_tracing_enabled_rpc() # Create the tokenizer group. # TODO: refactor OAI server to avoid needing this info. @@ -102,15 +103,14 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, async def get_tokenizer(self, lora_request: LoRARequest): return await self.tokenizer.get_lora_tokenizer_async(lora_request) - async def get_decoding_config(self): + async def get_decoding_config(self) -> DecodingConfig: return self.decoding_config - async def get_model_config(self): + async def get_model_config(self) -> ModelConfig: return self.model_config - async def is_tracing_enabled(self): - # TODO: what is this? 
- return False + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag async def wait_for_server(self): """Wait for the RPCServer to start up.""" @@ -141,7 +141,7 @@ async def _get_parallel_config_rpc(self) -> ParallelConfig: return await self._send_get_data_rpc_request( RPCUtilityRequest.GET_PARALLEL_CONFIG, expected_type=ParallelConfig, - error_message="Could not get ModelConfig from RPC Server") + error_message="Could not get ParallelConfig from RPC Server") async def _get_scheduler_config_rpc(self) -> SchedulerConfig: """Get SchedulerConfig from the RPCServer""" @@ -159,6 +159,15 @@ async def _get_lora_config_rpc(self): expected_type=LoRAConfig, error_message="Could not get LoRAConfig from RPC Server") + async def _is_tracing_enabled_rpc(self) -> ParallelConfig: + """Get is_tracing_enabled flag from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.IS_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled flag from RPC " + "Server") + async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" @@ -232,3 +241,8 @@ async def check_health(self) -> None: if health_message != VLLM_RPC_HEALTHY_STR: raise ValueError("Expected healthy response from backend but got " "f{health_message}") + + async def encode(self, *args, + **kwargs) -> AsyncIterator[EmbeddingRequestOutput]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index c26cca7099b4..f5f785ca8166 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -71,6 +71,13 @@ async def get_parallel_config(self, identity): await self.socket.send_multipart( [identity, cloudpickle.dumps(parallel_config)]) + async def is_tracing_enabled(self, identity): + """Send the is_tracing_enabled flag""" + tracing_flag = await self.engine.is_tracing_enabled() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(tracing_flag)]) + async def do_log_stats(self, identity): """Log stats and confirm success.""" await self.engine.do_log_stats() @@ -153,6 +160,8 @@ def _make_handler_coro(self, identity, return self.is_server_ready(identity) elif request == RPCUtilityRequest.CHECK_HEALTH: return self.check_health(identity) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + return self.is_tracing_enabled(identity) else: raise ValueError(f"Unknown RPCUtilityRequest type: {request}") From 9c415ad2a99bb17a1d3f0581d060650d104e2d97 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 20:31:26 +0000 Subject: [PATCH 74/80] integration tests for old backend --- tests/entrypoints/openai/test_basic.py | 17 - tests/entrypoints/openai/test_disable_mp.py | 715 ++++++++++++++++++++ 2 files changed, 715 insertions(+), 17 deletions(-) create mode 100644 tests/entrypoints/openai/test_disable_mp.py diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index b3bbd10dfe87..4c2d2a1190d4 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -61,20 +61,3 @@ async def test_log_metrics(client: openai.AsyncOpenAI): response = requests.get(base_url + "/metrics") assert response.status_code == HTTPStatus.OK - - -@pytest.mark.asyncio -async def test_fronted_multiprocessing_flag(): - # Build server without the flag to disable multiprocessing - with 
RemoteOpenAIServer("facebook/opt-125m", []), \ - socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s, \ - pytest.raises(OSError, match="Address already in use"): - # Ensure we see the backend port in use - s.bind(("localhost", envs.VLLM_RPC_PORT)) - - # Build server with the flag to disable multiprocessing - with RemoteOpenAIServer("facebook/opt-125m", - ["--disable-frontend-multiprocessing"]), \ - socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - # Ensure the backend port is free -> no multiprocessing is happening - s.bind(("localhost", envs.VLLM_RPC_PORT)) diff --git a/tests/entrypoints/openai/test_disable_mp.py b/tests/entrypoints/openai/test_disable_mp.py new file mode 100644 index 000000000000..12c805413311 --- /dev/null +++ b/tests/entrypoints/openai/test_disable_mp.py @@ -0,0 +1,715 @@ +""" +Repeat of tests in test_completion.py with the non-mp backend. +""" + +# imports for guided decoding tests +import json +import re +import shutil +from tempfile import TemporaryDirectory +from typing import List + +import jsonschema +import openai # use the official client for correctness check +import pytest +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoTokenizer + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically these adapters use a different base model, +# but we're not testing generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def zephyr_pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", + "128", + "--disable-frontend-multiprocessing" + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + 
+@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name,num_virtual_tokens", + [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], +) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, + num_virtual_tokens: int): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, + prompt_tokens=6 + num_virtual_tokens, + total_tokens=11 + num_virtual_tokens) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + + +@pytest.mark.asyncio +async def test_added_lora_tokens(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model="zephyr-lora2", + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should appear in tokenized prompt + assert completion.choices[0].text.startswith("vllm1vllm2vllm3") + + +@pytest.mark.asyncio +async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should not appear in tokenized prompt + assert "vllm" not in completion.choices[0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora and 1 pa hereafter + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + 
assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is an LLM?" + + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" 
+ + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + 
stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +async def test_logits_bias(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 5 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + token_id = 1000 + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token_id): 100}, + seed=42, + ) + assert len(completion.choices[0].text) >= 5 + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), + add_special_tokens=False)["input_ids"] + assert all([ + response == expected + for response, expected in zip(response_tokens, expected_tokens) + ]) + + # Test ban + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + ) + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + first_response = completion.choices[0].text + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token): -100 + for token in response_tokens}, + ) + assert first_response != completion.choices[0].text + + +@pytest.mark.asyncio +async def test_allowed_token_ids(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 1 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + allowed_ids = [21555, 21557, 21558] + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + seed=42, + extra_body=dict(allowed_token_ids=allowed_ids), + logprobs=1, + ) + response_tokens = completion.choices[0].logprobs.tokens + assert len(response_tokens) == 1 + assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}", + n=3, + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_json=sample_json_schema, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_regex): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {sample_regex}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict(guided_regex=sample_regex, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + assert re.fullmatch(sample_regex, + completion.choices[i].text) is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_guided_choice): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="The best language for type-safe systems programming is ", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict(guided_choice=sample_guided_choice, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 2 + for i in range(2): + assert completion.choices[i].text in sample_guided_choice + + +@pytest.mark.asyncio +async def test_guided_grammar(client: openai.AsyncOpenAI, + sample_sql_statements): + + completion = await client.completions.create( + model=MODEL_NAME, + prompt=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_grammar=sample_sql_statements)) + + content = completion.choices[0].text + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(content) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") + + assert content.strip() == ground_truth + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = 
tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, sample_regex): + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex", + extra_body=dict(guided_regex=sample_regex, + guided_json=sample_json_schema)) From 62036add2cd30feb1cf668f665ccb3d1553b31d8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 21:13:31 +0000 Subject: [PATCH 75/80] rename --- vllm/engine/protocol.py | 4 +- vllm/entrypoints/openai/api_server.py | 54 +++++++++---------- vllm/entrypoints/openai/rpc/client.py | 2 +- vllm/entrypoints/openai/rpc/server.py | 6 +-- vllm/entrypoints/openai/serving_chat.py | 14 ++--- vllm/entrypoints/openai/serving_completion.py | 16 +++--- vllm/entrypoints/openai/serving_embedding.py | 12 ++--- vllm/entrypoints/openai/serving_engine.py | 8 +-- .../openai/serving_tokenization.py | 10 ++-- 9 files changed, 61 insertions(+), 65 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index b8f8eea44573..fc94ef6662e0 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -15,8 +15,8 @@ @runtime_checkable -class VLLMBackend(Protocol): - """Protocol class for asynchronous vllm backends""" +class AsyncEngineClient(Protocol): + """Protocol class for Clients to AsyncLLMEngine""" @property def is_running(self) -> bool: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 5c8e5c4d76f9..309a3acb5a4d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,7 +21,7 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.protocol import VLLMBackend +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -34,7 +34,7 @@ EmbeddingRequest, ErrorResponse, TokenizeRequest, TokenizeResponse) -from vllm.entrypoints.openai.rpc.client import RPCClient +from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -54,7 +54,7 @@ openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding openai_serving_tokenization: 
OpenAIServingTokenization -backend: VLLMBackend +async_engine_client: AsyncEngineClient logger = init_logger('vllm.entrypoints.openai.api_server') @@ -67,7 +67,7 @@ async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await backend.do_log_stats() + await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) @@ -78,14 +78,14 @@ async def _force_log(): @asynccontextmanager -async def build_backend(args) -> AsyncIterator[VLLMBackend]: - # Context manager to handle backend lifecycle +async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: + # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) # Backend itself still global for the silly lil' health handler - global backend + global async_engine_client # First need to determine if this is an embeddings model # (no remote backend for those) @@ -96,16 +96,13 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: seed=0, dtype="float16") if model_config.embedding_mode or args.disable_frontend_multiprocessing: - # local backend - backend = AsyncLLMEngine.from_engine_args( + async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - yield backend - # No cleanup + yield async_engine_client return else: - # remote backend - ## First need to start the backend process + # Start the RPC Server, which has the AsyncLLMEngine. port = get_open_port(envs.VLLM_RPC_PORT) rpc_server_process = Process(target=run_rpc_server, args=(engine_args, @@ -113,19 +110,18 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: port)) rpc_server_process.start() - ## Then build the client for the backend process - backend = RPCClient(port) - await backend.setup() + # Build the RPC Client, which conforms to the AsyncEngineClient protocol. 
+ async_engine_client = AsyncEngineRPCClient(port) + await async_engine_client.setup() try: - yield backend + yield async_engine_client finally: - ## Cleanup: - # Ensure backend process was terminated + # Ensure rpc server process was terminated rpc_server_process.terminate() # Close all open connections to the backend - backend.close() + async_engine_client.close() # Wait for server process to join rpc_server_process.join() @@ -145,7 +141,7 @@ def mount_metrics(app: fastapi.FastAPI): @router.get("/health") async def health() -> Response: """Health check.""" - await backend.check_health() + await async_engine_client.check_health() return Response(status_code=200) @@ -274,7 +270,7 @@ async def authentication(request: Request, call_next): async def build_server( - backend: VLLMBackend, + async_engine_client: AsyncEngineClient, args, **uvicorn_kwargs, ) -> uvicorn.Server: @@ -285,7 +281,7 @@ async def build_server( else: served_model_names = [args.model] - model_config = await backend.get_model_config() + model_config = await async_engine_client.get_model_config() if args.disable_log_requests: request_logger = None @@ -298,7 +294,7 @@ async def build_server( global openai_serving_tokenization openai_serving_chat = OpenAIServingChat( - backend, + async_engine_client, model_config, served_model_names, args.response_role, @@ -309,7 +305,7 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_completion = OpenAIServingCompletion( - backend, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -318,13 +314,13 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_embedding = OpenAIServingEmbedding( - backend, + async_engine_client, model_config, served_model_names, request_logger=request_logger, ) openai_serving_tokenization = OpenAIServingTokenization( - backend, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -361,10 +357,10 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("args: %s", args) shutdown_task = None - async with build_backend(args) as backend: + async with build_async_engine_client(args) as async_engine_client: server = await build_server( - backend, + async_engine_client, args, **uvicorn_kwargs, ) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 0f609ba83424..45bf88b5bf57 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -19,7 +19,7 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -class RPCClient: +class AsyncEngineRPCClient: def __init__(self, port: int): self.context = zmq.asyncio.Context() diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index f5f785ca8166..ef44010ecd65 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -17,7 +17,7 @@ logger = init_logger('vllm.entrypoints.openai.rpc.server') -class RPCServer: +class AsyncEngineRPCServer: def __init__(self, async_engine_args: AsyncEngineArgs, usage_context: UsageContext, port: int): @@ -188,7 +188,7 @@ async def run_server_loop(self): task.add_done_callback(running_tasks.discard) -async def run_server(server: RPCServer): +async def run_server(server: AsyncEngineRPCServer): # Put the server task into the asyncio loop. 
loop = asyncio.get_running_loop() server_task = loop.create_task(server.run_server_loop()) @@ -212,5 +212,5 @@ def signal_handler() -> None: def run_rpc_server(async_engine_args: AsyncEngineArgs, usage_context: UsageContext, port: int): - server = RPCServer(async_engine_args, usage_context, port) + server = AsyncEngineRPCServer(async_engine_args, usage_context, port) asyncio.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b6ea7b41c8dc..ea875d815e6f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.protocol import VLLMBackend +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, - vllm_backend: VLLMBackend, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -50,7 +50,7 @@ def __init__( chat_template: Optional[str], return_tokens_as_token_ids: bool = False, ): - super().__init__(vllm_backend=vllm_backend, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -89,7 +89,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.vllm_backend.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -161,7 +161,7 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.vllm_backend.is_tracing_enabled() + is_tracing_enabled = await self.async_engine_client.is_tracing_enabled() trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -169,7 +169,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.vllm_backend.generate( + result_generator = self.async_engine_client.generate( engine_inputs, sampling_params, request_id, @@ -441,7 +441,7 @@ async def chat_completion_full_generator( async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.vllm_backend.abort(request_id) + await self.async_engine_client.abort(request_id) return self.create_error_response("Client disconnected") final_res = res assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9ea5da1dad00..20cb46ed863a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.protocol import VLLMBackend +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - vllm_backend: VLLMBackend, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -51,7 +51,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(vllm_backend=vllm_backend, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -91,7 +91,7 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.vllm_backend.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -119,7 +119,7 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.vllm_backend.is_tracing_enabled( + is_tracing_enabled = await self.async_engine_client.is_tracing_enabled( ) trace_headers = None if is_tracing_enabled: @@ -128,7 +128,7 @@ async def create_completion(self, request: CompletionRequest, raw_request.headers): log_tracing_disabled_warning() - generator = self.vllm_backend.generate( + generator = self.async_engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, @@ -169,7 +169,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.vllm_backend.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res @@ -231,7 +231,7 @@ async def completion_stream_generator( # Abort the request if the client disconnects. 
if await raw_request.is_disconnected(): - await self.vllm_backend.abort(f"{request_id}-{prompt_idx}") + await self.async_engine_client.abort(f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9518c42057cf..52294591ed06 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -6,7 +6,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import VLLMBackend +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - vllm_backend: VLLMBackend, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(vllm_backend=vllm_backend, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -99,7 +99,7 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.vllm_backend.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -124,7 +124,7 @@ async def create_embedding(self, request: EmbeddingRequest, "Prompt adapter is not supported " "for embedding models") - generator = self.vllm_backend.encode( + generator = self.async_engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, @@ -146,7 +146,7 @@ async def create_embedding(self, request: EmbeddingRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.vllm_backend.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 108f20c29290..df4932d8fe18 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.protocol import VLLMBackend +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -61,7 +61,7 @@ class OpenAIServing: def __init__( self, - vllm_backend: VLLMBackend, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -72,7 +72,7 @@ def __init__( ): super().__init__() - self.vllm_backend = vllm_backend + self.async_engine_client = async_engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -155,7 +155,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.vllm_backend.get_decoding_config() + decoding_config = await self.async_engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index ab6a01570802..c4350881a27a 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -3,7 +3,7 @@ from vllm.config import ModelConfig # yapf conflicts with isort for this block # yapf: disable -from vllm.engine.protocol import VLLMBackend +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - vllm_backend: VLLMBackend, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -32,7 +32,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(vllm_backend=vllm_backend, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -57,7 +57,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.vllm_backend.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -113,7 +113,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.vllm_backend.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, From a177d87ba9aaac0e5c62fd439e5e4bfb23602800 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 21:14:21 +0000 Subject: [PATCH 76/80] cleaning --- 
tests/entrypoints/openai/test_basic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 4c2d2a1190d4..2c721d9ba760 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,11 +1,9 @@ -import socket from http import HTTPStatus import openai import pytest import requests -from vllm import envs from vllm.version import __version__ as VLLM_VERSION from ...utils import RemoteOpenAIServer From 9ca3b93893bfcba9c479ee1372eac52ede5a90ea Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 21:15:52 +0000 Subject: [PATCH 77/80] ordering --- vllm/entrypoints/openai/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 309a3acb5a4d..d8b191113e03 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -49,12 +49,12 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds +async_engine_client: AsyncEngineClient engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding openai_serving_tokenization: OpenAIServingTokenization -async_engine_client: AsyncEngineClient logger = init_logger('vllm.entrypoints.openai.api_server') From f8b5fb1fd8bcda58c3d641e22bba0afed85e17a4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 21:22:07 +0000 Subject: [PATCH 78/80] fix embedding model feedback --- vllm/entrypoints/openai/api_server.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d8b191113e03..1749f27a3846 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -60,6 +60,13 @@ _running_tasks: Set[asyncio.Task] = set() +def model_is_embedding(model_name: str) -> bool: + return ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, dtype="float16").embedding_mode + @asynccontextmanager async def lifespan(app: fastapi.FastAPI): @@ -87,22 +94,19 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # Backend itself still global for the silly lil' health handler global async_engine_client - # First need to determine if this is an embeddings model - # (no remote backend for those) - model_config = ModelConfig(model=args.model, - tokenizer=args.tokenizer, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="float16") - if model_config.embedding_mode or args.disable_frontend_multiprocessing: + + # If manually triggered or embedding model, use AsyncLLMEngine in process. + # TODO: support embedding model via RPC. + if (model_is_embedding(args.model) or + args.disable_frontend_multiprocessing): async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) yield async_engine_client return + # Otherwise, use the multiprocessing AsyncLLMEngine. else: - # Start the RPC Server, which has the AsyncLLMEngine. + # Start the RPC Server in separate process (holds the AsyncLLMEngine). 
port = get_open_port(envs.VLLM_RPC_PORT) rpc_server_process = Process(target=run_rpc_server, args=(engine_args, From fca5a7109718c4724c0ccabb2bf45e142de9c09e Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 1 Aug 2024 17:22:54 -0400 Subject: [PATCH 79/80] Update vllm/entrypoints/openai/rpc/server.py Co-authored-by: Simon Mo --- vllm/entrypoints/openai/rpc/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index ef44010ecd65..7a72a6f732c9 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -logger = init_logger('vllm.entrypoints.openai.rpc.server') +logger = init_logger(__name__) class AsyncEngineRPCServer: From 5f07f866b6d2bcca90c5be636b63c8cebedbb3ee Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 1 Aug 2024 21:38:23 +0000 Subject: [PATCH 80/80] format --- vllm/entrypoints/openai/api_server.py | 13 +++++++------ vllm/entrypoints/openai/serving_chat.py | 6 ++++-- vllm/entrypoints/openai/serving_completion.py | 10 ++++++---- vllm/entrypoints/openai/serving_embedding.py | 3 ++- vllm/tracing.py | 2 +- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1749f27a3846..e330ee81f7e4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -60,12 +60,14 @@ _running_tasks: Set[asyncio.Task] = set() + def model_is_embedding(model_name: str) -> bool: return ModelConfig(model=model_name, tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=False, - seed=0, dtype="float16").embedding_mode + seed=0, + dtype="float16").embedding_mode @asynccontextmanager @@ -94,11 +96,10 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # Backend itself still global for the silly lil' health handler global async_engine_client - # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. - if (model_is_embedding(args.model) or - args.disable_frontend_multiprocessing): + if (model_is_embedding(args.model) + or args.disable_frontend_multiprocessing): async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) yield async_engine_client @@ -106,7 +107,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # Otherwise, use the multiprocessing AsyncLLMEngine. else: - # Start the RPC Server in separate process (holds the AsyncLLMEngine). + # Start RPCServer in separate process (holds the AsyncLLMEngine). port = get_open_port(envs.VLLM_RPC_PORT) rpc_server_process = Process(target=run_rpc_server, args=(engine_args, @@ -114,7 +115,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: port)) rpc_server_process.start() - # Build the RPC Client, which conforms to the AsyncEngineClient protocol. + # Build RPCClient, which conforms to AsyncEngineClient Protocol. 
async_engine_client = AsyncEngineRPCClient(port) await async_engine_client.setup() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ea875d815e6f..ebb1d57fbb9a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -89,7 +89,8 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -161,7 +162,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.async_engine_client.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 20cb46ed863a..edc83d83fbba 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -91,7 +91,8 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -119,8 +120,8 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.async_engine_client.is_tracing_enabled( - ) + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -231,7 +232,8 @@ async def completion_stream_generator( # Abort the request if the client disconnects. if await raw_request.is_disconnected(): - await self.async_engine_client.abort(f"{request_id}-{prompt_idx}") + await self.async_engine_client.abort( + f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 52294591ed06..e61c82f9a8a6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -99,7 +99,8 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) pooling_params = request.to_pooling_params() diff --git a/vllm/tracing.py b/vllm/tracing.py index ba6732cab68f..9f7bea872678 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -60,7 +60,7 @@ def get_span_exporter(endpoint): OTLPSpanExporter) elif protocol == "http/protobuf": from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) + OTLPSpanExporter) # type: ignore else: raise ValueError( f"Unsupported OTLP protocol '{protocol}' is configured")
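
The last several patches all orbit one idea: the OpenAI-compatible server only ever talks to something that satisfies the AsyncEngineClient Protocol, and build_async_engine_client decides at startup whether that is the in-process AsyncLLMEngine or an AsyncEngineRPCClient backed by an RPCServer running in a separate process. The sketch below illustrates that pattern in isolation. It is not code from this series: every class name, method signature, and the port number are simplified stand-ins, kept only to show how a runtime-checkable Protocol plus an async context manager lets call sites stay identical no matter which implementation they receive.

    # Illustrative sketch only -- simplified stand-ins, not vLLM code.
    import asyncio
    from contextlib import asynccontextmanager
    from typing import AsyncIterator, Protocol, runtime_checkable


    @runtime_checkable
    class EngineClient(Protocol):
        """Stand-in for the AsyncEngineClient Protocol."""

        async def generate(self, prompt: str, request_id: str) -> str:
            ...

        async def check_health(self) -> None:
            ...


    class InProcessEngine:
        """Plays the role of AsyncLLMEngine living in the API server process."""

        async def generate(self, prompt: str, request_id: str) -> str:
            return f"[in-process:{request_id}] {prompt}"

        async def check_health(self) -> None:
            return None


    class RPCClient:
        """Plays the role of AsyncEngineRPCClient talking to another process."""

        def __init__(self, port: int) -> None:
            self.port = port

        async def generate(self, prompt: str, request_id: str) -> str:
            # A real client would ship the request over a socket here.
            return f"[rpc:{self.port}:{request_id}] {prompt}"

        async def check_health(self) -> None:
            return None


    @asynccontextmanager
    async def build_client(disable_multiprocessing: bool) -> AsyncIterator[EngineClient]:
        # Mirrors build_async_engine_client: pick an implementation, clean up on exit.
        client: EngineClient = (InProcessEngine() if disable_multiprocessing
                                else RPCClient(port=5570))  # port is arbitrary here
        try:
            yield client
        finally:
            pass  # a real RPC client would close sockets and join the server process


    async def main() -> None:
        async with build_client(disable_multiprocessing=False) as client:
            # Works because the Protocol is declared @runtime_checkable.
            assert isinstance(client, EngineClient)
            await client.check_health()
            print(await client.generate("Hello, my name is", request_id="demo-0"))


    if __name__ == "__main__":
        asyncio.run(main())

Because the Protocol is runtime-checkable (as vllm/engine/protocol.py is in the patches above), isinstance checks stay available for sanity checks, while static type checkers verify structurally that both implementations expose the same async methods.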