[Feature] vLLM CLI for serving and querying OpenAI compatible server #5090
Changes from 24 commits
vllm/entrypoints/openai/api_server.py

@@ -1,3 +1,4 @@
import argparse
import asyncio
import importlib
import inspect

@@ -8,7 +9,7 @@

import fastapi
import uvicorn
from fastapi import Request
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

@@ -32,6 +33,9 @@

TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
engine: AsyncLLMEngine
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding

@@ -57,47 +61,35 @@ async def _force_log():
    yield


app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
    parser = make_arg_parser()
    return parser.parse_args()


router = APIRouter()

# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
router.routes.append(route)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    err = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
@router.get("/health")
async def health() -> Response:
    """Health check."""
    await openai_serving_chat.engine.check_health()
    return Response(status_code=200)


@app.get("/v1/models")
@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@app.get("/version")
@router.get("/version")
async def show_version():
    ver = {"version": vllm.__version__}
    return JSONResponse(content=ver)


@app.post("/v1/chat/completions")
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(

@@ -113,7 +105,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
    return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)

@@ -127,7 +119,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
    return JSONResponse(content=generator.model_dump())


@app.post("/v1/embeddings")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)

@@ -138,8 +130,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
    args = parse_args()
def build_app(args):
    app = fastapi.FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    app.add_middleware(
        CORSMiddleware,

@@ -149,6 +143,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:

        @app.middleware("http")

@@ -174,6 +174,12 @@ async def authentication(request: Request, call_next):
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


def run_server(args):
    app = build_app(args)

    logger.info("vLLM API server version %s", vllm.__version__)
    logger.info("args: %s", args)

@@ -182,6 +188,8 @@ async def authentication(request: Request, call_next):
    else:
        served_model_names = [args.model]

    global engine, engine_args

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

@@ -200,6 +208,10 @@ async def authentication(request: Request, call_next):
    # When using single vLLM without engine_use_ray
    model_config = asyncio.run(engine.get_model_config())

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding

    openai_serving_chat = OpenAIServingChat(engine, model_config,
                                            served_model_names,
                                            args.response_role,

@@ -219,3 +231,13 @@ async def authentication(request: Request, call_next):
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    run_server(args)

Review discussion on the NOTE above:
Comment: is this note true? They seem to be different? (Also, in this case, should we have a common main method to share?)
Reply: In sync in their usage of
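
On the open question above about whether the two entrypoints could share a common main method: one possible arrangement, sketched here purely for illustration, is a small helper that both this module's __main__ block and the vllm serve subcommand could call. The helper name _run_api_server_from_cli is an assumption (it is not in this PR), and the serve subcommand would still need its model_tag handling before delegating; only the make_arg_parser/run_server sequence already shown in this diff is reused.

# Hypothetical shared helper; not part of this PR.
import argparse

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser


def _run_api_server_from_cli() -> None:
    # Build the parser exactly as api_server's __main__ block does,
    # then hand the parsed args to run_server.
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    run_server(args)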

vllm/scripts.py (new file)

@@ -0,0 +1,157 @@
# The CLI entrypoint to vLLM.
import argparse
import os
import signal
import sys
from typing import Optional

from openai import OpenAI

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser


def registrer_signal_handlers():

    def signal_handler(sig, frame):
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTSTP, signal_handler)


def serve(args: argparse.Namespace) -> None:
    # EngineArgs expects the model name to be passed as --model.
    if args.model is not None and args.model == args.model_tag:
        raise ValueError(
            "The --model argument is not supported for the serve command. "
            "Use positional argument [model_tag] instead.")
    args.model = args.model_tag

    run_server(args)


def interactive_cli(args: argparse.Namespace) -> None:
    registrer_signal_handlers()

    base_url = args.url
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
    openai_client = OpenAI(api_key=api_key, base_url=base_url)

    if args.model_name:
        model_name = args.model_name
    else:
        available_models = openai_client.models.list()
        model_name = available_models.data[0].id

    print(f"Using model: {model_name}")

    if args.command == "complete":
        complete(model_name, openai_client)
    elif args.command == "chat":
        chat(args.system_prompt, model_name, openai_client)


def complete(model_name: str, client: OpenAI) -> None:
    print("Please enter prompt to complete:")
    while True:
        input_prompt = input("> ")

        completion = client.completions.create(model=model_name,
                                               prompt=input_prompt)
        output = completion.choices[0].text
        print(output)

Review discussion on complete():
Comment: QQ: do we support setting some sampling params?
Reply: I think Sang's question is whether or not we should support setting sampling params through
Reply: Indeed, it would require checking field by field whether any sampling params are provided to override the default value.
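
On the sampling-params discussion above: if the CLI were extended to expose a few sampling options, a minimal sketch could forward them as standard keyword arguments of the OpenAI client's completions.create call. The temperature and max_tokens plumbing below is an assumption for illustration and is not part of this PR; the point raised in the reply still stands, since only explicitly provided values should override the server defaults, which this sketch does not attempt to handle field by field.

# Hypothetical variant of complete(); not part of this PR.
def complete(model_name: str, client: OpenAI,
             temperature: float = 1.0, max_tokens: int = 16) -> None:
    print("Please enter prompt to complete:")
    while True:
        input_prompt = input("> ")

        # temperature and max_tokens are standard /v1/completions request
        # fields accepted by the OpenAI client.
        completion = client.completions.create(model=model_name,
                                               prompt=input_prompt,
                                               temperature=temperature,
                                               max_tokens=max_tokens)
        print(completion.choices[0].text)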


def chat(system_prompt: Optional[str], model_name: str,
         client: OpenAI) -> None:
    conversation = []
    if system_prompt is not None:
        conversation.append({"role": "system", "content": system_prompt})

    print("Please enter a message for the chat model:")
    while True:
        input_message = input("> ")
        message = {"role": "user", "content": input_message}
        conversation.append(message)

        chat_completion = client.chat.completions.create(model=model_name,
                                                         messages=conversation)

        response_message = chat_completion.choices[0].message
        output = response_message.content

        conversation.append(response_message)
        print(output)


def _add_query_options(
        parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:8000/v1",
        help="url of the running OpenAI-Compatible RESTful API server")
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help=("The model name used in prompt completion, default to "
              "the first model in list models API call."))
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help=(
            "API key for OpenAI services. If provided, this api key "
            "will overwrite the api key obtained through environment variables."
        ))
    return parser


def main():
    parser = argparse.ArgumentParser(description="vLLM CLI")
    subparsers = parser.add_subparsers(required=True)

    serve_parser = subparsers.add_parser(
        "serve",
        help="Start the vLLM OpenAI Compatible API server",
        usage="vllm serve <model_tag> [options]")
    serve_parser.add_argument("model_tag",
                              type=str,
                              help="The model tag to serve")
    serve_parser = make_arg_parser(serve_parser)
    serve_parser.set_defaults(dispatch_function=serve)

    complete_parser = subparsers.add_parser(
        "complete",
        help=("Generate text completions based on the given prompt "
              "via the running API server"),
        usage="vllm complete [options]")
    _add_query_options(complete_parser)
    complete_parser.set_defaults(dispatch_function=interactive_cli,
                                 command="complete")

    chat_parser = subparsers.add_parser(
        "chat",
        help="Generate chat completions via the running API server",
        usage="vllm chat [options]")
    _add_query_options(chat_parser)
    chat_parser.add_argument(
        "--system-prompt",
        type=str,
        default=None,
        help=("The system prompt to be added to the chat template, "
              "used for models that support system prompts."))
    chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")

Review comment on lines +123 to +143: I would add some documentation on how to use these (can be in a separate PR) if we're planning to release this feature.

    args = parser.parse_args()
    # One of the sub commands should be executed.
    if hasattr(args, "dispatch_function"):
        args.dispatch_function(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
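
As requested in the review comment above about documentation, the intended invocations can be read off the usage strings and options defined in main(); the model tag, model name, and prompt below are placeholders:

vllm serve <model_tag> [options]
vllm complete --url http://localhost:8000/v1 --model-name <model_name>
vllm chat --url http://localhost:8000/v1 --system-prompt "You are a helpful assistant."

The --url default is http://localhost:8000/v1, and --model-name falls back to the first model returned by the server's list models API, so both query commands work with no options when a local server is running.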

Review comment: Is it possible to make engine an optional arg to this function? This can help external applications reuse the llm engine and attach other API interfaces (like gRPC) to the same llm engine. To be used with the other suggestion of changing line 204 to:
Reply: +1, this would be useful.
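
The suggestion above, making the engine an optional argument so that external applications can reuse the same llm engine, could take roughly the following shape. This is only an illustrative sketch written against the build_app/run_server code in this diff: the llm_engine parameter name and the fallback logic are assumptions, not the reviewer's actual proposed change for line 204.

# Hypothetical signature change to run_server in api_server.py; not part of this PR.
# Assumes api_server.py's existing imports plus typing.Optional.
from typing import Optional


def run_server(args, llm_engine: Optional[AsyncLLMEngine] = None):
    app = build_app(args)

    global engine, engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)
    # Reuse an externally constructed engine when one is supplied; otherwise
    # build one from the CLI args exactly as the PR currently does.
    engine = llm_engine if llm_engine is not None else (
        AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
    # ... the remainder of run_server (serving objects, uvicorn.run) is unchanged.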