diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index cba99f1faefd..66ce449616cf 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -15,11 +15,11 @@
 from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
 
-HAVE_TE=True
+HAVE_TE = True
 try:
     import transformer_engine
 except (ImportError, ModuleNotFoundError):
-    HAVE_TE=False
+    HAVE_TE = False
 
 if TYPE_CHECKING:
     from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
@@ -82,12 +82,14 @@ def local_layer_spec(config: "GPTConfig") -> ModuleSpec:
         num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
     )
 
+
 def default_layer_spec(config: "GPTConfig") -> ModuleSpec:
     if HAVE_TE:
         return transformer_engine_layer_spec(config)
     else:
         return local_layer_spec(config)
 
+
 @dataclass
 class GPTConfig(TransformerConfig, io.IOMixin):
     # From megatron.core.models.gpt.gpt_model.GPTModel