Speed investigation on A100 for flash attention on/off #128

Closed
arnocandel opened this issue May 12, 2023 · 2 comments

arnocandel commented May 12, 2023

Torch==2.0.0

8-bit with flash-attn package:

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True

3%|██▉ | 10/377 [08:44<5:20:03, 52.33s/it]

16-bit with defaults:

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=False --train_8bit=False
0%|▎ | 1/377 [00:21<2:16:32, 21.79s/it]
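
For context, the difference between the 8-bit and 16-bit runs above comes down to how the base model is loaded. A minimal, hedged sketch of the usual transformers/bitsandbytes pattern follows (this is not finetune.py's actual code; only the model name is taken from the commands above, and what --train_8bit maps to internally is an assumption):

# Hedged sketch: typical 8-bit vs 16-bit model load (requires transformers,
# bitsandbytes, and a CUDA GPU); finetune.py's internals may differ.
import torch
from transformers import AutoModelForCausalLM

base_model = "decapoda-research/llama-13b-hf"

# 8-bit: weights quantized on load via bitsandbytes
# (presumably what --train_8bit=True selects).
model_8bit = AutoModelForCausalLM.from_pretrained(
    base_model, load_in_8bit=True, device_map="auto"
)

# 16-bit: plain fp16 weights (presumably what --train_8bit=False selects).
model_16bit = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.float16, device_map="auto"
)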

Torch==2.0.1

8-bit with flash-attn package:

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True
1%|█▏ | 4/377 [03:33<5:29:20, 52.98s/it]

16-bit with flash-attn package:

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True --train_8bit=False
0%|▎ | 1/377 [00:19<2:04:22, 19.85s/it]

16-bit with defaults (PyTorch's built-in flash SDP attention):

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=False --train_8bit=False
2%|██ | 7/377 [02:31<2:14:38, 21.83s/it]

16-bit with defaults, but with flash attention disabled:

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ git diff
diff --git a/finetune.py b/finetune.py
index e112138..4477fb2 100644
--- a/finetune.py
+++ b/finetune.py
@@ -643,7 +643,7 @@ def train(
         model = torch.compile(model)
         # WIP (not generally replacing layers until pytorch 2.1)
         if not llama_flash_attn:
-            torch.backends.cuda.enable_flash_sdp(True)
+            torch.backends.cuda.enable_flash_sdp(False)
 
     if gpus > 1 and not ddp:
         assert trainer.is_model_parallel

1%|█▍ | 5/377 [01:50<2:15:45, 21.90s/it]
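
For reference, the one-line change above flips PyTorch 2.0's global switch for the flash kernel behind scaled_dot_product_attention. A standalone sketch (toy tensor shapes, any CUDA GPU; only the enable_flash_sdp call is taken from finetune.py):

import torch
import torch.nn.functional as F

# Disable the flash kernel globally; the mem-efficient and math SDP backends
# remain available.
torch.backends.cuda.enable_flash_sdp(False)
assert not torch.backends.cuda.flash_sdp_enabled()

# Toy half-precision tensors: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Dispatches to a non-flash backend now.
out = F.scaled_dot_product_attention(q, k, v)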

16-bit with defaults, but with flash attention disabled and no torch.compile():

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ git diff
diff --git a/finetune.py b/finetune.py
index e112138..2ea48d4 100644
--- a/finetune.py
+++ b/finetune.py
@@ -640,10 +640,10 @@ def train(
     ).__get__(model, type(model))
 
     if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
+        # model = torch.compile(model)
         # WIP (not generally replacing layers until pytorch 2.1)
         if not llama_flash_attn:
-            torch.backends.cuda.enable_flash_sdp(True)
+            torch.backends.cuda.enable_flash_sdp(False)
 
     if gpus > 1 and not ddp:
         assert trainer.is_model_parallel

1%|▉ | 3/377 [01:06<2:16:33, 21.91s/it]
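
The second hunk just comments out the version/platform-gated torch.compile call. A hedged sketch of that guard as a helper (the name maybe_compile is mine, not from finetune.py); passing enabled=False reproduces the "no torch.compile()" run:

import sys
import torch

def maybe_compile(model: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
    # torch.compile exists only from PyTorch 2.0 onward and is not supported on
    # Windows, hence the guard mirrored from finetune.py.
    if enabled and torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    return model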

arnocandel (Member, Author) commented:

#31 #86

arnocandel changed the title from "Confirm Flash attn" to "Speed investigation on A100 for flash attention on/off" on May 12, 2023
arnocandel (Member, Author) commented:

So nothing really matters here: if you can do 16-bit, then do 16-bit; otherwise do 8-bit. No flash attention specifics are needed.
