[NeMo-UX] Add token drop callback and optimize mixtral configs #10361

Merged: 21 commits, Sep 16, 2024
3 changes: 1 addition & 2 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -57,11 +57,10 @@ class MixtralConfig(GPTConfig):
# MoE
num_moe_experts: int = 8
moe_aux_loss_coeff: float = 0.01
moe_expert_capacity_factor: float = 1.0
moe_pad_expert_input_to_capacity: bool = True
moe_router_topk: int = 2
moe_router_pre_softmax: bool = True
moe_token_dispatcher_type: str = "alltoall"
moe_router_load_balancing_type: str = 'aux_loss'

init_method_std: float = 0.02
layernorm_epsilon: float = 1e-5
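The two token-drop fields deleted here are not gone: they now come from the new MegatronTokenDropCallback (added later in this PR), which writes them onto the model config at setup time. A minimal sketch of the equivalent opt-in, using the former defaults; the usage is illustrative rather than taken from the PR:

from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback

# Former MixtralConfig defaults, now supplied per run through the callback instead:
callback = MegatronTokenDropCallback(
    moe_expert_capacity_factor=1.0,         # cap each expert at 1.0x the average token load
    moe_pad_expert_input_to_capacity=True,  # pad under-filled experts up to that capacity
)
# At trainer setup the callback copies both values onto model.config (and model.__io__.config),
# so a plain MixtralConfig now trains dropless unless the callback is attached.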
65 changes: 55 additions & 10 deletions nemo/collections/llm/recipes/mixtral_8x22b.py
@@ -29,6 +29,8 @@
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x22b"
@@ -54,14 +56,14 @@ def model() -> run.Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int = 8,
pipeline_parallelism: int = 8,
tensor_parallelism: int = 2,
pipeline_parallelism: int = 4,
pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = 7,
context_parallelism: int = 1,
virtual_pipeline_parallelism: Optional[int] = 14,
context_parallelism: int = 2,
sequence_parallelism: bool = True,
expert_parallelism: int = 1,
num_nodes: int = 8,
expert_parallelism: int = 8,
num_nodes: int = 16,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
@@ -92,7 +94,7 @@ def trainer(
$ nemo llm pretrain trainer=mixtral_8x22b ...

Python API usage:
>>> trainer_config = trainer(num_nodes=8, num_gpus_per_node=8)
>>> trainer_config = trainer(num_nodes=16, num_gpus_per_node=8)
>>> print(trainer_config)

Note:
@@ -139,7 +141,7 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None, name: str = "default", num_nodes: int = 16, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mixtral 8x22B model.
@@ -160,10 +162,10 @@ def pretrain_recipe(
Examples:
CLI usage:
$ nemo llm pretrain --factory mixtral_8x22b
$ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=2, name='my_mixtral_pretrain')"
$ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=16, name='my_mixtral_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=2)
>>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=16)
>>> print(recipe)
"""
return run.Partial(
@@ -179,6 +181,49 @@ def pretrain_recipe(
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Mixtral 8x22B model.

This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.

Returns:
run.Partial: Partial configuration for performance-optimized pre-training.

Examples:
CLI usage:
$ nemo llm pretrain --factory "mixtral_8x22b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8)
>>> print(recipe)

Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)
recipe.trainer.callbacks.extend(
[
run.Config(MegatronTokenDropCallback),
run.Config(MegatronCommOverlapCallback),
]
)

return recipe


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x22B model.
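A quick consistency check of the new 8x22B trainer defaults above; the 56-layer depth of Mixtral 8x22B is an assumption here, everything else follows from the defaults in this diff:

# Sanity-check the new parallelism defaults (TP=2, PP=4, VPP=14, CP=2, EP=8 on 16x8 GPUs).
tp, pp, vpp, cp, ep = 2, 4, 14, 2, 8
num_nodes, gpus_per_node = 16, 8
num_layers = 56  # assumed Mixtral 8x22B depth

world_size = num_nodes * gpus_per_node          # 128 GPUs
model_parallel = tp * pp * cp                   # 16 GPUs per model replica
data_parallel = world_size // model_parallel    # 8 replicas
assert data_parallel % ep == 0                  # EP=8 fits inside the data-parallel dimension
assert num_layers % (pp * vpp) == 0             # 56 layers -> one layer per virtual stage
print(world_size, data_parallel, num_layers // (pp * vpp))  # 128 8 1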
58 changes: 52 additions & 6 deletions nemo/collections/llm/recipes/mixtral_8x3b.py
@@ -30,6 +30,8 @@
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x3b"
@@ -55,14 +57,14 @@ def model() -> run.Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int = 4,
tensor_parallelism: int = 1,
pipeline_parallelism: int = 1,
pipeline_parallelism_type: Optional[torch.dtype] = None,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 1,
sequence_parallelism: bool = True,
expert_parallelism: int = 1,
num_nodes: int = 1,
sequence_parallelism: bool = False,
expert_parallelism: int = 4,
num_nodes: int = 2,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
@@ -93,7 +95,7 @@ def trainer(
$ nemo llm pretrain trainer=mixtral_8x3b ...

Python API usage:
>>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=8)
>>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
>>> print(trainer_config)
"""
strategy = run.Config(
@@ -139,7 +141,7 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mixtral 8x3B model.
@@ -181,6 +183,50 @@ def pretrain_recipe(
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Mixtral 8x3B model.

This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.

Returns:
run.Partial: Partial configuration for performance-optimized pre-training.

Examples:
CLI usage:
$ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=2, name='perf_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe_performance(name="mixtral_8x3b_perf", num_nodes=4)
>>> print(recipe)

Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

recipe.trainer.callbacks.extend(
[
run.Config(MegatronTokenDropCallback),
run.Config(MegatronCommOverlapCallback),
]
)

return recipe


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure the Hugging Face model resuming for Mixtral 8x3B model.
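The scaled-down 8x3B defaults above (TP=PP=CP=1, EP=4) also fit on a single node; a sketch of a one-node run, assuming the strategy exposes the expert_model_parallel_size field used by trainer() and that expert parallelism must divide the data-parallel size:

from nemo.collections.llm.recipes import mixtral_8x3b

# One 8-GPU node still satisfies the defaults: TP=PP=CP=1 gives DP=8, and EP=4 divides 8.
recipe = mixtral_8x3b.pretrain_recipe(name="mixtral_8x3b_1node", num_nodes=1, num_gpus_per_node=8)

# For anything smaller (say 4 GPUs), shrink expert parallelism to match:
recipe.trainer.devices = 4
recipe.trainer.strategy.expert_model_parallel_size = 2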
63 changes: 54 additions & 9 deletions nemo/collections/llm/recipes/mixtral_8x7b.py
@@ -29,6 +29,8 @@
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x7b"
@@ -54,14 +56,14 @@ def model() -> run.Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int = 8,
pipeline_parallelism: int = 2,
tensor_parallelism: int = 1,
pipeline_parallelism: int = 4,
pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = None,
virtual_pipeline_parallelism: Optional[int] = 8,
context_parallelism: int = 1,
sequence_parallelism: bool = True,
expert_parallelism: int = 1,
num_nodes: int = 2,
sequence_parallelism: bool = False,
expert_parallelism: int = 8,
num_nodes: int = 8,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
@@ -138,7 +140,7 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mixtral 8x7B model.
@@ -159,10 +161,10 @@
Examples:
CLI usage:
$ nemo llm pretrain --factory mixtral_8x7b
$ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=2, name='my_mixtral_pretrain')"
$ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=8, name='my_mixtral_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=2)
>>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=8)
>>> print(recipe)
"""
return run.Partial(
@@ -178,6 +180,49 @@ def pretrain_recipe(
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Mixtral 8x7B model.

This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.

Returns:
run.Partial: Partial configuration for performance-optimized pre-training.

Examples:
CLI usage:
$ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe_performance(name="mixtral_8x7b_perf", num_nodes=8)
>>> print(recipe)

Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)
recipe.trainer.callbacks.extend(
[
run.Config(MegatronTokenDropCallback),
run.Config(MegatronCommOverlapCallback),
]
)

return recipe


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x7B model.
55 changes: 55 additions & 0 deletions nemo/lightning/pytorch/callbacks/moe_token_drop.py
@@ -0,0 +1,55 @@
import pytorch_lightning as pl
from megatron.core import ModelParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy


class MegatronTokenDropCallback(Callback):
"""
A PyTorch Lightning callback to enable token drop for MoEs. Token drop improves performance by better
balancing work across experts, but may affect convergence.

Args:
moe_expert_capacity_factor (float): The capacity factor for all experts
moe_pad_expert_input_to_capacity (bool): Pad the input for each expert to the expert capacity length

Example:
>>> callback = MegatronTokenDropCallback()
>>> trainer = Trainer(callbacks=[callback])
"""

def __init__(
self,
moe_expert_capacity_factor: float = 1.0,
moe_pad_expert_input_to_capacity: bool = True,
):

if moe_expert_capacity_factor < 0:
moe_expert_capacity_factor = None
self.moe_expert_capacity_factor = moe_expert_capacity_factor
self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity

def _set_cfgs(self, cfg):
cfg.moe_expert_capacity_factor = self.moe_expert_capacity_factor
cfg.moe_pad_expert_input_to_capacity = self.moe_pad_expert_input_to_capacity

def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
assert isinstance(trainer.strategy, MegatronStrategy), "MegatronTokenDrop requires MegatronStrategy"
if hasattr(trainer.model, "config") and isinstance(trainer.model.config, ModelParallelConfig):
assert trainer.model.config.moe_token_dispatcher_type in [
"alltoall",
"alltoall_seq",
], 'moe_expert_capacity_factor only works with alltoall token dispatcher'
assert trainer.model.config.moe_router_load_balancing_type in [
"aux_loss",
"none",
], 'moe_expert_capacity_factor only works with aux_loss or none load balancing'

if self.moe_pad_expert_input_to_capacity:
if self.moe_expert_capacity_factor is None:
raise ValueError('moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity')

self._set_cfgs(trainer.model.config)
if hasattr(trainer.model, '__io__'):
self._set_cfgs(trainer.model.__io__.config)
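For reference, a minimal sketch of attaching the callback to one of the recipes above with a non-default capacity factor; the 1.25 value is illustrative, and the nemo_run import mirrors how the recipes build their configs:

import nemo_run as run

from nemo.collections.llm.recipes import mixtral_8x7b
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback

recipe = mixtral_8x7b.pretrain_recipe(name="mixtral_8x7b_token_drop", num_nodes=8)
recipe.trainer.callbacks.append(
    run.Config(
        MegatronTokenDropCallback,
        moe_expert_capacity_factor=1.25,        # 25% head-room per expert before dropping
        moe_pad_expert_input_to_capacity=True,  # pad under-filled experts to capacity
    )
)
# setup() asserts a MegatronStrategy, an "alltoall"/"alltoall_seq" dispatcher and
# "aux_loss"/"none" load balancing, all satisfied by the Mixtral defaults in this PR.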