From 1035c706a2263b72d13e3e19cab5abfe7dec2de2 Mon Sep 17 00:00:00 2001
From: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com>
Date: Mon, 16 Sep 2024 12:46:59 -0400
Subject: [PATCH] [NeMo-UX] Add token drop callback and optimize mixtral configs (#10361)

* add token drop plugin

Signed-off-by: Jimmy Zhang

* add checks

Signed-off-by: Jimmy Zhang

* add expert parallel configs

Signed-off-by: Jimmy Zhang

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12

* amend comment

Signed-off-by: Jimmy Zhang

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12

* add comm overlap

Signed-off-by: Jimmy Zhang

* fix rebase errors

Signed-off-by: Jimmy Zhang

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12

* fix typo

Signed-off-by: Jimmy Zhang

* add test configs

Signed-off-by: Jimmy Zhang

* fix

Signed-off-by: Jimmy Zhang

---------

Signed-off-by: Jimmy Zhang
Signed-off-by: JimmyZhang12
Co-authored-by: Jimmy Zhang
Co-authored-by: JimmyZhang12
Co-authored-by: Pablo Garay
Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Signed-off-by: George Armstrong
---
 nemo/collections/llm/gpt/model/mixtral.py     |  3 +-
 nemo/collections/llm/recipes/mixtral_8x22b.py | 65 ++++++++++++++++---
 nemo/collections/llm/recipes/mixtral_8x3b.py  | 58 +++++++++++++++--
 nemo/collections/llm/recipes/mixtral_8x7b.py  | 63 +++++++++++++++---
 .../pytorch/callbacks/moe_token_drop.py       | 55 ++++++++++++++++
 .../llm/recipes/test_mixtral_8x22b.py         | 12 ++--
 .../llm/recipes/test_mixtral_8x3b.py          |  8 +--
 .../llm/recipes/test_mixtral_8x7b.py          | 12 ++--
 8 files changed, 233 insertions(+), 43 deletions(-)
 create mode 100644 nemo/lightning/pytorch/callbacks/moe_token_drop.py

diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py
index bc255ae8fb87..bb3dc0068ca3 100644
--- a/nemo/collections/llm/gpt/model/mixtral.py
+++ b/nemo/collections/llm/gpt/model/mixtral.py
@@ -57,11 +57,10 @@ class MixtralConfig(GPTConfig):
     # MoE
     num_moe_experts: int = 8
     moe_aux_loss_coeff: float = 0.01
-    moe_expert_capacity_factor: float = 1.0
-    moe_pad_expert_input_to_capacity: bool = True
     moe_router_topk: int = 2
     moe_router_pre_softmax: bool = True
     moe_token_dispatcher_type: str = "alltoall"
+    moe_router_load_balancing_type: str = 'aux_loss'

     init_method_std: float = 0.02
     layernorm_epsilon: float = 1e-5
diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py
index 2320c89dfd2c..82f7cae23dba 100644
--- a/nemo/collections/llm/recipes/mixtral_8x22b.py
+++ b/nemo/collections/llm/recipes/mixtral_8x22b.py
@@ -29,6 +29,8 @@
 from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
+from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
+from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
 from nemo.utils.exp_manager import TimingCallback

 NAME = "mixtral_8x22b"
@@ -54,14 +56,14 @@ def model() -> run.Config[pl.LightningModule]:


 def trainer(
-    tensor_parallelism: int = 8,
-    pipeline_parallelism: int = 8,
+    tensor_parallelism: int = 2,
+    pipeline_parallelism: int = 4,
     pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
-    virtual_pipeline_parallelism: Optional[int] = 7,
-    context_parallelism: int = 1,
+    virtual_pipeline_parallelism: Optional[int] = 14,
+    context_parallelism: int = 2,
     sequence_parallelism: bool = True,
-    expert_parallelism: int = 1,
-    num_nodes: int = 8,
+    expert_parallelism: int = 8,
+    num_nodes: int = 16,
     num_gpus_per_node: int = 8,
     max_steps: int = 1168251,
     callbacks: Optional[list[run.Config[Callback]]] = None,
@@ -92,7 +94,7 @@ def trainer(
             $ nemo llm pretrain trainer=mixtral_8x22b ...

         Python API usage:
-            >>> trainer_config = trainer(num_nodes=8, num_gpus_per_node=8)
+            >>> trainer_config = trainer(num_nodes=16, num_gpus_per_node=8)
             >>> print(trainer_config)

     Note:
@@ -139,7 +141,7 @@ def trainer(

 @run.cli.factory(target=pretrain, name=NAME)
 def pretrain_recipe(
-    dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
+    dir: Optional[str] = None, name: str = "default", num_nodes: int = 16, num_gpus_per_node: int = 8, fn=pretrain
 ) -> run.Partial:
     """
     Create a pre-training recipe for Mixtral 8x22B model.
@@ -160,10 +162,10 @@ def pretrain_recipe(
     Examples:
         CLI usage:
             $ nemo llm pretrain --factory mixtral_8x22b
-            $ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=2, name='my_mixtral_pretrain')"
+            $ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=16, name='my_mixtral_pretrain')"

         Python API usage:
-            >>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=2)
+            >>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=16)
             >>> print(recipe)
     """
     return run.Partial(
@@ -179,6 +181,49 @@ def pretrain_recipe(
     )


+@run.cli.factory(target=pretrain, name=NAME + "_performance")
+def pretrain_recipe_performance(
+    dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
+) -> run.Partial:
+    """
+    Create a performance-optimized pre-training recipe for Mixtral 8x22B model.
+
+    This recipe enables performance optimizations that may not be suitable for all use cases.
+    It builds upon the standard pre-training recipe and adds additional performance enhancements.
+
+    Args:
+        dir (Optional[str]): Directory for saving logs and checkpoints.
+        name (str): Name of the pre-training run.
+        num_nodes (int): Number of compute nodes to use.
+        num_gpus_per_node (int): Number of GPUs per node.
+        fn (Callable): The pre-training function to use.
+
+    Returns:
+        run.Partial: Partial configuration for performance-optimized pre-training.
+
+    Examples:
+        CLI usage:
+            $ nemo llm pretrain --factory "mixtral_8x22b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')"
+
+        Python API usage:
+            >>> recipe = pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8)
+            >>> print(recipe)
+
+    Note:
+        Use this recipe with caution and only when you need maximum performance.
+        It may not be suitable for all hardware configurations or use cases.
+    """
+    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)
+    recipe.trainer.callbacks.extend(
+        [
+            run.Config(MegatronTokenDropCallback),
+            run.Config(MegatronCommOverlapCallback),
+        ]
+    )
+
+    return recipe
+
+
 def hf_resume() -> run.Config[nl.AutoResume]:
     """
     Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x22B model.
diff --git a/nemo/collections/llm/recipes/mixtral_8x3b.py b/nemo/collections/llm/recipes/mixtral_8x3b.py
index 9a70e8d952a3..ca5b4e35039f 100644
--- a/nemo/collections/llm/recipes/mixtral_8x3b.py
+++ b/nemo/collections/llm/recipes/mixtral_8x3b.py
@@ -30,6 +30,8 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
+from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
+from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
 from nemo.utils.exp_manager import TimingCallback

 NAME = "mixtral_8x3b"
@@ -55,14 +57,14 @@ def model() -> run.Config[pl.LightningModule]:


 def trainer(
-    tensor_parallelism: int = 4,
+    tensor_parallelism: int = 1,
     pipeline_parallelism: int = 1,
     pipeline_parallelism_type: Optional[torch.dtype] = None,
     virtual_pipeline_parallelism: Optional[int] = None,
     context_parallelism: int = 1,
-    sequence_parallelism: bool = True,
-    expert_parallelism: int = 1,
-    num_nodes: int = 1,
+    sequence_parallelism: bool = False,
+    expert_parallelism: int = 4,
+    num_nodes: int = 2,
     num_gpus_per_node: int = 8,
     max_steps: int = 1168251,
     callbacks: Optional[list[run.Config[Callback]]] = None,
@@ -93,7 +95,7 @@ def trainer(
             $ nemo llm pretrain trainer=mixtral_8x3b ...

         Python API usage:
-            >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=8)
+            >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
             >>> print(trainer_config)
     """
     strategy = run.Config(
@@ -139,7 +141,7 @@ def trainer(

 @run.cli.factory(target=pretrain, name=NAME)
 def pretrain_recipe(
-    dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
+    dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
 ) -> run.Partial:
     """
     Create a pre-training recipe for Mixtral 8x3B model.
@@ -181,6 +183,50 @@ def pretrain_recipe(
     )


+@run.cli.factory(target=pretrain, name=NAME + "_performance")
+def pretrain_recipe_performance(
+    dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
+) -> run.Partial:
+    """
+    Create a performance-optimized pre-training recipe for Mixtral 8x3B model.
+
+    This recipe enables performance optimizations that may not be suitable for all use cases.
+    It builds upon the standard pre-training recipe and adds additional performance enhancements.
+
+    Args:
+        dir (Optional[str]): Directory for saving logs and checkpoints.
+        name (str): Name of the pre-training run.
+        num_nodes (int): Number of compute nodes to use.
+        num_gpus_per_node (int): Number of GPUs per node.
+        fn (Callable): The pre-training function to use.
+
+    Returns:
+        run.Partial: Partial configuration for performance-optimized pre-training.
+
+    Examples:
+        CLI usage:
+            $ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=2, name='perf_pretrain')"
+
+        Python API usage:
+            >>> recipe = pretrain_recipe_performance(name="mixtral_8x3b", num_nodes=4)
+            >>> print(recipe)
+
+    Note:
+        Use this recipe with caution and only when you need maximum performance.
+        It may not be suitable for all hardware configurations or use cases.
+ """ + recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn) + + recipe.trainer.callbacks.extend( + [ + run.Config(MegatronTokenDropCallback), + run.Config(MegatronCommOverlapCallback), + ] + ) + + return recipe + + def hf_resume() -> run.Config[nl.AutoResume]: """ Configure the Hugging Face model resuming for Mixtral 8x3B model. diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index 4c88c10a55a1..9000c66c3445 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -29,6 +29,8 @@ from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback from nemo.utils.exp_manager import TimingCallback NAME = "mixtral_8x7b" @@ -54,14 +56,14 @@ def model() -> run.Config[pl.LightningModule]: def trainer( - tensor_parallelism: int = 8, - pipeline_parallelism: int = 2, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 4, pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, - virtual_pipeline_parallelism: Optional[int] = None, + virtual_pipeline_parallelism: Optional[int] = 8, context_parallelism: int = 1, - sequence_parallelism: bool = True, - expert_parallelism: int = 1, - num_nodes: int = 2, + sequence_parallelism: bool = False, + expert_parallelism: int = 8, + num_nodes: int = 8, num_gpus_per_node: int = 8, max_steps: int = 1168251, callbacks: Optional[list[run.Config[Callback]]] = None, @@ -138,7 +140,7 @@ def trainer( @run.cli.factory(target=pretrain, name=NAME) def pretrain_recipe( - dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain + dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain ) -> run.Partial: """ Create a pre-training recipe for Mixtral 8x7B model. @@ -159,10 +161,10 @@ def pretrain_recipe( Examples: CLI usage: $ nemo llm pretrain --factory mixtral_8x7b - $ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=2, name='my_mixtral_pretrain')" + $ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=8, name='my_mixtral_pretrain')" Python API usage: - >>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=2) + >>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=8) >>> print(recipe) """ return run.Partial( @@ -178,6 +180,49 @@ def pretrain_recipe( ) +@run.cli.factory(target=pretrain, name=NAME + "_performance") +def pretrain_recipe_performance( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Mixtral 8x7B model. + + This recipe enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. 
+        fn (Callable): The pre-training function to use.
+
+    Returns:
+        run.Partial: Partial configuration for performance-optimized pre-training.
+
+    Examples:
+        CLI usage:
+            $ nemo llm pretrain --factory "mixtral_8x7b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')"
+
+        Python API usage:
+            >>> recipe = pretrain_recipe_performance(name="mixtral_8x7b_perf", num_nodes=8)
+            >>> print(recipe)
+
+    Note:
+        Use this recipe with caution and only when you need maximum performance.
+        It may not be suitable for all hardware configurations or use cases.
+    """
+    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)
+    recipe.trainer.callbacks.extend(
+        [
+            run.Config(MegatronTokenDropCallback),
+            run.Config(MegatronCommOverlapCallback),
+        ]
+    )
+
+    return recipe
+
+
 def hf_resume() -> run.Config[nl.AutoResume]:
     """
     Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x7B model.
diff --git a/nemo/lightning/pytorch/callbacks/moe_token_drop.py b/nemo/lightning/pytorch/callbacks/moe_token_drop.py
new file mode 100644
index 000000000000..fc2aea84f3c1
--- /dev/null
+++ b/nemo/lightning/pytorch/callbacks/moe_token_drop.py
@@ -0,0 +1,55 @@
+import pytorch_lightning as pl
+from megatron.core import ModelParallelConfig
+from pytorch_lightning.callbacks.callback import Callback
+
+from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
+
+
+class MegatronTokenDropCallback(Callback):
+    """
+    A PyTorch Lightning callback to enable token drop for MoEs. Token drop improves performance by better
+    balancing work across experts, but may affect convergence.
+
+    Args:
+        moe_expert_capacity_factor (float): The capacity factor for all experts
+        moe_pad_expert_input_to_capacity (bool): Pad the input for each expert to the expert capacity length
+
+    Example:
+        >>> callback = MegatronTokenDropCallback()
+        >>> trainer = Trainer(callbacks=[callback])
+    """
+
+    def __init__(
+        self,
+        moe_expert_capacity_factor: float = 1.0,
+        moe_pad_expert_input_to_capacity: bool = True,
+    ):
+
+        if moe_expert_capacity_factor < 0:
+            moe_expert_capacity_factor = None
+        self.moe_expert_capacity_factor = moe_expert_capacity_factor
+        self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity
+
+    def _set_cfgs(self, cfg):
+        cfg.moe_expert_capacity_factor = self.moe_expert_capacity_factor
+        cfg.moe_pad_expert_input_to_capacity = self.moe_pad_expert_input_to_capacity
+
+    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
+        assert isinstance(trainer.strategy, MegatronStrategy), "MegatronTokenDropCallback requires MegatronStrategy"
+        if hasattr(trainer.model, "config") and isinstance(trainer.model.config, ModelParallelConfig):
+            assert trainer.model.config.moe_token_dispatcher_type in [
+                "alltoall",
+                "alltoall_seq",
+            ], 'moe_expert_capacity_factor only works with alltoall token dispatcher'
+            assert trainer.model.config.moe_router_load_balancing_type in [
+                "aux_loss",
+                "none",
+            ], 'moe_expert_capacity_factor only works with aux_loss or none load balancing'
+
+            if self.moe_pad_expert_input_to_capacity:
+                if self.moe_expert_capacity_factor is None:
+                    raise ValueError('moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity')
+
+            self._set_cfgs(trainer.model.config)
+            if hasattr(trainer.model, '__io__'):
+                self._set_cfgs(trainer.model.__io__.config)
diff --git a/tests/collections/llm/recipes/test_mixtral_8x22b.py b/tests/collections/llm/recipes/test_mixtral_8x22b.py
index f2891408c6d6..3f855721e14f 100644
--- a/tests/collections/llm/recipes/test_mixtral_8x22b.py
+++ b/tests/collections/llm/recipes/test_mixtral_8x22b.py
@@ -30,18 +30,18 @@ def test_trainer(self, recipe_module):
         assert trainer_config.__fn_or_cls__ == Trainer
         assert trainer_config.accelerator == "gpu"
         assert trainer_config.devices == 8
-        assert trainer_config.num_nodes == 8
+        assert trainer_config.num_nodes == 16

         # Check strategy configuration
         assert isinstance(trainer_config.strategy, run.Config)
         assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy"
-        assert trainer_config.strategy.tensor_model_parallel_size == 8
-        assert trainer_config.strategy.pipeline_model_parallel_size == 8
+        assert trainer_config.strategy.tensor_model_parallel_size == 2
+        assert trainer_config.strategy.pipeline_model_parallel_size == 4
         assert trainer_config.strategy.pipeline_dtype == torch.bfloat16
-        assert trainer_config.strategy.virtual_pipeline_model_parallel_size == 7
-        assert trainer_config.strategy.context_parallel_size == 1
+        assert trainer_config.strategy.virtual_pipeline_model_parallel_size == 14
+        assert trainer_config.strategy.context_parallel_size == 2
         assert trainer_config.strategy.sequence_parallel is True
-        assert trainer_config.strategy.expert_model_parallel_size == 1
+        assert trainer_config.strategy.expert_model_parallel_size == 8

         # Check DDP configuration
         assert isinstance(trainer_config.strategy.ddp, run.Config)
diff --git a/tests/collections/llm/recipes/test_mixtral_8x3b.py b/tests/collections/llm/recipes/test_mixtral_8x3b.py
index 949246c54c2a..238fec74e0e1 100644
--- a/tests/collections/llm/recipes/test_mixtral_8x3b.py
+++ b/tests/collections/llm/recipes/test_mixtral_8x3b.py
@@ -28,18 +28,18 @@ def test_trainer(self, recipe_module):
         assert trainer_config.__fn_or_cls__ == Trainer
         assert trainer_config.accelerator == "gpu"
         assert trainer_config.devices == 8
-        assert trainer_config.num_nodes == 1
+        assert trainer_config.num_nodes == 2

         # Check strategy configuration
         assert isinstance(trainer_config.strategy, run.Config)
         assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy"
-        assert trainer_config.strategy.tensor_model_parallel_size == 4
+        assert trainer_config.strategy.tensor_model_parallel_size == 1
         assert trainer_config.strategy.pipeline_model_parallel_size == 1
         assert trainer_config.strategy.pipeline_dtype is None
         assert trainer_config.strategy.virtual_pipeline_model_parallel_size is None
         assert trainer_config.strategy.context_parallel_size == 1
-        assert trainer_config.strategy.sequence_parallel is True
-        assert trainer_config.strategy.expert_model_parallel_size == 1
+        assert trainer_config.strategy.sequence_parallel is False
+        assert trainer_config.strategy.expert_model_parallel_size == 4

     def test_pretrain_recipe(self, recipe_module):
         recipe = recipe_module.pretrain_recipe()
diff --git a/tests/collections/llm/recipes/test_mixtral_8x7b.py b/tests/collections/llm/recipes/test_mixtral_8x7b.py
index ff8e2ee0724e..75003891930d 100644
--- a/tests/collections/llm/recipes/test_mixtral_8x7b.py
+++ b/tests/collections/llm/recipes/test_mixtral_8x7b.py
@@ -30,18 +30,18 @@ def test_trainer(self, recipe_module):
         assert trainer_config.__fn_or_cls__ == Trainer
         assert trainer_config.accelerator == "gpu"
         assert trainer_config.devices == 8
-        assert trainer_config.num_nodes == 2
+        assert trainer_config.num_nodes == 8

         # Check strategy configuration
         assert isinstance(trainer_config.strategy, run.Config)
         assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy"
"MegatronStrategy" - assert trainer_config.strategy.tensor_model_parallel_size == 8 - assert trainer_config.strategy.pipeline_model_parallel_size == 2 + assert trainer_config.strategy.tensor_model_parallel_size == 1 + assert trainer_config.strategy.pipeline_model_parallel_size == 4 assert trainer_config.strategy.pipeline_dtype == torch.bfloat16 - assert trainer_config.strategy.virtual_pipeline_model_parallel_size is None + assert trainer_config.strategy.virtual_pipeline_model_parallel_size == 8 assert trainer_config.strategy.context_parallel_size == 1 - assert trainer_config.strategy.sequence_parallel is True - assert trainer_config.strategy.expert_model_parallel_size == 1 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.expert_model_parallel_size == 8 # Check DDP configuration assert isinstance(trainer_config.strategy.ddp, run.Config)