feat: add CLI tools #52

Closed · wants to merge 3 commits
17 changes: 11 additions & 6 deletions benchmarks/benchmark_serving.py
@@ -2,8 +2,8 @@

On the server side, run one of the following commands:
vLLM OpenAI API server
python -m vllm.entrypoints.openai.api_server \
--model <your_model> --swap-space 16 \
vllm serve <your_model> \
--swap-space 16 \
--disable-log-requests

(TGI backend)
@@ -17,7 +17,7 @@
--dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000

when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
@@ -44,6 +44,11 @@
except ImportError:
from backend_request_func import get_tokenizer

try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
@@ -72,7 +77,6 @@ def sample_sharegpt_requests(
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
@@ -191,6 +195,7 @@ async def get_request(
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue

# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
@@ -214,7 +219,7 @@ def calculate_metrics(
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note: this may inflate the output token count slightly
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
@@ -511,7 +516,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
2 changes: 1 addition & 1 deletion docs/source/serving/openai_compatible_server.md
@@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)

```{argparse}
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:func: create_parser_for_docs
:prog: -m vllm.entrypoints.openai.api_server
```

5 changes: 5 additions & 0 deletions setup.py
@@ -450,4 +450,9 @@ def _read_requirements(filename: str) -> List[str]:
},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data,
entry_points={
"console_scripts": [
"vllm=vllm.scripts:main",
],
},
)
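
The `console_scripts` entry above is what backs the new `vllm serve <your_model>` invocation shown in the benchmark docstring. `vllm/scripts.py` itself does not appear in this diff view, so the following is only a rough sketch of what its `main` could look like, assuming it wires a `serve` subcommand to the same `make_arg_parser`/`run_server` pair used by `api_server.py` below; the subcommand layout and the `model_tag` positional argument are assumptions, not taken from this PR.

```python
# Hypothetical sketch of vllm/scripts.py (not part of this diff view).
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
from vllm.utils import FlexibleArgumentParser


def main() -> None:
    parser = FlexibleArgumentParser(description="vLLM CLI")
    subparsers = parser.add_subparsers(dest="subcommand", required=True)

    # `vllm serve <model>` would reuse the OpenAI API server arguments,
    # mirroring the __main__ block of api_server.py in this PR.
    serve_parser = subparsers.add_parser(
        "serve", help="Start the OpenAI-compatible API server")
    serve_parser.add_argument("model_tag", type=str,
                              help="Model name or path to serve")
    make_arg_parser(serve_parser)
    add_tgis_args(serve_parser)

    args = parser.parse_args()
    if args.subcommand == "serve":
        args = postprocess_tgis_args(args)
        # Map the positional model tag onto the --model engine argument.
        args.model = args.model_tag
        run_server(args)
```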
50 changes: 50 additions & 0 deletions tests/tgis/test_hub.py
@@ -0,0 +1,50 @@
from pathlib import Path

import pytest
from huggingface_hub.utils import LocalEntryNotFoundError

from vllm.tgis_utils.hub import (convert_files, download_weights, weight_files,
weight_hub_files)


def test_convert_files():
model_id = "bigscience/bloom-560m"
local_pt_files = download_weights(model_id, extension=".bin")
local_pt_files = [Path(p) for p in local_pt_files]
local_st_files = [
p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
for p in local_pt_files
]
convert_files(local_pt_files, local_st_files, discard_names=[])

found_st_files = weight_files(model_id)

assert all([str(p) in found_st_files for p in local_st_files])


def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]


def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [
f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)
]


def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", ".errors")
assert filenames == []


def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files


def test_weight_files_error():
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")
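
The tests above double as usage documentation for the new `vllm.tgis_utils.hub` helpers. A condensed sketch of the same flow outside pytest (the model id simply mirrors the tests and is only illustrative):

```python
# Condensed usage sketch of the hub helpers exercised by the tests above.
from pathlib import Path

from vllm.tgis_utils.hub import convert_files, download_weights, weight_files

model_id = "bigscience/bloom-560m"

# Download the .bin weights, derive .safetensors target paths, and convert.
pt_files = [Path(p) for p in download_weights(model_id, extension=".bin")]
st_files = [
    p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
    for p in pt_files
]
convert_files(pt_files, st_files, discard_names=[])

# The converted files are now discoverable in the local cache.
print(weight_files(model_id))
```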
6 changes: 4 additions & 2 deletions tests/utils.py
@@ -13,7 +13,7 @@
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port
from vllm.utils import FlexibleArgumentParser, get_open_port

# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -74,7 +74,9 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:

cli_args = cli_args + ["--port", str(get_open_port())]

parser = make_arg_parser()
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
self.host = str(args.host or 'localhost')
self.port = int(args.port)
82 changes: 53 additions & 29 deletions vllm/entrypoints/openai/api_server.py
@@ -8,7 +8,7 @@

import fastapi
import uvicorn
from fastapi import Request
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -30,10 +30,14 @@
from vllm.logger import init_logger
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5 # seconds

logger = init_logger(__name__)
engine: AsyncLLMEngine
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
@@ -67,50 +71,35 @@ async def _force_log():
logger.info("gRPC server stopped")


app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
parser = make_arg_parser()
parser = add_tgis_args(parser)
parsed_args = parser.parse_args()
parsed_args = postprocess_tgis_args(parsed_args)
return parsed_args

router = APIRouter()

# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)

router.routes.append(route)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
@router.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200)


@app.get("/v1/models")
@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump())


@app.get("/version")
@router.get("/version")
async def show_version():
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)


@app.post("/v1/chat/completions")
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
@@ -126,7 +115,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
@@ -140,7 +129,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


@app.post("/v1/embeddings")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
@@ -151,8 +140,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
args = parse_args()
def build_app(args):
app = fastapi.FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path

app.add_middleware(
CORSMiddleware,
@@ -162,6 +153,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
allow_headers=args.allowed_headers,
)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)

if token := envs.VLLM_API_KEY or args.api_key:

@app.middleware("http")
@@ -187,6 +184,12 @@ async def authentication(request: Request, call_next):
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")

return app


def run_server(args, llm_engine=None):
app = build_app(args)

logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

@@ -195,6 +198,8 @@ async def authentication(request: Request, call_next):
else:
served_model_names = [args.model]

global engine, engine_args

engine_args = AsyncEngineArgs.from_cli_args(args)

# Enforce pixel values as image input type for vision language models
@@ -206,8 +211,9 @@
"Only --image-input-type 'pixel_values' is supported for serving "
"vision language models with the vLLM API server.")

engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

event_loop: Optional[asyncio.AbstractEventLoop]
try:
@@ -223,6 +229,11 @@
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())

global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
global async_llm_engine

openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
@@ -247,3 +258,16 @@
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)


if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser = add_tgis_args(parser)
args = parser.parse_args()
args = postprocess_tgis_args(args)

run_server(args)
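
The new `llm_engine` parameter on `run_server` lets a caller inject an already-constructed `AsyncLLMEngine` instead of having the server build one from the CLI args. A hedged sketch of that embedding path, following the `__main__` block above; the model name and flag values are placeholders:

```python
# Illustrative embedding of the API server with a pre-built engine via the
# new llm_engine parameter; argument values are placeholders.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
    description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser = add_tgis_args(parser)
args = postprocess_tgis_args(
    parser.parse_args(["--model", "facebook/opt-125m"]))

# Build the engine up front, then hand it to run_server instead of letting
# run_server call AsyncLLMEngine.from_engine_args itself.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs.from_cli_args(args),
    usage_context=UsageContext.OPENAI_API_SERVER)
run_server(args, llm_engine=engine)
```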
12 changes: 8 additions & 4 deletions vllm/entrypoints/openai/cli_args.py
@@ -10,7 +10,7 @@

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.tgis_utils.args import EnvVarArgumentParser
from vllm.utils import FlexibleArgumentParser


class LoRAParserAction(argparse.Action):
@@ -23,9 +23,7 @@ def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, lora_list)


def make_arg_parser():
parser = EnvVarArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
@@ -114,3 +112,9 @@ def make_arg_parser():

parser = AsyncEngineArgs.add_cli_args(parser)
return parser


def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)
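
With this signature change, callers construct the parser themselves and `make_arg_parser` only populates it, as `tests/utils.py` and the `__main__` block of `api_server.py` do above. A minimal sketch of the new calling convention (argument values are placeholders):

```python
# Minimal sketch of the new calling convention: the caller owns the parser
# and make_arg_parser only adds the server/engine arguments to it.
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8000"])
```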