Commit 391ef23
import in function to fix test
Signed-off-by: Chen Cui <chcui@nvidia.com>
cuichenx authored Aug 16, 2024
1 parent: ff29c45 · commit: 391ef23
Showing 1 changed file with 2 additions and 1 deletion.
@@ -39,7 +39,6 @@
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig
 from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_spec import get_falcon_layer_spec
-from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_spec import get_gemma2_layer_spec
 from nemo.collections.nlp.models.language_modeling.megatron.gpt_full_te_layer_autocast_spec import (
     get_gpt_full_te_layer_autocast_spec,
 )
@@ -155,6 +154,8 @@ def mcore_supports_moe() -> bool:
 
 ## TODO: This function will not work if TE is not installed
 def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict = None):
+    from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_spec import get_gemma2_layer_spec
+
     # else cases for backwards compatibility with neva
     num_experts = transformer_config.num_moe_experts if transformer_config else None
     moe_grouped_gemm = transformer_config.moe_grouped_gemm if transformer_config else False
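The change is the standard deferred-import pattern: moving the gemma2 spec import from module scope into the body of get_specs means any ImportError it raises (the TODO above suggests a Transformer Engine dependency) surfaces only when the function is actually called, so a test that merely imports this module no longer fails. A minimal sketch of the pattern, with a hypothetical optional_backend module standing in for the real NeMo import:

def get_specs(spec_name):
    # Deferred import: resolved when the function is called rather than
    # when this module is first imported, so importing the module succeeds
    # even if the optional dependency is missing.
    from optional_backend.layer_specs import get_optional_layer_spec  # hypothetical module

    if spec_name == "optional":
        return get_optional_layer_spec()
    raise ValueError(f"unknown spec name: {spec_name}")

The trade-off is that the import cost, and any ImportError, move from module import time to the first call, which is usually acceptable for optional backends.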
