diff --git a/examples/huggingface/pytorch/code-generation/quantization/run_generation.py b/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
index b3dd9a37471..1b9dd7c40fb 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
@@ -60,7 +60,17 @@
                     help="Save accuracy results path.")
 parser.add_argument("--tasks", default="humaneval", type=str, \
                     help="tasks list for accuracy validation")
-
+# WeightOnlyQuant config
+parser.add_argument("--woq", action="store_true")
+parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'AWQ', 'TEQ'],
+                    help="Weight-only parameter.")
+parser.add_argument("--woq_compute_dtype", type=str, default="fp32")
+parser.add_argument("--woq_weight_dtype", type=str, default="int8",
+                    choices=["int8", "int4_clip", "int4_fullrange", "fp4_e2m1_bnb", "fp4_e2m1", "nf4"])
+parser.add_argument("--woq_group_size", type=int, default=-1)
+parser.add_argument("--woq_scheme", default="sym")
+parser.add_argument("--woq_enable_mse_search", action="store_true")
+parser.add_argument("--woq_enable_full_range", action="store_true")
 # Harness config
 parser.add_argument("--n_samples", default=200, type=int)
 parser.add_argument("--limit", default=None, type=int, help="Limit number of samples to eval")
@@ -100,7 +110,7 @@
 args = parser.parse_args()
 
 
-from intel_extension_for_transformers.transformers import AutoModelForCausalLM
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
 user_model = AutoModelForCausalLM.from_pretrained(
     args.model,
     torchscript=True
@@ -274,6 +284,19 @@ def calib_func(prepared_model):
             recipes=recipes,
             example_inputs=example_inputs,
         )
+    elif args.woq:
+        # NOTE(review): mse_range is assumed to be the WeightOnlyQuantConfig
+        # keyword that enables MSE search -- confirm against the installed version.
+        # --woq_enable_full_range is still not forwarded; full range looks to be
+        # selected through --woq_weight_dtype int4_fullrange -- TODO confirm.
+        conf = WeightOnlyQuantConfig(
+            compute_dtype=args.woq_compute_dtype,
+            weight_dtype=args.woq_weight_dtype,
+            group_size=args.woq_group_size,
+            scheme=args.woq_scheme,
+            algorithm=args.woq_algo,
+            mse_range=args.woq_enable_mse_search,
+        )
     else:
         conf = PostTrainingQuantConfig(
             backend="ipex" if args.ipex else "default",
@@ -283,10 +306,19 @@ def calib_func(prepared_model):
     )
 
     # save config
     user_model.config.save_pretrained(args.output_dir)
-    q_model = quantization.fit(
-        user_model,
-        conf,
-        calib_func=calib_func,
-    )
+    if args.woq:
+        # Weight-only configs are handed to from_pretrained at load time.
+        q_model = AutoModelForCausalLM.from_pretrained(
+            args.model,
+            quantization_config=conf,
+            use_llm_runtime=False,
+        )
+    else:
+        # PostTrainingQuantConfig branches rely on calib_func, so keep quantization.fit.
+        q_model = quantization.fit(
+            user_model,
+            conf,
+            calib_func=calib_func,
+        )
     q_model.save(args.output_dir)