recover scale_dtype and remove double_quant_scale_dtype
Signed-off-by: changwangss <chang1.wang@intel.com>
changwangss committed Dec 13, 2023
1 parent 7a8bb9d commit 294f334
Showing 7 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/weightonlyquant.md
@@ -68,7 +68,7 @@ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 ```bash
 from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
 # weight_dtype: int8/int4_fullrange/int4_clip/nf4/fp4_e2m1_bnb/fp4_e2m1/fp8_e5m2/fp8_e4m3
-# double_quant_scale_dtype: fp32/fp8, fp8 only used for weight_dtype "fp8_e5m2", "fp8_e4m3"
+# scale_dtype: fp32/fp8, fp8 only used for weight_dtype "fp8_e5m2", "fp8_e4m3"
 woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", group_size=32)
 woq_model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
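For context, a minimal usage sketch with the renamed argument. The fp8 pairing follows the comment in the doc snippet above; the model path placeholder and the `quantization_config` keyword are assumptions for illustration, since the hunk is truncated before the remaining `from_pretrained` arguments.

```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# scale_dtype "fp8" pairs with the fp8 weight dtypes; fp32 covers the other weight dtypes.
woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e4m3", scale_dtype="fp8")
woq_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,              # placeholder, as in the doc snippet above
    quantization_config=woq_config,  # assumed keyword; not shown in the truncated hunk
)
```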
@@ -187,7 +187,7 @@
 elif args.woq:
     quantization_config = WeightOnlyQuantConfig(
         weight_dtype=args.woq_weight_dtype,
-        double_quant_scale_dtype=args.woq_scale_dtype,
+        scale_dtype=args.woq_scale_dtype,
         group_size=args.woq_group_size,
         scheme=args.woq_scheme,
         algorithm=args.woq_algo,
@@ -206,7 +206,7 @@
 elif args.woq:
     quantization_config = WeightOnlyQuantConfig(
         compute_dtype=args.woq_compute_dtype,
-        double_quant_scale_dtype=args.woq_scale_dtype,
+        scale_dtype=args.woq_scale_dtype,
         weight_dtype=args.woq_weight_dtype,
         scheme=args.woq_scheme,
         group_size=args.woq_group_size,
32 changes: 16 additions & 16 deletions intel_extension_for_transformers/llm/quantization/utils.py
@@ -48,45 +48,45 @@ def replace_linear(model, modules_to_not_convert=None, current_key_name=None, qu

 def get_weight_type_from_config(config):
     if config.weight_dtype == "int8":
-        if config.double_quant_scale_dtype == "fp32":
+        if config.scale_dtype == "fp32":
             weight_type = "s8_scalef32"
         else:
-            raise Exception("double_quant_scale_dtype only support fp32 now!")
+            raise Exception("scale_dtype only support fp32 now!")
     elif config.weight_dtype == "int4_fullrange":
-        if config.double_quant_scale_dtype == "fp32":
+        if config.scale_dtype == "fp32":
             weight_type = "s4fullrange_scalef32"
         else:
-            raise Exception("double_quant_scale_dtype only support fp32 now!")
+            raise Exception("scale_dtype only support fp32 now!")
     elif config.weight_dtype == "int4_clip":
-        if config.double_quant_scale_dtype == "fp32":
+        if config.scale_dtype == "fp32":
             weight_type = "s4clip_scalef32"
         else:
-            raise Exception("double_quant_scale_dtype only support fp32 now!")
+            raise Exception("scale_dtype only support fp32 now!")
     elif config.weight_dtype == "fp4_e2m1_bnb":
-        if config.double_quant_scale_dtype == "fp32":
+        if config.scale_dtype == "fp32":
             weight_type = "fp4bnb_scalef32"
         else:
-            raise Exception("double_quant_scale_dtype only support fp32 now!")
+            raise Exception("scale_dtype only support fp32 now!")
     elif config.weight_dtype == "fp4_e2m1":
-        if config.double_quant_scale_dtype == "fp32":
+        if config.scale_dtype == "fp32":
             weight_type = "fp4e2m1_scalef32"
         else:
-            raise Exception("double_quant_scale_dtype only support fp32 now!")
+            raise Exception("scale_dtype only support fp32 now!")
     elif config.weight_dtype == "nf4":
-        if config.double_quant_scale_dtype == "fp32":
+        if config.scale_dtype == "fp32":
             weight_type = "nf4_scalef32"
         else:
-            raise Exception("double_quant_scale_dtype only support fp32 now!")
+            raise Exception("scale_dtype only support fp32 now!")
     elif config.weight_dtype == "fp8_e5m2":
-        if config.double_quant_scale_dtype == "fp8":
+        if config.scale_dtype == "fp8":
             weight_type = "fp8e5m2_scalef8"
         else:
-            raise Exception("double_quant_scale_dtype only support fp8 now!")
+            raise Exception("scale_dtype only support fp8 now!")
     elif config.weight_dtype == "fp8_e4m3":
-        if config.double_quant_scale_dtype == "fp8":
+        if config.scale_dtype == "fp8":
             weight_type = "fp8e4m3_scalef8"
         else:
-            raise Exception("double_quant_scale_dtype only support fp8 now!")
+            raise Exception("scale_dtype only support fp8 now!")
     return weight_type


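Read as a table, the renamed branches above map each supported (weight_dtype, scale_dtype) pair to a kernel weight type. The following condensed sketch covers the same pairs as the hunk above but is not the repository's code; it raises for any pair outside that set.

```python
# Condensed restatement of get_weight_type_from_config after the rename; sketch only.
_WEIGHT_TYPE_MAP = {
    ("int8", "fp32"): "s8_scalef32",
    ("int4_fullrange", "fp32"): "s4fullrange_scalef32",
    ("int4_clip", "fp32"): "s4clip_scalef32",
    ("fp4_e2m1_bnb", "fp32"): "fp4bnb_scalef32",
    ("fp4_e2m1", "fp32"): "fp4e2m1_scalef32",
    ("nf4", "fp32"): "nf4_scalef32",
    ("fp8_e5m2", "fp8"): "fp8e5m2_scalef8",
    ("fp8_e4m3", "fp8"): "fp8e4m3_scalef8",
}

def get_weight_type_from_config(config):
    key = (config.weight_dtype, config.scale_dtype)
    if key not in _WEIGHT_TYPE_MAP:
        raise Exception(f"unsupported weight_dtype/scale_dtype pair: {key}")
    return _WEIGHT_TYPE_MAP[key]
```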
@@ -125,7 +125,7 @@ Argument description of WeightOnlyQuantConfig:
 | weight_dtype | String | Data type of quantized weight: int4/int8 (default int4) |
 | alg | String | Quantization algorithm: sym/asym (default sym) |
 | group_size | Int | Group size: Int (default: 32) |
-| double_quant_scale_dtype | String | Data type of scales: fp32/bf16 (default fp32) |
+| scale_dtype | String | Data type of scales: fp32/bf16 (default fp32) |
 | use_ggml | Bool | Enable ggml for quantization and inference (default: False) |
 | use_quant | Bool | Determine whether or not the model will be quantized. (default: True) |

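A hedged sketch of building the runtime (GGML) configuration that the table describes. The table's `alg` row appears to correspond to the config's `scheme` field (note `alg=quantization_config.scheme` in the `from_pretrained` hunk below), and `use_ggml`/`use_quant` are assumed to be constructor keywords, since only their attribute reads are visible in this diff.

```python
from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig

# Values mirror the defaults listed in the table above; keyword availability is assumed.
woq_config = WeightOnlyQuantConfig(
    weight_dtype="int4",
    scheme="sym",        # surfaced to the runtime as `alg`
    group_size=32,
    scale_dtype="fp32",
    use_ggml=True,       # assumed keyword; the table and the hunk below reference this field
    use_quant=True,      # assumed keyword
)
```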
@@ -174,7 +174,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             weight_dtype=quantization_config.weight_dtype,
             alg=quantization_config.scheme,
             group_size=quantization_config.group_size,
-            scale_dtype=quantization_config.double_quant_scale_dtype,
+            scale_dtype=quantization_config.scale_dtype,
             compute_dtype=quantization_config.compute_dtype,
             use_ggml=quantization_config.use_ggml,
             use_quant=quantization_config.use_quant,
16 changes: 8 additions & 8 deletions intel_extension_for_transformers/transformers/utils/config.py
@@ -33,10 +33,10 @@ def __init__(
         llm_int8_skip_modules=None,
         compute_dtype=None,
         weight_dtype=None,
+        scale_dtype="fp32",
         mse_range=False,
         use_double_quant=False,
         double_quant_dtype="int8", # reserve for double quant
-        double_quant_scale_dtype="fp32", # reserve for double quant
         group_size=32,
         scheme="sym",
         algorithm="RTN",
@@ -56,7 +56,7 @@ def __init__(
         self.mse_range = mse_range
         self.use_double_quant = use_double_quant
         self.double_quant_dtype = double_quant_dtype
-        self.double_quant_scale_dtype = double_quant_scale_dtype
+        self.scale_dtype = scale_dtype
         self.scheme = scheme
         self.algorithm = algorithm
         self.group_size = group_size
@@ -107,9 +107,9 @@ def post_init(self):
f"'int8', 'int4_fullrange', 'int4_clip', 'nf4', 'fp4_e2m1_bnb', 'fp4_e2m1', 'fp8_e5m2, fp8_e4m3'"
)

if self.double_quant_scale_dtype not in ["fp32", "fp8"]:
if self.scale_dtype not in ["fp32", "fp8"]:
raise ValueError(
f"double_quant_scale_dtype must be a string in 'fp32', 'fp8' "
f"scale_dtype must be a string in 'fp32', 'fp8' "
f"and fp8 only used for weight_dtype 'fp8_e5m2', 'fp8_e4m3'")

if not isinstance(self.mse_range, bool):
@@ -121,8 +121,8 @@
         if self.use_double_quant and not isinstance(self.double_quant_dtype, str):
             raise ValueError("double_quant_dtype must be a string")
 
-        if self.use_double_quant and not isinstance(self.double_quant_scale_dtype, str):
-            raise ValueError("double_quant_scale_dtype must be a string")
+        if self.use_double_quant and not isinstance(self.scale_dtype, str):
+            raise ValueError("scale_dtype must be a string")
 
         if not isinstance(self.group_size, int):
             raise ValueError("group_size must be a int")
@@ -150,8 +150,8 @@ def post_init_runtime(self):
         elif self.weight_dtype not in ["int4", "int8"]:
             raise ValueError(f"weight_dtype must be 'int4', 'int8'.")
 
-        if self.double_quant_scale_dtype not in ["fp32", "fp16"]:
-            raise ValueError("double_quant_scale_dtype must be 'fp32', 'fp16'.")
+        if self.scale_dtype not in ["fp32", "fp16"]:
+            raise ValueError("scale_dtype must be 'fp32', 'fp16'.")
 
         if self.group_size not in [-1, 32, 128]:
             raise ValueError("group_size must be an integer in [-1, 32, 128]")
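As a reading aid (not library code): after the rename, the two validation paths above accept different scale_dtype sets. `post_init` allows fp32/fp8, while `post_init_runtime` allows fp32/fp16. A tiny sketch of that rule, with the allowed sets taken directly from the two hunks above:

```python
# Illustration only; the allowed sets mirror the checks in post_init and post_init_runtime.
def scale_dtype_allowed(scale_dtype: str, runtime: bool) -> bool:
    allowed = ("fp32", "fp16") if runtime else ("fp32", "fp8")
    return scale_dtype in allowed

assert scale_dtype_allowed("fp8", runtime=False)      # accepted by post_init
assert not scale_dtype_allowed("fp8", runtime=True)   # rejected by post_init_runtime
```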
