[Coati] Train DPO using PP #6054
Conversation
Changes and Improvements:
Checkpoint Saving Logic:
The script saves checkpoints at several points during training. This is very useful for large datasets, and even more so when using distributed training. However, the guard `self.num_train_step > 0` in the checkpoint save condition is unnecessary and should be removed to make the code clearer, as sketched below.
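A minimal sketch of the suggested simplification; the attribute and helper names (`save_interval`, `save_checkpoint`) are assumptions rather than the PR's exact code:

```python
# If num_train_step is incremented before this check, it is always >= 1,
# so the extra `> 0` guard can never change the outcome and can be dropped.

# Before
if self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
    save_checkpoint(self.model, self.save_dir)  # hypothetical helper

# After
if self.num_train_step % self.save_interval == 0:
    save_checkpoint(self.model, self.save_dir)
```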
Loss Handling:
For distributed training, the DPO loss is computed and then synchronized across processes by an `all_reduce_mean` helper. This guarantees consistent loss values on all ranks during multi-GPU training.
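A minimal sketch of what such a helper typically looks like, assuming `torch.distributed` is already initialized (the actual Coati implementation may differ):

```python
import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all ranks so every process sees the same loss."""
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor
```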
Optimizer and Scheduler Update:
Optimizer updates happen only once every `self.accumulation_steps` batches. Gradient accumulation is beneficial when training large models on limited hardware, since it increases the effective batch size without increasing memory usage; see the sketch below.
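A minimal sketch of the accumulation pattern; `dataloader`, `optimizer`, `lr_scheduler`, and `compute_dpo_loss` are stand-in names, not the PR's exact code:

```python
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    loss = compute_dpo_loss(model, batch)    # hypothetical loss function
    (loss / accumulation_steps).backward()   # scale so accumulated grads average out
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()      # update weights once per accumulation window
        lr_scheduler.step()   # keep the scheduler in step with the optimizer
        optimizer.zero_grad()
```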
Accumulative Meter:
The accumulative meter tracks running sums and averages of the loss, rewards, and accuracy. This is useful for following model performance over time, both during training and evaluation; a hypothetical version is sketched below.
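A hypothetical version of such a meter, to illustrate the idea rather than reproduce Coati's implementation:

```python
from collections import defaultdict

class AccumulativeMeanMeter:
    """Running sums and means for scalar metrics (loss, rewards, accuracy)."""

    def __init__(self):
        self._sum = defaultdict(float)
        self._count = defaultdict(int)

    def add(self, name: str, value: float) -> None:
        self._sum[name] += value
        self._count[name] += 1

    def get(self, name: str) -> float:
        return self._sum[name] / max(self._count[name], 1)

    def reset(self) -> None:
        self._sum.clear()
        self._count.clear()
```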
Evaluation Logging:
During evaluation, the script writes each result along with the identifier of the evaluated sample to a text file. This is helpful for observing the model's performance across training epochs. The logging could still be improved by adopting a framework such as TensorBoard for metric visualization.
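For example, with TensorBoard (the log directory and metric names are assumptions; `eval_loss`, `accuracy`, and `global_step` would come from the evaluation loop):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/dpo")  # hypothetical log directory
writer.add_scalar("eval/loss", eval_loss, global_step)
writer.add_scalar("eval/reward_accuracy", accuracy, global_step)
writer.flush()
```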
Suggested Additions
Gradient Clipping:
To prevent exploding gradients, especially with the noisy gradients of large models, add gradient clipping via `torch.nn.utils.clip_grad_norm_`.
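For example, right before the optimizer step (`max_norm=1.0` is an assumed value, not the PR's setting):

```python
import torch

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```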
Mixed Precision Training:
Consider mixed-precision training with `torch.cuda.amp` for better performance on modern GPUs. Running the forward pass in reduced precision shortens training time without compromising model accuracy.
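A minimal sketch with `torch.cuda.amp`; `dataloader`, `model`, and `compute_dpo_loss` are stand-in names:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():            # run the forward pass in reduced precision
        loss = compute_dpo_loss(model, batch)  # hypothetical loss function
    scaler.scale(loss).backward()              # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                     # unscales grads, then calls optimizer.step()
    scaler.update()
```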
More Granular Logging:
Also log learning rates, gradient norms, and model weights. This makes debugging easier and helps verify training behavior over long runs.
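For example, extending the TensorBoard writer from the evaluation-logging sketch above (names are assumptions):

```python
# Assumes the SummaryWriter `writer` defined earlier.
writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[0], global_step)
for name, param in model.named_parameters():
    if param.grad is not None:
        writer.add_scalar(f"grad_norm/{name}", param.grad.norm().item(), global_step)
```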
Asynchronous Checkpoint Saving:
Save checkpoints asynchronously so that disk I/O does not cause latency spikes or stalls during training, which matters most when the model is large.
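One possible approach, sketched with a background thread (the helper and its names are hypothetical):

```python
import threading
import torch

def save_checkpoint_async(model, path):
    """Snapshot weights to CPU, then write to disk on a background thread."""
    cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    threading.Thread(target=torch.save, args=(cpu_state, path), daemon=True).start()
```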
Distributed Training Enhancements:
Add support for `torch.distributed` and launch via `torchrun`. This paves the way for scaling across multiple nodes and further speeds up training.
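A minimal initialization sketch for a script launched with `torchrun`, which sets the `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` environment variables:

```python
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")               # join the process group
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # pin each process to its GPU
```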
Usage of the Script:
This script trains models with DPO (Direct Preference Optimization). It is particularly useful in distributed training environments where resource management is critical: it handles model checkpoints, synchronizes losses and rewards across processes, and supports gradient accumulation, making it suitable for large-scale alignment training tasks.