diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index b2924b9e8463e..b625f92d77d38 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -2,8 +2,8 @@ On the server side, run one of the following commands: vLLM OpenAI API server - python -m vllm.entrypoints.openai.api_server \ - --model --swap-space 16 \ + vllm serve \ + --swap-space 16 \ --disable-log-requests (TGI backend) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 6248d84683753..092c3c6cb9a3d 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) ```{argparse} :module: vllm.entrypoints.openai.cli_args -:func: make_arg_parser +:func: create_parser_for_docs :prog: -m vllm.entrypoints.openai.api_server ``` diff --git a/setup.py b/setup.py index 067ad13fed71b..2ecde92311f40 100644 --- a/setup.py +++ b/setup.py @@ -459,4 +459,9 @@ def _read_requirements(filename: str) -> List[str]: }, cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, package_data=package_data, + entry_points={ + "console_scripts": [ + "vllm=vllm.scripts:main", + ], + }, ) diff --git a/tests/utils.py b/tests/utils.py index ad4d097b0e8ed..c1382a4a981a1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import get_open_port, is_hip +from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip if is_hip(): from amdsmi import (amdsmi_get_gpu_vram_usage, @@ -103,7 +103,9 @@ def __init__(self, cli_args = cli_args + ["--port", str(get_open_port())] - parser = make_arg_parser() + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) args = parser.parse_args(cli_args) self.host = str(args.host or 'localhost') self.port = int(args.port) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6cba356c47063..45c634b4a2991 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -8,7 +8,7 @@ import fastapi import uvicorn -from fastapi import Request +from fastapi import APIRouter, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -35,10 +35,14 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds +logger = init_logger(__name__) +engine: AsyncLLMEngine +engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding @@ -64,35 +68,23 @@ async def _force_log(): yield -app = fastapi.FastAPI(lifespan=lifespan) - - -def parse_args(): - parser = make_arg_parser() - return parser.parse_args() - +router = APIRouter() # Add prometheus asgi middleware to route /metrics requests route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics route.path_regex = re.compile('^/metrics(?P.*)$') -app.routes.append(route) - - -@app.exception_handler(RequestValidationError) -async def validation_exception_handler(_, exc): - err = openai_serving_chat.create_error_response(message=str(exc)) - return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) +router.routes.append(route) -@app.get("/health") +@router.get("/health") async def health() -> Response: """Health check.""" await openai_serving_chat.engine.check_health() return Response(status_code=200) -@app.post("/tokenize") +@router.post("/tokenize") async def tokenize(request: TokenizeRequest): generator = await openai_serving_completion.create_tokenize(request) if isinstance(generator, ErrorResponse): @@ -103,7 +95,7 @@ async def tokenize(request: TokenizeRequest): return JSONResponse(content=generator.model_dump()) -@app.post("/detokenize") +@router.post("/detokenize") async def detokenize(request: DetokenizeRequest): generator = await openai_serving_completion.create_detokenize(request) if isinstance(generator, ErrorResponse): @@ -114,19 +106,19 @@ async def detokenize(request: DetokenizeRequest): return JSONResponse(content=generator.model_dump()) -@app.get("/v1/models") +@router.get("/v1/models") async def show_available_models(): models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) -@app.get("/version") +@router.get("/version") async def show_version(): ver = {"version": VLLM_VERSION} return JSONResponse(content=ver) -@app.post("/v1/chat/completions") +@router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): generator = await openai_serving_chat.create_chat_completion( @@ -142,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return JSONResponse(content=generator.model_dump()) -@app.post("/v1/completions") +@router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): generator = await openai_serving_completion.create_completion( request, raw_request) @@ -156,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -@app.post("/v1/embeddings") +@router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): generator = await openai_serving_embedding.create_embedding( request, raw_request) @@ -167,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -if __name__ == "__main__": - args = parse_args() +def build_app(args): + app = fastapi.FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path app.add_middleware( CORSMiddleware, @@ -178,6 +172,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): allow_headers=args.allowed_headers, ) + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + err = openai_serving_chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + if token := envs.VLLM_API_KEY or args.api_key: @app.middleware("http") @@ -203,6 +203,12 @@ async def authentication(request: Request, call_next): raise ValueError(f"Invalid middleware {middleware}. " f"Must be a function or a class.") + return app + + +def run_server(args, llm_engine=None): + app = build_app(args) + logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) @@ -211,10 +217,12 @@ async def authentication(request: Request, call_next): else: served_model_names = [args.model] - engine_args = AsyncEngineArgs.from_cli_args(args) + global engine, engine_args - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) event_loop: Optional[asyncio.AbstractEventLoop] try: @@ -230,6 +238,10 @@ async def authentication(request: Request, call_next): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) + global openai_serving_chat + global openai_serving_completion + global openai_serving_embedding + openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, args.response_role, @@ -258,3 +270,13 @@ async def authentication(request: Request, call_next): ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs) + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + run_server(args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 81c474ecc808a..f841633b572a9 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -34,9 +34,7 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, adapter_list) -def make_arg_parser(): - parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--host", type=nullable_str, default=None, @@ -133,3 +131,9 @@ def make_arg_parser(): parser = AsyncEngineArgs.add_cli_args(parser) return parser + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/vllm/scripts.py b/vllm/scripts.py new file mode 100644 index 0000000000000..3f334be925ee8 --- /dev/null +++ b/vllm/scripts.py @@ -0,0 +1,154 @@ +# The CLI entrypoint to vLLM. +import argparse +import os +import signal +import sys +from typing import Optional + +from openai import OpenAI + +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils import FlexibleArgumentParser + + +def registrer_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def serve(args: argparse.Namespace) -> None: + # EngineArgs expects the model name to be passed as --model. + args.model = args.model_tag + + run_server(args) + + +def interactive_cli(args: argparse.Namespace) -> None: + registrer_signal_handlers() + + base_url = args.url + api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY") + openai_client = OpenAI(api_key=api_key, base_url=base_url) + + if args.model_name: + model_name = args.model_name + else: + available_models = openai_client.models.list() + model_name = available_models.data[0].id + + print(f"Using model: {model_name}") + + if args.command == "complete": + complete(model_name, openai_client) + elif args.command == "chat": + chat(args.system_prompt, model_name, openai_client) + + +def complete(model_name: str, client: OpenAI) -> None: + print("Please enter prompt to complete:") + while True: + input_prompt = input("> ") + + completion = client.completions.create(model=model_name, + prompt=input_prompt) + output = completion.choices[0].text + print(output) + + +def chat(system_prompt: Optional[str], model_name: str, + client: OpenAI) -> None: + conversation = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + input_message = input("> ") + message = {"role": "user", "content": input_message} + conversation.append(message) + + chat_completion = client.chat.completions.create(model=model_name, + messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) + print(output) + + +def _add_query_options( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + "--url", + type=str, + default="http://localhost:8000/v1", + help="url of the running OpenAI-Compatible RESTful API server") + parser.add_argument( + "--model-name", + type=str, + default=None, + help=("The model name used in prompt completion, default to " + "the first model in list models API call.")) + parser.add_argument( + "--api-key", + type=str, + default=None, + help=( + "API key for OpenAI services. If provided, this api key " + "will overwrite the api key obtained through environment variables." + )) + return parser + + +def main(): + parser = FlexibleArgumentParser(description="vLLM CLI") + subparsers = parser.add_subparsers(required=True) + + serve_parser = subparsers.add_parser( + "serve", + help="Start the vLLM OpenAI Compatible API server", + usage="vllm serve [options]") + serve_parser.add_argument("model_tag", + type=str, + help="The model tag to serve") + serve_parser = make_arg_parser(serve_parser) + serve_parser.set_defaults(dispatch_function=serve) + + complete_parser = subparsers.add_parser( + "complete", + help=("Generate text completions based on the given prompt " + "via the running API server"), + usage="vllm complete [options]") + _add_query_options(complete_parser) + complete_parser.set_defaults(dispatch_function=interactive_cli, + command="complete") + + chat_parser = subparsers.add_parser( + "chat", + help="Generate chat completions via the running API server", + usage="vllm chat [options]") + _add_query_options(chat_parser) + chat_parser.add_argument( + "--system-prompt", + type=str, + default=None, + help=("The system prompt to be added to the chat template, " + "used for models that support system prompts.")) + chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat") + + args = parser.parse_args() + # One of the sub commands should be executed. + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main()