Commit b59b8d2: Squash 3125

joerunde committed Apr 2, 2024
1 parent 5d9fb86

Showing 6 changed files with 120 additions and 34 deletions.

25 changes: 23 additions & 2 deletions vllm/config.py

@@ -87,12 +87,14 @@ def __init__(
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        local_files_only: bool = False,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
+        self.local_files_only = local_files_only
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
@@ -110,6 +112,10 @@ def __init__(
             from modelscope.hub.snapshot_download import snapshot_download

             if not os.path.exists(model):
+                if self.local_files_only:
+                    raise ValueError(
+                        f"Unable to find cached ModelScope model for {model} "
+                        f"with local_files_only==True")
                 model_path = snapshot_download(model_id=model,
                                                cache_dir=download_dir,
                                                revision=revision)
@@ -119,8 +125,23 @@ def __init__(
             self.download_dir = model_path
             self.tokenizer = model_path

-        self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision)
+        elif self.local_files_only:
+            # TODO: fully support local_files_only propagation through
+            # each model class's load_weights function
+            #
+            # For places where we don't propagate local_files_only, modify
+            # the env var...
+            os.environ['HF_HUB_OFFLINE'] = "1"
+            # and monkey patch...
+            import huggingface_hub
+            huggingface_hub.constants.HF_HUB_OFFLINE = True
+
+        self.hf_config = get_config(self.model,
+                                    trust_remote_code=trust_remote_code,
+                                    local_files_only=local_files_only,
+                                    cache_dir=download_dir,
+                                    revision=revision,
+                                    code_revision=code_revision)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
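
A note on the env var plus monkey patch above: huggingface_hub reads HF_HUB_OFFLINE from the environment once, at import time, into a module-level constant, so setting the environment variable alone does not affect code that has already imported the library. A minimal sketch of the same two-step trick outside of vLLM (the model name is illustrative and assumed to already be in the local HF cache):

    import os

    # Setting the env var only helps code that has not imported
    # huggingface_hub yet...
    os.environ["HF_HUB_OFFLINE"] = "1"

    import huggingface_hub

    # ...so the already-materialized constant is patched as well.
    huggingface_hub.constants.HF_HUB_OFFLINE = True

    # snapshot_download now resolves paths from the local cache and raises
    # instead of downloading when the files are missing.
    path = huggingface_hub.snapshot_download("facebook/opt-125m",
                                             local_files_only=True)
    print(path)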

10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py

@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple

+from huggingface_hub.constants import HF_HUB_OFFLINE
+
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
                          VisionLanguageConfig)
@@ -17,6 +19,7 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
+    local_files_only: bool = HF_HUB_OFFLINE  # checks TRANSFORMERS_OFFLINE too
     load_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
@@ -124,6 +127,11 @@ def add_cli_args(
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
+        parser.add_argument(
+            '--local-files-only',
+            action='store_true',
+            default=EngineArgs.local_files_only,
+            help='disable downloads and only look at local files')
         parser.add_argument(
             '--load-format',
             type=str,
@@ -395,7 +403,7 @@ def create_engine_configs(
             self.dtype, self.seed, self.revision, self.code_revision,
             self.tokenizer_revision, self.max_model_len, self.quantization,
             self.enforce_eager, self.max_context_len_to_capture,
-            self.max_logprobs)
+            self.max_logprobs, self.local_files_only)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
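
With this plumbing in place, the option can be set three ways: `--local-files-only` on the CLI, `local_files_only=True` through EngineArgs, or implicitly by exporting HF_HUB_OFFLINE=1 (or TRANSFORMERS_OFFLINE=1, which the huggingface_hub constant honors too, per the comment above). A sketch of the Python path, relying on LLM forwarding extra keyword arguments through to EngineArgs (model name illustrative and assumed to be cached locally):

    from vllm import LLM

    # No network access: config, tokenizer, and weights must all resolve
    # from the local HuggingFace cache.
    llm = LLM(model="facebook/opt-125m", local_files_only=True)
    outputs = llm.generate("Hello, my name is")
    print(outputs[0].outputs[0].text)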

3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py

@@ -83,6 +83,7 @@ def __init__(
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, "
f"local_files_only={model_config.local_files_only}, "
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce="
@@ -231,6 +232,8 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
+            local_files_only=self.model_config.local_files_only,
+            cache_dir=self.model_config.download_dir,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
         self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(

5 changes: 4 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py

@@ -66,7 +66,10 @@ async def _post_init(self):
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
+            trust_remote_code=engine_model_config.trust_remote_code,
+            local_files_only=engine_model_config.local_files_only,
+            cache_dir=engine_model_config.download_dir,
+        )

     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
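
The two tokenizer-related changes above are pure parameter threading: local_files_only and the download directory now reach the tokenizer loader. For reference, a sketch of what the tokenizer resolution reduces to in plain transformers (model name illustrative):

    from transformers import AutoTokenizer

    # Resolves the tokenizer strictly from the local cache; raises an
    # OSError if it was never downloaded.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m",
                                              local_files_only=True)
    print(tokenizer.tokenize("offline loading works"))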

84 changes: 62 additions & 22 deletions vllm/model_executor/weight_utils.py

@@ -1,4 +1,5 @@
"""Utilities for downloading and initializing model weights."""
import contextlib
import fnmatch
import glob
import hashlib
@@ -9,8 +10,11 @@

 import filelock
 import numpy as np
+import requests
 import torch
 from huggingface_hub import HfFileSystem, snapshot_download
+from huggingface_hub.constants import HF_HUB_OFFLINE
+from huggingface_hub.utils import OfflineModeIsEnabled, RevisionNotFoundError
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm

@@ -143,14 +147,11 @@ def prepare_hf_model_weights(
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
-    # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
-    use_safetensors = False
+    # Determine the format of weights to load
     # Some quantized models use .pt files for storing the weights.
     if load_format == "auto":
         allow_patterns = ["*.safetensors", "*.bin"]
     elif load_format == "safetensors":
-        use_safetensors = True
         allow_patterns = ["*.safetensors"]
     elif load_format == "pt":
         allow_patterns = ["*.pt"]
@@ -162,29 +163,68 @@
     if fall_back_to_pt:
         allow_patterns += ["*.pt"]

-    if not is_local:
-        # Before we download we look at what is available:
-        fs = HfFileSystem()
-        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
-
-        # depending on what is available we download different things
-        for pattern in allow_patterns:
-            matching = fnmatch.filter(file_list, pattern)
-            if len(matching) > 0:
-                allow_patterns = [pattern]
-                break
-
-        logger.info(f"Using model weights format {allow_patterns}")
-        # Use file lock to prevent multiple processes from
-        # downloading the same model weights at the same time.
-        with get_lock(model_name_or_path, cache_dir):
+    # Find the model weights to load:
+    # - check if pointing at a local directory
+    # - download weights from HuggingFace Hub (including a newer
+    #   revision if it exists)
+    # - discover weights in the local HuggingFace Hub cache (fall back
+    #   to this if the download fails)
+    if os.path.isdir(model_name_or_path):
+        hf_folder = model_name_or_path
+    else:
+        # If there is an error downloading from the HF API, we'll fall
+        # back to loading from the local cache
+        local_files_only = False
+        if HF_HUB_OFFLINE:
+            local_files_only = True
+        else:
+            try:
+                # Before we download we check the available files
+                fs = HfFileSystem()
+                file_list = fs.ls(model_name_or_path,
+                                  detail=False,
+                                  revision=revision)
+
+                # Depending on what is available we download different things
+                for pattern in allow_patterns:
+                    matching = fnmatch.filter(file_list, pattern)
+                    if len(matching) > 0:
+                        allow_patterns = [pattern]
+                        break
+
+                logger.info(f"Using model weights format {allow_patterns}")
+            except (
+                    requests.exceptions.SSLError,
+                    requests.exceptions.ProxyError,
+                    requests.exceptions.ConnectionError,
+                    requests.exceptions.Timeout,
+                    OfflineModeIsEnabled,
+                    RevisionNotFoundError,
+                    FileNotFoundError,
+                    requests.HTTPError,
+            ) as error:
+                # If querying the repo fails (e.g. the network is down, HF
+                # Hub is down, HF Hub returns an access error, or
+                # HF_HUB_OFFLINE=1), see if we can fall back to loading
+                # from locally cached files instead of crashing
+                logger.warning(f"Error in call to HF Hub: {error}. "
+                               f"Attempting to load from local cache instead.")
+                local_files_only = True
+
+        # Use file lock to prevent multiple processes from downloading the
+        # same model weights at the same time. If we fall back to local
+        # files only, we don't need the lock, but we still use
+        # snapshot_download to resolve the path to the model files in the
+        # cache
+        with get_lock(model_name_or_path, cache_dir
+                      ) if not local_files_only else contextlib.nullcontext():
             hf_folder = snapshot_download(model_name_or_path,
                                           allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
+                                          local_files_only=local_files_only,
                                           tqdm_class=Disabledtqdm,
                                           revision=revision)
-    else:
-        hf_folder = model_name_or_path

+    use_safetensors = False
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
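
The core of the weight_utils change is a "try the Hub, fall back to the cache" pattern: probe the repo listing, and if the probe fails for any network- or offline-related reason, resolve the snapshot from the local cache instead of crashing. A condensed, self-contained sketch of that pattern (exception list abbreviated; repo name illustrative):

    from typing import Optional

    import requests
    from huggingface_hub import HfFileSystem, snapshot_download
    from huggingface_hub.utils import OfflineModeIsEnabled

    def resolve_snapshot(repo_id: str, revision: Optional[str] = None) -> str:
        """Return a local folder for repo_id, downloading only if possible."""
        local_files_only = False
        try:
            # Cheap metadata call: if this fails, skip the download attempt.
            HfFileSystem().ls(repo_id, detail=False, revision=revision)
        except (requests.exceptions.ConnectionError,
                requests.exceptions.Timeout, requests.HTTPError,
                OfflineModeIsEnabled) as error:
            print(f"Hub unreachable ({error!r}); trying the local cache.")
            local_files_only = True
        # With local_files_only=True this does no network I/O; it only
        # resolves (and validates) the cached snapshot path.
        return snapshot_download(repo_id,
                                 revision=revision,
                                 local_files_only=local_files_only)

    # print(resolve_snapshot("facebook/opt-125m"))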

27 changes: 19 additions & 8 deletions vllm/transformers_utils/config.py

@@ -14,16 +14,23 @@
 }


-def get_config(model: str,
-               trust_remote_code: bool,
-               revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    code_revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    local_files_only: bool = False,
+) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
             trust_remote_code=trust_remote_code,
             revision=revision,
-            code_revision=code_revision)
+            code_revision=code_revision,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -37,9 +44,13 @@ def get_config(model: str,
             raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model,
-                                              revision=revision,
-                                              code_revision=code_revision)
+        config = config_class.from_pretrained(
+            model,
+            revision=revision,
+            code_revision=code_revision,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
     return config

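
Both new parameters are standard transformers arguments, so the get_config change boils down to the following call (sketch; model name illustrative):

    from transformers import AutoConfig

    # cache_dir=None means the default HF cache; local_files_only=True
    # forbids any download and raises if the config is not cached.
    config = AutoConfig.from_pretrained("facebook/opt-125m",
                                        cache_dir=None,
                                        local_files_only=True)
    print(config.model_type)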
