Recipe tuning for mixtral, nemotron4
Signed-off-by: Guyue Huang <guyueh@login-preos01.a51.clusters.nvidia.com>
Guyue Huang committed Jan 7, 2025
1 parent 65aad76 commit 4829c33
Showing 4 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nemo/collections/llm/gpt/model/mixtral.py
@@ -70,7 +70,8 @@ class MixtralConfig(GPTConfig):
     rotary_base: float = 1000000.0
     bf16: bool = True
     params_dtype: torch.dtype = torch.bfloat16
-
+    apply_rope_fusion: bool = True
+    bias_activation_fusion: bool = True
 
 @dataclass
 class MixtralConfig8x3B(MixtralConfig):
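
For reference, a minimal sketch of what these new defaults mean when instantiating the config (using MixtralConfig8x3B as shown above; whether the fused kernels are actually available depends on the installed Transformer Engine/Apex stack, so the override below is an illustrative assumption, not part of this commit):

from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x3B

# Both fusions now default to True (per this commit). They can still be
# disabled explicitly, e.g. on stacks without the fused RoPE kernel.
config = MixtralConfig8x3B(
    apply_rope_fusion=False,       # fall back to the unfused RoPE implementation
    bias_activation_fusion=True,   # keep the fused bias + activation in the MLP
)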
4 changes: 4 additions & 0 deletions nemo/collections/llm/recipes/nemotron4_15b.py
@@ -25,6 +25,9 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
+from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
+    userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192
+)
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback
 
@@ -202,6 +205,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
         run.Config(
             MegatronCommOverlapCallback,
             tp_comm_overlap=True,
+            tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192,
         )
     )
     return recipe
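
For context, a minimal usage sketch. It assumes this module's pretrain_recipe factory exposes a performance_mode flag that routes through pretrain_performance_optimizations, following the NeMo 2.0 recipe convention; exact signatures and argument names may differ across versions:

from nemo.collections.llm.recipes import nemotron4_15b

# With performance_mode=True, the recipe attaches MegatronCommOverlapCallback
# with tp_comm_overlap=True and the H100/TP2/seqlen-8192 userbuffers config above.
recipe = nemotron4_15b.pretrain_recipe(
    name="nemotron4_15b_perf",  # hypothetical run name
    num_nodes=8,                # hypothetical cluster size
    performance_mode=True,
)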
4 changes: 4 additions & 0 deletions nemo/collections/llm/recipes/nemotron4_340b.py
@@ -25,6 +25,9 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
+from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
+    userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096
+)
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback
 
@@ -209,6 +212,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
         run.Config(
             MegatronCommOverlapCallback,
             tp_comm_overlap=True,
+            tp_comm_overlap_cfg=userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=22,
             overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
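
A minimal sketch of tweaking the attached callback after the optimized recipe is built; the callback's position in the list and the None fallback for tp_comm_overlap_cfg are assumptions, not guaranteed by this diff:

# `recipe` is assumed to come from pretrain_performance_optimizations above.
overlap_cb = recipe.trainer.callbacks[-1]        # the run.Config added by this function
overlap_cb.defer_embedding_wgrad_compute = False # e.g. disable embedding-wgrad deferral
overlap_cb.tp_comm_overlap_cfg = None            # e.g. fall back to default overlap settings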
14 changes: 14 additions & 0 deletions nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py
@@ -182,3 +182,17 @@ class TransformerLayerTPOverlapCfg:
     proj_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True),
     fc2_fprop=RingExchangeOverlapCfg(num_sm=1, set_sm_margin=True),
 )
+
+# Nemotron 340B
+userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096 = TransformerLayerTPOverlapCfg(
+    qkv_dgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False),
+    qkv_wgrad=BulkOverlapCfg(num_sm=32, cga_size=2, set_sm_margin=False),
+    fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False),
+    fc1_wgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False),
+    qkv_fprop=RingExchangeOverlapCfg(aggregate=False),
+    proj_dgrad=RingExchangeOverlapCfg(aggregate=False),
+    fc1_fprop=RingExchangeOverlapCfg(aggregate=False),
+    fc2_dgrad=RingExchangeOverlapCfg(aggregate=False),
+    proj_fprop=PipelineOverlapCfg(num_sm=32, cga_size=2, num_splits=2, set_sm_margin=True, fp8_buf=True),
+    fc2_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True),
+)
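
The new entry appears to follow the module's naming convention, userbuffers_<dtype>_<gpu>_h<hidden>_tp<tp>_mbs<mbs>_seqlen<seqlen>: bf16 on H100, hidden size 18432 (Nemotron 340B), tensor parallel 8, micro-batch size 1, sequence length 4096. A minimal sketch of reading fields off the dataclass; the name decoding and the overlap-category summaries below are inferences from the existing entries, not stated in this diff:

from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
    userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096 as tp_overlap_cfg,
)

# BulkOverlapCfg         -> bulk all-gather/reduce-scatter overlap (dgrad/wgrad GEMMs)
# RingExchangeOverlapCfg -> point-to-point ring-exchange overlap
# PipelineOverlapCfg     -> chunked, pipelined overlap (num_splits chunks)
print(tp_overlap_cfg.proj_fprop.num_sm)     # 32 SMs reserved for communication
print(tp_overlap_cfg.fc2_fprop.num_splits)  # 4 pipeline chunks for the fc2 forward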
