From 3b056775f46ec54b8bd56d71a5c325f2367c88e6 Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Wed, 17 Jul 2024 15:27:23 -0400 Subject: [PATCH 1/2] Add tgis tools Signed-off-by: Rafael Vasquez Co-authored-by: Prashant Gupta --- tests/tgis/__init__.py | 0 tests/tgis/test_hub.py | 50 +++++ vllm/entrypoints/openai/api_server.py | 29 +++ vllm/scripts.py | 190 ++++++++++++++++++ vllm/tgis_utils/hub.py | 270 ++++++++++++++++++++++++++ 5 files changed, 539 insertions(+) create mode 100644 tests/tgis/__init__.py create mode 100644 tests/tgis/test_hub.py create mode 100644 vllm/tgis_utils/hub.py diff --git a/tests/tgis/__init__.py b/tests/tgis/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tgis/test_hub.py b/tests/tgis/test_hub.py new file mode 100644 index 000000000000..5ecde4bc67e5 --- /dev/null +++ b/tests/tgis/test_hub.py @@ -0,0 +1,50 @@ +from pathlib import Path + +import pytest +from huggingface_hub.utils import LocalEntryNotFoundError + +from vllm.tgis_utils.hub import (convert_files, download_weights, weight_files, + weight_hub_files) + + +def test_convert_files(): + model_id = "bigscience/bloom-560m" + local_pt_files = download_weights(model_id, extension=".bin") + local_pt_files = [Path(p) for p in local_pt_files] + local_st_files = [ + p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors" + for p in local_pt_files + ] + convert_files(local_pt_files, local_st_files, discard_names=[]) + + found_st_files = weight_files(model_id) + + assert all([str(p) in found_st_files for p in local_st_files]) + + +def test_weight_hub_files(): + filenames = weight_hub_files("bigscience/bloom-560m") + assert filenames == ["model.safetensors"] + + +def test_weight_hub_files_llm(): + filenames = weight_hub_files("bigscience/bloom") + assert filenames == [ + f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73) + ] + + +def test_weight_hub_files_empty(): + filenames = weight_hub_files("bigscience/bloom", ".errors") + assert filenames == [] + + +def test_download_weights(): + files = download_weights("bigscience/bloom-560m") + local_files = weight_files("bigscience/bloom-560m") + assert files == local_files + + +def test_weight_files_error(): + with pytest.raises(LocalEntryNotFoundError): + weight_files("bert-base-uncased") \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 45c634b4a299..857c7d641db7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,6 +18,7 @@ import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.grpc.grpc_server import start_grpc_server from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block # yapf: disable @@ -34,6 +35,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger +from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION @@ -46,6 +48,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +async_llm_engine: AsyncLLMEngine logger = init_logger('vllm.entrypoints.openai.api_server') 
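For orientation before the next hunk: patch 1 runs a TGIS gRPC server alongside the OpenAI HTTP server by hooking it into the FastAPI lifespan context. Below is a minimal, self-contained sketch of that pattern, not vLLM's actual code; _DummyGrpcServer is a hypothetical stand-in for the handle returned by start_grpc_server().

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI


class _DummyGrpcServer:
    # Hypothetical stand-in for a grpc.aio server handle.

    async def stop(self, grace: float) -> None:
        # A real handle would drain in-flight RPCs here.
        await asyncio.sleep(0)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: launch the companion server before HTTP requests are served.
    grpc_server = _DummyGrpcServer()
    yield  # the HTTP app handles requests while suspended here
    # Shutdown: stop the companion server with a grace period, as the hunk
    # below does with grpc_server.stop(30).
    await grpc_server.stop(grace=30)


app = FastAPI(lifespan=lifespan)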
@@ -65,8 +68,15 @@ async def _force_log(): _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) + grpc_server = await start_grpc_server(async_llm_engine, args) + yield + logger.info("Gracefully stopping gRPC server") + await grpc_server.stop(30) #TODO configurable grace + await grpc_server.wait_for_termination() + logger.info("gRPC server stopped") + router = APIRouter() @@ -220,6 +230,16 @@ def run_server(args, llm_engine=None): global engine, engine_args engine_args = AsyncEngineArgs.from_cli_args(args) + + # Enforce pixel values as image input type for vision language models + # when serving with API server + if engine_args.image_input_type is not None and \ + engine_args.image_input_type.upper() != "PIXEL_VALUES": + raise ValueError( + f"Invalid image_input_type: {engine_args.image_input_type}. " + "Only --image-input-type 'pixel_values' is supported for serving " + "vision language models with the vLLM API server.") + engine = (llm_engine if llm_engine is not None else AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) @@ -241,6 +261,7 @@ def run_server(args, llm_engine=None): global openai_serving_chat global openai_serving_completion global openai_serving_embedding + global async_llm_engine openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, @@ -252,6 +273,11 @@ def run_server(args, llm_engine=None): args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) + + # 🌶️🌶️🌶️ Sets the engine for the TGIS gRPC server. + # Do not delete on merge conflicts! + async_llm_engine = engine + app.root_path = args.root_path logger.info("Available routes are:") @@ -278,5 +304,8 @@ def run_server(args, llm_engine=None): parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) + parser = add_tgis_args(parser) args = parser.parse_args() + args = postprocess_tgis_args(args) + run_server(args) diff --git a/vllm/scripts.py b/vllm/scripts.py index 3f334be925ee..c1ba1c910ff8 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -3,6 +3,7 @@ import os import signal import sys +from pathlib import Path from typing import Optional from openai import OpenAI @@ -49,6 +50,19 @@ def interactive_cli(args: argparse.Namespace) -> None: chat(args.system_prompt, model_name, openai_client) +def tgis_cli(args: argparse.Namespace) -> None: + registrer_signal_handlers() + + if args.command == "download-weights": + download_weights(args.model_name, args.revision, args.token, + args.extension, args.auto_convert) + elif args.command == "convert-to-safetensors": + convert_to_safetensors(args.model_name, args.revision) + elif args.command == "convert-to-fast-tokenizer": + convert_to_fast_tokenizer(args.model_name, args.revision, + args.output_path) + + def complete(model_name: str, client: OpenAI) -> None: print("Please enter prompt to complete:") while True: @@ -82,6 +96,151 @@ def chat(system_prompt: Optional[str], model_name: str, print(output) +def download_weights( + model_name: str, + revision: Optional[str] = None, + token: Optional[str] = None, + extension: str = ".safetensors", + auto_convert: bool = True, +) -> None: + from vllm.tgis_utils import hub + + print(extension) + meta_exts = [".json", ".py", ".model", ".md"] + + extensions = extension.split(",") + + if len(extensions) == 1 and extensions[0] not in meta_exts: + extensions.extend(meta_exts) + + files = hub.download_weights(model_name, + extensions, + 
revision=revision, + auth_token=token) + + if auto_convert and ".safetensors" in extensions: + if not hub.local_weight_files(hub.get_model_path(model_name, revision), + ".safetensors"): + if ".bin" not in extensions: + print(".safetensors weights not found, \ + downloading pytorch weights to convert...") + hub.download_weights(model_name, + ".bin", + revision=revision, + auth_token=token) + + print(".safetensors weights not found, \ + converting from pytorch weights...") + convert_to_safetensors(model_name, revision) + elif not any(f.endswith(".safetensors") for f in files): + print(".safetensors weights not found on hub, \ + but were found locally. Remove them first to re-convert") + if auto_convert: + convert_to_fast_tokenizer(model_name, revision) + + +def convert_to_safetensors( + model_name: str, + revision: Optional[str] = None, +): + from vllm.tgis_utils import hub + + # Get local pytorch file paths + model_path = hub.get_model_path(model_name, revision) + local_pt_files = hub.local_weight_files(model_path, ".bin") + local_pt_index_files = hub.local_index_files(model_path, ".bin") + if len(local_pt_index_files) > 1: + print( + f"Found more than one .bin.index.json file: {local_pt_index_files}" + ) + return + if not local_pt_files: + print("No pytorch .bin files found to convert") + return + + local_pt_files = [Path(f) for f in local_pt_files] + local_pt_index_file = local_pt_index_files[ + 0] if local_pt_index_files else None + + # Safetensors final filenames + local_st_files = [ + p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors" + for p in local_pt_files + ] + + if any(os.path.exists(p) for p in local_st_files): + print("Existing .safetensors weights found, \ + remove them first to reconvert") + return + + try: + import transformers + + config = transformers.AutoConfig.from_pretrained( + model_name, + revision=revision, + ) + architecture = config.architectures[0] + + class_ = getattr(transformers, architecture) + + # Name for this variable depends on transformers version + discard_names = getattr(class_, "_tied_weights_keys", []) + discard_names.extend( + getattr(class_, "_keys_to_ignore_on_load_missing", [])) + + except Exception: + discard_names = [] + + if local_pt_index_file: + local_pt_index_file = Path(local_pt_index_file) + st_prefix = local_pt_index_file.stem.removeprefix( + "pytorch_").removesuffix(".bin.index") + local_st_index_file = (local_pt_index_file.parent / + f"{st_prefix}.safetensors.index.json") + + if os.path.exists(local_st_index_file): + print("Existing .safetensors.index.json file found, \ + remove it first to reconvert") + return + + hub.convert_index_file(local_pt_index_file, local_st_index_file, + local_pt_files, local_st_files) + + # Convert pytorch weights to safetensors + hub.convert_files(local_pt_files, local_st_files, discard_names) + + +def convert_to_fast_tokenizer( + model_name: str, + revision: Optional[str] = None, + output_path: Optional[str] = None, +): + from vllm.tgis_utils import hub + + # Check for existing "tokenizer.json" + model_path = hub.get_model_path(model_name, revision) + + if os.path.exists(os.path.join(model_path, "tokenizer.json")): + print(f"Model {model_name} already has a fast tokenizer") + return + + if output_path is not None: + if not os.path.isdir(output_path): + print(f"Output path {output_path} must exist and be a directory") + return + else: + output_path = model_path + + import transformers + + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, + revision=revision) + 
tokenizer.save_pretrained(output_path)
+
+    print(f"Saved tokenizer to {output_path}")
+
+
 def _add_query_options(
         parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument(
@@ -142,6 +301,42 @@ def main():
                               "used for models that support system prompts."))
    chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
 
+    download_weights_parser = subparsers.add_parser(
+        "download-weights",
+        help=("Download the weights of a given model"),
+        usage="vllm download-weights <model_name> [options]")
+    download_weights_parser.add_argument("model_name")
+    download_weights_parser.add_argument("--revision")
+    download_weights_parser.add_argument("--token")
+    download_weights_parser.add_argument("--extension", default=".safetensors")
+    # Parse the flag into a real boolean; a bare default=True with no type
+    # would turn any supplied string (even "false") into a truthy value.
+    download_weights_parser.add_argument(
+        "--auto_convert",
+        default=True,
+        type=lambda s: str(s).lower() in ("1", "true", "yes"))
+    download_weights_parser.set_defaults(dispatch_function=tgis_cli,
+                                         command="download-weights")
+
+    convert_to_safetensors_parser = subparsers.add_parser(
+        "convert-to-safetensors",
+        help=("Convert model weights to safetensors"),
+        usage="vllm convert-to-safetensors <model_name> [options]")
+    convert_to_safetensors_parser.add_argument("model_name")
+    convert_to_safetensors_parser.add_argument("--revision")
+    convert_to_safetensors_parser.set_defaults(
+        dispatch_function=tgis_cli, command="convert-to-safetensors")
+
+    convert_to_fast_tokenizer_parser = subparsers.add_parser(
+        "convert-to-fast-tokenizer",
+        help=("Convert a model's slow tokenizer to a fast tokenizer"),
+        usage="vllm convert-to-fast-tokenizer <model_name> [options]")
+    convert_to_fast_tokenizer_parser.add_argument("model_name")
+    convert_to_fast_tokenizer_parser.add_argument("--revision")
+    convert_to_fast_tokenizer_parser.add_argument("--output_path")
+    convert_to_fast_tokenizer_parser.set_defaults(
+        dispatch_function=tgis_cli, command="convert-to-fast-tokenizer")
+
     args = parser.parse_args()
     # One of the sub commands should be executed.
if hasattr(args, "dispatch_function"):
diff --git a/vllm/tgis_utils/hub.py b/vllm/tgis_utils/hub.py
new file mode 100644
index 000000000000..4361b189fdea
--- /dev/null
+++ b/vllm/tgis_utils/hub.py
@@ -0,0 +1,270 @@
+import concurrent
+import datetime
+import glob
+import json
+import logging
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import torch
+from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
+from huggingface_hub.utils import LocalEntryNotFoundError
+from safetensors.torch import (_find_shared_tensors, _is_complete, load_file,
+                               save_file)
+from tqdm import tqdm
+
+TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE") == "true"
+logger = logging.getLogger(__name__)
+
+
+def weight_hub_files(model_name,
+                     extension=".safetensors",
+                     revision=None,
+                     auth_token=None):
+    """Get the safetensors filenames on the hub"""
+    exts = [extension] if isinstance(extension, str) else extension
+    api = HfApi()
+    info = api.model_info(model_name, revision=revision, token=auth_token)
+    filenames = [
+        s.rfilename for s in info.siblings if any(
+            s.rfilename.endswith(ext) and len(s.rfilename.split("/")) == 1
+            and "arguments" not in s.rfilename and "args" not in s.rfilename
+            and "training" not in s.rfilename for ext in exts)
+    ]
+    return filenames
+
+
+def weight_files(model_name, extension=".safetensors", revision=None):
+    """Get the local safetensors filenames"""
+    filenames = weight_hub_files(model_name, extension, revision=revision)
+    files = []
+    for filename in filenames:
+        cache_file = try_to_load_from_cache(model_name,
+                                            filename=filename,
+                                            revision=revision)
+        if cache_file is None:
+            raise LocalEntryNotFoundError(
+                f"File {filename} of model {model_name} not found in "
+                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+                f"Please run `vllm download-weights {model_name}` first.")
+        files.append(cache_file)
+
+    return files
+
+
+def download_weights(model_name,
+                     extension=".safetensors",
+                     revision=None,
+                     auth_token=None):
+    """Download the safetensors files from the hub"""
+    filenames = weight_hub_files(model_name,
+                                 extension,
+                                 revision=revision,
+                                 auth_token=auth_token)
+
+    download_function = partial(
+        hf_hub_download,
+        repo_id=model_name,
+        local_files_only=False,
+        revision=revision,
+        token=auth_token,
+    )
+
+    print(f"Downloading {len(filenames)} files for model {model_name}")
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [
+        executor.submit(download_function, filename=filename)
+        for filename in filenames
+    ]
+    files = [
+        future.result()
+        for future in tqdm(concurrent.futures.as_completed(futures),
+                           total=len(futures))
+    ]
+
+    return files
+
+
+def get_model_path(model_name: str, revision: Optional[str] = None):
+    """Get path to model dir in local huggingface hub (model) cache"""
+    config_file = "config.json"
+    err = None
+    try:
+        config_path = try_to_load_from_cache(
+            model_name,
+            config_file,
+            cache_dir=os.getenv("TRANSFORMERS_CACHE"
+                                ),  # will fall back to HUGGINGFACE_HUB_CACHE
+            revision=revision,
+        )
+        if config_path is not None:
+            return config_path.removesuffix(f"/{config_file}")
+    except ValueError as e:
+        err = e
+
+    if os.path.isfile(f"{model_name}/{config_file}"):
+        return model_name  # Just treat the model name as an explicit model path
+
+    if err is not None:
+        raise err
+
+    raise ValueError(
+        f"Weights not found in local cache for model {model_name}")
+
+
+def local_weight_files(model_path: str, extension=".safetensors"):
+    """Get the local safetensors filenames"""
+    ext = "" if extension is None else extension
+    return glob.glob(f"{model_path}/*{ext}")
+
+
+def local_index_files(model_path: str, extension=".safetensors"):
+    """Get the local .index.json filenames"""
+    ext = "" if extension is None else extension
+    return glob.glob(f"{model_path}/*{ext}.index.json")
+
+
+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: Optional[List[str]] = None,
+    discard_names: Optional[List[str]] = None,
+) -> Dict[str, List[str]]:
+    if preferred_names is None:
+        preferred_names = []
+    preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        # _find_shared_tensors returns a list of sets of names of tensors
+        # that have the same data, including one-element sets that aren't
+        # actually shared
+        if len(shared) == 1:
+            continue
+
+        complete_names = set(
+            [name for name in shared if _is_complete(state_dict[name])])
+        if not complete_names:
+            raise RuntimeError(
+                "Error while trying to find names to remove to save state "
+                "dict, but found no suitable name to keep for saving amongst "
+                f"{shared}. None is covering the entire storage. Refusing to "
+                "save/load the model since you could be storing much more "
+                "memory than needed. Please refer to "
+                "https://huggingface.co/docs/safetensors/torch_shared_tensors "
+                "for more information. Or open an issue.")
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
+    """
+    Convert a pytorch file to a safetensors file.
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect the *transformers* convention,
+    forcing us to check for potentially different keys during load when
+    looking for specific tensors (making tensor sharing explicit).
+    """
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_file)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_index_file(source_file: Path, dest_file: Path,
+                       pt_files: List[Path], sf_files: List[Path]):
+    weight_file_map = {s.name: d.name for s, d in zip(pt_files, sf_files)}
+
+    logger.info(
+        "Converting pytorch .bin.index.json files to .safetensors.index.json")
+    with open(source_file, "r") as f:
+        index = json.load(f)
+
+    index["weight_map"] = {
+        k: weight_file_map[v]
+        for k, v in index["weight_map"].items()
+    }
+
+    with open(dest_file, "w") as f:
+        json.dump(index, f, indent=4)
+
+
+def convert_files(pt_files: List[Path],
+                  sf_files: List[Path],
+                  discard_names: Optional[List[str]] = None):
+    assert len(pt_files) == len(sf_files)
+
+    # Filter non-inference files
+    pairs = [
+        p for p in zip(pt_files, sf_files) if not any(s in p[0].name for s in [
+            "arguments",
+            "args",
+            "training",
+            "optimizer",
+            "scheduler",
+            "index",
+        ])
+    ]
+
+    N = len(pairs)
+
+    if N == 0:
+        logger.warning("No pytorch .bin weight files found to convert")
+        return
+
+    logger.info("Converting %d pytorch .bin files to .safetensors...", N)
+
+    for i, (pt_file, sf_file) in enumerate(pairs):
+        # Log progress as "current/total"
+        file_count = f"{i + 1}/{N}"
+        logger.info('Converting: [%s] "%s"', file_count, pt_file.name)
+        start = datetime.datetime.now()
+        convert_file(pt_file, sf_file, discard_names)
+        elapsed = datetime.datetime.now() - start
+        logger.info('Converted: [%s] "%s" -- Took: %s', file_count,
+                    sf_file.name, elapsed)

From a9f34f2e616fef7722406ca731a4b69db9882489 Mon Sep 17 00:00:00 2001
From: Rafael Vasquez
Date: Fri, 19 Jul 2024 14:54:47 -0400
Subject: [PATCH 2/2] Separate extra commands

Signed-off-by: Rafael Vasquez
---
 vllm/entrypoints/openai/api_server.py |  29 -----
 vllm/scripts.py                       | 159 +-------------------------
 vllm/tgis_utils/scripts.py            |  96
++++++++++++++++ 3 files changed, 97 insertions(+), 187 deletions(-) create mode 100644 vllm/tgis_utils/scripts.py diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 857c7d641db7..45c634b4a299 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,7 +18,6 @@ import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.grpc.grpc_server import start_grpc_server from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block # yapf: disable @@ -35,7 +34,6 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger -from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION @@ -48,7 +46,6 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding -async_llm_engine: AsyncLLMEngine logger = init_logger('vllm.entrypoints.openai.api_server') @@ -68,15 +65,8 @@ async def _force_log(): _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) - grpc_server = await start_grpc_server(async_llm_engine, args) - yield - logger.info("Gracefully stopping gRPC server") - await grpc_server.stop(30) #TODO configurable grace - await grpc_server.wait_for_termination() - logger.info("gRPC server stopped") - router = APIRouter() @@ -230,16 +220,6 @@ def run_server(args, llm_engine=None): global engine, engine_args engine_args = AsyncEngineArgs.from_cli_args(args) - - # Enforce pixel values as image input type for vision language models - # when serving with API server - if engine_args.image_input_type is not None and \ - engine_args.image_input_type.upper() != "PIXEL_VALUES": - raise ValueError( - f"Invalid image_input_type: {engine_args.image_input_type}. " - "Only --image-input-type 'pixel_values' is supported for serving " - "vision language models with the vLLM API server.") - engine = (llm_engine if llm_engine is not None else AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) @@ -261,7 +241,6 @@ def run_server(args, llm_engine=None): global openai_serving_chat global openai_serving_completion global openai_serving_embedding - global async_llm_engine openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, @@ -273,11 +252,6 @@ def run_server(args, llm_engine=None): args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) - - # 🌶️🌶️🌶️ Sets the engine for the TGIS gRPC server. - # Do not delete on merge conflicts! 
- async_llm_engine = engine - app.root_path = args.root_path logger.info("Available routes are:") @@ -304,8 +278,5 @@ def run_server(args, llm_engine=None): parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) - parser = add_tgis_args(parser) args = parser.parse_args() - args = postprocess_tgis_args(args) - run_server(args) diff --git a/vllm/scripts.py b/vllm/scripts.py index c1ba1c910ff8..f2ee45abc042 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -12,6 +12,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser +from vllm.tgis_utils.scripts import tgis_cli def registrer_signal_handlers(): @@ -50,19 +51,6 @@ def interactive_cli(args: argparse.Namespace) -> None: chat(args.system_prompt, model_name, openai_client) -def tgis_cli(args: argparse.Namespace) -> None: - registrer_signal_handlers() - - if args.command == "download-weights": - download_weights(args.model_name, args.revision, args.token, - args.extension, args.auto_convert) - elif args.command == "convert-to-safetensors": - convert_to_safetensors(args.model_name, args.revision) - elif args.command == "convert-to-fast-tokenizer": - convert_to_fast_tokenizer(args.model_name, args.revision, - args.output_path) - - def complete(model_name: str, client: OpenAI) -> None: print("Please enter prompt to complete:") while True: @@ -96,151 +84,6 @@ def chat(system_prompt: Optional[str], model_name: str, print(output) -def download_weights( - model_name: str, - revision: Optional[str] = None, - token: Optional[str] = None, - extension: str = ".safetensors", - auto_convert: bool = True, -) -> None: - from vllm.tgis_utils import hub - - print(extension) - meta_exts = [".json", ".py", ".model", ".md"] - - extensions = extension.split(",") - - if len(extensions) == 1 and extensions[0] not in meta_exts: - extensions.extend(meta_exts) - - files = hub.download_weights(model_name, - extensions, - revision=revision, - auth_token=token) - - if auto_convert and ".safetensors" in extensions: - if not hub.local_weight_files(hub.get_model_path(model_name, revision), - ".safetensors"): - if ".bin" not in extensions: - print(".safetensors weights not found, \ - downloading pytorch weights to convert...") - hub.download_weights(model_name, - ".bin", - revision=revision, - auth_token=token) - - print(".safetensors weights not found, \ - converting from pytorch weights...") - convert_to_safetensors(model_name, revision) - elif not any(f.endswith(".safetensors") for f in files): - print(".safetensors weights not found on hub, \ - but were found locally. 
Remove them first to re-convert") - if auto_convert: - convert_to_fast_tokenizer(model_name, revision) - - -def convert_to_safetensors( - model_name: str, - revision: Optional[str] = None, -): - from vllm.tgis_utils import hub - - # Get local pytorch file paths - model_path = hub.get_model_path(model_name, revision) - local_pt_files = hub.local_weight_files(model_path, ".bin") - local_pt_index_files = hub.local_index_files(model_path, ".bin") - if len(local_pt_index_files) > 1: - print( - f"Found more than one .bin.index.json file: {local_pt_index_files}" - ) - return - if not local_pt_files: - print("No pytorch .bin files found to convert") - return - - local_pt_files = [Path(f) for f in local_pt_files] - local_pt_index_file = local_pt_index_files[ - 0] if local_pt_index_files else None - - # Safetensors final filenames - local_st_files = [ - p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors" - for p in local_pt_files - ] - - if any(os.path.exists(p) for p in local_st_files): - print("Existing .safetensors weights found, \ - remove them first to reconvert") - return - - try: - import transformers - - config = transformers.AutoConfig.from_pretrained( - model_name, - revision=revision, - ) - architecture = config.architectures[0] - - class_ = getattr(transformers, architecture) - - # Name for this variable depends on transformers version - discard_names = getattr(class_, "_tied_weights_keys", []) - discard_names.extend( - getattr(class_, "_keys_to_ignore_on_load_missing", [])) - - except Exception: - discard_names = [] - - if local_pt_index_file: - local_pt_index_file = Path(local_pt_index_file) - st_prefix = local_pt_index_file.stem.removeprefix( - "pytorch_").removesuffix(".bin.index") - local_st_index_file = (local_pt_index_file.parent / - f"{st_prefix}.safetensors.index.json") - - if os.path.exists(local_st_index_file): - print("Existing .safetensors.index.json file found, \ - remove it first to reconvert") - return - - hub.convert_index_file(local_pt_index_file, local_st_index_file, - local_pt_files, local_st_files) - - # Convert pytorch weights to safetensors - hub.convert_files(local_pt_files, local_st_files, discard_names) - - -def convert_to_fast_tokenizer( - model_name: str, - revision: Optional[str] = None, - output_path: Optional[str] = None, -): - from vllm.tgis_utils import hub - - # Check for existing "tokenizer.json" - model_path = hub.get_model_path(model_name, revision) - - if os.path.exists(os.path.join(model_path, "tokenizer.json")): - print(f"Model {model_name} already has a fast tokenizer") - return - - if output_path is not None: - if not os.path.isdir(output_path): - print(f"Output path {output_path} must exist and be a directory") - return - else: - output_path = model_path - - import transformers - - tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, - revision=revision) - tokenizer.save_pretrained(output_path) - - print(f"Saved tokenizer to {output_path}") - - def _add_query_options( parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( diff --git a/vllm/tgis_utils/scripts.py b/vllm/tgis_utils/scripts.py new file mode 100644 index 000000000000..e1c56e9ec8c9 --- /dev/null +++ b/vllm/tgis_utils/scripts.py @@ -0,0 +1,96 @@ +# The CLI entrypoint to vLLM. 
+import argparse
+import os
+from pathlib import Path
+from typing import Optional
+
+
+def tgis_cli(args: argparse.Namespace) -> None:
+    # Imported lazily: vllm.scripts imports tgis_cli from this module at
+    # import time, so a module-level import here would be circular.
+    from vllm.scripts import registrer_signal_handlers
+
+    registrer_signal_handlers()
+
+    if args.command == "download-weights":
+        download_weights(args.model_name, args.revision, args.token,
+                         args.extension, args.auto_convert)
+    elif args.command == "convert-to-safetensors":
+        convert_to_safetensors(args.model_name, args.revision)
+    elif args.command == "convert-to-fast-tokenizer":
+        convert_to_fast_tokenizer(args.model_name, args.revision,
+                                  args.output_path)
+
+
+def download_weights(
+    model_name: str,
+    revision: Optional[str] = None,
+    token: Optional[str] = None,
+    extension: str = ".safetensors",
+    auto_convert: bool = True,
+) -> None:
+    from vllm.tgis_utils import hub
+
+    meta_exts = [".json", ".py", ".model", ".md"]
+
+    extensions = extension.split(",")
+
+    if len(extensions) == 1 and extensions[0] not in meta_exts:
+        extensions.extend(meta_exts)
+
+    files = hub.download_weights(model_name,
+                                 extensions,
+                                 revision=revision,
+                                 auth_token=token)
+
+    if auto_convert and ".safetensors" in extensions:
+        if not hub.local_weight_files(hub.get_model_path(model_name, revision),
+                                      ".safetensors"):
+            if ".bin" not in extensions:
+                print(".safetensors weights not found, "
+                      "downloading pytorch weights to convert...")
+                hub.download_weights(model_name,
+                                     ".bin",
+                                     revision=revision,
+                                     auth_token=token)
+
+            print(".safetensors weights not found, "
+                  "converting from pytorch weights...")
+            convert_to_safetensors(model_name, revision)
+        elif not any(f.endswith(".safetensors") for f in files):
+            print(".safetensors weights not found on hub, "
+                  "but were found locally. Remove them first to re-convert")
+    if auto_convert:
+        convert_to_fast_tokenizer(model_name, revision)
+
+
+def convert_to_safetensors(
+    model_name: str,
+    revision: Optional[str] = None,
+):
+    # Convert all local pytorch .bin weight files (and the .bin.index.json,
+    # if any) for this model to .safetensors, via the hub helpers.
+    from vllm.tgis_utils import hub
+
+    # Get local pytorch file paths
+    model_path = hub.get_model_path(model_name, revision)
+    local_pt_files = hub.local_weight_files(model_path, ".bin")
+    local_pt_index_files = hub.local_index_files(model_path, ".bin")
+    if len(local_pt_index_files) > 1:
+        print(
+            f"Found more than one .bin.index.json file: {local_pt_index_files}"
+        )
+        return
+    if not local_pt_files:
+        print("No pytorch .bin files found to convert")
+        return
+
+    local_pt_files = [Path(f) for f in local_pt_files]
+    local_pt_index_file = local_pt_index_files[
+        0] if local_pt_index_files else None
+
+    # Safetensors final filenames
+    local_st_files = [
+        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
+        for p in local_pt_files
+    ]
+
+    if any(os.path.exists(p) for p in local_st_files):
+        print("Existing .safetensors weights found, "
+              "remove them first to reconvert")
+        return
+
+    try:
+        import transformers
+
+        config = transformers.AutoConfig.from_pretrained(
+            model_name,
+            revision=revision,
+        )
+        architecture = config.architectures[0]
+
+        class_ = getattr(transformers, architecture)
+
+        # Name for this variable depends on transformers version
+        discard_names = getattr(class_, "_tied_weights_keys", [])
+        discard_names.extend(
+            getattr(class_, "_keys_to_ignore_on_load_missing", []))
+
+    except Exception:
+        discard_names = []
+
+    if local_pt_index_file:
+        local_pt_index_file = Path(local_pt_index_file)
+        st_prefix = local_pt_index_file.stem.removeprefix(
+            "pytorch_").removesuffix(".bin.index")
+        local_st_index_file = (local_pt_index_file.parent /
+                               f"{st_prefix}.safetensors.index.json")
+
+        if os.path.exists(local_st_index_file):
+            print("Existing .safetensors.index.json file found, "
+                  "remove it first to reconvert")
+            return
+
+        hub.convert_index_file(local_pt_index_file, local_st_index_file,
+                               local_pt_files, local_st_files)
+
+    # Convert pytorch weights to safetensors
+    hub.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+def convert_to_fast_tokenizer(
+    model_name: str,
+    revision: Optional[str] = None,
+    output_path: Optional[str] = None,
+):
+    from vllm.tgis_utils import hub
+
+    # Check for existing "tokenizer.json"
+    model_path = hub.get_model_path(model_name, revision)
+
+    if os.path.exists(os.path.join(model_path, "tokenizer.json")):
+        print(f"Model {model_name} already has a fast tokenizer")
+        return
+
+    if output_path is not None:
+        if not os.path.isdir(output_path):
+            print(f"Output path {output_path} must exist and be a directory")
+            return
+    else:
+        output_path = model_path
+
+    import transformers
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,
+                                                           revision=revision)
+    tokenizer.save_pretrained(output_path)
+
+    print(f"Saved tokenizer to {output_path}")
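For reference, the flow that the "vllm download-weights" subcommand drives can also be exercised directly against the hub module. A rough sketch, assuming the package layout introduced by this patch series and using bigscience/bloom-560m purely as an example model:

from pathlib import Path

from vllm.tgis_utils import hub

model = "bigscience/bloom-560m"  # example only

# Fetch .safetensors weights plus the usual metadata files into the HF cache.
hub.download_weights(model, [".safetensors", ".json", ".model"])

# If a model only ships .bin weights, convert whatever is cached locally.
model_path = hub.get_model_path(model)
pt_files = [Path(f) for f in hub.local_weight_files(model_path, ".bin")]
st_files = [
    p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
    for p in pt_files
]
hub.convert_files(pt_files, st_files, discard_names=[])

Note that convert_files filters out non-inference checkpoints (optimizer, scheduler, training arguments) and warns rather than failing when no .bin files are present, so the conversion step is safe to run unconditionally.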