Skip to content

Commit

Permalink
[NeMo-UX] Add token drop callback and optimize mixtral configs (NVIDIA#10361)
Browse files Browse the repository at this point in the history

* add token drop plugin

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* add checks

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* add expert parallel configs

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>

* amend comment

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>

* add comm overlap

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* fix rebase errors

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>

* fix typo

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* add test configs

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* fix

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

---------

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Signed-off-by: George Armstrong <georgea@nvidia.com>
  • Loading branch information
5 people authored and gwarmstrong committed Sep 19, 2024
1 parent 34041b5 commit 1035c70
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 43 deletions.
3 changes: 1 addition & 2 deletions nemo/collections/llm/gpt/model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ class MixtralConfig(GPTConfig):
# MoE
num_moe_experts: int = 8
moe_aux_loss_coeff: float = 0.01
moe_expert_capacity_factor: float = 1.0
moe_pad_expert_input_to_capacity: bool = True
moe_router_topk: int = 2
moe_router_pre_softmax: bool = True
moe_token_dispatcher_type: str = "alltoall"
moe_router_load_balancing_type: str = 'aux_loss'

init_method_std: float = 0.02
layernorm_epsilon: float = 1e-5
Expand Down
65 changes: 55 additions & 10 deletions nemo/collections/llm/recipes/mixtral_8x22b.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x22b"
Expand All @@ -54,14 +56,14 @@ def model() -> run.Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int = 8,
pipeline_parallelism: int = 8,
tensor_parallelism: int = 2,
pipeline_parallelism: int = 4,
pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = 7,
context_parallelism: int = 1,
virtual_pipeline_parallelism: Optional[int] = 14,
context_parallelism: int = 2,
sequence_parallelism: bool = True,
expert_parallelism: int = 1,
num_nodes: int = 8,
expert_parallelism: int = 8,
num_nodes: int = 16,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
Expand Down Expand Up @@ -92,7 +94,7 @@ def trainer(
$ nemo llm pretrain trainer=mixtral_8x22b ...
Python API usage:
>>> trainer_config = trainer(num_nodes=8, num_gpus_per_node=8)
>>> trainer_config = trainer(num_nodes=16, num_gpus_per_node=8)
>>> print(trainer_config)
Note:
Expand Down Expand Up @@ -139,7 +141,7 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None, name: str = "default", num_nodes: int = 16, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mixtral 8x22B model.
Expand All @@ -160,10 +162,10 @@ def pretrain_recipe(
Examples:
CLI usage:
$ nemo llm pretrain --factory mixtral_8x22b
$ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=2, name='my_mixtral_pretrain')"
$ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=16, name='my_mixtral_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=2)
>>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=16)
>>> print(recipe)
"""
return run.Partial(
Expand All @@ -179,6 +181,49 @@ def pretrain_recipe(
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
    dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
    """
    Create a performance-optimized pre-training recipe for Mixtral 8x22B model.

    This recipe enables performance optimizations that may not be suitable for all use cases.
    It builds upon the standard pre-training recipe and appends the MoE token-drop and
    communication-overlap callbacks to the trainer.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for performance-optimized pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory "mixtral_8x22b_performance(num_nodes=8, name='perf_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8)
            >>> print(recipe)

    Note:
        Use this recipe with caution and only when you need maximum performance.
        It may not be suitable for all hardware configurations or use cases.
    """
    # NOTE(review): num_nodes defaults to 8 here, while pretrain_recipe in this file appears to
    # default to 16 -- confirm the smaller default is intentional for the chosen parallelism.
    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

    # The trainer's callbacks argument is Optional, so make sure there is a list to extend.
    if recipe.trainer.callbacks is None:
        recipe.trainer.callbacks = []
    recipe.trainer.callbacks.extend(
        [
            run.Config(MegatronTokenDropCallback),
            run.Config(MegatronCommOverlapCallback),
        ]
    )

    return recipe


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x22B model.
Expand Down
58 changes: 52 additions & 6 deletions nemo/collections/llm/recipes/mixtral_8x3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x3b"
Expand All @@ -55,14 +57,14 @@ def model() -> run.Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int = 4,
tensor_parallelism: int = 1,
pipeline_parallelism: int = 1,
pipeline_parallelism_type: Optional[torch.dtype] = None,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 1,
sequence_parallelism: bool = True,
expert_parallelism: int = 1,
num_nodes: int = 1,
sequence_parallelism: bool = False,
expert_parallelism: int = 4,
num_nodes: int = 2,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
Expand Down Expand Up @@ -93,7 +95,7 @@ def trainer(
$ nemo llm pretrain trainer=mixtral_8x3b ...
Python API usage:
>>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=8)
>>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
>>> print(trainer_config)
"""
strategy = run.Config(
Expand Down Expand Up @@ -139,7 +141,7 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mixtral 8x3B model.
Expand Down Expand Up @@ -181,6 +183,50 @@ def pretrain_recipe(
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
    dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
    """
    Create a performance-optimized pre-training recipe for Mixtral 8x3B model.

    This recipe enables performance optimizations that may not be suitable for all use cases.
    It builds upon the standard pre-training recipe and appends the MoE token-drop and
    communication-overlap callbacks to the trainer.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for performance-optimized pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory "mixtral_8x3b_performance(num_nodes=2, name='perf_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe_performance(name="mixtral_8x3b_perf", num_nodes=2)
            >>> print(recipe)

    Note:
        Use this recipe with caution and only when you need maximum performance.
        It may not be suitable for all hardware configurations or use cases.
    """
    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

    # The trainer's callbacks argument is Optional, so make sure there is a list to extend.
    if recipe.trainer.callbacks is None:
        recipe.trainer.callbacks = []
    recipe.trainer.callbacks.extend(
        [
            run.Config(MegatronTokenDropCallback),
            run.Config(MegatronCommOverlapCallback),
        ]
    )

    return recipe


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure the Hugging Face model resuming for Mixtral 8x3B model.
Expand Down
63 changes: 54 additions & 9 deletions nemo/collections/llm/recipes/mixtral_8x7b.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x7b"
Expand All @@ -54,14 +56,14 @@ def model() -> run.Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int = 8,
pipeline_parallelism: int = 2,
tensor_parallelism: int = 1,
pipeline_parallelism: int = 4,
pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = None,
virtual_pipeline_parallelism: Optional[int] = 8,
context_parallelism: int = 1,
sequence_parallelism: bool = True,
expert_parallelism: int = 1,
num_nodes: int = 2,
sequence_parallelism: bool = False,
expert_parallelism: int = 8,
num_nodes: int = 8,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
Expand Down Expand Up @@ -138,7 +140,7 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mixtral 8x7B model.
Expand All @@ -159,10 +161,10 @@ def pretrain_recipe(
Examples:
CLI usage:
$ nemo llm pretrain --factory mixtral_8x7b
$ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=2, name='my_mixtral_pretrain')"
$ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=8, name='my_mixtral_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=2)
>>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=8)
>>> print(recipe)
"""
return run.Partial(
Expand All @@ -178,6 +180,49 @@ def pretrain_recipe(
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
    dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
    """
    Create a performance-optimized pre-training recipe for Mixtral 8x7B model.

    This recipe enables performance optimizations that may not be suitable for all use cases.
    It builds upon the standard pre-training recipe and appends the MoE token-drop and
    communication-overlap callbacks to the trainer.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for performance-optimized pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory "mixtral_8x7b_performance(num_nodes=8, name='perf_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe_performance(name="mixtral_8x7b_perf", num_nodes=8)
            >>> print(recipe)

    Note:
        Use this recipe with caution and only when you need maximum performance.
        It may not be suitable for all hardware configurations or use cases.
    """
    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

    # The trainer's callbacks argument is Optional, so make sure there is a list to extend.
    if recipe.trainer.callbacks is None:
        recipe.trainer.callbacks = []
    recipe.trainer.callbacks.extend(
        [
            run.Config(MegatronTokenDropCallback),
            run.Config(MegatronCommOverlapCallback),
        ]
    )

    return recipe


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x7B model.
Expand Down
55 changes: 55 additions & 0 deletions nemo/lightning/pytorch/callbacks/moe_token_drop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytorch_lightning as pl
from megatron.core import ModelParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy


class MegatronTokenDropCallback(Callback):
    """
    A PyTorch Lightning callback to enable token drop for MOEs. Token drop improves performance by better
    balancing work across experts, but may affect convergence.

    Args:
        moe_expert_capacity_factor (float): The capacity factor for all experts. A negative value
            (or None) is stored as None, i.e. no capacity factor is configured.
        moe_pad_expert_input_to_capacity (bool): Pad the input for each expert to the expert capacity
            length. Requires moe_expert_capacity_factor to be set.

    Example:
        >>> callback = MegatronTokenDropCallback()
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        moe_expert_capacity_factor: float = 1.0,
        moe_pad_expert_input_to_capacity: bool = True,
    ):
        # None is the internal "no capacity factor" value; accept it directly instead of
        # failing on the `< 0` comparison, and keep mapping negative values to None.
        if moe_expert_capacity_factor is not None and moe_expert_capacity_factor < 0:
            moe_expert_capacity_factor = None
        self.moe_expert_capacity_factor = moe_expert_capacity_factor
        self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity

    def _set_cfgs(self, cfg):
        """Copy the token-drop settings held by this callback onto ``cfg``."""
        cfg.moe_expert_capacity_factor = self.moe_expert_capacity_factor
        cfg.moe_pad_expert_input_to_capacity = self.moe_pad_expert_input_to_capacity

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
        """
        Validate the model's MoE settings and write the token-drop options into its config.

        Nothing is changed unless ``trainer.model`` has a ``config`` attribute that is a
        ``ModelParallelConfig``.

        Raises:
            AssertionError: If the strategy is not a MegatronStrategy, or the configured token
                dispatcher / load-balancing type is not one of the accepted values.
            ValueError: If padding to capacity is requested without a capacity factor.
        """
        assert isinstance(trainer.strategy, MegatronStrategy), "MegatronTokenDrop requires MegatronStrategy"
        if hasattr(trainer.model, "config") and isinstance(trainer.model.config, ModelParallelConfig):
            assert trainer.model.config.moe_token_dispatcher_type in [
                "alltoall",
                "alltoall_seq",
            ], 'moe_expert_capacity_factor only works with alltoall token dispatcher'
            assert trainer.model.config.moe_router_load_balancing_type in [
                "aux_loss",
                "none",
            ], 'moe_expert_capacity_factor only works with aux_loss or none load balancing'

            if self.moe_pad_expert_input_to_capacity:
                if self.moe_expert_capacity_factor is None:
                    raise ValueError('moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity')

            self._set_cfgs(trainer.model.config)
            # Mirror the values onto the __io__ config copy when the model carries one.
            if hasattr(trainer.model, '__io__'):
                self._set_cfgs(trainer.model.__io__.config)
Loading

0 comments on commit 1035c70

Please sign in to comment.