Adding a feature that will stop the training/eval process after reaching some max_steps #428

Merged
merged 3 commits from feature/add_max_steps into main on Apr 9, 2024

Conversation

wukaixingxp
Contributor

Adding a feature that will stop the training/eval process after reaching some max_steps. By default, this feature is disabled: max_train_step and max_eval_step are set to 0. Users can set max_train_step or max_eval_step
to a positive integer so that the training or eval process stops once that step limit is reached. This PR introduces two more arguments in the config file, and the testing functions test_gradient_accumulation and test_save_to_json have been modified accordingly.
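
For reference, a minimal sketch of what the two new config fields could look like (the dataclass below is illustrative; the actual field layout in llama-recipes may differ):

from dataclasses import dataclass

@dataclass
class train_config:
    num_epochs: int = 3      # existing field, shown for context
    max_train_step: int = 0  # 0 disables the limit; training stops once this many steps complete
    max_eval_step: int = 0   # 0 disables the limit; evaluation stops once this many steps complete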

test log with max_train_step=6, max_eval_step=2

~/work/llama-recipes (feature/add_max_steps)]$ torchrun --nnodes 1 --nproc_per_node 4  recipes/finetuning/finetuning.py --max_train_step 6 --max_eval_step 2 --use_peft --peft_method lora  --model_name meta-llama/Llama-2-7b-chat-hf --enable_fsdp --use_fast_kernels --dist_checkpoint_root_folder /home/kaiwu/work/llama2-7b/ --dist_checkpoint_folder /home/kaiwu/work/finetune-output
[2024-04-08 15:46:25,374] torch.distributed.run: [WARNING] 
[2024-04-08 15:46:25,374] torch.distributed.run: [WARNING] *****************************************
[2024-04-08 15:46:25,374] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-04-08 15:46:25,374] torch.distributed.run: [WARNING] *****************************************
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:485: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:485: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:485: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:485: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.57s/it]
--> Model meta-llama/Llama-2-7b-chat-hf

--> meta-llama/Llama-2-7b-chat-hf has 6738.415616 Million params

trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199
bFloat16 enabled for mixed precision - using bfSixteen policy
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.59s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.57s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.60s/it]
trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199
trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199
trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Parameter 'function'=<function get_preprocessed_samsum.<locals>.apply_prompt_template at 0x7f2af7ba9e10> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-1c80317fa3b1799d.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-bdd640fb06671ad1.arrow
Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Parameter 'function'=<function get_preprocessed_samsum.<locals>.apply_prompt_template at 0x7feb86989e10> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-1c80317fa3b1799d.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-bdd640fb06671ad1.arrow
Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Parameter 'function'=<function get_preprocessed_samsum.<locals>.apply_prompt_template at 0x7fb55cd99f30> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-1c80317fa3b1799d.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-bdd640fb06671ad1.arrow
--> Training Set Length = 14732
Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Parameter 'function'=<function get_preprocessed_samsum.<locals>.apply_prompt_template at 0x7f1c1f3a1e10> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-1c80317fa3b1799d.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-bdd640fb06671ad1.arrow
Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-3eb13b9046685257.arrow
Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-23b8c1e9392456de.arrow
Preprocessing dataset:   0%|                                                                                                            | 0/14732 [00:00<?, ?it/s]Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-3eb13b9046685257.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-23b8c1e9392456de.arrow
--> Validation Set Length = 818
Preprocessing dataset:   0%|                                                                                                            | 0/14732 [00:00<?, ?it/s]Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-3eb13b9046685257.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-23b8c1e9392456de.arrow
Preprocessing dataset:   4%|████                                                                                            | 616/14732 [00:00<00:02, 6154.09it/s]Reusing dataset samsum (/home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-3eb13b9046685257.arrow
Loading cached processed dataset at /home/kaiwu/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-23b8c1e9392456de.arrow
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:02<00:00, 5964.66it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:02<00:00, 5955.82it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:02<00:00, 6072.75it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:02<00:00, 6058.17it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 5870.09it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 5810.93it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 5721.43it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 5878.17it/s]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                   | 0/48 [00:00<?, ?it/s]/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                   | 0/48 [00:00<?, ?it/s]/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                   | 0/48 [00:00<?, ?it/s]NCCL version 2.19.3+cuda12.1
Training Epoch: 1/3, step 5/48 completed (loss: 1.6506224870681763):  12%|███████▏                                                 | 6/48 [00:27<03:09,  4.52s/it]
Training Epoch: 1/3, step 5/48 completed (loss: 1.578351378440857):  12%|███████▎                                                  | 6/48 [00:27<03:09,  4.52s/it]
Training Epoch: 1/3, step 5/48 completed (loss: 1.608504295349121):  12%|███████▎                                                  | 6/48 [00:27<03:10,  4.54s/it]
Training Epoch: 1/3, step 5/48 completed (loss: 1.5797399282455444):  12%|███████▏                                                 | 6/48 [00:27<03:10,  4.54s/it]
Max CUDA memory allocated was 17 GB
Max CUDA memory reserved was 21 GB
Peak active CUDA memory was 18 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 8 GB
evaluating Epoch:  18%|███████████████████▋                                                                                        | 2/11 [00:00<00:04,  2.17it/s]max eval steps reached, stopping evaluation, total_eval_steps:  2
evaluating Epoch:  18%|███████████████████▋                                                                                        | 2/11 [00:01<00:05,  1.79it/s]
evaluating Epoch:  18%|███████████████████▋                                                                                        | 2/11 [00:01<00:04,  1.88it/s]
evaluating Epoch:  18%|███████████████████▋                                                                                        | 2/11 [00:00<00:04,  2.16it/s]
evaluating Epoch:  18%|███████████████████▋                                                                                        | 2/11 [00:00<00:04,  2.15it/s]
 eval_ppl=tensor(1.3011, device='cuda:0') eval_epoch_loss=tensor(0.2632, device='cuda:0')
we are about to save the PEFT modules
PEFT modules are saved in PATH/to/save/PEFT/model directory
best eval loss on epoch 1 is 0.26322659850120544
Epoch 1: train_perplexity=1.2784, train_epoch_loss=0.2456, epoch time 27.707596888765693s
max training steps reached, stopping training, total_train_steps:  6
Key: avg_train_prep, Value: 1.2784380912780762
Key: avg_train_loss, Value: 0.24563908576965332
Key: avg_eval_prep, Value: 1.301121473312378
Key: avg_eval_loss, Value: 0.26322659850120544
Key: avg_epoch_time, Value: 27.707596888765693
Key: avg_checkpoint_time, Value: 1.0763725992292166

Contributor

@HamidShojanazeri left a comment


Thanks @wukaixingxp for the PR! Just left a minor comment.

epoch_times = []
checkpoint_times = []
results = {}
best_val_loss = float("inf")
total_train_steps = 0
for epoch in range(train_config.num_epochs):
Contributor


@wukaixingxp I wonder if we should remove some of the redundancy, maybe something like this?

max_steps_reached = False  # Flag to indicate max training steps reached

for epoch in range(train_config.num_epochs):
    if max_steps_reached:
        break  # Stop starting a new epoch if max steps reached

    epoch_start_time = time.perf_counter()
    with MemoryTrace():  
        model.train()
        total_loss = 0.0
        total_length = len(train_dataloader) // gradient_accumulation_steps
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)

        for step, batch in enumerate(train_dataloader):
            total_train_steps += 1

            # Check if max training steps reached
            if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                print(f"Max training steps reached at step {total_train_steps}. Stopping training.")
                max_steps_reached = True
                break  # Exit batch loop

if max_steps_reached and (not train_config.enable_fsdp or local_rank == 0):
    print("Training stopped after reaching the maximum number of steps.")

Contributor Author


I changed my code accordingly, thanks for the feedback.

@HamidShojanazeri merged commit aaa9e2c into main Apr 9, 2024
3 checks passed
@wukaixingxp deleted the feature/add_max_steps branch April 11, 2024 02:51