Commit d882d18 (parent 95c21e3)
4 files changed: 7 additions, 93 deletions

Deleted file: diff not loaded; this file was deleted.
Changed file (the diff adds a single line, the relative path of the shared flash-attention patch):
+ ../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
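The local attn helper now appears to resolve to the shared flash_attention_patch.py from Colossal-LLaMA-2, and the three scripts below drop the SUPPORT_FLASH / SUPPORT_XFORMERS assert in favor of a single replace_with_flash_attention(model) call. The sketch below is an illustrative guess at what such a patch does, not the repository's implementation: it rebinds an attention module's forward to PyTorch's fused scaled_dot_product_attention, which dispatches to FlashAttention kernels when they are available. ToyAttention and every name other than replace_with_flash_attention are assumptions made for the sketch.

# Illustrative sketch only, not the code from flash_attention_patch.py.
# Assumptions: a toy attention module (ToyAttention) stands in for the LLaMA
# attention layers, and the "flash" path is PyTorch's fused
# scaled_dot_product_attention (available in PyTorch >= 2.0).
import types

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    """Naive multi-head self-attention that materializes the full score matrix."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q, k, v = (self._split_heads(t) for t in self.qkv(x).chunk(3, dim=-1))
        scores = (q @ k.transpose(-2, -1)) / self.head_dim**0.5  # O(s^2) memory
        out = scores.softmax(dim=-1) @ v
        return self.proj(out.transpose(1, 2).reshape(b, s, d))


def replace_with_flash_attention(model: nn.Module) -> None:
    """Rebind every ToyAttention.forward to a fused-attention version."""
    # The per-script assert that this commit removes can live here instead.
    if not hasattr(F, "scaled_dot_product_attention"):
        raise RuntimeError("PyTorch >= 2.0 is required for fused attention")

    def flash_forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q, k, v = (self._split_heads(t) for t in self.qkv(x).chunk(3, dim=-1))
        # Fused kernel: never materializes the (s x s) attention matrix.
        out = F.scaled_dot_product_attention(q, k, v)
        return self.proj(out.transpose(1, 2).reshape(b, s, d))

    for module in model.modules():
        if isinstance(module, ToyAttention):
            module.forward = types.MethodType(flash_forward, module)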
Changed file:
@@ -3,7 +3,7 @@
 from contextlib import nullcontext
 
 import torch
-from attn import SUPPORT_FLASH, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator
@@ -188,8 +188,7 @@ def empty_init():
         model.gradient_checkpointing_enable()
 
     if args.xformers:
-        assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
Changed file:
@@ -9,7 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import load_json, prepare_dataloader, save_json
 from datasets import load_dataset
 from torch.optim import Optimizer
@@ -219,8 +219,7 @@ def main():
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if args.flash_attention:
-        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
Changed file:
@@ -8,7 +8,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import load_json, prepare_dataloader, save_json
 from datasets import load_dataset
 from torch.optim import Optimizer
@@ -238,8 +238,7 @@ def main():
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if args.flash_attention:
-        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
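Taken together, the three call sites collapse to the same pattern: gate on the CLI flag and call replace_with_flash_attention(model), with the availability check presumably handled inside the shared patch rather than by a per-script assert. Continuing the toy sketch above (same assumptions; this is not the repository's script code), end-to-end usage looks like:

# Usage of the sketch above; ToyAttention and replace_with_flash_attention are
# the illustrative stand-ins defined earlier, not the repository's modules.
import torch
import torch.nn as nn

model = nn.Sequential(ToyAttention(dim=64, n_heads=4))
replace_with_flash_attention(model)

with torch.no_grad():
    out = model(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])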