[misc] fix: gradient accumulation in seq balance and modify default vllm log level #141
Conversation
@@ -125,6 +125,7 @@ def __init__(self, config: DictConfig, role: str):
         self.config.actor.ppo_micro_batch_size //= (self.device_mesh.shape[0] //
                                                     self.ulysses_sequence_parallel_size)
         self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
+        assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0
We also have to ensure self.config.actor.ppo_mini_batch_size >= n_gpus * self.config.actor.ppo_micro_batch_size_per_gpu
Is it necessary? The mini_batch_size here is already normalized, so if self.config.actor.ppo_mini_batch_size < n_gpus * self.config.actor.ppo_micro_batch_size_per_gpu, the modulo in the line above will not be 0 and the assertion will already fail.
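The normalization the reply refers to can be sketched as follows. This is an assumed reconstruction for illustration, not the project's actual code: the function name, signature, and exact division order are invented, and the diff above only shows the micro-batch side of it.

```python
def normalize_batch_sizes(ppo_mini_batch_size, ppo_micro_batch_size,
                          n_gpus, ulysses_sequence_parallel_size=1):
    # Hypothetical sketch: both batch sizes are divided by the
    # data-parallel group size before the divisibility assert,
    # mirroring the lines in the diff above.
    dp_size = n_gpus // ulysses_sequence_parallel_size
    ppo_mini_batch_size //= dp_size
    ppo_micro_batch_size //= dp_size
    ppo_micro_batch_size_per_gpu = ppo_micro_batch_size
    # A mini batch that is too small relative to the per-GPU micro batch
    # typically leaves a non-zero remainder here after normalization,
    # which is why a separate >= check may be redundant.
    assert ppo_mini_batch_size % ppo_micro_batch_size_per_gpu == 0
    gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu
    return ppo_mini_batch_size, ppo_micro_batch_size_per_gpu, gradient_accumulation
```

For example, with a global mini batch of 256, a global micro batch of 32, 8 GPUs, and sequence parallel size 2, the data-parallel size is 4, giving a per-rank mini batch of 64, a per-GPU micro batch of 8, and 8 gradient-accumulation steps.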
…llm log level (volcengine#141)
- The previous gradient accumulation value was computed from micro_batch_size, which is wrong when using dynamic_bsz
- Fix the CI script to avoid overlooking this issue
- Change the vLLM state log default value to True to disable the log
- Check `self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0` after normalization in fsdp_workers instead of in dp_actor and dp_critic
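The gradient-accumulation issue in the commit message can be illustrated with a small sketch. Everything here is an assumed reconstruction (the function name and loss-weighting scheme are invented for illustration): with dynamic_bsz the micro-batches within a mini-batch have varying sizes, so scaling each micro-batch loss by a fixed factor derived from micro_batch_size over-weights small micro-batches; weighting each loss by its actual share of the mini-batch keeps the accumulated gradient equivalent to a single full-batch step.

```python
def accumulate_losses(micro_batch_losses, micro_batch_sizes, mini_batch_size):
    # Hypothetical sketch: weight each micro-batch's mean loss by the
    # fraction of the mini-batch it actually contains, rather than by
    # 1 / gradient_accumulation computed from a fixed micro_batch_size.
    total = 0.0
    for loss, size in zip(micro_batch_losses, micro_batch_sizes):
        total += loss * (size / mini_batch_size)
    return total
```

With mean losses [1.0, 2.0] over micro-batches of 3 and 1 samples (mini batch of 4), the correctly weighted total is 1.0 * 3/4 + 2.0 * 1/4 = 1.25, whereas a fixed 1/2 scale would give 1.5.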