
Fix/load model with torch dtype auto #663 after cookbook refactor #871

Merged

Conversation

IgorKasianenko
Contributor

What does this PR do?

This PR loads the model with torch_dtype=auto instead of bfloat16 when train_config.use_fp16 is not specified.
For Llama models this makes no practical difference, since their default dtype is already bfloat16.
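
To illustrate the behavior described above, here is a minimal sketch of the dtype selection, assuming a Hugging Face AutoModelForCausalLM loading path; the helper name and exact flag handling are illustrative, not the cookbook's actual code:

```python
# Hypothetical sketch of the dtype selection described in this PR;
# not the exact llama-cookbook implementation.
import torch
from transformers import AutoModelForCausalLM

def load_model(model_name: str, use_fp16: bool = False):
    # When use_fp16 is not requested, defer to the checkpoint's own dtype via
    # torch_dtype="auto" instead of forcing bfloat16. Llama checkpoints already
    # store bfloat16 weights, so the loaded dtype is unchanged for them.
    torch_dtype = torch.float16 if use_fp16 else "auto"
    return AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)

# Example (hypothetical): load_model("meta-llama/Meta-Llama-3.1-8B-Instruct")
```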

Fixes #656 (kind of)

Feature/Issue validation/testing

Please describe the tests you ran to verify your changes and summarize the relevant results. Provide instructions so they can be reproduced, and list any relevant details of your test configuration.

  • torchrun --nnodes 1 --nproc_per_node 8 ./recipes/quickstart/finetuning/finetuning.py --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --enable_fsdp --max_train_step=2 --batch_size_training 1 --batching_strategy packing --dataset samsum_dataset --save_model False --context_length 4096 --fsdp_config.pure _bf16 --fsdp_config.optimizer anyprecision --samsum_dataset.trust_remote_code 1

Logs

W0203 15:55:32.585000 720518 site-packages/torch/distributed/run.py:793] 
W0203 15:55:32.585000 720518 site-packages/torch/distributed/run.py:793] *****************************************
W0203 15:55:32.585000 720518 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0203 15:55:32.585000 720518 site-packages/torch/distributed/run.py:793] *****************************************
/home/ubuntu/anaconda3/envs/recipe/lib/python3.12/site-packages/llama_cookbook/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Warning: fsdp_config does not accept parameter: fsdp_config.pure
(the two messages above are repeated once per rank)
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail

...


evaluating Epoch: 100%|██████████| 4/4 [00:18<00:00,  4.69s/it]
Starting epoch 1/3
Starting epoch 1/3
Starting epoch 1/3
Starting epoch 1/3
train_config.max_train_step: 2
train_config.max_train_step: 2
train_config.max_train_step: 2
train_config.max_train_step: 2
Starting epoch 1/3
train_config.max_train_step: 2
Starting epoch 1/3
train_config.max_train_step: 2
Starting epoch 1/3
train_config.max_train_step: 2
 eval_ppl=tensor(47375.3711, device='cuda:0') eval_epoch_loss=tensor(10.7659, device='cuda:0')
best eval loss on epoch 1 is 10.765857696533203
Epoch 1: train_perplexity=1.1558, train_epoch_loss=0.1448, epoch time 46.755609701387584s
Starting epoch 1/3
train_config.max_train_step: 2
training params are saved in /home/ubuntu/projects/llama-cookbook-663/PATH/to/save/FSDP/model/fine-tuned-meta-llama/Meta-Llama-3-8B-Instruct/train_params.yaml
Key: avg_train_prep, Value: 1.1558226346969604
Key: avg_train_loss, Value: 0.1448122262954712
Key: avg_eval_prep, Value: 47375.37109375
Key: avg_eval_loss, Value: 10.765857696533203
Key: avg_epoch_time, Value: 46.755609701387584
Key: avg_checkpoint_time, Value: 5.895271897315979e-07
[rank0]:[W203 16:03:24.135413049 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.

Thanks for contributing 🎉!

Use auto instead of bf16 to determine model dtype; Add comments to describe the various use_fp16 and pure_bf16 options
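
For context on these options, here is a hedged sketch of how use_fp16 and pure_bf16 are commonly mapped to FSDP mixed-precision policies; the function signature and config handling below are illustrative assumptions, not the cookbook's exact identifiers:

```python
# Illustrative (assumed) mapping of use_fp16 / pure_bf16 to an FSDP MixedPrecision
# policy; not the exact llama-cookbook code.
import torch
from torch.distributed.fsdp import MixedPrecision

def mixed_precision_policy(use_fp16: bool, pure_bf16: bool):
    if pure_bf16:
        # "pure" bf16: parameters, gradient reduction, and buffers all in bfloat16
        return MixedPrecision(param_dtype=torch.bfloat16,
                              reduce_dtype=torch.bfloat16,
                              buffer_dtype=torch.bfloat16)
    if use_fp16:
        # fp16 mixed precision, typically paired with a gradient scaler
        return MixedPrecision(param_dtype=torch.float16,
                              reduce_dtype=torch.float16,
                              buffer_dtype=torch.float16)
    # neither flag set: no mixed-precision policy, keep the model's loaded dtype
    return None
```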
@IgorKasianenko IgorKasianenko requested a review from mreso February 3, 2025 16:07
@IgorKasianenko IgorKasianenko changed the title from "Use auto instead of bf16 to determine model dtype; Add comments to de…" to "https://github.com/meta-llama/llama-cookbook/pull/663" on Feb 3, 2025
@IgorKasianenko IgorKasianenko changed the title from "https://github.com/meta-llama/llama-cookbook/pull/663" to "Fix/load model with torch dtype auto #663 after cookbook refactor" on Feb 3, 2025
Contributor

@mreso mreso left a comment

LGTM

@IgorKasianenko IgorKasianenko merged commit 341539a into meta-llama:main Feb 5, 2025
4 checks passed
@IgorKasianenko IgorKasianenko deleted the fix/663-after-cookbook-refactor branch February 5, 2025 12:49
@IgorKasianenko IgorKasianenko self-assigned this Feb 5, 2025