diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py
index d06e22fc2180..c101656ce68f 100644
--- a/nemo/collections/llm/recipes/mixtral_8x7b.py
+++ b/nemo/collections/llm/recipes/mixtral_8x7b.py
@@ -210,20 +210,10 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
 
     It may not be suitable for all hardware configurations or use cases.
     """
-    # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
-    # by MegatronCommOverlapCallback. They are added here for user's knowledge.
-    # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
-    # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
-    # each PP stage launches independently as needed.
-
     recipe.trainer.callbacks.extend(
         [
             run.Config(MegatronTokenDropCallback),
-            run.Config(
-                MegatronCommOverlapCallback,
-                overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to issue with checkpointing.
-                align_param_gather=True,
-            ),
+            run.Config(MegatronCommOverlapCallback),
         ]
     )
     recipe.trainer.strategy.expert_model_parallel_size = 1
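
After this change the recipe relies on MegatronCommOverlapCallback's automatic defaults for the two removed options. Below is a minimal sketch, not part of this PR, of how a user could still set those options explicitly on top of the recipe if they want to override the defaults. The helper name `add_explicit_comm_overlap` is hypothetical, and the import path for MegatronCommOverlapCallback is assumed to follow current NeMo conventions; only the two parameter names come from the removed code.

```python
# Sketch only: explicitly re-enabling the comm-overlap options removed from the recipe.
# Import path is an assumption based on typical NeMo layout.
import nemo_run as run
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback


def add_explicit_comm_overlap(recipe: run.Partial) -> run.Partial:
    # overlap_param_gather_with_optimizer_step: overlap the param all-gather of the
    # first bucket with the optimizer step (kept False here, matching the old recipe,
    # which disabled it due to a checkpointing issue).
    # align_param_gather: if True, all PP stages launch param all-gathers together;
    # otherwise each PP stage launches independently as needed.
    recipe.trainer.callbacks.append(
        run.Config(
            MegatronCommOverlapCallback,
            overlap_param_gather_with_optimizer_step=False,
            align_param_gather=True,
        )
    )
    return recipe
```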