Skip to content

Commit

Permalink
Lints
Browse files Browse the repository at this point in the history
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
  • Loading branch information
rafvasq committed Jun 27, 2024
1 parent d807096 commit e1608f1
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 29 deletions.
5 changes: 1 addition & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ def __init__(self, cli_args: List[str], *, wait_url: str,
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
[
"vllm", "serve",
*cli_args
],
["vllm", "serve", *cli_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
Expand Down
15 changes: 7 additions & 8 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def build_app(args):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)

if token := envs.VLLM_API_KEY or args.api_key:

Expand All @@ -182,10 +183,10 @@ async def authentication(request: Request, call_next):
else:
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")

return app


def run_server(args, llm_engine=None):
app = build_app(args)

Expand All @@ -211,9 +212,8 @@ def run_server(args, llm_engine=None):
"vision language models with the vLLM API server.")

engine = (llm_engine
if llm_engine is not None
else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

event_loop: Optional[asyncio.AbstractEventLoop]
try:
Expand All @@ -228,7 +228,7 @@ def run_server(args, llm_engine=None):
else:
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())

global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
Expand Down Expand Up @@ -272,4 +272,3 @@ def run_server(args, llm_engine=None):
args = postprocess_tgis_args(args)

run_server(args)

5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.tgis_utils.args import EnvVarArgumentParser


class LoRAParserAction(argparse.Action):
Expand Down Expand Up @@ -114,6 +113,8 @@ def make_arg_parser(
parser = AsyncEngineArgs.add_cli_args(parser)
return parser


def create_parser_for_docs() -> argparse.ArgumentParser:
    """Build the CLI argument parser used for generating documentation.

    Returns:
        The fully-populated ``argparse.ArgumentParser`` for the
        ``-m vllm.entrypoints.openai.api_server`` entry point, with all
        server CLI arguments registered via ``make_arg_parser``.
    """
    # Single construction; the previous duplicate assignment was dead code.
    parser_for_docs = argparse.ArgumentParser(
        prog="-m vllm.entrypoints.openai.api_server")
    return make_arg_parser(parser_for_docs)
22 changes: 12 additions & 10 deletions vllm/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ def tgis_cli(args: argparse.Namespace) -> None:
registrer_signal_handlers()

if args.command == "download-weights":
download_weights(args.model_name, args.revision, args.token, args.extension, args.auto_convert)
download_weights(args.model_name, args.revision, args.token,
args.extension, args.auto_convert)
elif args.command == "convert-to-safetensors":
convert_to_safetensors(args.model_name, args.revision)
elif args.command == "convert-to-fast-tokenizer":
convert_to_fast_tokenizer(args.model_name, args.revision, args.output_path)
convert_to_fast_tokenizer(args.model_name, args.revision,
args.output_path)


def complete(model_name: str, client: OpenAI) -> None:
Expand Down Expand Up @@ -310,28 +312,28 @@ def main():
download_weights_parser.add_argument("--revision")
download_weights_parser.add_argument("--token")
download_weights_parser.add_argument("--extension", default=".safetensors")
download_weights_parser.add_argument("--auto_convert" , default=True)
download_weights_parser.add_argument("--auto_convert", default=True)
download_weights_parser.set_defaults(dispatch_function=tgis_cli,
command="download-weights")
command="download-weights")

convert_to_safetensors_parser = subparsers.add_parser(
"convert-to-safetensors",
help=("Convert model weights to safetensors"),
usage="vllm convert-to-safetensors <model_name> [options]")
convert_to_safetensors_parser.add_argument("model_name")
convert_to_safetensors_parser.add_argument("--revision")
convert_to_safetensors_parser.set_defaults(dispatch_function=tgis_cli,
command="convert-to-safetensors")
convert_to_safetensors_parser.set_defaults(
dispatch_function=tgis_cli, command="convert-to-safetensors")

convert_to_fast_tokenizer_parser = subparsers.add_parser(
"convert-to-fast-tokenizer",
help=("Convert to fast tokenizer"),
usage="vllm convert-to-fast-tokenizer <model_name> [options]")
convert_to_fast_tokenizer_parser.add_argument("model_name")
convert_to_fast_tokenizer_parser.add_argument("--revision")
convert_to_fast_tokenizer_parser.add_argument("--output_path")
convert_to_fast_tokenizer_parser.set_defaults(dispatch_function=tgis_cli,
command="convert-to-fast-tokenizer")
convert_to_fast_tokenizer_parser.set_defaults(
dispatch_function=tgis_cli, command="convert-to-fast-tokenizer")

args = parser.parse_args()
# One of the sub commands should be executed.
Expand Down
10 changes: 5 additions & 5 deletions vllm/tgis_utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def convert_files(pt_files: List[Path],
logger.warning("No pytorch .bin weight files found to convert")
return

logger.info(f"Converting {N} pytorch .bin files to .safetensors...")
logger.info("Converting %d pytorch .bin files to .safetensors...", N)

for i, (pt_file, sf_file) in enumerate(pairs):
logger.info(f'Converting: [{i + 1}/{N}] "{pt_file.name}"')
        file_count = f"{i + 1}/{N}"
        logger.info('Converting: [%s] "%s"', file_count, pt_file.name)
start = datetime.datetime.now()
convert_file(pt_file, sf_file, discard_names)
elapsed = datetime.datetime.now() - start
logger.info(
f'Converted: [{i + 1}/{N}] "{sf_file.name}" -- Took: {elapsed}')

        logger.info('Converted: [%s] "%s" -- Took: %s', file_count,
                    sf_file.name, elapsed)

0 comments on commit e1608f1

Please sign in to comment.