diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 100a5df33a22..2405cdccc5bc 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -229,8 +229,13 @@ def empty_init():
         init_kwargs["empty_init"] = False
 
     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
-
+        model = AutoModelForCausalLM.from_config(
+            config,
+            trust_remote_code=True,
+            **init_kwargs,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.float16,
+        )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if config.model_type == "chatglm":
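
Note (not part of the patch): hard-coding attn_implementation="flash_attention_2" will make model construction fail on machines where the flash-attn package is not installed. A minimal sketch of a guarded variant is below; the is_flash_attn_2_available check comes from transformers.utils, and the fallback to the default "eager" backend and the stand-in LlamaConfig are assumptions for illustration, not behavior of this PR.

# Sketch only: hypothetical availability guard around the new call.
import torch
from transformers import AutoModelForCausalLM, LlamaConfig
from transformers.utils import is_flash_attn_2_available

config = LlamaConfig()  # stand-in for the benchmark's real config object

# Fall back to the default attention backend when flash-attn is unavailable
# (assumed fallback; the patch itself requests flash_attention_2 unconditionally).
attn_impl = "flash_attention_2" if is_flash_attn_2_available() else "eager"

model = AutoModelForCausalLM.from_config(
    config,
    trust_remote_code=True,
    attn_implementation=attn_impl,
    torch_dtype=torch.float16,
)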