This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

[Example] Add WOQ to code-generation example (#581)
* add woq for code-generation example

Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
violetch24 authored Nov 14, 2023
1 parent a3ee671 commit 65a645f
Showing 1 changed file with 24 additions and 6 deletions.
@@ -60,7 +60,17 @@
                     help="Save accuracy results path.")
 parser.add_argument("--tasks", default="humaneval", type=str, \
                     help="tasks list for accuracy validation")
-
+# WeightOnlyQuant config
+parser.add_argument("--woq", action="store_true")
+parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'AWQ', 'TEQ'],
+                    help="Weight-only parameter.")
+parser.add_argument("--woq_compute_dtype", type=str, default="fp32")
+parser.add_argument("--woq_weight_dtype", type=str, default="int8",
+                    choices=["int8", "int4_clip", "int4_fullrange", "fp4_e2m1_bnb", "fp4_e2m1", "nf4"])
+parser.add_argument("--woq_group_size", type=int, default=-1)
+parser.add_argument("--woq_scheme", default="sym")
+parser.add_argument("--woq_enable_mse_search", action="store_true")
+parser.add_argument("--woq_enable_full_range", action="store_true")
 # Harness config
 parser.add_argument("--n_samples", default=200, type=int)
 parser.add_argument("--limit", default=None, type=int, help="Limit number of samples to eval")
@@ -100,7 +110,7 @@
 args = parser.parse_args()
 
 
-from intel_extension_for_transformers.transformers import AutoModelForCausalLM
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
 user_model = AutoModelForCausalLM.from_pretrained(
     args.model,
     torchscript=True
@@ -274,6 +284,14 @@ def calib_func(prepared_model):
         recipes=recipes,
         example_inputs=example_inputs,
     )
+elif args.woq:
+    conf = WeightOnlyQuantConfig(
+        compute_dtype=args.woq_compute_dtype,
+        weight_dtype=args.woq_weight_dtype,
+        group_size=args.woq_group_size,
+        scheme=args.woq_scheme,
+        algorithm=args.woq_algo,
+    )
 else:
     conf = PostTrainingQuantConfig(
         backend="ipex" if args.ipex else "default",
@@ -283,10 +301,10 @@ def calib_func(prepared_model):
     )
 # save config
 user_model.config.save_pretrained(args.output_dir)
-q_model = quantization.fit(
-    user_model,
-    conf,
-    calib_func=calib_func,
+q_model = AutoModelForCausalLM.from_pretrained(
+    args.model,
+    quantization_config=conf,
+    use_llm_runtime=False
 )
 q_model.save(args.output_dir)
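
Since the returned q_model behaves like a regular causal LM, a quick smoke test can follow q_model.save() in the same process. A sketch assuming the standard transformers tokenizer API; the prompt and generation settings are illustrative, not part of the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(args.model)
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
# Decode a short completion with the weight-only-quantized model.
outputs = q_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))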
