
Commit d882d18

[example] reuse flash attn patch (#5400)
1 parent 95c21e3 commit d882d18

4 files changed: +7, -93 lines


examples/language/llama2/attn.py (-84 lines)

This file was deleted.

examples/language/llama2/attn.py (+1 line)

@@ -0,0 +1 @@
+../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
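
The replacement attn.py contains nothing but a relative path, which is how a symlink is recorded in a diff: the llama2 example now reuses the flash attention patch shipped with Colossal-LLaMA-2 instead of carrying its own 84-line copy, so "from attn import replace_with_flash_attention" resolves to that shared module. For orientation, below is a minimal sketch of how such a monkey patch is typically structured; only the name replace_with_flash_attention comes from this commit, and the stubbed forward and the LlamaAttention target are assumptions, not the actual implementation.

# Hypothetical sketch only -- not the contents of flash_attention_patch.py.
import torch
from transformers.models.llama.modeling_llama import LlamaAttention


def _flash_attention_forward(self, hidden_states, *args, **kwargs):
    # A real patch would recompute attention with flash-attn kernels
    # (e.g. flash_attn.flash_attn_func) instead of the eager implementation.
    raise NotImplementedError("illustrative stub")


def replace_with_flash_attention(model: torch.nn.Module) -> None:
    # Walk the module tree and rebind forward on every LlamaAttention block.
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = _flash_attention_forward.__get__(module, LlamaAttention)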

examples/language/llama2/benchmark.py (+2 -3)

@@ -3,7 +3,7 @@
 from contextlib import nullcontext
 
 import torch
-from attn import SUPPORT_FLASH, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator
@@ -188,8 +188,7 @@ def empty_init():
         model.gradient_checkpointing_enable()
 
     if args.xformers:
-        assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

examples/language/llama2/finetune.py (+2 -3)

@@ -9,7 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import load_json, prepare_dataloader, save_json
 from datasets import load_dataset
 from torch.optim import Optimizer
@@ -219,8 +219,7 @@ def main():
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if args.flash_attention:
-        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

examples/language/llama2/pretrain.py (+2 -3)

@@ -8,7 +8,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import load_json, prepare_dataloader, save_json
 from datasets import load_dataset
 from torch.optim import Optimizer
@@ -238,8 +238,7 @@ def main():
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if args.flash_attention:
-        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
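
With the shared patch in place, benchmark.py, finetune.py and pretrain.py all follow the same call pattern: import the helper through the symlinked attn.py and apply it behind a CLI switch, with the old SUPPORT_FLASH / SUPPORT_XFORMERS assert dropped from the call sites. A condensed sketch of that pattern follows, with the argument parsing simplified and the model construction elided; whether the flags are plain store_true switches is an assumption.

import argparse

import torch
from attn import replace_with_flash_attention  # resolves to flash_attention_patch.py via the symlinked attn.py

parser = argparse.ArgumentParser()
parser.add_argument("--flash_attention", action="store_true")  # benchmark.py names this switch --xformers
args = parser.parse_args()

model: torch.nn.Module = ...  # each script builds its LlamaForCausalLM here (elided)

if args.flash_attention:
    replace_with_flash_attention(model)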
