From b9b309f5568887523e6f3e1d57aba58bd6262858 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 25 Jun 2024 05:18:37 -0700
Subject: [PATCH] extend get_gpt_layer_modelopt_spec to support MoE

Signed-off-by: Alexandros Koumparoulis
---
 .../megatron/gpt_layer_modelopt_spec.py      | 39 ++++++++++++++++-----
 .../language_modeling/megatron_gpt_model.py  |  2 +-
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
index f9ba58736cbd3..d4ea6bfcf0941 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
@@ -21,6 +21,7 @@
     from megatron.core.transformer.enums import AttnMaskType
     from megatron.core.transformer.identity_op import IdentityOp
     from megatron.core.transformer.mlp import MLP, MLPSubmodules
+    from megatron.core.transformer.moe.moe_layer import MoELayer
     from megatron.core.transformer.spec_utils import ModuleSpec
     from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 
@@ -38,7 +39,7 @@
 
 
 # Use this spec for Model Optimizer PTQ and TensorRT-LLM export
-def get_gpt_layer_modelopt_spec() -> ModuleSpec:
+def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
     """Mix the native spec with TENorm.
 
     This is essentially the native local spec except for the layernorm implementation
@@ -65,18 +66,38 @@
             ),
             self_attn_bda=get_bias_dropout_add,
             pre_mlp_layernorm=TENorm,
-            mlp=ModuleSpec(
-                module=MLP,
-                submodules=MLPSubmodules(
-                    linear_fc1=ColumnParallelLinear,
-                    linear_fc2=RowParallelLinear,
-                ),
-            ),
+            mlp=_get_mlp_module_spec(num_experts=num_experts),
             mlp_bda=get_bias_dropout_add,
             # Map TE-layernorm-fusion keys back
             sharded_state_dict_keys_map={
                 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
-                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
+                **({'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_'} if num_experts is None else {}),
             },
         ),
     )
+
+
+# Helper function to get module spec for MLP/MoE
+def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
+    if num_experts is None:
+        # Dense MLP w/ or w/o TE modules.
+        return ModuleSpec(
+            module=MLP,
+            submodules=MLPSubmodules(
+                linear_fc1=ColumnParallelLinear,
+                linear_fc2=RowParallelLinear,
+            ),
+        )
+    else:
+        # Mixture of experts with modules in megatron core.
+        return ModuleSpec(
+            module=MoELayer,
+            submodules=(
+                MLPSubmodules(
+                    linear_fc1=ColumnParallelLinear,
+                    linear_fc2=RowParallelLinear,
+                )
+                if not moe_grouped_gemm
+                else None
+            ),
+        )
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index f603e853cb103..fc57b208f1140 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -155,7 +155,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True,
         "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
         "megatron_falcon_gpt": get_falcon_layer_spec(),
         "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
-        "modelopt": get_gpt_layer_modelopt_spec(),
+        "modelopt": get_gpt_layer_modelopt_spec(num_experts),
         "te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg),
     }
     if spec_name not in name_spec_dict:
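
Illustrative usage (not part of the patch): a minimal sketch of how the extended spec behaves, assuming NeMo and Megatron-Core are importable in the environment; the expert count of 8 is an arbitrary example value.

    from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import (
        get_gpt_layer_modelopt_spec,
    )

    # Dense model: with num_experts left as None, the layer spec keeps the plain MLP submodule.
    dense_spec = get_gpt_layer_modelopt_spec()
    print(dense_spec.submodules.mlp.module)  # MLP

    # MoE model: passing num_experts swaps the MLP slot for a MoELayer spec and omits the
    # 'pre_mlp_layernorm.' -> 'mlp.linear_fc1.layer_norm_' key remap, which only applies to the dense path.
    moe_spec = get_gpt_layer_modelopt_spec(num_experts=8)
    print(moe_spec.submodules.mlp.module)  # MoELayer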