Llama flash attn #86

Merged 15 commits into main on Apr 28, 2023
Conversation

arnocandel (Member) commented Apr 26, 2023

  • Compare with quantized versions, e.g. https://huggingface.co/elinas/alpaca-30b-lora-int4
  • Compare with existing 30B LoRAs: https://huggingface.co/serpdotai/llama-oasst-lora-30B
  • Run with 8-bit to see the memory/speed difference
  • Run with a larger LoRA
  • Run with LoRA on more layers, not just attention
  • Run without flash attention to measure speed and memory for cases that fit (e.g. smaller cutoff length). Only indirect evidence so far: 20B NeoX without flash attention, even with efficient batching, was 24x slower at 512 cutoff.
  • Separate out optional requirements, since flash attention only works on A100+ OOTB

Note:

Requires A100+ (sm80) to work OOTB; otherwise transformers must be patched or nightly torch used to avoid errors.

Otherwise we would need to patch transformers: lm-sys/FastChat#581 (comment)
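
For reference, the FastChat-style patch works by monkey-patching the HF llama attention. A minimal sketch of the mechanism only (llama_flash_attn_forward is a placeholder for a flash-attn-backed forward; see the linked FastChat PR for the actual implementation):

import transformers.models.llama.modeling_llama as hf_llama

def replace_llama_attn_with_flash_attn(llama_flash_attn_forward):
    # Patching the class method affects every LlamaAttention layer,
    # swapping the stock forward for a flash-attn-backed one.
    hf_llama.LlamaAttention.forward = llama_flash_attn_forward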

We could also try the BetterTransformer wrapper, which wraps HF models into BT models using the native torch version of flash attention (reportedly slower but more memory efficient):

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#:~:text=Scaled%20dot%20product%20attention%20attempts,for%20enabling%20and%20disabling%20implementations.

That might require PyTorch nightly, which fixes a related bug: flash attention should have been disabled for some head sizes when sm80 is not available. With that fix we can at least test on A6000/4090, and it should all work on A100, using flash attention to some degree without the head-size limit.
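
A minimal sketch of that native path (PyTorch 2.0+), assuming fp16 CUDA tensors; head_dim=128 matches llama and is exactly the case the flash kernel rejects on non-sm80 cards:

import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) in fp16 on CUDA so the flash kernel is eligible
q = torch.randn(1, 32, 512, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash kernel only, so sm80/head-size limits surface as errors
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)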

See also:
pytorch/pytorch#99105
pytorch/pytorch#98771
pytorch/pytorch#98140
huggingface/transformers#18439
lm-sys/FastChat#459
pytorch/pytorch#94883

arnocandel (Member, Author) commented Apr 26, 2023

python finetune.py --base_model=decapoda-research/llama-7b-hf --batch_size=100 --micro_batch_size=1 --train_8bit=False --num_epochs=0.001 --val_set_size=0 --cutoff_len=512 --llama_flash_attn=True
  File "/nfs4/llm/h2o-llm/env/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/nfs4/llm/h2o-llm/env/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/nfs4/llm/h2o-llm/env/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 78, in backward
    _flash_attn_backward(
  File "/nfs4/llm/h2o-llm/env/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 44, in _flash_attn_backward
    _, _, _, softmax_d = flash_attn_cuda.bwd(
TypeError: bwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: torch.Tensor, arg6: torch.Tensor, arg7: torch.Tensor, arg8: torch.Tensor, arg9: torch.Tensor, arg10: torch.Tensor, arg11: int, arg12: int, arg13: float, arg14: float, arg15: bool, arg16: bool, arg17: int, arg18: Optional[torch.Generator], arg19: Optional[torch.Tensor]) -> List[torch.Tensor]

arnocandel marked this pull request as draft April 26, 2023 07:10
arnocandel (Member, Author) commented:
Fails with the same error on a 48GB card, without 8-bit and without LoRA.
CUDA_VISIBLE_DEVICES=0 torchrun finetune.py --base_model=decapoda-research/llama-7b-hf --batch_size=100 --micro_batch_size=1 --train_8bit=False --num_epochs=0.001 --val_set_size=0 --cutoff_len=512 --llama_flash_attn=True --lora_r=0


NOTE: flash attention requires installing CUDA 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local; when running the installer, skip the driver, docs, and samples and install just the toolkit. Then, when pip installing flash attention, do:

CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn
pseudotensor (Collaborator) commented Apr 26, 2023

On a 24GB board, 7B runs in 16-bit with 512 cutoff, using 92% of memory.

CUDA_VISIBLE_DEVICES=0 python finetune.py --base_model=decapoda-research/llama-7b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt --prompt_type=plain --micro_batch_size=1 --batch_size=100 --cutoff_len=512 --num_epochs=0.001 --val_set_size=0 --save_steps=1000000  --eval_steps=100000 --train_8bit=False --output_dir=foo103

See also:
lm-sys/FastChat#581
Alternative: lm-sys/FastChat#177
pytorch/pytorch#99105

RuntimeError: Expected is_sm80 || is_sm90 to be true, but got false. This seems to be because llama has a large head dimension and the PyTorch team didn't condition flash-attention dispatch on head size; nightly has that fix. But for the llama wrapper patch we still need A100+.
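
A quick check of what a given card reports (flash attention OOTB needs sm80, i.e. A100):

import torch

major, minor = torch.cuda.get_device_capability(0)
print(f"sm{major}{minor}")  # sm80 = A100; sm86 = A6000; sm89 = 4090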

BetterTransformer now has built-in support for flash attention via torch updates, but only for certain parts of the operations. We can try: https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2
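
A minimal sketch of that route via optimum (whether llama decoders are supported depends on the installed optimum/transformers versions):

from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")
model = BetterTransformer.transform(model)  # swaps supported modules to torch-native SDPA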

To run:

CUDA_VISIBLE_DEVICES=0 python finetune.py --base_model=decapoda-research/llama-7b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt --prompt_type=plain --micro_batch_size=1 --batch_size=100 --cutoff_len=512 --num_epochs=0.001 --val_set_size=0 --save_steps=1000000  --eval_steps=100000 --train_8bit=False --output_dir=foo104 --llama_flash_attn=True

pseudotensor (Collaborator) commented:
Training works fine:

image

^M  4%|▎         | 1294/36234 [11:25<5:08:51,  1.89it/s]^M  4%|▎         | 1295/36234 [11:25<5:08:37,  1.89it/s]^M                                                      ^M{'loss': 1.0692, 'learning_rate': 0.0002901367133447722, 'epoch': 0.21}
^M  4%|▎         | 1295/36234 [11:25<5:08:37,  1.89it/s]^M  4%|▎         | 1296/36234 [11:26<5:13:42,  1.86it/s]^M                                                      ^M{'loss': 1.2193, 'learning_rate': 0.00029012841091492775, 'epoch': 0.21}
^M  4%|▎         | 1296/36234 [11:26<5:13:42,  1.86it/s]^M  4%|▎         | 1297/36234 [11:26<5:08:50,  1.89it/s]^M                                                      ^M{'loss': 1.4059, 'learning_rate': 0.00029012010848508327, 'epoch': 0.21}
^M  4%|▎         | 1297/36234 [11:26<5:08:50,  1.89it/s]^M  4%|▎         | 1298/36234 [11:27<5:06:25,  1.90it/s]^M                                                      ^M{'loss': 1.2362, 'learning_rate': 0.0002901118060552388, 'epoch': 0.21}
^M  4%|▎         | 1298/36234 [11:27<5:06:25,  1.90it/s]^M  4%|▎         | 1299/36234 [11:27<5:07:58,  1.89it/s]^M                                                      ^M{'loss': 1.1192, 'learning_rate': 0.0002901035036253943, 'epoch': 0.22}
^M  4%|▎         | 1299/36234 [11:27<5:07:58,  1.89it/s]^M  4%|▎         | 1300/36234 [11:28<5:15:15,  1.85it/s]^M                                                      ^M{'loss': 1.3157, 'learning_rate': 0.0002900952011955499, 'epoch': 0.22}
^M  4%|▎         | 1300/36234 [11:28<5:15:15,  1.85it/s]^M  4%|▎         | 1301/36234 [11:28<5:15:37,  1.84it/s]^M                                                      ^M{'loss': 1.5309, 'learning_rate': 0.0002900868987657054, 'epoch': 0.22}
^M  4%|▎         | 1301/36234 [11:28<5:15:37,  1.84it/s]^M  4%|▎         | 1302/36234 [11:29<5:05:14,  1.91it/s]^M                                                      ^M{'loss': 1.266, 'learning_rate': 0.00029007859633586094, 'epoch': 0.22}
^M  4%|▎         | 1302/36234 [11:29<5:05:14,  1.91it/s]^M  4%|▎         | 1303/36234 [11:29<5:12:01,  1.87it/s]^M                                                      ^M{'loss': 1.4601, 'learning_rate': 0.00029007029390601646, 'epoch': 0.22}
^M  4%|▎         | 1303/36234 [11:29<5:12:01,  1.87it/s]^M  4%|▎         | 1304/36234 [11:30<5:09:24,  1.88it/s]^M                                                      ^M{'loss': 1.0876, 'learning_rate': 0.000290061991476172, 'epoch': 0.22}
^M  4%|▎         | 1304/36234 [11:30<5:09:24,  1.88it/s]^M  4%|▎         | 1305/36234 [11:31<5:07:03,  1.90it/s]^M                                                      ^M{'loss': 1.4668, 'learning_rate': 0.0002900536890463275, 'epoch': 0.22}
^M  4%|▎         | 1305/36234 [11:31<5:07:03,  1.90it/s]^M  4%|▎         | 1306/36234 [11:31<5:05:29,  1.91it/s]^M                                                      ^M{'loss': 1.1907, 'learning_rate': 0.0002900453866164831, 'epoch': 0.22}
^M  4%|▎         | 1306/36234 [11:31<5:05:29,  1.91it/s]^M  4%|▎         | 1307/36234 [11:32<5:06:48,  1.90it/s]^M                                                      ^M{'loss': 1.3131, 'learning_rate': 0.0002900370841866386, 'epoch': 0.22}
^M  4%|▎         | 1307/36234 [11:32<5:06:48,  1.90it/s]^M  4%|▎         | 1308/36234 [11:32<5:03:21,  1.92it/s]^M                                                      ^M{'loss': 1.1192, 'learning_rate': 0.00029002878175679413, 'epoch': 0.22}
^M  4%|▎         | 1308/36234 [11:32<5:03:21,  1.92it/s]^M  4%|▎         | 1309/36234 [11:33<5:06:37,  1.90it/s]^M                                                      ^M{'loss': 1.4406, 'learning_rate': 0.00029002047932694965, 'epoch': 0.22}
^M  4%|▎         | 1309/36234 [11:33<5:06:37,  1.90it/s]^M  4%|▎         | 1310/36234 [11:33<5:07:38,  1.89it/s]^M                                                      ^M{'loss': 1.3113, 'learning_rate': 0.0002900121768971052, 'epoch': 0.22}
^M  4%|▎         | 1310/36234 [11:33<5:07:38,  1.89it/s]^M  4%|▎         | 1311/36234 [11:34<5:07:12,  1.89it/s]^M                                                      ^M{'loss': 1.3578, 'learning_rate': 0.00029000387446726075, 'epoch': 0.22}
^M  4%|▎         | 1311/36234 [11:34<5:07:12,  1.89it/s]^M  4%|▎         | 1312/36234 [11:34<5:10:10,  1.88it/s]^M                                                      ^M{'loss': 1.1698, 'learning_rate': 0.0002899955720374162, 'epoch': 0.22}
^M  4%|▎         | 1312/36234 [11:34<5:10:10,  1.88it/s]^M  4%|▎         | 1313/36234 [11:35<5:08:21,  1.89it/s]^M                                                      ^M{'loss': 1.4657, 'learning_rate': 0.0002899872696075718, 'epoch': 0.22}
^M  4%|▎         | 1313/36234 [11:35<5:08:21,  1.89it/s]^M  4%|▎         | 1314/36234 [11:35<5:08:30,  1.89it/s]^M                                                      ^M{'loss': 1.2712, 'learning_rate': 0.0002899789671777273, 'epoch': 0.22}
^M  4%|▎         | 1314/36234 [11:35<5:08:30,  1.89it/s]^M  4%|▎         | 1315/36234 [11:36<5:10:14,  1.88it/s]^M                                                      ^M{'loss': 1.2937, 'learning_rate': 0.00028997066474788284, 'epoch': 0.22}
^M  4%|▎         | 1315/36234 [11:36<5:10:14,  1.88it/s]^M  4%|▎         | 1316/36234 [11:36<5:12:26,  1.86it/s]^M                                                      ^M{'loss': 1.3456, 'learning_rate': 0.0002899623623180384, 'epoch': 0.22}
^M  4%|▎         | 1316/36234 [11:36<5:12:26,  1.86it/s]^M  4%|▎         | 1317/36234 [11:37<5:10:26,  1.87it/s]^M                                                      ^M{'loss': 1.2501, 'learning_rate': 0.0002899540598881939, 'epoch': 0.22}
^M  4%|▎         | 1317/36234 [11:37<5:10:26,  1.87it/s]^M  4%|▎         | 1318/36234 [11:37<5:12:34,  1.86it/s]^M                                                      ^M{'loss': 1.5545, 'learning_rate': 0.00028994575745834946, 'epoch': 0.22}
^M  4%|▎         | 1318/36234 [11:37<5:12:34,  1.86it/s]^M  4%|▎         | 1319/36234 [11:38<5:06:46,  1.90it/s]^M                                                      ^M{'loss': 1.3465, 'learning_rate': 0.000289937455028505, 'epoch': 0.22}
^M  4%|▎         | 1319/36234 [11:38<5:06:46,  1.90it/s]^M  4%|▎         | 1320/36234 [11:38<5:06:11,  1.90it/s]^M                                                      ^M{'loss': 1.3056, 'learning_rate': 0.0002899291525986605, 'epoch': 0.22}
^M  4%|▎         | 1320/36234 [11:38<5:06:11,  1.90it/s]^M  4%|▎         | 1321/36234 [11:39<5:06:08,  1.90it/s]^M                                                      ^M{'loss': 1.0685, 'learning_rate': 0.00028992085016881603, 'epoch': 0.22}
^M  4%|▎         | 1321/36234 [11:39<5:06:08,  1.90it/s]^M  4%|▎         | 1322/36234 [11:40<5:06:07,  1.90it/s]^M                                                      ^M{'loss': 1.0574, 'learning_rate': 0.00028991254773897155, 'epoch': 0.22}
^M  4%|▎         | 1322/36234 [11:40<5:06:07,  1.90it/s]^M  4%|▎         | 1323/36234 [11:40<5:05:41,  1.90it/s]^M                                                      ^M{'loss': 1.3323, 'learning_rate': 0.00028990424530912713, 'epoch': 0.22}
^M  4%|▎         | 1323/36234 [11:40<5:05:41,  1.90it/s]^M  4%|▎         | 1324/36234 [11:41<5:05:31,  1.90it/s]^M                                                      ^M{'loss': 1.2647, 'learning_rate': 0.00028989594287928265, 'epoch': 0.22}
^M  4%|▎         | 1324/36234 [11:41<5:05:31,  1.90it/s]^M  4%|▎         | 1325/36234 [11:41<5:07:59,  1.89it/s]^M                                                      ^M{'loss': 1.1907, 'learning_rate': 0.0002898876404494382, 'epoch': 0.22}
^M  4%|▎         | 1325/36234 [11:41<5:07:59,  1.89it/s]^M  4%|▎         | 1326/36234 [11:42<5:10:47,  1.87it/s]^M                                                      ^M{'loss': 1.5984, 'learning_rate': 0.0002898793380195937, 'epoch': 0.22}
^M  4%|▎         | 1326/36234 [11:42<5:10:47,  1.87it/s]^M  4%|▎         | 1327/36234 [11:42<5:11:57,  1.86it/s]^M                                                      ^M{'loss': 1.367, 'learning_rate': 0.0002898710355897492, 'epoch': 0.22}
^M  4%|▎         | 1327/36234 [11:42<5:11:57,  1.86it/s]^M  4%|▎         | 1328/36234 [11:43<5:36:23,  1.73it/s]^M                                                      ^M{'loss': 1.2906, 'learning_rate': 0.00028986273315990475, 'epoch': 0.22}
^M  4%|▎         | 1328/36234 [11:43<5:36:23,  1.73it/s]^M  4%|▎         | 1329/36234 [11:43<5:24:26,  1.79it/s]^M                                                      ^M{'loss': 1.2658, 'learning_rate': 0.0002898544307300603, 'epoch': 0.22}
^M  4%|▎         | 1329/36234 [11:43<5:24:26,  1.79it/s]^M  4%|▎         | 1330/36234 [11:44<5:23:59,  1.80it/s]^M                                                      ^M{'loss': 1.4504, 'learning_rate': 0.00028984612830021585, 'epoch': 0.22}
^M  4%|▎         | 1330/36234 [11:44<5:23:59,  1.80it/s]^M  4%|▎         | 1331/36234 [11:44<5:19:36,  1.82it/s]^M                                                      ^M{'loss': 1.2514, 'learning_rate': 0.00028983782587037137, 'epoch': 0.22}
^M  4%|▎         | 1331/36234 [11:44<5:19:36,  1.82it/s]^M  4%|▎         | 1332/36234 [11:45<5:26:05,  1.78it/s]^M                                                      ^M{'loss': 1.456, 'learning_rate': 0.0002898295234405269, 'epoch': 0.22}
^M  4%|▎         | 1332/36234 [11:45<5:26:05,  1.78it/s]^M  4%|▎         | 1333/36234 [11:46<5:23:04,  1.80it/s]^M                                                      ^M{'loss': 1.3635, 'learning_rate': 0.0002898212210106824, 'epoch': 0.22}
^M  4%|▎         | 1333/36234 [11:46<5:23:04,  1.80it/s]^M  4%|▎         | 1334/36234 [11:46<5:19:58,  1.82it/s]^M                                                      ^M{'loss': 1.2169, 'learning_rate': 0.000289812918580838, 'epoch': 0.22}
^M  4%|▎         | 1334/36234 [11:46<5:19:58,  1.82it/s]^M  4%|▎         | 1335/36234 [11:47<5:21:39,  1.81it/s]^M                                                      ^M{'loss': 1.4433, 'learning_rate': 0.00028980461615099346, 'epoch': 0.22}
^M  4%|▎         | 1335/36234 [11:47<5:21:39,  1.81it/s]^M  4%|▎         | 1336/36234 [11:47<6:00:39,  1.61it/s]^M                                                      ^M{'loss': 1.1642, 'learning_rate': 0.00028979631372114904, 'epoch': 0.22}
^M  4%|▎         | 1336/36234 [11:47<6:00:39,  1.61it/s]^M  4%|▎         | 1337/36234 [11:48<5:45:39,  1.68it/s]^M                                                      ^M{'loss': 1.2496, 'learning_rate': 0.00028978801129130456, 'epoch': 0.22}
^M  4%|▎         | 1337/36234 [11:48<5:45:39,  1.68it/s]^M  4%|▎         | 1338/36234 [11:49<5:32:23,  1.75it/s]^M                                                      ^M{'loss': 1.3304, 'learning_rate': 0.0002897797088614601, 'epoch': 0.22}
^M  4%|▎         | 1338/36234 [11:49<5:32:23,  1.75it/s]^M  4%|▎         | 1339/36234 [11:49<5:25:37,  1.79it/s]^M                                                      ^M{'loss': 1.2467, 'learning_rate': 0.00028977140643161566, 'epoch': 0.22}
^M  4%|▎         | 1339/36234 [11:49<5:25:37,  1.79it/s]^M  4%|▎         | 1340/36234 [11:50<5:22:43,  1.80it/s]^M                                                      ^M{'loss': 1.2184, 'learning_rate': 0.00028976310400177113, 'epoch': 0.22}
^M  4%|▎         | 1340/36234 [11:50<5:22:43,  1.80it/s]^M  4%|▎         | 1341/36234 [11:50<5:17:04,  1.83it/s]^M                                                      ^M{'loss': 1.1651, 'learning_rate': 0.0002897548015719267, 'epoch': 0.22}
^M  4%|▎         | 1341/36234 [11:50<5:17:04,  1.83it/s]^M  4%|▎         | 1342/36234 [11:51<5:11:55,  1.86it/s]^M                                                      ^M{'loss': 1.181, 'learning_rate': 0.0002897464991420822, 'epoch': 0.22}

Will be done with 6 epochs in 5 hours.

Training 20B was 12 s/iteration; this 30B llama runs at 2 iterations/s, i.e. 0.5 s/iteration (note the inverted units). So 24x faster for a 50% larger model.

pseudotensor (Collaborator) commented Apr 26, 2023

No crash during checkpoint at save_steps=2000

(h2ollm) ubuntu@cloudvm:~/h2o-llm$ ls -alrt llama-30b-hf.h2oaiopenassistant_oasst1_h2ogpt.6.0_epochs.31eef248d53c9f39e51c60b8b030c1e3cafc34b0.llama30b_1
total 24
drwxrwxr-x 43 ubuntu ubuntu 12288 Apr 26 21:08 ..
drwxrwxr-x  3 ubuntu ubuntu  4096 Apr 26 21:08 runs
drwxrwxr-x  4 ubuntu ubuntu  4096 Apr 26 21:21 .
drwxrwxr-x  2 ubuntu ubuntu  4096 Apr 26 21:21 checkpoint-1500
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ ls -alrt llama-30b-hf.h2oaiopenassistant_oasst1_h2ogpt.6.0_epochs.31eef248d53c9f39e51c60b8b030c1e3cafc34b0.llama30b_1/checkpoint-1500/
total 150876
drwxrwxr-x 4 ubuntu ubuntu      4096 Apr 26 21:21 ..
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_7.pth
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_6.pth
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_5.pth
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_4.pth
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_3.pth
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_2.pth
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_1.pth
-rw-rw-r-- 1 ubuntu ubuntu  51204365 Apr 26 21:21 pytorch_model.bin
-rw-rw-r-- 1 ubuntu ubuntu      3899 Apr 26 21:21 training_args.bin
-rw-rw-r-- 1 ubuntu ubuntu    499723 Apr 26 21:21 tokenizer.model
-rw-rw-r-- 1 ubuntu ubuntu       715 Apr 26 21:21 tokenizer_config.json
-rw-rw-r-- 1 ubuntu ubuntu       423 Apr 26 21:21 special_tokens_map.json
-rw-rw-r-- 1 ubuntu ubuntu       627 Apr 26 21:21 scheduler.pt
-rw-rw-r-- 1 ubuntu ubuntu       557 Apr 26 21:21 scaler.pt
-rw-rw-r-- 1 ubuntu ubuntu 102437061 Apr 26 21:21 optimizer.pt
-rw-rw-r-- 1 ubuntu ubuntu    180399 Apr 26 21:21 trainer_state.json
-rw-rw-r-- 1 ubuntu ubuntu     14583 Apr 26 21:21 rng_state_0.pth
drwxrwxr-x 2 ubuntu ubuntu      4096 Apr 26 21:21 .
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ 

No crash with LoRA r=16 either, with 0.01 epochs. So all good.

pseudotensor (Collaborator) commented Apr 26, 2023

Non-test run on 31eef24

The below is the content of llama30b.sh, run as: (nohup ./go_llama30b.sh &> llama30b_5.log &)

torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-30b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt  --micro_batch_size=1 --batch_size=8 --cutoff_len=512 --num_epochs=8.0 --val_set_size=0 --eval_steps=100000 --save_steps=6000 --prompt_type=plain --save_code=True --train_8bit=False --run_id=llama30b_5 --llama_flash_attn=True --lora_r=16

To ensure a checkpoint roughly every epoch, with:

diff --git a/finetune.py b/finetune.py
index a22bbe1..27b92af 100644
--- a/finetune.py
+++ b/finetune.py
@@ -625,7 +625,7 @@ def train(
             eval_steps=eval_steps if val_set_size > 0 else None,
             save_steps=save_steps,
             output_dir=output_dir,
-            save_total_limit=3,
+            save_total_limit=20,
             load_best_model_at_end=True if val_set_size > 0 else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,

Roughly 8 hours:

^M  0%|          | 0/48312 [00:00<?, ?it/s]^M  0%|          | 1/48312 [00:01<17:08:02,  1.28s/it]^M                                                    ^M{'loss': 1.8545, 'learning_rate': 0.0, 'epoch': 0.0}
^M  0%|          | 1/48312 [00:01<17:08:02,  1.28s/it]^M  0%|          | 2/48312 [00:01<10:52:27,  1.23it/s]^M                                                    ^M{'loss': 1.6515, 'learning_rate': 0.0, 'epoch': 0.0}
^M  0%|          | 2/48312 [00:01<10:52:27,  1.23it/s]^M  0%|          | 3/48312 [00:03<14:51:12,  1.11s/it]^M                                                    ^M{'loss': 1.694, 'learning_rate': 2.9999999999999997e-06, 'epoch': 0.0}
^M  0%|          | 3/48312 [00:03<14:51:12,  1.11s/it]^M  0%|          | 4/48312 [00:03<11:32:17,  1.16it/s]^M                                                    ^M{'loss': 1.7847, 'learning_rate': 2.9999999999999997e-06, 'epoch': 0.0}
^M  0%|          | 4/48312 [00:03<11:32:17,  1.16it/s]^M  0%|          | 5/48312 [00:04<9:57:41,  1.35it/s] ^M                                                   ^M{'loss': 2.0736, 'learning_rate': 2.9999999999999997e-06, 'epoch': 0.0}
^M  0%|          | 5/48312 [00:04<9:57:41,  1.35it/s]^M  0%|          | 6/48312 [00:04<8:52:16,  1.51it/s]^M                                                   ^M{'loss': 2.2253, 'learning_rate': 5.999999999999999e-06, 'epoch': 0.0}
^M  0%|          | 6/48312 [00:04<8:52:16,  1.51it/s]^M  0%|          | 7/48312 [00:05<8:16:00,  1.62it/s]^M                                                   ^M{'loss': 1.9398, 'learning_rate': 8.999999999999999e-06, 'epoch': 0.0}
^M  0%|          | 7/48312 [00:05<8:16:00,  1.62it/s]^M  0%|          | 8/48312 [00:05<7:49:23,  1.72it/s]^M                                                   ^M{'loss': 2.3214, 'learning_rate': 1.1999999999999999e-05, 'epoch': 0.0}
^M  0%|          | 8/48312 [00:05<7:49:23,  1.72it/s]^M  0%|          | 9/48312 [00:06<7:27:18,  1.80it/s]^M                                                   ^M{'loss': 2.3765, 'learning_rate': 1.4999999999999999e-05, 'epoch': 0.0}
^M  0%|          | 9/48312 [00:06<7:27:18,  1.80it/s]^M  0%|          | 10/48312 [00:06<7:14:16,  1.85it/s]^M                                                    ^M{'loss': 1.3629, 'learning_rate': 1.7999999999999997e-05, 'epoch': 0.0}
^M  0%|          | 10/48312 [00:06<7:14:16,  1.85it/s]^M  0%|          | 11/48312 [00:07<6:58:37,  1.92it/s]^M                                                    ^M{'loss': 1.5485, 'learning_rate': 2.1e-05, 'epoch': 0.0}
^M  0%|          | 11/48312 [00:07<6:58:37,  1.92it/s]^M  0%|          | 12/48312 [00:07<6:47:36,  1.97it/s]^M                                                    ^M{'loss': 2.0413, 'learning_rate': 2.3999999999999997e-05, 'epoch': 0.0}
^M  0%|          | 12/48312 [00:07<6:47:36,  1.97it/s]^M  0%|          | 13/48312 [00:08<6:38:28,  2.02it/s]^M                                                    ^M{'loss': 2.1327, 'learning_rate': 2.6999999999999996e-05, 'epoch': 0.0}
^M  0%|          | 13/48312 [00:08<6:38:28,  2.02it/s]^M  0%|          | 14/48312 [00:08<6:34:41,  2.04it/s]^M                                                    ^M{'loss': 1.811, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.0}
^M  0%|          | 14/48312 [00:08<6:34:41,  2.04it/s]^M  0%|          | 15/48312 [00:09<6:37:01,  2.03it/s]^M                                                    ^M{'loss': 1.8994, 'learning_rate': 3.2999999999999996e-05, 'epoch': 0.0}
^M  0%|          | 15/48312 [00:09<6:37:01,  2.03it/s]^M  0%|          | 16/48312 [00:09<6:37:35,  2.02it/s]^M                                                    ^M{'loss': 2.0232, 'learning_rate': 3.5999999999999994e-05, 'epoch': 0.0}
^M  0%|          | 16/48312 [00:09<6:37:35,  2.02it/s]^M  0%|          | 17/48312 [00:10<6:33:40,  2.04it/s]^M                                                    ^M{'loss': 1.9286, 'learning_rate': 3.9e-05, 'epoch': 0.0}
^M  0%|          | 17/48312 [00:10<6:33:40,  2.04it/s]^M  0%|          | 18/48312 [00:10<6:32:53,  2.05it/s]^M                                                    ^M{'loss': 2.0756, 'learning_rate': 4.2e-05, 'epoch': 0.0}
^M  0%|          | 18/48312 [00:10<6:32:53,  2.05it/s]^M  0%|          | 19/48312 [00:11<6:46:53,  1.98it/s]^M                                                    ^M{'loss': 2.1375, 'learning_rate': 4.4999999999999996e-05, 'epoch': 0.0}
^M  0%|          | 19/48312 [00:11<6:46:53,  1.98it/s]^M  0%|          | 20/48312 [00:11<7:02:11,  1.91it/s]^M                                                    ^M{'loss': 2.8463, 'learning_rate': 4.7999999999999994e-05, 'epoch': 0.0}
^M  0%|          | 20/48312 [00:11<7:02:11,  1.91it/s]^M  0%|          | 21/48312 [00:12<7:04:05,  1.90it/s]^M                                                    ^M{'loss': 2.1112, 'learning_rate': 5.1e-05, 'epoch': 0.0}
^M  0%|          | 21/48312 [00:12<7:04:05,  1.90it/s]^M  0%|          | 22/48312 [00:12<7:05:01,  1.89it/s]^M                                                    ^M{'loss': 1.9269, 'learning_rate': 5.399999999999999e-05, 'epoch': 0.0}
^M  0%|          | 22/48312 [00:12<7:05:01,  1.89it/s]

image

image

pseudotensor self-assigned this Apr 26, 2023
pseudotensor (Collaborator) commented Apr 27, 2023

First epoch:

tar cvzf snap_llama1.tgz llama-* runs* llama*.log run_llama*

1 epoch:

image

2 epochs:

image

3-4 epochs:

image

4-5 epochs:

image

8 epochs:

image

dump of state:

jon@mr-dl10:~$ ls -alrth /home/jon/snap_llama2.tgz 
-rwxr-xr-x 1 jon jon 2.3G Apr 26 22:28 /home/jon/snap_llama2.tgz*

pseudotensor (Collaborator) commented:
This OOMs:

# https://huggingface.co/serpdotai/llama-oasst-lora-30B                                                                                                                                                                                                     
torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-30b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt  --micro_batch_size=1 --batch_size=8 --cutoff_len=2048 --num_epochs=8.0 --val_set_size=0 --eval_steps=100000 --save_steps=6000 --prompt_type=plain --save_code=True --train_8bit=False --run_id=llama30b_6 --llama_flash_attn=True --lora_r=64 --lora_target_modules="['q_proj', 'k_proj', 'v_proj', 'o_proj']" --learning_rate=2e-4 --lora_alpha=32

pseudotensor (Collaborator) commented Apr 27, 2023

So far not OOMing with a cutoff of just 512. ETA 8 hours.

torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-30b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt  --micro_batch_size=1 --batch_size=8 --cutoff_len=512 --num_epochs=8.0 --val_set_size=0 --eval_steps=100000 --save_steps=6000 --prompt_type=plain --save_code=True --train_8bit=False --run_id=llama30b_7 --llama_flash_attn=True --lora_r=64 --lora_target_modules="['q_proj', 'k_proj', 'v_proj', 'o_proj']" --learning_rate=2e-4 --lora_alpha=32
^M  0%|          | 72/48312 [00:45<8:02:39,  1.67it/s]^M  0%|          | 73/48312 [00:46<8:17:06,  1.62it/s]^M                                                    ^M{'loss': 1.4825, 'learning_rate': 0.00013600000000000003, 'epoch': 0.01}
^M  0%|          | 73/48312 [00:46<8:17:06,  1.62it/s]^M  0%|          | 74/48312 [00:47<8:06:45,  1.65it/s]^M                                                    ^M{'loss': 1.5193, 'learning_rate': 0.000138, 'epoch': 0.01}
^M  0%|          | 74/48312 [00:47<8:06:45,  1.65it/s]^M  0%|          | 75/48312 [00:47<8:08:09,  1.65it/s]^M                                                    ^M{'loss': 1.4011, 'learning_rate': 0.00014, 'epoch': 0.01}
^M  0%|          | 75/48312 [00:47<8:08:09,  1.65it/s]^M  0%|          | 76/48312 [00:48<8:19:07,  1.61it/s]^M                                                    ^M{'loss': 1.443, 'learning_rate': 0.000142, 'epoch': 0.01}
^M  0%|          | 76/48312 [00:48<8:19:07,  1.61it/s]^M  0%|          | 77/48312 [00:48<8:19:23,  1.61it/s]^M                                                    ^M{'loss': 1.5454, 'learning_rate': 0.000144, 'epoch': 0.01}
^M  0%|          | 77/48312 [00:49<8:19:23,  1.61it/s]^M  0%|          | 78/48312 [00:49<8:18:39,  1.61it/s]^M                                                    ^M{'loss': 1.7086, 'learning_rate': 0.000146, 'epoch': 0.01}
^M  0%|          | 78/48312 [00:49<8:18:39,  1.61it/s]^M  0%|          | 79/48312 [00:50<8:01:41,  1.67it/s]^M                                                    ^M{'loss': 1.5335, 'learning_rate': 0.000148, 'epoch': 0.01}
^M  0%|          | 79/48312 [00:50<8:01:41,  1.67it/s]^M  0%|          | 80/48312 [00:50<8:06:45,  1.65it/s]^M                                                    ^M{'loss': 1.4207, 'learning_rate': 0.00015000000000000001, 'epoch': 0.01}
^M  0%|          | 80/48312 [00:50<8:06:45,  1.65it/s]^M  0%|          | 81/48312 [00:51<7:59:44,  1.68it/s]^M                                                    ^M{'loss': 1.6293, 'learning_rate': 0.000152, 'epoch': 0.01}
^M  0%|          | 81/48312 [00:51<7:59:44,  1.68it/s]^M  0%|          | 82/48312 [00:51<7:59:14,  1.68it/s]^M                                                    ^M{'loss': 1.5173, 'learning_rate': 0.000154, 'epoch': 0.01}
^M  0%|          | 82/48312 [00:51<7:59:14,  1.68it/s]^M  0%|          | 83/48312 [00:52<7:54:28,  1.69it/s]^M                                                    ^M{'loss': 1.1987, 'learning_rate': 0.00015600000000000002, 'epoch': 0.01}
^M  0%|          | 83/48312 [00:52<7:54:28,  1.69it/s]^M  0%|          | 84/48312 [00:53<7:56:55,  1.69it/s]^M                                                    ^M{'loss': 1.3694, 'learning_rate': 0.00015800000000000002, 'epoch': 0.01}
^M  0%|          | 84/48312 [00:53<7:56:55,  1.69it/s]^M  0%|          | 85/48312 [00:53<8:01:50,  1.67it/s]^M                                                    ^M{'loss': 1.3214, 'learning_rate': 0.00016, 'epoch': 0.01}
^M  0%|          | 85/48312 [00:53<8:01:50,  1.67it/s]^M  0%|          | 86/48312 [00:54<8:00:41,  1.67it/s]^M                                                    ^M{'loss': 1.5173, 'learning_rate': 0.000162, 'epoch': 0.01}
^M  0%|          | 86/48312 [00:54<8:00:41,  1.67it/s]^M  0%|          | 87/48312 [00:54<8:02:37,  1.67it/s]^M                                                    ^M{'loss': 1.3231, 'learning_rate': 0.000164, 'epoch': 0.01}

image

Much lower loss:

image

image

jon@mr-dl10:~$ ls -alrth /home/jon/snap_llama3.tgz
-rwxr-xr-x 1 jon jon 19G Apr 27 09:50 /home/jon/snap_llama3.tgz*
jon@mr-dl10:~$ 

_7 case with eval:

image

pseudotensor (Collaborator) commented Apr 27, 2023

8-bit will take 30 hours. Won't run this one.

torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-30b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt  --micro_batch_size=1 --batch_size=8 --cutoff_len=512 --num_epochs=8.0 --val_set_size=0 --eval_steps=100000 --save_steps=6000 --prompt_type=plain --save_code=True --train_8bit=True --run_id=llama30b_8 --llama_flash_attn=True --lora_r=64 --lora_target_modules="['q_proj', 'k_proj', 'v_proj', 'o_proj']" --learning_rate=2e-4 --lora_alpha=32

The model alone is 42% of 80GB (consistent with ~30 GB of 8-bit weights for 30B parameters, plus overhead):

image

^M  0%|          | 221/48312 [08:28<30:20:57,  2.27s/it]^M  0%|          | 222/48312 [08:30<30:14:28,  2.26s/it]^M                                                      ^M{'loss': 1.4584, 'learning_rate': 0.0001995187920019912, 'epoch': 0.04}
^M  0%|          | 222/48312 [08:30<30:14:28,  2.26s/it]^M  0%|          | 223/48312 [08:32<30:17:07,  2.27s/it]^M                                                      ^M{'loss': 1.6878, 'learning_rate': 0.0001995146436571808, 'epoch': 0.04}
^M  0%|          | 223/48312 [08:32<30:17:07,  2.27s/it]^M  0%|          | 224/48312 [08:35<30:37:46,  2.29s/it]^M                                                      ^M{'loss': 1.3812, 'learning_rate': 0.00019951049531237038, 'epoch': 0.04}
^M  0%|          | 224/48312 [08:35<30:37:46,  2.29s/it]^M  0%|          | 225/48312 [08:37<30:30:40,  2.28s/it]^M                                                      ^M{'loss': 1.2646, 'learning_rate': 0.00019950634696755995, 'epoch': 0.04}
^M  0%|          | 225/48312 [08:37<30:30:40,  2.28s/it]^M  0%|          | 226/48312 [08:39<30:08:36,  2.26s/it]^M                                                      ^M{'loss': 1.5205, 'learning_rate': 0.00019950219862274952, 'epoch': 0.04}
^M  0%|          | 226/48312 [08:39<30:08:36,  2.26s/it]^M  0%|          | 227/48312 [08:42<30:31:06,  2.28s/it]^M                                                      ^M{'loss': 1.3945, 'learning_rate': 0.0001994980502779391, 'epoch': 0.04}
^M  0%|          | 227/48312 [08:42<30:31:06,  2.28s/it]^M  0%|          | 228/48312 [08:44<30:19:39,  2.27s/it]^M                                                      ^M{'loss': 1.3554, 'learning_rate': 0.00019949390193312868, 'epoch': 0.04}
^M  0%|          | 228/48312 [08:44<30:19:39,  2.27s/it]^M  0%|          | 229/48312 [08:46<30:05:57,  2.25s/it]^M                                                      ^M{'loss': 1.456, 'learning_rate': 0.00019948975358831828, 'epoch': 0.04}
^M  0%|          | 229/48312 [08:46<30:05:57,  2.25s/it]
(END)

image

pseudotensor (Collaborator) commented Apr 27, 2023

8-bit with a larger batch size uses the GPUs more efficiently:

torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-30b-hf --data_path=h2oai/openassistant_oasst1_h2ogpt  --micro_batch_size=32 --batch_size=256 --cutoff_len=512 --num_epochs=8.0 --val_set_size=0 --eval_steps=100000 --save_steps=225 --prompt_type=plain --save_code=True --train_8bit=True --run_id=llama30b_11 --llama_flash_attn=True --lora_r=64 --lora_target_modules="['q_proj', 'k_proj', 'v_proj', 'o_proj']" --learning_rate=2e-4 --lora_alpha=32
TensorBoardCallback
^M  0%|          | 0/1512 [00:00<?, ?it/s]^M  0%|          | 1/1512 [00:21<8:53:50, 21.20s/it]^M                                                  ^M{'loss': 1.6806, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.01}
^M  0%|          | 1/1512 [00:21<8:53:50, 21.20s/it]^M  0%|          | 2/1512 [00:40<8:32:39, 20.37s/it]^M                                                  ^M{'loss': 1.7186, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.01}
^M  0%|          | 2/1512 [00:41<8:32:39, 20.37s/it]^M  0%|          | 3/1512 [01:00<8:25:28, 20.10s/it]^M                                                  ^M{'loss': 1.6805, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.02}
^M  0%|          | 3/1512 [01:00<8:25:28, 20.10s/it]^M  0%|          | 4/1512 [01:20<8:22:11, 19.98s/it]^M                                                  ^M{'loss': 1.6612, 'learning_rate': 6e-06, 'epoch': 0.02}
^M  0%|          | 4/1512 [01:20<8:22:11, 19.98s/it]^M  0%|          | 5/1512 [01:40<8:21:17, 19.96s/it]^M                                                  ^M{'loss': 1.6556, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.03}
^M  0%|          | 5/1512 [01:40<8:21:17, 19.96s/it]^M  0%|          | 6/1512 [02:00<8:19:50, 19.91s/it]^M                                                  ^M{'loss': 1.6891, 'learning_rate': 1e-05, 'epoch': 0.03}
^M  0%|          | 6/1512 [02:00<8:19:50, 19.91s/it]^M  0%|          | 7/1512 [02:20<8:18:23, 19.87s/it]^M                                                  ^M{'loss': 1.659, 'learning_rate': 1.2e-05, 'epoch': 0.04}
^M  0%|          | 7/1512 [02:20<8:18:23, 19.87s/it]^M  1%|          | 8/1512 [02:39<8:16:47, 19.82s/it]^M                                                  ^M{'loss': 1.7286, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.04}
^M  1%|          | 8/1512 [02:39<8:16:47, 19.82s/it]
(END)

image

pseudotensor (Collaborator) commented Apr 27, 2023

OIG w/ OASST mixed in with pure OASST at an equal level. ETA about 22 hours for 2 epochs, which is roughly equivalent to 20 epochs of pure OASST for the OASST portion of the data. Checkpoints are every 0.25 epochs (save_steps=17000 out of ~137,900 total steps for 2 epochs), to allow later evaluation and choosing the best model.

torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-30b-hf  --micro_batch_size=1 --batch_size=8 --cutoff_len=512 --num_epochs=2.0 --val_set_size=0 --eval_steps=100000 --save_steps=17000 --save_total_limit=20 --prompt_type=plain --save_code=True --train_8bit=False --run_id=llama30b_16 --llama_flash_attn=True --lora_r=64 --lora_target_modules="['q_proj', 'k_proj', 'v_proj', 'o_proj']" --learning_rate=2e-4 --lora_alpha=32 --drop_truncations=True --data_path=h2oai/h2ogpt-oig-oasst1-instruct-cleaned-v2 --data_mix_in_path=h2oai/openassistant_oasst1_h2ogpt --data_mix_in_factor=1.0 --data_mix_in_prompt_type='plain' --data_mix_in_col_dict="{'input': 'input'}" 
^M  0%|          | 60/137906 [00:37<23:10:56,  1.65it/s]^M  0%|          | 61/137906 [00:37<22:57:24,  1.67it/s]^M                                                      ^M{'loss': 2.0783, 'learning_rate': 0.00010600000000000002, 'epoch': 0.0}
^M  0%|          | 61/137906 [00:38<22:57:24,  1.67it/s]^M  0%|          | 62/137906 [00:38<22:03:21,  1.74it/s]^M                                                      ^M{'loss': 1.8075, 'learning_rate': 0.00010800000000000001, 'epoch': 0.0}
^M  0%|          | 62/137906 [00:38<22:03:21,  1.74it/s]^M  0%|          | 63/137906 [00:39<22:16:29,  1.72it/s]^M                                                      ^M{'loss': 1.5789, 'learning_rate': 0.00011000000000000002, 'epoch': 0.0}
^M  0%|          | 63/137906 [00:39<22:16:29,  1.72it/s]^M  0%|          | 64/137906 [00:39<22:28:35,  1.70it/s]^M                                                      ^M{'loss': 1.596, 'learning_rate': 0.00011200000000000001, 'epoch': 0.0}
^M  0%|          | 64/137906 [00:39<22:28:35,  1.70it/s]^M  0%|          | 65/137906 [00:40<22:40:52,  1.69it/s]^M                                                      ^M{'loss': 1.8478, 'learning_rate': 0.00011399999999999999, 'epoch': 0.0}
^M  0%|          | 65/137906 [00:40<22:40:52,  1.69it/s]^M  0%|          | 66/137906 [00:40<22:40:08,  1.69it/s]^M                                                      ^M{'loss': 1.8958, 'learning_rate': 0.000116, 'epoch': 0.0}
^M  0%|          | 66/137906 [00:40<22:40:08,  1.69it/s]^M  0%|          | 67/137906 [00:41<22:23:11,  1.71it/s]^M                                                      ^M{'loss': 1.583, 'learning_rate': 0.000118, 'epoch': 0.0}
^M  0%|          | 67/137906 [00:41<22:23:11,  1.71it/s]^M  0%|          | 68/137906 [00:42<22:16:46,  1.72it/s]^M                                                      ^M{'loss': 1.8608, 'learning_rate': 0.00012, 'epoch': 0.0}
^M  0%|          | 68/137906 [00:42<22:16:46,  1.72it/s]^M  0%|          | 69/137906 [00:42<22:33:59,  1.70it/s]^M                                                      ^M{'loss': 1.6374, 'learning_rate': 0.000122, 'epoch': 0.0}
^M  0%|          | 69/137906 [00:42<22:33:59,  1.70it/s]^M  0%|          | 70/137906 [00:43<22:25:44,  1.71it/s]^M                                                      ^M{'loss': 1.4426, 'learning_rate': 0.000124, 'epoch': 0.0}
^M  0%|          | 70/137906 [00:43<22:25:44,  1.71it/s]^M  0%|          | 71/137906 [00:43<23:34:56,  1.62it/s]^M                                                      ^M{'loss': 1.3956, 'learning_rate': 0.000126, 'epoch': 0.0}
^M  0%|          | 71/137906 [00:43<23:34:56,  1.62it/s]^M  0%|          | 72/137906 [00:44<23:23:10,  1.64it/s]^M                                                      ^M{'loss': 1.4945, 'learning_rate': 0.00012800000000000002, 'epoch': 0.0}
^M  0%|          | 72/137906 [00:44<23:23:10,  1.64it/s]^M  0%|          | 73/137906 [00:45<23:13:48,  1.65it/s]^M                                                      ^M{'loss': 1.3996, 'learning_rate': 0.00013000000000000002, 'epoch': 0.0}
^M  0%|          | 73/137906 [00:45<23:13:48,  1.65it/s]^M  0%|          | 74/137906 [00:45<22:41:00,  1.69it/s]^M                                                      ^M{'loss': 1.5406, 'learning_rate': 0.000132, 'epoch': 0.0}
^M  0%|          | 74/137906 [00:45<22:41:00,  1.69it/s]^M  0%|          | 75/137906 [00:46<22:21:34,  1.71it/s]^M                                                      ^M{'loss': 1.7863, 'learning_rate': 0.000134, 'epoch': 0.0}
^M  0%|          | 75/137906 [00:46<22:21:34,  1.71it/s]

image

Early logs:
llama30b_16.log.early.zip

At about 1 epoch (yellow line):

image

1.6 epochs:

image

At exactly 2 epochs, just when it would have finished, it hit an odd error:

^M100%|█████████▉| 137844/137918 [22:16:42<00:45,  1.64it/s]^M100%|█████████▉| 137845/137918 [22:16:43<00:43,  1.67it/s]^M                                                          ^M{'loss': 0.4093, 'learning_rate': 2.046176841921955e-07, 'epoch': 2.0}
^M100%|█████████▉| 137845/137918 [22:16:43<00:43,  1.67it/s]^M100%|█████████▉| 137846/137918 [22:16:43<00:42,  1.70it/s]^M                                                          ^M{'loss': 0.6818, 'learning_rate': 2.0316649494260547e-07, 'epoch': 2.>
^M100%|█████████▉| 137846/137918 [22:16:44<00:42,  1.70it/s][E ProcessGroupNCCL.cpp:828] [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686933, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1802547 milliseconds before timing >
Traceback (most recent call last):
  File "/home/ubuntu/h2o-llm/finetune.py", line 982, in <module>
    fire.Fire(train)
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/ubuntu/h2o-llm/finetune.py", line 634, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/transformers/trainer.py", line 1929, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/transformers/trainer.py", line 2709, in training_step
    self.scaler.scale(loss).backward()
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: NCCL communicator was aborted on rank 5.  Original reason for failure was: [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686933, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1802547 milliseconds before timing >
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686934, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1803371 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686933, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1802547 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686935, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1805818 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686934, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806319 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686933, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1807550 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686934, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1807761 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4686935, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808107 milliseconds before timing out.
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90643 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90644 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90645 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90646 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90647 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90649 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 90650 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 5 (pid: 90648) of binary: /home/ubuntu/miniconda3/envs/h2ollm/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/h2ollm/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.0.0', 'console_scripts', 'torchrun')())
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/miniconda3/envs/h2ollm/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
======================================================

image

jon@mr-dl10:~$ ls -alrt /home/jon/snap_llama5.tgz 
-rwxr-xr-x 1 jon jon 38992525917 Apr 28 21:51 /home/jon/snap_llama5.tgz*
jon@mr-dl10:~$ 

pseudotensor marked this pull request as ready for review April 27, 2023 23:19
pseudotensor self-requested a review April 27, 2023 23:20
arnocandel merged commit 9e0f08f into main Apr 28, 2023
pseudotensor mentioned this pull request Apr 29, 2023
pseudotensor (Collaborator) commented Apr 29, 2023

llama 7B:

Batch sizes 16, 32, and 64 OOM: 16 almost survived but died soon after, while 32 and 64 died early.

torchrun --nproc_per_node=8 finetune.py --base_model=decapoda-research/llama-7b-hf  --micro_batch_size=8 --batch_size=64 --cutoff_len=512 --num_epochs=10.0 --val_set_size=0 --eval_steps=100000 --save_steps=2125 --save_total_limit=20 --prompt_type=plain --save_code=True --train_8bit=False --run_id=llama7b_5 --llama_flash_attn=True --lora_r=64 --lora_target_modules="['q_proj', 'k_proj', 'v_proj', 'o_proj']" --learning_rate=2e-4 --lora_alpha=32 --drop_truncations=True --data_path=h2oai/h2ogpt-oig-oasst1-instruct-cleaned-v2 --data_mix_in_path=h2oai/openassistant_oasst1_h2ogpt --data_mix_in_factor=1.0 --data_mix_in_prompt_type='plain' --data_mix_in_col_dict="{'input': 'input'}" 

image

epoch ~8:

^M 79%|███████▉  | 67875/86170 [15:39:12<4:10:45,  1.22it/s]^M 79%|███████▉  | 67876/86170 [15:39:13<4:07:13,  1.23it/s]^M                                                          ^M{'loss': 0.47, 'learning_rate': 4.2570001161845015e-05, 'epoch': 7.88}
^M 79%|███████▉  | 67876/86170 [15:39:13<4:07:13,  1.23it/s]^M 79%|███████▉  | 67877/86170 [15:39:14<4:11:15,  1.21it/s]^M                                                          ^M{'loss': 0.4129, 'learning_rate': 4.256767747182526e-05, 'epoch': 7.8>
^M 79%|███████▉  | 67877/86170 [15:39:14<4:11:15,  1.21it/s]^M 79%|███████▉  | 67878/86170 [15:39:14<4:01:42,  1.26it/s]^M                                                          ^M{'loss': 0.5061, 'learning_rate': 4.256535378180551e-05, 'epoch': 7.8>
^M 79%|███████▉  | 67878/86170 [15:39:14<4:01:42,  1.26it/s]^M 79%|███████▉  | 67879/86170 [15:39:15<4:05:49,  1.24it/s]^M                                                          ^M{'loss': 0.6733, 'learning_rate': 4.256303009178576e-05, 'epoch': 7.8>
^M 79%|███████▉  | 67879/86170 [15:39:15<4:05:49,  1.24it/s]^M 79%|███████▉  | 67880/86170 [15:39:16<4:03:18,  1.25it/s]^M                                                          ^M{'loss': 0.4663, 'learning_rate': 4.2560706401766006e-05, 'epoch': 7.>
^M 79%|███████▉  | 67880/86170 [15:39:16<4:03:18,  1.25it/s]^M 79%|███████▉  | 67881/86170 [15:39:17<4:02:37,  1.26it/s]^M                                                          ^M{'loss': 0.6107, 'learning_rate': 4.2558382711746254e-05, 'epoch': 7.>
^M 79%|███████▉  | 67881/86170 [15:39:17<4:02:37,  1.26it/s]^M 79%|███████▉  | 67882/86170 [15:39:17<4:01:41,  1.26it/s]^M                                                          ^M{'loss': 0.4612, 'learning_rate': 4.25560590217265e-05, 'epoch': 7.88}
^M 79%|███████▉  | 67882/86170 [15:39:17<4:01:41,  1.26it/s]^M 79%|███████▉  | 67883/86170 [15:39:18<4:05:28,  1.24it/s]^M                                                          ^M{'loss': 0.5064, 'learning_rate': 4.255373533170675e-05, 'epoch': 7.8>
^M 79%|███████▉  | 67883/86170 [15:39:18<4:05:28,  1.24it/s]^M 79%|███████▉  | 67884/86170 [15:39:19<3:58:23,  1.28it/s]^M                                                          ^M{'loss': 0.4071, 'learning_rate': 4.2551411641687e-05, 'epoch': 7.88}
^M 79%|███████▉  | 67884/86170 [15:39:19<3:58:23,  1.28it/s]^M 79%|███████▉  | 67885/86170 [15:39:20<4:01:25,  1.26it/s]^M                                                          ^M{'loss': 0.3507, 'learning_rate': 4.254908795166725e-05, 'epoch': 7.8>
^M 79%|███████▉  | 67885/86170 [15:39:20<4:01:25,  1.26it/s]^M 79%|███████▉  | 67886/86170 [15:39:21<4:04:58,  1.24it/s]^M                                                          ^M{'loss': 0.4495, 'learning_rate': 4.25467642616475e-05, 'epoch': 7.88}
^M 79%|███████▉  | 67886/86170 [15:39:21<4:04:58,  1.24it/s]
(END)

image

image

The purple curve ends at about loss=0.5.

jon@mr-dl10:~$ scp ubuntu@176.56.197.2:h2o-llm/snap_llama6_llama7b.tgz .

pseudotensor (Collaborator) commented Apr 29, 2023

For llama 30B _7:

default generate eval:

image

score_llama30B_jon7h.log
score_llama30B_jon7g.log

left is GPT-3.5 turbo, right is 30B llama _7:

[images: side-by-side response comparisons, GPT-3.5 turbo (left) vs 30B llama _7 (right)]

More creative settings:

(alpaca) jon@gpu:/data/jon/h2o-llm$ python generate.py --base_model=decapoda-research/llama-30b-hf --prompt_type=human_bot --lora_weights=/data/jon/snap_llama3/llama-30b-hf.h2oaiopenassistant_oasst1_h2ogpt.8.0_epochs.31eef248d53c9f39e51c60b8b030c1e3cafc34b0.llama30b_7/ --gradio=False --infer_devices=False --eval_sharegpt_prompts_only=100 --eval_sharegpt_as_output=False --num_beams=1 --temperature=0.5 --do_sample=True --top_p=0.8 --top_k=80

image

Still much worse than GPT3.5 turbo:

image

pseudotensor (Collaborator) commented:
llama is sensitive to non-helpful training data in the OASST raw data:

BTW, I checked the OASST/OIG data to see if there's a reason for the lack of response, and the kinds of things I worried about are in there, e.g. this in the data:
<bot>: I'm not sure what you are asking.
is the same as what I got:
I'm sorry, I'm not sure what you're asking for here.
It must be degrading the model to some extent. But too often I disagree with the reward model score, so we really need OpenAI eval etc.

Another one in the data:
<bot>: I'm sorry, I cannot perform this task as I am an AI language model and do not have access
and what I got:
I apologize, but I cannot perform the task you have requested.
Yet it does know the answer if you raise temperature etc. So it's just a push in the wrong direction from the training data.

For non-conversational tasks, e.g. summarization, prompt engineering helps a lot, as I showed w.r.t. the samsum dataset. The issue with that dataset is that it's NC-licensed for some reason, from Samsung.

Another one in the data:
<bot>: Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t
and in the sharegpt eval (as others above):
I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?
The same refusal even though it's not true; it could have answered. Basically the model is a lot dumber because it semi-randomly answers as if it can't answer, not because of what it knows, but because some semi-random sequence triggered it.

For a case where it's asked to rephrase, it gives this poor response:
I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.
But there was nothing difficult for it; it can do it in other retries.

I tried to give a context prompt, but it only made things worse for the reward score, i.e. went from 0.55 -> 0.44 with this context prompt: https://github.com/h2oai/h2ogpt/pull/100 (I merged anyway, but set it to False for now.) I would swear the prompt should have helped, but it only hurt.

Another one I got just now from a retry of a shareGPT one:
I'm sorry, but as an AI language model,
and in the data there is:
<bot>: I'm sorry, but as an AI language model I
Just too much unhelpful stuff in the data. I know it's random, because I got a great answer at first, then did a regenerate and got this answer. All parameters were the same.

Yet another one:
I do not understand your question. Can you please try to make it clearer?
and in the data:
<bot>: I'm sorry, I didn't quite understand your question
If I regenerate (sometimes it takes up to 3 tries), the answers are amazing. It just gets stuck in that bad-data trap. I think llama is particularly sensitive to such unhelpful training data, in that it seems to do this more often than 20B, which is just sort of dull w.r.t. training data.

Yet another: a regenerate of one particular problematic question (even though the question was perfectly fine) gave:
I'm sorry, but I don't understand your question. Could you please rephrase it?
and in the data:
<bot>: I'm sorry, I didn't quite understand your question, could you please rephrase it?
It's just parroting back prior interactions, for no good reason. I have more unique ones, won't write them all out, you get the point.

Response score 99.8% :slightly_smiling_face:

[images: response and reward-score screenshots]

A way to fix this, perhaps overfitting to the reward model score, is to filter the instruct data through the reward model, since it's very fast: only keep human/bot pairs where the bot response scores higher than (say) 50%. It's a way of doing RLHF through choice of data instead of choice of objective. Since it's just removing data, not making up new data, it's fairly safe and unlikely to bias the model towards heavier hallucinations. It could help give a boost to 20B as well. Of course, just going by the tags in the OASST data is ok, but the OIG data has no tags; the reward model effectively gives it a tag.

I find it odd, because you already did this with add_deberta_grade. Is this not used in any of our data? I ran existing human/bot pairs in the data (those I found above) through that reward model, and it gives a very low score (< 0.2, the threshold you have). I'm not sure h2oGPT.cleaned.graded.human_bot.parquet is/was ever used. But that would probably have fixed everything I'm seeing. I'm confident llama at least would get near GPT3.5 levels without that bad data. The 20B is less affected, as it doesn't use those patterns even when they're in the data.
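
A minimal sketch of that filtering idea, assuming the OpenAssistant DeBERTa reward model (the sigmoid mapping of the raw logit to a 0-1 score, and all names here, are illustrative, not the repo's actual pipeline):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
tokenizer = AutoTokenizer.from_pretrained(reward_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_name).eval()

def keep_pair(human: str, bot: str, threshold: float = 0.5) -> bool:
    # Score the bot response given the human prompt; keep only pairs above threshold
    inputs = tokenizer(human, bot, return_tensors="pt", truncation=True)
    with torch.no_grad():
        score = torch.sigmoid(reward_model(**inputs).logits[0, 0]).item()
    return score > threshold

pairs = [("What's the capital of France?", "The capital of France is Paris.")]
filtered = [p for p in pairs if keep_pair(*p)]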
