extend get_gpt_layer_modelopt_spec to support MoE
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa committed Jun 25, 2024
1 parent 26aef8e commit b9b309f
Showing 2 changed files with 31 additions and 10 deletions.
First changed file:
@@ -21,6 +21,7 @@
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

@@ -38,7 +39,7 @@


# Use this spec for Model Optimizer PTQ and TensorRT-LLM export
-def get_gpt_layer_modelopt_spec() -> ModuleSpec:
+def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec except for the layernorm implementation
@@ -65,18 +66,38 @@ def get_gpt_layer_modelopt_spec() -> ModuleSpec:
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
-            mlp=ModuleSpec(
-                module=MLP,
-                submodules=MLPSubmodules(
-                    linear_fc1=ColumnParallelLinear,
-                    linear_fc2=RowParallelLinear,
-                ),
-            ),
+            mlp=_get_mlp_module_spec(num_experts=num_experts),
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
            sharded_state_dict_keys_map={
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
-                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
+                **({'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_'} if num_experts is None else {}),
            },
        ),
    )
+
+
+# Helper function to get module spec for MLP/MoE
+def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
+    if num_experts is None:
+        # Dense MLP w/ or w/o TE modules.
+        return ModuleSpec(
+            module=MLP,
+            submodules=MLPSubmodules(
+                linear_fc1=ColumnParallelLinear,
+                linear_fc2=RowParallelLinear,
+            ),
+        )
+    else:
+        # Mixture of experts with modules in megatron core.
+        return ModuleSpec(
+            module=MoELayer,
+            submodules=(
+                MLPSubmodules(
+                    linear_fc1=ColumnParallelLinear,
+                    linear_fc2=RowParallelLinear,
+                )
+                if not moe_grouped_gemm
+                else None
+            ),
+        )
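For reference, a minimal usage sketch (not part of the commit) of what the new argument does: with num_experts=None the spec keeps the dense MLP submodule, while any integer switches the layer's MLP slot to MoELayer. The import path of get_gpt_layer_modelopt_spec below is an assumption and may differ in this repository; the expert count is an arbitrary example value.

# Usage sketch; the import path of the spec module is assumed, not taken from this diff.
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.moe.moe_layer import MoELayer

from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import (  # assumed path
    get_gpt_layer_modelopt_spec,
)

dense_spec = get_gpt_layer_modelopt_spec()             # num_experts defaults to None -> dense MLP
moe_spec = get_gpt_layer_modelopt_spec(num_experts=8)  # illustrative expert count -> MoELayer

assert dense_spec.submodules.mlp.module is MLP
assert moe_spec.submodules.mlp.module is MoELayer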
Second changed file:
@@ -155,7 +155,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True,
        "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
        "megatron_falcon_gpt": get_falcon_layer_spec(),
        "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
-        "modelopt": get_gpt_layer_modelopt_spec(),
+        "modelopt": get_gpt_layer_modelopt_spec(num_experts),
        "te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg),
    }
    if spec_name not in name_spec_dict:
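Similarly, a hedged sketch of the call-site effect in get_specs (not part of the commit; the "modelopt" key comes from the hunk above, the expert count is illustrative, and the import of get_specs is omitted since its module path is not shown in this diff):

# Before this commit, the "modelopt" entry ignored num_experts entirely.
dense_layer_spec = get_specs("modelopt")               # unchanged: dense-MLP modelopt spec
moe_layer_spec = get_specs("modelopt", num_experts=8)  # now forwarded to get_gpt_layer_modelopt_spec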
