diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 4f0edf6b55a..55e06480aec 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -290,9 +290,9 @@ def build_args_parser() -> argparse.ArgumentParser: ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}" parser = argparse.ArgumentParser() parser.add_argument("-o", "--output-dir", default=".", help="output directory") - parser.add_argument( - "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file" - ) + # parser.add_argument( + # "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file" + # ) parser.add_argument( "-E", "--embedding-quantize", @@ -396,8 +396,10 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument( "-d", "--dtype-override", - default=None, - help="Override the dtype of the model (default is the checkpoint dtype). Options: fp16, fp32", + default="fp32", + type=str, + choices=["fp32"], + help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32", ) parser.add_argument( @@ -495,7 +497,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: # source transforms transforms = [] - if args.quantized_ckpt or args.quantization_mode: + if args.quantization_mode: modelname = f"{modelname}_q" transforms.append( partial(