[Coati] Train DPO using PP #6054

Merged
merged 24 commits into hpcaitech:main on Oct 11, 2024
Conversation

TongLi3701
Member

@TongLi3701 TongLi3701 commented Sep 10, 2024

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs
  • I have installed pre-commit: pip install pre-commit && pre-commit install

🚨 Issue number

Link this PR to your issue with words like fixed to automatically close the linked issue upon merge

e.g. fixed #1234, closed #1234, resolved #1234

📝 What does this PR do?

Summarize your work here.
If you have any plots/diagrams/screenshots/tables, please attach them here.

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

@TongLi3701 TongLi3701 requested a review from a team as a code owner September 10, 2024 10:39
@TongLi3701 TongLi3701 changed the title Feat/dpo [Coati] Train DPO using PP Sep 10, 2024

@RahulVadisetty91 RahulVadisetty91 left a comment


Changes and Improvements:

Checkpoint Saving Logic:
The script saves checkpoints at several points, which is very useful when training on large datasets and especially in distributed training. However, the checkpoint-save condition self.num_train_step > 0 is unnecessary and would be better removed to make the code clearer.

Loss Handling:
For distributed training, the DPO loss is computed and then synchronized across processes with an all_reduce_mean helper. This guarantees consistent loss values on all ranks in multi-GPU training.
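For reference, a minimal sketch of what such a synchronization helper typically looks like (the actual utility shipped with Coati may differ in details):

```python
import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor in place across all ranks of the default process group."""
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor
```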

Optimizer and Scheduler Update:
The optimizer is stepped only after a configurable number of micro-batches, defined by self.accumulation_steps. This gradient accumulation is helpful when training large models on limited hardware, since it yields a larger effective batch size without additional memory.
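As an illustration (not the trainer's exact code), gradient accumulation usually follows this pattern, where model, optimizer, lr_scheduler, dataloader, and accumulation_steps stand in for the trainer's attributes:

```python
for step, batch in enumerate(dataloader):
    # Scale the loss so the accumulated gradients average over the micro-batches.
    loss = model(**batch).loss / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```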

Accumulative Meter:
The accumulative meter keeps running sums and averages of the loss, the rewards, and the model's reward accuracy. This is useful for tracking performance over time during both training and evaluation.
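A self-contained sketch of such a meter (the meter used in Coati may track additional statistics):

```python
from collections import defaultdict

class AccumulativeMeter:
    """Keeps running sums and means of named scalar metrics."""

    def __init__(self):
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def add(self, name: str, value: float) -> None:
        self.sums[name] += value
        self.counts[name] += 1

    def get(self, name: str) -> float:
        return self.sums[name] / max(self.counts[name], 1)

    def reset(self) -> None:
        self.sums.clear()
        self.counts.clear()
```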

Evaluation Logging:
During evaluation, the script writes the results and the identifiers of the evaluated samples to a text file. This is useful for observing the model's performance across training epochs. The logging could be further improved by adopting a framework such as TensorBoard for metrics visualization.
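A hedged example of what TensorBoard logging could look like here (the tag names and values are placeholders, not the script's actual variables):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/dpo_eval")  # directory name is an arbitrary example
writer.add_scalar("eval/loss", 0.42, global_step=100)             # placeholder values
writer.add_scalar("eval/reward_accuracy", 0.73, global_step=100)  # placeholder values
writer.flush()
```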

Suggested Additions

Gradient Clipping:
To prevent exploding gradients, especially when gradients are noisy in large models, add gradient clipping (torch.nn.utils.clip_grad_norm_).
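For example, just before the optimizer step (max_norm=1.0 is an assumed value, and model stands in for the trainer's model):

```python
import torch

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```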

Mixed Precision Training:
Consider mixed-precision training with torch.cuda.amp for better performance on modern GPUs. Reduced precision shortens training time without compromising model accuracy.
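A minimal torch.cuda.amp sketch, where model, optimizer, and dataloader are placeholders; ColossalAI's booster plugins may already manage precision, so this is purely illustrative:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for batch in dataloader:
    optimizer.zero_grad()
    with autocast():                  # run the forward pass in reduced precision
        loss = model(**batch).loss
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```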

More Granular Logging:
Log the learning rate, gradient norms, and model weight statistics as well. This makes debugging and performance monitoring easier in long training runs.
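For instance, per-step logging of the learning rate and gradient norm might look like this (writer, lr_scheduler, model, and global_step are placeholders):

```python
import torch

current_lr = lr_scheduler.get_last_lr()[0]
grad_norm = torch.norm(
    torch.stack([p.grad.detach().norm() for p in model.parameters() if p.grad is not None])
)
writer.add_scalar("train/lr", current_lr, global_step)
writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
```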

Asynchronous Checkpoint Saving:
Save checkpoints asynchronously to minimize latency and stalls during training, especially when the model is large.
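A simplified sketch of one way to do this with a background thread (production code, and ColossalAI's checkpoint IO, may offer more robust options):

```python
import threading
import torch

def save_checkpoint_async(model: torch.nn.Module, path: str) -> threading.Thread:
    """Copy the state dict to CPU, then write it to disk in a background thread."""
    state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    thread = threading.Thread(target=torch.save, args=(state, path), daemon=True)
    thread.start()
    return thread  # join() before exiting to make sure the file is fully written
```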

Distributed Training Enhancements:
Add first-class support for torch.distributed, with launching via torchrun. This paves the way for scaling across multiple nodes and further improves the training process.
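A sketch of the process-group setup for a script launched with torchrun, e.g. torchrun --nproc_per_node=8 train_dpo.py (the script name is illustrative; ColossalAI also provides its own launch helpers):

```python
import os

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")     # torchrun sets the rank/world-size env vars
local_rank = int(os.environ["LOCAL_RANK"])  # provided by torchrun for each process
torch.cuda.set_device(local_rank)
```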

Usage of the Script:
This script trains models with DPO (Direct Preference Optimization). It is particularly useful in distributed training environments where resource management is critical. The script handles model checkpointing, synchronizes losses and rewards across processes, and supports gradient accumulation, making it suitable for large-scale training tasks.


@RahulVadisetty91 RahulVadisetty91 left a comment


(attached screenshot of the GPTLMLoss code)

In GPTLMLoss, you might want to set ignore_index=-100 explicitly for clarity, as mentioned in the comment. It currently relies on the default value; stating it explicitly (or making it configurable) would make the intent clearer:

self.loss = nn.CrossEntropyLoss(ignore_index=-100)
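A hedged sketch of how the class could make this explicit and configurable; the actual GPTLMLoss in coati may differ in details:

```python
import torch.nn as nn

class GPTLMLoss(nn.Module):
    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, labels):
        # Standard causal-LM shift: position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return self.loss(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
```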

@TongLi3701 TongLi3701 requested a review from ver217 October 11, 2024 06:26
@TongLi3701 TongLi3701 merged commit 4c8e85e into hpcaitech:main Oct 11, 2024
6 checks passed
@TongLi3701 TongLi3701 deleted the feat/DPO branch October 11, 2024 11:32