diff --git a/docs/weightonlyquant.md b/docs/weightonlyquant.md
index 2367ee01966..5ba72bd5741 100644
--- a/docs/weightonlyquant.md
+++ b/docs/weightonlyquant.md
@@ -67,7 +67,8 @@ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 4-bit/8-bit inference with `WeightOnlyQuantConfig` on CPU device.
 ```bash
 from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
-# weight_dtype: int8/int4_fullrange/int4_clip/nf4/fp4_e2m1_bnb/fp4_e2m1
+# weight_dtype: int8/int4_fullrange/int4_clip/nf4/fp4_e2m1_bnb/fp4_e2m1/fp8_e5m2/fp8_e4m3
+# scale_dtype: fp32/fp8; fp8 is only used with weight_dtype "fp8_e5m2" and "fp8_e4m3"
 woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", group_size=32)
 woq_model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
@@ -78,7 +79,7 @@ gen_ids = woq_model.generate(input_ids, max_new_tokens=32, **generate_kwargs)
 gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
 print(gen_text)
 ```
-4-bit/8-bit inference with Huggingface Transformers `BitsAndBytesConfig` is also supported on CUDA GPU device.
+4-bit/8-bit inference with Huggingface Transformers `BitsAndBytesConfig` on CUDA GPU device.
 ```bash
 from intel_extension_for_transformers.transformers import AutoModelForCausalLM, BitsAndBytesConfig
 woq_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
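A minimal sketch (not part of the patch) of the new fp8 options documented above; the checkpoint name, prompt, and generation length are placeholders chosen for illustration.

```python
# Illustrative sketch only: uses the fp8_e5m2 weight dtype together with the
# fp8 scale dtype that this documentation change introduces.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

model_name_or_path = "facebook/opt-125m"  # placeholder checkpoint
prompt = "Once upon a time"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# fp8 weight dtypes are paired with scale_dtype="fp8" (see the config changes later in this diff).
woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e5m2", scale_dtype="fp8")
woq_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
    use_llm_runtime=False,
)
gen_ids = woq_model.generate(input_ids, max_new_tokens=32)
print(tokenizer.batch_decode(gen_ids, skip_special_tokens=True))
```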
diff --git a/examples/huggingface/pytorch/code-generation/quantization/README.md b/examples/huggingface/pytorch/code-generation/quantization/README.md
index fe7517c24ba..c2b2cac0cb5 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/README.md
+++ b/examples/huggingface/pytorch/code-generation/quantization/README.md
@@ -1,5 +1,5 @@
 # Step-by-Step
-We provide the inference benchmarking script `run_generation.py` for Starcoder and CodeLlama models, [bigcode/starcode](https://huggingface.co/bigcode/starcoder), [bigcode/starcodebase](https://huggingface.co/bigcode/starcoderbase), [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf) for code generation tasks, the evaluation part(solution execution) for [MultiPL-E](https://github.com/nuprl/MultiPL-E) requires extra dependencies for some programming languages, we provide a `Dockerfile-multiple` with all dependencies, see [Docker](./Dockerfile-multiple) for more details.
+We provide the inference benchmarking script `run_generation.py` for Starcoder and CodeLlama models ([bigcode/starcoder](https://huggingface.co/bigcode/starcoder), [bigcode/starcoderbase](https://huggingface.co/bigcode/starcoderbase), [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf)) on code generation tasks. The evaluation part (solution execution) of [MultiPL-E](https://github.com/nuprl/MultiPL-E) requires extra dependencies for some programming languages; we provide a `Dockerfile-multiple` with all dependencies, see [Docker](./Dockerfile-multiple) for more details.
 
 # Prerequisite
@@ -18,59 +18,130 @@ pip install -r requirements.txt
 ```
 # Run
-
-## 1. Quantization
-``` bash
-python run_generation.py \
-    --model bigcode/starcoder \
-    --output_dir "./saved_results" \
-    --sq \
-    --alpha 0.7 \
-    --calib_iters 500 \
-    --calib_batch_size 1 \
-    --dataset "mbpp"
-```
-``` bash
-python run_generation.py \
-    --model codellama/CodeLlama-7b-hf \
-    --output_dir "./saved_results" \
-    --woq \
-    --calib_iters 500 \
-    --calib_batch_size 1 \
-    --dataset "mbpp"
-```
-
-## 2. Performance
-
+We provide compression technologies such as `MixedPrecision`, `SmoothQuant` and `WeightOnlyQuant` with `RTN`/`AWQ`/`TEQ` algorithms, as well as `BitsandBytes`, `load_in_4bit` and `load_in_8bit`, all of which work on CPU. The following commands show how to use them.
+## 1. Performance
 ```bash
 export KMP_BLOCKTIME=1
 export KMP_SETTINGS=1
 export KMP_AFFINITY=granularity=fine,compact,1,0
 export LD_PRELOAD=${CONDA_PREFIX}/lib/libiomp5.so
 export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
-# --int8 is used for int8 model
+# fp32
+OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
+    --model bigcode/starcoder \
+    --benchmark \
+    --batch_size 1
+# mixedprecision
 OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
+    --model bigcode/starcoder \
+    --mixed_precision \
+    --benchmark \
+    --batch_size 1
+# smoothquant
+# [alternative] --int8 is used for int8 only, --int8_bf16_mixed is used for int8 mixed bfloat16 precision.
+python run_generation.py \
     --model bigcode/starcoder \
     --output_dir "./saved_results" \
+    --sq \
+    --alpha 0.7 \
+    --calib_iters 500 \
+    --dataset "mbpp" \
    --int8 \
    --benchmark \
    --batch_size 1
+# weightonlyquant
+OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
+    --model bigcode/starcoder \
+    --woq \
+    --benchmark \
+    --batch_size 1
+# load_in_4bit
+OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
+    --model bigcode/starcoder \
+    --load_in_4bit True \
+    --benchmark \
+    --batch_size 1
+# load_in_8bit
+OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
+    --model bigcode/starcoder \
+    --load_in_8bit True \
+    --benchmark \
+    --batch_size 1
 ```
+## 2. Accuracy
 
-## 3. Accuracy
 ```bash
-# --int8 is used for int8 model
+# fp32
 python run_generation.py \
     --model bigcode/starcoder \
-    --output_dir "./saved_results" \
+    --accuracy \
+    --batch_size 20 \
+    --n_samples 20 \
+    --allow_code_execution \
+    --temperature 0.2 \
+    --do_sample \
+    --tasks "humaneval"
+# mixedprecision
+python run_generation.py \
+    --model bigcode/starcoder \
+    --mixed_precision \
+    --accuracy \
+    --batch_size 20 \
+    --n_samples 20 \
+    --allow_code_execution \
+    --temperature 0.2 \
+    --do_sample \
+    --tasks "humaneval"
+# smoothquant
+# [alternative] --int8 is used for int8 only, --int8_bf16_mixed is used for int8 mixed bfloat16 precision.
+python run_generation.py \
+    --model bigcode/starcoder \
+    --sq \
+    --alpha 1.0 \
    --int8 \
+    --accuracy \
    --batch_size 20 \
+    --n_samples 20 \
+    --allow_code_execution \
+    --temperature 0.2 \
+    --do_sample \
+    --tasks "humaneval"
+# weightonlyquant
+python run_generation.py \
+    --model bigcode/starcoder \
+    --woq \
+    --woq_weight_dtype "nf4" \
+    --accuracy \
+    --batch_size 20 \
+    --n_samples 20 \
+    --allow_code_execution \
+    --temperature 0.2 \
+    --do_sample \
+    --tasks "humaneval"
+# load_in_4bit
+python run_generation.py \
+    --model bigcode/starcoder \
+    --load_in_4bit True \
+    --accuracy \
+    --batch_size 20 \
+    --n_samples 20 \
+    --allow_code_execution \
+    --temperature 0.2 \
+    --do_sample \
+    --tasks "humaneval"
+# load_in_8bit
+python run_generation.py \
+    --model bigcode/starcoder \
+    --load_in_8bit True \
    --accuracy \
+    --batch_size 20 \
    --n_samples 20 \
    --allow_code_execution \
    --temperature 0.2 \
-    --do_sample
+    --do_sample \
+    --tasks "humaneval"
 ```
+
 >Note: please follow the [guide](https://huggingface.co/docs/accelerate/usage_guides/ipex) to set up the configuration if `accelerate launch` is used.
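A minimal sketch (not part of the patch) of the Python API behind the new `load_in_4bit`/`load_in_8bit` commands above; the checkpoint name is a placeholder — the README itself uses `bigcode/starcoder`.

```python
# Sketch of the CPU 4-bit/8-bit loading path driven by --load_in_4bit / --load_in_8bit.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

model_name = "facebook/opt-125m"  # placeholder; the README examples use bigcode/starcoder
tokenizer = AutoTokenizer.from_pretrained(model_name)

# CPU execution of 4-bit/8-bit weights is provided by intel-extension-for-transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,   # or load_in_8bit=True, but not both
    use_llm_runtime=False,
)
inputs = tokenizer("def hello_world():", return_tensors="pt").input_ids
print(tokenizer.batch_decode(model.generate(inputs, max_new_tokens=16), skip_special_tokens=True))
```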
@@ -122,4 +193,4 @@ docker run -v $(CURDIR):$(CURDIR) \
     --int8 --accuracy --tasks multiple-py --batch_size 20 --n_samples 20 --allow_code_execution \
     --do_sample --temperature 0.2 --limit 2
-```
\ No newline at end of file
+```
diff --git a/examples/huggingface/pytorch/code-generation/quantization/requirements.txt b/examples/huggingface/pytorch/code-generation/quantization/requirements.txt
index d1572d2129a..b1c2ab59734 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/requirements.txt
+++ b/examples/huggingface/pytorch/code-generation/quantization/requirements.txt
@@ -5,6 +5,7 @@ protobuf
 sentencepiece != 0.1.92
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.1.0+cpu
+peft==0.6.2
 transformers >= 4.35.0
 neural-compressor
 intel_extension_for_pytorch
diff --git a/examples/huggingface/pytorch/code-generation/quantization/run_generation.py b/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
index 2a1c8721c4b..c6d3f46a0d4 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
@@ -9,21 +9,16 @@
 import numpy as np
 from itertools import chain
 from pathlib import Path
-from datasets import load_dataset
-from torch.nn.functional import pad
-from torch.utils.data import DataLoader
-from transformers import AutoTokenizer, PretrainedConfig, AutoConfig
-import transformers
+from transformers import AutoTokenizer, AutoConfig
 from optimum.utils import NormalizedConfigManager
-from optimum.intel.generation.modeling import TSModelForCausalLM
 from intel_extension_for_transformers.transformers import (
     MixedPrecisionConfig,
     WeightOnlyQuantConfig,
     SmoothQuantConfig,
+    BitsAndBytesConfig,
 )
 from intel_extension_for_transformers.transformers import (
     AutoModelForCausalLM,
-    AutoModel,
 )
 
 parser = argparse.ArgumentParser()
@@ -41,9 +36,6 @@
 )
 parser.add_argument("--output_dir", nargs="?", default="./saved_results")
 parser.add_argument("--calib_iters", default=32, type=int, help="calibration iters.")
-parser.add_argument(
-    "--calib_batch_size", default=1, type=int, help="calibration batch size."
-)
 parser.add_argument("--int8", action="store_true")
 parser.add_argument(
     "--int8_bf16_mixed",
@@ -69,6 +61,9 @@
 parser.add_argument("--sq", action="store_true")
 parser.add_argument("--alpha", default="0.5", help="Smooth quant parameter.")
 # ============WeightOnlyQuant configs============
+parser.add_argument("--bitsandbytes", action="store_true")
+parser.add_argument("--load_in_4bit", action="store_true")
+parser.add_argument("--load_in_8bit", action="store_true")
 parser.add_argument("--woq", action="store_true")
 parser.add_argument(
     "--woq_algo",
@@ -77,17 +72,30 @@
     help="Weight-only parameter.",
 )
 parser.add_argument(
-    "--woq_dtype",
+    "--woq_weight_dtype",
     type=str,
     default="int4_fullrange",
-    choices=["int8", "int4_clip", "int4_fullrange", "fp4_e2m1_bnb", "fp4_e2m1", "nf4"],
+    choices=[
+        "int8",
+        "int4_clip",
+        "int4_fullrange",
+        "fp4_e2m1_bnb",
+        "fp4_e2m1",
+        "nf4",
+        "fp8_e5m2",
+        "fp8_e4m3",
+    ],
+)
+parser.add_argument(
+    "--woq_scale_dtype",
+    type=str,
+    default="fp32",
+    choices=["fp32", "fp8"],
 )
 parser.add_argument("--woq_group_size", type=int, default=32)
 parser.add_argument("--woq_scheme", default="sym")
 # ============Harness configs============
-parser.add_argument(
-    "--tasks", default="humaneval", help="Evaluation tasks", choices=["mbpp", "humaneval"]
-)
+parser.add_argument("--tasks", default=None, help="Evaluation tasks")
 parser.add_argument("--n_samples", default=200, type=int)
 parser.add_argument(
     "--limit", default=None, type=int, help="Limit number of samples to eval"
 )
@@ -178,13 +186,21 @@
     )
 elif args.woq:
     quantization_config = WeightOnlyQuantConfig(
-        weight_dtype=args.woq_dtype,
+        weight_dtype=args.woq_weight_dtype,
+        scale_dtype=args.woq_scale_dtype,
         group_size=args.woq_group_size,
         scheme=args.woq_scheme,
         algorithm=args.woq_algo,
     )  # default is A32W4G32
+# bitsandbytes
+elif args.bitsandbytes:
+    # GPU device is needed for `load_in_4bit` and `load_in_8bit`.
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+    )
-
+# get optimized model
 if quantization_config is not None:
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model,
@@ -193,15 +209,15 @@
         revision=args.revision,
         use_llm_runtime=False,
     )
-    # save model
-    if args.sq:
-        config.save_pretrained(args.output_dir)
-        user_model.save(args.output_dir)
-    elif args.mixed_precision:
-        user_model.config.save_pretrained(args.output_dir)
-        torch.save(
-            user_model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")
-        )
+elif args.load_in_4bit or args.load_in_8bit:
+    # CPU device usage is provided by intel-extension-for-transformers.
+    user_model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        load_in_4bit=args.load_in_4bit,
+        load_in_8bit=args.load_in_8bit,
+        revision=args.revision,
+        use_llm_runtime=False,
+    )
 elif not args.int8 and not args.int8_bf16_mixed:
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model,
@@ -211,6 +227,15 @@
         use_llm_runtime=False,
     )
 
+# save model
+if args.sq:
+    config.save_pretrained(args.output_dir)
+    user_model.save(args.output_dir)
+elif args.mixed_precision:
+    user_model.config.save_pretrained(args.output_dir)
+    torch.save(
+        user_model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")
+    )
 
 if args.int8 or args.int8_bf16_mixed:
     # TorchScript model don't attribute generate method, the wrapper is provided.
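A condensed, illustrative restatement of the configuration branching this script now implements; the helper function, its defaults, and the example checkpoint are stand-ins for the argparse-driven code above, not code from the patch.

```python
# Sketch only: mirrors the config selection added to run_generation.py.
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    WeightOnlyQuantConfig,
)


def build_quant_config(
    woq=False,
    bitsandbytes=False,
    weight_dtype="int4_fullrange",
    scale_dtype="fp32",
    group_size=32,
    scheme="sym",
    algorithm="RTN",
):
    """Stand-in for the argparse-driven branching in run_generation.py."""
    if woq:
        # --woq --woq_weight_dtype ... --woq_scale_dtype ...
        return WeightOnlyQuantConfig(
            weight_dtype=weight_dtype,
            scale_dtype=scale_dtype,
            group_size=group_size,
            scheme=scheme,
            algorithm=algorithm,
        )
    if bitsandbytes:
        # --bitsandbytes: GPU path, mirrors the script's BitsAndBytesConfig branch
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
    return None


# e.g. the newly supported fp8 combination:
config = build_quant_config(woq=True, weight_dtype="fp8_e4m3", scale_dtype="fp8")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint
    quantization_config=config,
    use_llm_runtime=False,
)
```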
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
index 84605c8d677..3b5150694d9 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
@@ -35,7 +35,10 @@
     help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
 )
 parser.add_argument(
-    "--restore", action="store_true", help="restore ipex quantized model from output_dir/best_configure.json")
+    "--restore",
+    action="store_true",
+    help="restore ipex quantized model from output_dir/best_configure.json",
+)
 parser.add_argument(
     "--peft_model_id", type=str, default=None, help="model_name_or_path of peft model"
 )
@@ -61,7 +64,9 @@
 # ============SmoothQuant configs==============
 parser.add_argument("--sq", action="store_true")
 parser.add_argument("--alpha", default="0.5", help="Smooth quant parameter.")
-parser.add_argument("--fallback_add", action="store_true", help="Whether to fallback add ops to FP32")
+parser.add_argument(
+    "--fallback_add", action="store_true", help="Whether to fallback add ops to FP32"
+)
 # ============WeightOnlyQuant configs===============
 parser.add_argument("--woq", action="store_true")
 parser.add_argument(
@@ -74,7 +79,22 @@
     "--woq_weight_dtype",
     type=str,
     default="int8",
-    choices=["int8", "int4_clip", "int4_fullrange", "fp4_e2m1_bnb", "fp4_e2m1", "nf4"],
+    choices=[
+        "int8",
+        "int4_clip",
+        "int4_fullrange",
+        "fp4_e2m1_bnb",
+        "fp4_e2m1",
+        "nf4",
+        "fp8_e5m2",
+        "fp8_e4m3",
+    ],
+)
+parser.add_argument(
+    "--woq_scale_dtype",
+    type=str,
+    default="fp32",
+    choices=["fp32", "fp8"],
 )
 parser.add_argument(
     "--woq_compute_dtype",
@@ -165,7 +185,10 @@
     else:
         op_type_dict = {}
     if args.fallback_add:
-        op_type_dict["add"] = {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}
+        op_type_dict["add"] = {
+            "weight": {"dtype": ["fp32"]},
+            "activation": {"dtype": ["fp32"]},
+        }
     excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"]
     recipes = {
         "smooth_quant": True,
@@ -183,6 +206,7 @@
 elif args.woq:
     quantization_config = WeightOnlyQuantConfig(
         compute_dtype=args.woq_compute_dtype,
+        scale_dtype=args.woq_scale_dtype,
         weight_dtype=args.woq_weight_dtype,
         scheme=args.woq_scheme,
         group_size=args.woq_group_size,
@@ -248,9 +272,17 @@
     from intel_extension_for_transformers.llm.evaluation.models import (
         TSModelCausalLMForITREX,
     )
+
     if args.restore:
-        from intel_extension_for_transformers.transformers.utils.utility import recover_model_from_json
-        user_model = recover_model_from_json(user_model, os.path.join(args.output_dir, "best_configure.json"), args.trust_remote_code)
+        from intel_extension_for_transformers.transformers.utils.utility import (
+            recover_model_from_json,
+        )
+
+        user_model = recover_model_from_json(
+            user_model,
+            os.path.join(args.output_dir, "best_configure.json"),
+            args.trust_remote_code,
+        )
         user_model = TSModelCausalLMForITREX(user_model, config=config)
     else:
         user_model = TSModelCausalLMForITREX.from_pretrained(
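For context on the reformatted `--restore` branch above, a self-contained sketch of the restore flow under assumed inputs; the checkpoint name, output directory, and `trust_remote_code` value are placeholders, and the snippet condenses the script's logic rather than quoting it.

```python
# Sketch only: restore an IPEX/SmoothQuant-quantized model from the JSON
# configuration file that the quantization run wrote to --output_dir.
import os

from transformers import AutoConfig
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
from intel_extension_for_transformers.transformers.utils.utility import (
    recover_model_from_json,
)
from intel_extension_for_transformers.llm.evaluation.models import (
    TSModelCausalLMForITREX,
)

model_name = "facebook/opt-125m"   # placeholder checkpoint
output_dir = "./saved_results"     # where best_configure.json was saved

config = AutoConfig.from_pretrained(model_name)

# Start from the original fp32 model, then re-apply the saved int8 configuration.
user_model = AutoModelForCausalLM.from_pretrained(model_name, use_llm_runtime=False)
user_model = recover_model_from_json(
    user_model,
    os.path.join(output_dir, "best_configure.json"),
    False,  # trust_remote_code
)
user_model = TSModelCausalLMForITREX(user_model, config=config)
```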
diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h
index e0265e25c93..eed0c3bdaad 100644
--- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h
+++ b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h
@@ -356,6 +356,7 @@ inline int jblas_dtype_get_f8_quant_mbits(const JBLAS_DTYPE t) {
 
 inline float get_mxfp_maxnorm(const JBLAS_DTYPE t, int ebits, int mantissa_bits) {
   auto emax = std::pow(2, ebits - 1);
+  if (t == JBLAS_DTYPE::F8_E5M2) emax -= 1;
   auto max_norm = std::pow(2, emax);
   if (t != JBLAS_DTYPE::F8_E4M3) {
     max_norm *= ((std::pow(2, mantissa_bits - 1) - 1) / std::pow(2, mantissa_bits - 2));
diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h
index 05b5993684e..fcccd5747d0 100644
--- a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h
+++ b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h
@@ -870,7 +870,6 @@ int8_t f8_mx_quantize(float v, float shared_exp) {
   // saturate normals.
   auto max_norm = utils::get_mxfp_maxnorm(F8_T, ebits, quant_mantissa);
   v = std::clamp(v, -1 * max_norm, max_norm);
-  uint32_t* shift_v = reinterpret_cast<uint32_t*>(&v);
   // get sign;
   char* p = reinterpret_cast<char*>(&v);
@@ -878,7 +877,8 @@ int8_t f8_mx_quantize(float v, float shared_exp) {
   *shift_v <<= 1;
   uint8_t store_ebit = (*(p + 3) & 0xFF);
   store_ebit = store_ebit - 127 + std::pow(2, ebits - 1) - 1;
-  if (store_ebit > 15) store_ebit = 0;
+  if (store_ebit > 15 && F8_T == JBLAS_DTYPE::F8_E4M3) store_ebit = 0;
+  if (store_ebit > 31 && F8_T == JBLAS_DTYPE::F8_E5M2) store_ebit = 0;
   store_ebit <<= store_mantissa;
   *shift_v <<= 8;
   int8_t ox80_shift = -128 >> (store_mantissa - 1);
@@ -903,6 +903,7 @@ inline JBLAS_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t*
   shared_exp = std::floor(std::log2(shared_exp));
   auto ebits = utils::jblas_dtype_get_f8_ebits(F8_T);
   auto emax = std::pow(2, ebits - 1);
+  if (F8_T == JBLAS_DTYPE::F8_E5M2) emax -= 1;
   shared_exp -= emax;
   auto scale_max = std::pow(2, 7) - 1;  // e8m0 scale type.
   shared_exp = shared_exp < (-1 * scale_max) ? (-1 * scale_max) : shared_exp;
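The `emax -= 1` adjustments above reflect that E5M2 reserves its top exponent code for Inf/NaN (largest finite value 57344), whereas E4M3 keeps that exponent for finite values and only loses the all-ones mantissa pattern to NaN (largest finite value 448); the exponent-code bounds 15 and 31 in `f8_mx_quantize` are likewise the maximum codes of 4-bit and 5-bit exponent fields. The snippet below is an independent sanity check of those bounds, not code from the patch.

```python
# Independent check of the fp8 ranges used above (plain Python, not the jblas code).


def max_finite(ebits: int, mbits: int, ieee_inf: bool) -> float:
    """Largest finite value of a 1-sign/ebits/mbits mini-float.

    ieee_inf=True  -> top exponent code reserved for Inf/NaN (E5M2).
    ieee_inf=False -> only the all-ones exponent+mantissa pattern is NaN (E4M3).
    """
    bias = 2 ** (ebits - 1) - 1
    top_code = 2 ** ebits - 1
    if ieee_inf:
        exp = (top_code - 1) - bias               # E5M2: 30 - 15 = 15
        frac = 1 + (2 ** mbits - 1) / 2 ** mbits  # 1.75
    else:
        exp = top_code - bias                     # E4M3: 15 - 7 = 8
        frac = 1 + (2 ** mbits - 2) / 2 ** mbits  # 1.75 (all-ones mantissa is NaN)
    return frac * 2.0 ** exp


assert max_finite(5, 2, ieee_inf=True) == 57344.0   # E5M2
assert max_finite(4, 3, ieee_inf=False) == 448.0    # E4M3
# Effective emax is therefore 15 for E5M2 (2**(5-1) minus one) and 8 for E4M3
# (2**(4-1), unchanged), matching the adjustment applied only to F8_E5M2.
```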
diff --git a/intel_extension_for_transformers/llm/quantization/utils.py b/intel_extension_for_transformers/llm/quantization/utils.py
index 4e53cde5520..5a65c6724b2 100644
--- a/intel_extension_for_transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/llm/quantization/utils.py
@@ -77,6 +77,16 @@ def get_weight_type_from_config(config):
             weight_type = "nf4_scalef32"
         else:
             raise Exception("scale_dtype only support fp32 now!")
+    elif config.weight_dtype == "fp8_e5m2":
+        if config.scale_dtype == "fp8":
+            weight_type = "fp8e5m2_scalef8"
+        else:
+            raise Exception("scale_dtype only support fp8 now!")
+    elif config.weight_dtype == "fp8_e4m3":
+        if config.scale_dtype == "fp8":
+            weight_type = "fp8e4m3_scalef8"
+        else:
+            raise Exception("scale_dtype only support fp8 now!")
     return weight_type
 
 
@@ -239,38 +249,42 @@ def default_calib_func(model):
                 + "the calibration dataset is NeelNanda/pile-10k,"
                 + "batchsize is 1 and calibration iteration is 100."
             )
-        bits = 1  # only for int8
-        if config.weight_dtype == "int8":
-            dtype = "int8"
-            bits = 8
-        elif "int4" in config.weight_dtype:
-            dtype = "int4"
+        if config.weight_dtype in ["fp8_e4m3", "fp8_e5m2"]:
+            return replace_linear(model, None, None, config)
         else:
-            dtype = config.weight_dtype
-        conf = PostTrainingQuantConfig(
-            approach="weight_only",
-            op_type_dict={
-                ".*":{
-                    "weight": {
-                        "bits": bits,
-                        "dtype":dtype,
-                        "group_size": config.group_size,  # -1 (per-channel)
-                        "scheme": config.scheme,
-                        "algorithm": config.algorithm,
+            bits = 1  # only for int8
+            if config.weight_dtype == "int8":
+                dtype = "int8"
+                bits = 8
+            elif "int4" in config.weight_dtype:
+                dtype = "int4"
+            else:
+                dtype = config.weight_dtype
+            conf = PostTrainingQuantConfig(
+                approach="weight_only",
+                op_type_dict={
+                    ".*":{
+                        "weight": {
+                            "bits": bits,
+                            "dtype":dtype,
+                            "group_size": config.group_size,  # -1 (per-channel)
+                            "scheme": config.scheme,
+                            "algorithm": config.algorithm,
+                        },
                     },
                 },
-            },
-            recipes={
-                "rtn_args":{"enable_full_range": True if "fullrange" in config.weight_dtype else False,
-                            "enable_mse_search": config.mse_range},
-            },
-        )
-        # TEQ: set calib_func=None, use default training func as calib_func
-        # RTN: doesn't need calib_func
-        if config.algorithm in ['TEQ','RTN']:
-            calib_func=None
-        inc_model = quantization.fit(model,
-                                     conf,
-                                     calib_func=calib_func,
-                                     calib_dataloader=calib_dataloader)
-        return replace_linear(inc_model.model, None, None, config)
+                recipes={
+                    "rtn_args":{"enable_full_range": True if "fullrange" in config.weight_dtype else False,
+                                "enable_mse_search": config.mse_range},
+                },
+            )
+            # TEQ: set calib_func=None, use default training func as calib_func
+            # RTN: doesn't need calib_func
+            if config.algorithm in ['TEQ','RTN']:
+                calib_func=None
+            inc_model = quantization.fit(model,
+                                         conf,
+                                         calib_func=calib_func,
+                                         calib_dataloader=calib_dataloader)
+            return replace_linear(inc_model.model, None, None, config)
+
diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py
index a3266e90ab3..14fff572ac6 100644
--- a/intel_extension_for_transformers/transformers/utils/config.py
+++ b/intel_extension_for_transformers/transformers/utils/config.py
@@ -33,11 +33,10 @@ def __init__(
         llm_int8_skip_modules=None,
         compute_dtype=None,
         weight_dtype=None,
-        scale_dtype="fp32",  # Now only fp32
+        scale_dtype="fp32",
         mse_range=False,
         use_double_quant=False,
         double_quant_dtype="int8",  # reserve for double quant
-        double_quant_scale_dtype="fp32",  # reserve for double quant
         group_size=32,
         scheme="sym",
         algorithm="RTN",
@@ -54,11 +53,10 @@ def __init__(
             llm_int8_skip_modules if llm_int8_skip_modules else []
         )
         self.weight_dtype = weight_dtype
-        self.scale_dtype = scale_dtype
         self.mse_range = mse_range
         self.use_double_quant = use_double_quant
         self.double_quant_dtype = double_quant_dtype
-        self.double_quant_scale_dtype = double_quant_scale_dtype
+        self.scale_dtype = scale_dtype
         self.scheme = scheme
         self.algorithm = algorithm
         self.group_size = group_size
@@ -100,14 +98,19 @@ def post_init(self):
             "nf4",
             "fp4_e2m1_bnb",
             "fp4_e2m1",
+            "fp8_e5m2",
+            "fp8_e4m3",
+
         ]:
             raise ValueError(
                 f"weight_dtype must be a string in "
-                f"'int8', 'int4_fullrange', 'int4_clip', 'nf4', 'fp4_e2m1_bnb', 'fp4_e2m1'"
+                f"'int8', 'int4_fullrange', 'int4_clip', 'nf4', 'fp4_e2m1_bnb', 'fp4_e2m1', 'fp8_e5m2', 'fp8_e4m3'"
             )
-        if self.scale_dtype not in ["fp32"]:
-            raise ValueError("scale_dtype must be a string in 'fp32'")
+        if self.scale_dtype not in ["fp32", "fp8"]:
+            raise ValueError(
+                f"scale_dtype must be a string in 'fp32' or 'fp8', "
+                f"and 'fp8' is only used for weight_dtype 'fp8_e5m2' and 'fp8_e4m3'")
 
         if not isinstance(self.mse_range, bool):
             raise ValueError("mse_range must be a boolean")
 
@@ -118,8 +121,8 @@
         if self.use_double_quant and not isinstance(self.double_quant_dtype, str):
             raise ValueError("double_quant_dtype must be a string")
 
-        if self.use_double_quant and not isinstance(self.double_quant_scale_dtype, str):
-            raise ValueError("double_quant_scale_dtype must be a string")
+        if self.use_double_quant and not isinstance(self.scale_dtype, str):
+            raise ValueError("scale_dtype must be a string")
 
         if not isinstance(self.group_size, int):
             raise ValueError("group_size must be a int")
diff --git a/tests/CI/test_quantization.py b/tests/CI/test_quantization.py
index e61a2a6e676..b64bbaa65b1 100644
--- a/tests/CI/test_quantization.py
+++ b/tests/CI/test_quantization.py
@@ -365,6 +365,16 @@ def test_quantization_for_llm(self):
         )
         output = woq_model(dummy_input)
         self.assertTrue(isclose(float(output[0][0][0][0]), -6.6008710861206055, rel_tol=1e-04))
+        # fp8
+        woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e5m2", scale_dtype="fp8")
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path, quantization_config=woq_config, use_llm_runtime=False
+        )
+        output = woq_model(dummy_input)
+        self.assertTrue(
+            isclose(float(output[0][0][0][0]), -6.790275573730469, rel_tol=1e-04)
+        )
+
         # amp
         amp_config = MixedPrecisionConfig()
         amp_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,