This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Fixed some issue
Signed-off-by: Cheng Penghui <penghui.cheng@intel.com>
PenghuiCheng committed Mar 11, 2024
1 parent 3c140a2 commit ba0c692
Showing 5 changed files with 107 additions and 113 deletions.
@@ -7,7 +7,7 @@
from transformers.generation import GenerationConfig
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.llm.utils.generation import _beam_search, _greedy_search
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig, GPTQConfig
from intel_extension_for_transformers.llm.quantization.utils import convert_dtype_str2torch
from transformers.utils import check_min_version

@@ -76,12 +76,7 @@
"--gptq_nsamples", type=int, default=128, help="Number of calibration data samples."
)
parser.add_argument(
"--gptq_use_max_length",
action="store_true",
help="Set all sequence length to be same length of args.gptq_pad_max_length",
)
parser.add_argument(
"--gptq_pad_max_length",
"--max_input_length",
type=int,
default=2048,
help="Calibration dataset sequence max length, this should align with your model config",
@@ -118,26 +113,25 @@
quantization_config = None
if args.woq:
if args.woq_algo == "GPTQ":
algorithm_args = {
"act_order": False,
"percdamp": args.gptq_percdamp,
"block_size": args.gptq_block_size,
"nsamples": args.gptq_nsamples,
"use_max_length": args.gptq_use_max_length,
"pad_max_length": args.gptq_pad_max_length,
}
quantization_config = WeightOnlyQuantConfig(
quantization_config = GPTQConfig(
tokenizer=tokenizer,
dataset=args.dataset,
bits=args.bits,
desc_act=args.desc_act,
damp_percent=args.gptq_percdamp,
sym=True if args.woq_scheme == "sym" else False,
blocksize=args.gptq_block_size,
nsamples=args.gptq_nsamples,
static_groups=args.static_groups,
group_size=args.woq_group_size,
max_input_length=args.max_input_length,
compute_dtype=args.compute_dtype,
scale_dtype=args.compute_dtype,
weight_dtype=args.woq_dtype,
scheme=args.woq_scheme,
group_size=args.woq_group_size,
algorithm=args.woq_algo,
tokenizer=tokenizer,
algorithm_args=algorithm_args,
calib_iters=args.calib_iters,
)
else:
quantization_config = WeightOnlyQuantConfig(
quantization_config = RtnConfig(
compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype,
group_size=args.woq_group_size, scale_dtype=args.compute_dtype
) #default is A16W4G16
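
For context, a minimal usage sketch of the per-algorithm configs that replace WeightOnlyQuantConfig in this script. The checkpoint name, calibration dataset, and field values below are illustrative placeholders rather than values taken from this commit; the field names follow the diff above.

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, GPTQConfig, RtnConfig

model_name = "facebook/opt-125m"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPTQ weight-only quantization is calibration-driven, so it takes a tokenizer and a dataset.
gptq_config = GPTQConfig(
    tokenizer=tokenizer,
    dataset="NeelNanda/pile-10k",  # placeholder calibration dataset
    bits=4,
    group_size=128,
    damp_percent=0.01,
    sym=True,
    max_input_length=2048,
    compute_dtype="fp32",
    scale_dtype="fp32",
    weight_dtype="int4_clip",
)

# RTN weight-only quantization needs no calibration data.
rtn_config = RtnConfig(
    compute_dtype="fp32", weight_dtype="int4_clip",
    group_size=32, scale_dtype="fp32",
)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=rtn_config)
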
5 changes: 2 additions & 3 deletions intel_extension_for_transformers/llm/quantization/utils.py
@@ -508,9 +508,8 @@ def default_calib_func(model):
if orig_dtype != torch.float32:
q_model.to(dtype=orig_dtype)

config.low_bit_model = True
config.tokenizer = None
q_model.config.quantize_config = config.to_dict()
# config.tokenizer = None
# q_model.config.quantization_config = config.to_dict()
return q_model.to(device)


@@ -148,7 +148,6 @@ def save_low_bit(
commit_message=commit_message,
token=kwargs.get("token"),
)
self.quantization_config.low_bit_model = True
self.quantization_config.save_pretrained(save_directory, **kwargs)


@@ -209,29 +208,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
use_xpu = (True if device_map == torch.device("xpu") or device_map == "xpu" else False)

config = kwargs.pop("config", None)
trust_remote_code = kwargs.get("trust_remote_code", None)

if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
config, _ = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
return_unused_kwargs=True,
**kwargs,

)
if hasattr(config, "quantize_config"):
if config.quantize_config is None or config.quantize_config["low_bit_model"] != True:
if hasattr(config, "quantization_config"):
if config.quantization_config is None:
logger.warning("Quantization_config loading failed. If you want to load saved "
"low bit model, please check your quantization_config.json.")
else:
logger.info(
"quantization_config: {}".format(
config.quantize_config
config.quantization_config
)
)
try:
kwargs["device_map"] = \
config.quantize_config["device"] if "device" in config.quantize_config.keys() else "auto"
kwargs["quantize_config"] = config.quantize_config
model = cls.load_low_bit(pretrained_model_name_or_path, *model_args, **kwargs)
config.quantization_config["device"] if "device" in config.quantization_config.keys() else "auto"
model = cls.load_low_bit(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
logger.info("Saved low bit model loading successfully. Other input args "
"will be ignored.")
return model
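
A hedged sketch of the round trip this hunk enables: a model saved with save_low_bit() carries its quantization_config inside the saved config, so a plain from_pretrained() on that directory is detected here and redirected to load_low_bit(). The checkpoint name and output path are placeholders.

from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

# Quantize once and persist the low-bit weights together with their quantization_config.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",                    # placeholder checkpoint
    quantization_config=RtnConfig(bits=4),  # any of the weight-only configs works here
)
model.save_low_bit("./opt-125m-int4")       # placeholder output directory

# Reloading: from_pretrained() finds config.quantization_config in the saved config and
# dispatches to load_low_bit(); other quantization arguments are ignored with a log message.
reloaded = AutoModelForCausalLM.from_pretrained("./opt-125m-int4")
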
@@ -265,19 +263,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if isinstance(quantization_config, BitsAndBytesConfig):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
quantization_config=quantization_config,
*model_args,
config=config,
quantization_config=quantization_config,
**kwargs,
)
return model
if load_in_8bit or load_in_4bit:
if (is_accelerate_available() and is_bitsandbytes_available() and not use_cpu and not use_xpu):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
quantization_config=quantization_config,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
*model_args,
**kwargs,
)
logger.info("WeightOnlyQuant bitsandbytes done.")
@@ -328,7 +328,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"will fall to traditional load method with higher memory consumption."
)
kwargs["low_cpu_mem_usage"] = False
model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
model.config.update({"low_cpu_mem_usage": False})
model = model.to("cpu")
model.config.update({"device": "cpu"})
@@ -360,13 +360,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if use_xpu:
# TODO: if low_cpu_mem_usage is True, gptj will have accuracy issues on the CPU device.
kwargs["low_cpu_mem_usage"] = True
kwargs["device_map"] = "auto"
kwargs["device_map"] = "cpu"
try:
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
torchscript=True
if quantization_config.quant_method.value in ["teq", "awq"] and not use_xpu else False,
*model_args,
config=config,
**kwargs,
)
model.config.update({"low_cpu_mem_usage": True})
@@ -376,27 +375,25 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
kwargs["low_cpu_mem_usage"] = False
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
torchscript=True
if quantization_config.quant_method.value in ["teq", "awq"] and not use_xpu else False,
*model_args,
config=config,
**kwargs,
)
model.config.update({"low_cpu_mem_usage": False})
else:
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
torchscript=True
if quantization_config.quant_method.value in ["teq", "awq"] and not use_xpu else False,
*model_args,
config=config,
**kwargs,
)
model.eval()

quantization_config.update(kwargs={"device": "cpu"})
quantization_config.update(**{"device": "cpu"})
if use_xpu:
import intel_extension_for_pytorch
assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!"
quantization_config.update(kwargs={"device": "xpu"})
quantization_config.update(**{"device": "xpu"})
if (not torch.cuda.is_available() or device_map == "cpu"
or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
model = model.float()
@@ -405,6 +402,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif use_xpu:
quantization_config.post_init_xpu()
model = convert_to_quantized_model(model, quantization_config, device=device_map)
quantization_config.tokenizer = None
model.config.quantization_config = quantization_config

# add quantization_config and save_low_bit to pretrained model dynamically
model.device_map = device_map
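
Together with the utils.py change above (which stops attaching the config inside convert_to_quantized_model), these added lines are what the save/reload round trip relies on: the target device is recorded via update(**{"device": ...}) as a real keyword rather than the old nested kwargs={...} dict, the non-serializable tokenizer is cleared, and the typed config object is attached to model.config so it is persisted alongside the model. A quick illustrative check, using the attribute names from the diff and assuming a model quantized as in the first sketch:

quantized = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=rtn_config)
print(quantized.config.quantization_config)             # the typed config object attached above
print(quantized.config.quantization_config.tokenizer)   # None - cleared before attaching
quantized.save_low_bit("./opt-125m-int4")                # placeholder path; persists the config too
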
@@ -420,11 +419,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
assert (ipex.__version__ >= "2.2.0+cpu"), "Please use Intel Extension for PyTorch >=2.2.0+cpu."
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
low_cpu_mem_usage=True,
torch_dtype=torch.float,
torchscript=True,
use_cache=True,
*model_args,
**kwargs,
)

@@ -653,7 +652,9 @@ def calib_func(model):
)
logger.info("SmoothQuant done.")
else:
model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
if (not torch.cuda.is_available() or device_map == "cpu"
or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
model = model.float()
@@ -695,7 +696,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Autofactory
kwargs_orig = copy.deepcopy(kwargs)
# modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
trust_remote_code = kwargs.get("trust_remote_code", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Maybe needed when extract_local_archive_file
subfolder = kwargs.get("subfolder", "")
variant = kwargs.get("variant", None)
@@ -726,25 +727,20 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# if torch_dtype=auto was passed here, ensure to pass it on
if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto"
quantize_config = kwargs.pop("quantize_config")
if quantize_config["quant_method"] == "rtn":
quantization_config = RtnConfig.from_dict(quantize_config)
elif quantize_config["quant_method"] == "awq":
quantization_config = AwqConfig.from_dict(quantize_config)
elif quantize_config["quant_method"] == "teq":
quantization_config = TeqConfig.from_dict(quantize_config)
elif quantize_config["quant_method"] == "gptq":
quantization_config = GPTQConfig.from_dict(quantize_config)
elif quantize_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantize_config)
config = kwargs.pop("config", None)
quantization_config = config.quantization_config
if quantization_config["quant_method"] == "rtn":
quantization_config = RtnConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "awq":
quantization_config = AwqConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "teq":
quantization_config = TeqConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "gptq":
quantization_config = GPTQConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantization_config)

assert (quantization_config is not None), "Detect this model is not a low-bit model."
kwargs["trust_remote_code"] = trust_remote_code
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
**kwargs,
)

if commit_hash is None:
if not isinstance(config, PretrainedConfig):
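
For reference, the quant_method dispatch above in isolation, as a hedged sketch: the dict stored under config.quantization_config is mapped back to its typed config class via from_dict(). This assumes AwqConfig, TeqConfig, and AutoRoundConfig are exported from the same module as RtnConfig and GPTQConfig; the sample dict values are made up.

from intel_extension_for_transformers.transformers import (
    AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig,
)

_CONFIG_CLASSES = {
    "rtn": RtnConfig,
    "awq": AwqConfig,
    "teq": TeqConfig,
    "gptq": GPTQConfig,
    "autoround": AutoRoundConfig,
}

def resolve_quantization_config(quantization_config: dict):
    # Rebuild the typed config object from the dict that was stored with the low-bit model.
    return _CONFIG_CLASSES[quantization_config["quant_method"]].from_dict(quantization_config)

cfg = resolve_quantization_config({"quant_method": "rtn", "bits": 4, "group_size": 32})
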
@@ -768,8 +764,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
else:
commit_hash = getattr(config, "_commit_hash", None)

config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
low_cpu_mem_usage = config_dict.pop("low_cpu_mem_usage", True)
low_cpu_mem_usage = config.low_cpu_mem_usage

has_remote_code = (hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map)
