diff --git a/scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py
index a1295570a1257..18734a3ffdcb4 100644
--- a/scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py
+++ b/scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py
@@ -42,7 +42,7 @@
     PipelineMixedPrecisionPlugin,
 )
 from nemo.utils import logging
-
+import megatron.core.parallel_state as parallel_state
 
 def get_args():
     parser = ArgumentParser()
@@ -340,4 +340,5 @@ def convert(args):
 
 if __name__ == '__main__':
     args = get_args()
+    parallel_state.set_cpu_expert_model_parallel_world_size(1)
     convert(args)
diff --git a/scripts/nlp_language_modeling/convert_nemo_mixtral_to_hf.py b/scripts/nlp_language_modeling/convert_nemo_mixtral_to_hf.py
index d247652fe3e0f..095b8e82caf93 100644
--- a/scripts/nlp_language_modeling/convert_nemo_mixtral_to_hf.py
+++ b/scripts/nlp_language_modeling/convert_nemo_mixtral_to_hf.py
@@ -32,7 +32,7 @@
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
 from nemo.utils import logging
-
+import megatron.core.parallel_state as parallel_state
 
 def get_args():
     parser = ArgumentParser()
@@ -231,6 +231,7 @@ def convert(in_file, precision=None) -> None:
 
 if __name__ == '__main__':
     args = get_args()
+    parallel_state.set_cpu_expert_model_parallel_world_size(1)
     hf_state_dict, nemo_config = convert(args.in_file, args.precision)
    config = load_config(args.hf_model_name, nemo_config)
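Both converters now pin megatron.core's CPU expert-model-parallel world size to 1 before any Mixtral (MoE) weight conversion runs. Below is a minimal sketch of that shared pattern in isolation; do_conversion() is a hypothetical stand-in for the scripts' convert(args) / convert(in_file, precision) entry points, not code from this PR.

# Sketch of the pattern introduced by this diff; do_conversion() is a
# hypothetical placeholder for the actual checkpoint conversion logic.
import megatron.core.parallel_state as parallel_state


def do_conversion():
    # Stand-in for the real Mixtral <-> NeMo conversion routine.
    pass


if __name__ == '__main__':
    # Set the CPU expert-model-parallel world size to 1 before conversion,
    # mirroring what both scripts above now do in their __main__ blocks.
    parallel_state.set_cpu_expert_model_parallel_world_size(1)
    do_conversion()

The ordering is the point of the change: the world size is set once, in the script entry point, before the conversion function touches any expert-parallel state.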