[test] enable flash attention for benchmark by default
botbw committed Jun 25, 2024
1 parent eecf9a0 commit a6b602e
Showing 1 changed file with 7 additions and 2 deletions.
examples/language/llama/benchmark.py (7 additions, 2 deletions)
@@ -229,8 +229,13 @@ def empty_init():
         init_kwargs["empty_init"] = False
 
     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
-
+        model = AutoModelForCausalLM.from_config(
+            config,
+            trust_remote_code=True,
+            **init_kwargs,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.float16,
+        )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if config.model_type == "chatglm":
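
Note (not part of the commit): the new keyword arguments are passed unconditionally, so the benchmark now assumes flash-attn is installed and fp16-capable CUDA hardware is available. Below is a minimal, hypothetical sketch of how a caller could keep these defaults while falling back to the stock attention implementation when flash-attn is missing; the model id, guard logic, and variable names are illustrative assumptions, not taken from benchmark.py.

    # Hypothetical sketch: prefer flash_attention_2 when flash-attn is importable
    # and a GPU is present, otherwise let transformers pick its default attention.
    import importlib.util

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model id

    init_kwargs = {}
    if importlib.util.find_spec("flash_attn") is not None and torch.cuda.is_available():
        # FlashAttention-2 requires a CUDA build of flash-attn and fp16/bf16 weights.
        init_kwargs["attn_implementation"] = "flash_attention_2"
        init_kwargs["torch_dtype"] = torch.float16

    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)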

0 comments on commit a6b602e
