From 126f91835b4e9a6fe656e4be5a7d31511766f313 Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Sat, 23 Mar 2024 09:14:48 -0700
Subject: [PATCH] fp32 as default data type because fp16 not fully supported
 (#2597)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/2597

fp32 as default data type because fp16 not fully supported

Reviewed By: JacobSzwejbka

Differential Revision: D55258223

fbshipit-source-id: fee91743aa05f1c2e38d451c2bc146b2f7a31ff0
---
 examples/models/llama2/export_llama_lib.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 4f0edf6b55..55e06480ae 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -290,9 +290,9 @@ def build_args_parser() -> argparse.ArgumentParser:
     ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
-    parser.add_argument(
-        "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
-    )
+    # parser.add_argument(
+    #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
+    # )
     parser.add_argument(
         "-E",
         "--embedding-quantize",
@@ -396,8 +396,10 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-d",
         "--dtype-override",
-        default=None,
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: fp16, fp32",
+        default="fp32",
+        type=str,
+        choices=["fp32"],
+        help="Override the dtype of the model (default is fp32). Options: fp32",
     )

     parser.add_argument(
@@ -495,7 +497,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:

     # source transforms
     transforms = []
-    if args.quantized_ckpt or args.quantization_mode:
+    if args.quantization_mode:
         modelname = f"{modelname}_q"
         transforms.append(
             partial(