Add download_kwargs for load model (#302) #399

Merged · 1 commit · Apr 6, 2024
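This pull request threads a new download_kwargs argument through AutoAWQForCausalLM.from_pretrained and from_quantized in awq/models/auto.py, and through the matching methods in awq/models/base.py, down to _load_config, where it is forwarded to huggingface_hub.snapshot_download. Callers can now control how model files are fetched from the Hugging Face Hub, including supplying extra ignore_patterns that are merged into AutoAWQ's defaults.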
awq/models/auto.py (4 additions, 0 deletions)
@@ -51,6 +51,7 @@ def from_pretrained(
         trust_remote_code=True,
         safetensors=True,
         device_map=None,
+        download_kwargs=None,
         **model_init_kwargs,
     ) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(
@@ -63,6 +64,7 @@
             trust_remote_code=trust_remote_code,
             safetensors=safetensors,
             device_map=device_map,
+            download_kwargs=download_kwargs,
             **model_init_kwargs,
         )
@@ -80,6 +82,7 @@ def from_quantized(
         safetensors=True,
         device_map="balanced",
         offload_folder=None,
+        download_kwargs=None,
         **config_kwargs,
     ) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
@@ -104,5 +107,6 @@
             safetensors=safetensors,
             device_map=device_map,
             offload_folder=offload_folder,
+            download_kwargs=download_kwargs,
             **config_kwargs,
         )
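A minimal usage sketch of the new argument; the model id and cache path below are placeholders for illustration, not values from this PR:

    from awq import AutoAWQForCausalLM

    # download_kwargs is forwarded to huggingface_hub.snapshot_download
    # whenever model_path is not a local directory.
    model = AutoAWQForCausalLM.from_quantized(
        "org/some-awq-model",              # hypothetical model id
        download_kwargs={
            "revision": "main",            # standard snapshot_download kwargs
            "cache_dir": "/tmp/hf-cache",  # hypothetical cache location
        },
    )

The same keyword works with from_pretrained when loading an unquantized model for quantization.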
awq/models/base.py (24 additions, 3 deletions)
@@ -290,6 +290,9 @@ def from_pretrained(
                 "A device map that will be passed onto the model loading method from transformers."
             ),
         ] = None,
+        download_kwargs: Annotated[
+            Dict, Doc("Used to configure how the model is downloaded (passed to snapshot_download)."),
+        ] = None,
         **model_init_kwargs: Annotated[
             Dict,
             Doc(
@@ -300,7 +303,9 @@
         """A method for initialization of pretrained models, usually in FP16."""
         # Get weights path and quant config
         model_weights_path, config, quant_config = self._load_config(
-            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
+            self, model_path, "", safetensors,
+            trust_remote_code=trust_remote_code,
+            download_kwargs=download_kwargs
         )

         target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
@@ -383,6 +388,9 @@ def from_quantized(
             str,
             Doc("The folder to offload the model to."),
         ] = None,
+        download_kwargs: Annotated[
+            Dict, Doc("Used to configure how the model is downloaded (passed to snapshot_download)."),
+        ] = None,
         **config_kwargs: Annotated[
             Dict,
             Doc(
@@ -399,6 +407,7 @@
             safetensors,
             trust_remote_code,
             max_seq_len=max_seq_len,
+            download_kwargs=download_kwargs,
             **config_kwargs,
         )
@@ -470,6 +479,7 @@ def _load_config(
         safetensors=True,
         trust_remote_code=True,
         max_seq_len=4096,
+        download_kwargs=None,
         **config_kwargs,
     ):
         # [STEP 1] Download model if path is not a directory
@@ -479,8 +489,19 @@
                 ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
             else:
                 ignore_patterns.append("*.safetensors*")

-            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
+            if download_kwargs is None:
+                download_kwargs = {}
+
+            if "ignore_patterns" in download_kwargs:
+                download_kwargs_ignore_patterns = download_kwargs.pop("ignore_patterns")
+
+                if isinstance(download_kwargs_ignore_patterns, str):
+                    ignore_patterns.append(download_kwargs_ignore_patterns)
+                elif isinstance(download_kwargs_ignore_patterns, list):
+                    ignore_patterns.extend(download_kwargs_ignore_patterns)
+
+            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns, **download_kwargs)

         if model_filename != "":
             model_weights_path = model_path + f"/{model_filename}"
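The ignore_patterns handling above pops any caller-supplied patterns out of download_kwargs and folds them into AutoAWQ's defaults before calling snapshot_download, accepting either a single string or a list. A minimal standalone sketch of that merge; merge_ignore_patterns is a hypothetical helper written for illustration, not part of the PR:

    def merge_ignore_patterns(ignore_patterns, download_kwargs):
        # Mirrors the logic in _load_config: default to an empty dict,
        # then move the caller's ignore_patterns into the shared list.
        if download_kwargs is None:
            download_kwargs = {}

        if "ignore_patterns" in download_kwargs:
            extra = download_kwargs.pop("ignore_patterns")
            if isinstance(extra, str):
                ignore_patterns.append(extra)
            elif isinstance(extra, list):
                ignore_patterns.extend(extra)

        return ignore_patterns, download_kwargs

    # A string is appended, a list would be extended; remaining kwargs
    # go straight to snapshot_download.
    patterns, kwargs = merge_ignore_patterns(
        ["*msgpack*", "*h5*"], {"ignore_patterns": "*.md", "revision": "main"}
    )
    assert patterns == ["*msgpack*", "*h5*", "*.md"]
    assert kwargs == {"revision": "main"}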