This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Fix WOQ huggingface model loading (#1400)
changwangss authored Mar 21, 2024
1 parent 62b1e88 commit 01b1a44
Showing 2 changed files with 48 additions and 31 deletions.
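
Summary of the change, as reflected in the diff below: in `from_pretrained`, the resolution of `use_neural_speed` is moved ahead of the `quantization_config` branch, so a checkpoint saved with QBits weight-only quantization (WOQ) is only handed to Neural Speed when it was actually requested (explicitly or via the model-type allowlist). A minimal round-trip sketch of the affected workflow follows; it mirrors the API calls seen in this diff and in the updated test, but the import paths, model id, and output directory are assumptions, not part of the commit.

# Minimal WOQ save/load round-trip sketch (model id and paths are placeholders).
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

# Quantize with an RTN weight-only config (same kwargs as used in this diff).
woq_config = RtnConfig(compute_dtype="fp32", weight_dtype="nf4")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",              # placeholder model id
    quantization_config=woq_config,
    use_neural_speed=False,
)

# Save the low-bit checkpoint in the HuggingFace layout.
model.save_pretrained("./woq_model", safe_serialization=False)

# Reload it; passing use_neural_speed=False keeps the QBits/PyTorch loading
# path, which is the case this commit repairs for WOQ checkpoints.
loaded_model = AutoModelForCausalLM.from_pretrained("./woq_model", use_neural_speed=False)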
@@ -83,6 +83,7 @@ def recover_export_model(model, current_key_name=None):
Return optimum format model.
"""
from ..llm.quantization.nn.modules import QuantizedLinearQBits

for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
@@ -194,8 +195,13 @@ def save_low_bit(
)
return

if self.quantization_config.weight_dtype not in \
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
if self.quantization_config.weight_dtype not in [
"fp8_e5m2",
"fp8_e4m3",
"nf4",
"fp4",
"int4_fullrange",
]:
convert_model_to_public(self)
os.makedirs(save_directory, exist_ok=True)
# use transformers original `save_pretrained` function
@@ -336,7 +342,27 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
return_unused_kwargs=True,
**kwargs,
)
if hasattr(config, "quantization_config"):

if kwargs.get("use_llm_runtime", None) is not None:
use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu
logger.warning(
"use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead."
)
elif kwargs.get("use_neural_speed", None) is not None:
use_neural_speed = kwargs.pop("use_neural_speed", True) and not use_xpu
else:
if hasattr(config, "model_type") == False:
logger.error(
"Can't get the model_type. Please check the correct model_type"
)
exit(0)

if config.model_type in cls.model_type_list and not use_xpu:
use_neural_speed = True
else:
use_neural_speed = False

if hasattr(config, "quantization_config") and not use_neural_speed:
if config.quantization_config is None:
logger.warning(
"Quantization_config loading failed. If you want to load saved "
@@ -369,26 +395,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"Saved low bit model loading failed, please check your model."
)
exit(0)
if kwargs.get("use_llm_runtime", None) is not None:
use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu
logger.warning(
"use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead."
)
elif kwargs.get("use_neural_speed", None) is not None:
use_neural_speed = kwargs.pop("use_neural_speed", True) and not use_xpu
else:
if hasattr(config, "model_type") == False:
logger.error(
"Can't get the model_type. Please check the correct model_type"
)
exit(0)

if config.model_type in cls.model_type_list and not use_xpu:
logger.info("Using Neural Speed...")
use_neural_speed = True
else:
logger.info("Using Pytorch...")
use_neural_speed = False

import intel_extension_for_transformers.transformers.modeling.modeling_map

@@ -437,7 +443,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if quantization_config is None:
if use_neural_speed:
# use wnf4_sfp32_cfp32_g32_sym by default
quantization_config = RtnConfig(compute_dtype="fp32", weight_dtype="nf4")
quantization_config = RtnConfig(
compute_dtype="fp32", weight_dtype="nf4"
)
else:
quantization_config = RtnConfig(
bits=4,
@@ -502,7 +510,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
):
logger.info("Applying Weight Only Quantization.")
if use_neural_speed:
logger.info("Using LLM runtime.")
logger.info("Using Neural Speed.")
quantization_config.post_init_runtime()
from neural_speed import Model

@@ -966,6 +974,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
kwargs["torch_dtype"] = "auto"
config = kwargs.pop("config", None)
quantization_config = config.quantization_config

if quantization_config["quant_method"] == "rtn":
quantization_config = RtnConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "awq":
@@ -976,7 +985,6 @@
quantization_config = GPTQConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantization_config)

assert (
quantization_config is not None
), "Detect this model is not a low-bit model."
@@ -1170,8 +1178,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = model_class(config, *model_args, **kwargs)
else:
model = model_class(config, *model_args, **kwargs)
if config.quantization_config["weight_dtype"] not in \
["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int4_fullrange"]:
if config.quantization_config["weight_dtype"] not in [
"fp8_e5m2",
"fp8_e4m3",
"fp4",
"nf4",
"int4_fullrange",
]:
model = build_woq_model(model, quantization_config)
else:
model = replace_linear(
@@ -1221,8 +1234,12 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if config.quantization_config["weight_dtype"] not in \
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4" "int4_fullrange"]:
if config.quantization_config["weight_dtype"] not in [
"fp8_e5m2",
"fp8_e4m3",
"nf4",
"fp4" "int4_fullrange",
]:
model = replace_linear(
model,
quantization_config=quantization_config,
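
The key ordering change in the hunks above is that `use_neural_speed` is now decided before the `hasattr(config, "quantization_config") and not use_neural_speed` check, so saved low-bit models are only routed to Neural Speed when requested. Below is a standalone sketch of that resolution order; the helper name and argument list are hypothetical (the real logic lives inline in `from_pretrained`, logs a deprecation warning for `use_llm_runtime`, and errors out when `model_type` is missing).

def resolve_use_neural_speed(kwargs, config, model_type_list, use_xpu):
    # 1) The deprecated use_llm_runtime kwarg still wins if present.
    if kwargs.get("use_llm_runtime", None) is not None:
        return kwargs.pop("use_llm_runtime", True) and not use_xpu
    # 2) Then an explicit use_neural_speed kwarg.
    if kwargs.get("use_neural_speed", None) is not None:
        return kwargs.pop("use_neural_speed", True) and not use_xpu
    # 3) Otherwise fall back to the model-type allowlist; XPU always disables Neural Speed.
    return getattr(config, "model_type", None) in model_type_list and not use_xpu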
2 changes: 1 addition & 1 deletion tests/CI/test_weight_only.py
@@ -188,7 +188,7 @@ def test_auto_model_saving_loading(self):
module_list.append(name)
self.assertTrue(len(module_list) > 0)
model.save_pretrained(self.workspace, safe_serialization=False)
loaded_model = AutoModelForCausalLM.from_pretrained(self.workspace)
loaded_model = AutoModelForCausalLM.from_pretrained(self.workspace, use_neural_speed=False)
for name, module in loaded_model.named_modules():
if isinstance(module, QuantizedLinearQBits):
module_list.append(name)
