Initialize distributed environment in convert_hf_mixtral_to_nemo.py.
Why do we need this instead of just running it with python3?
`get_tensor_and_expert_parallel_world_size` requires
a distributed environment and is called when initializing Mixtral's MoE layers.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa committed Jan 19, 2024
1 parent c66aec6 commit 2f67223
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py
@@ -15,18 +15,19 @@
 r"""
 Conversion script to convert Huggingface Mixtral checkpoints into NeMo checkpoint.
 Example to run this conversion script:
-python3 convert_hf_mixtral_to_nemo.py \
-    --in-file <path_to_mixtral_checkpoints_folder> \
-    --out-file <path_to_output_nemo_file> \
-    [--fast-swiglu\
-"""
+torchrun --nproc_per_node=1 convert_hf_mixtral_to_nemo.py \
+    --in-file <path_to_mixtral_checkpoints_folder> \
+    --out-file <path_to_output_nemo_file>
+You must call this from a distributed environment, see also comments in __main__."""

 import json
 import os
 from argparse import ArgumentParser
 from collections import OrderedDict

+import megatron.core.parallel_state as parallel_state
 import torch
+import torch.distributed
 import torch.nn
 from omegaconf import OmegaConf
 from pytorch_lightning.core.saving import _load_state as ptl_load_state
@@ -42,7 +43,7 @@
     PipelineMixedPrecisionPlugin,
 )
 from nemo.utils import logging
-
+from nemo.utils.app_state import AppState

 def get_args():
     parser = ArgumentParser()
@@ -339,5 +340,15 @@ def convert(args):


 if __name__ == '__main__':
+    # @akoumparouli: `get_tensor_and_expert_parallel_world_size` requires
+    # a distributed environment and is called when initializing Mixtral's MoE layers.
+    # TODO(akoumparouli): remove distributed environment requirement
+    assert torch.distributed.is_available(), "Please run this with `torchrun --nproc_per_node=1`"
+    torch.distributed.init_process_group()
+    assert torch.distributed.is_initialized(), "Distributed group failed to initialize"
+    parallel_state.initialize_model_parallel()
+    # nlp_overrides
+    app_state = AppState()
+    app_state.data_parallel_rank = 0
     args = get_args()
     convert(args)
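
For context (not part of the commit): a minimal standalone sketch of the initialization pattern the converted script now relies on, runnable under torchrun, e.g. `torchrun --nproc_per_node=1 init_sketch.py`. The file name, the printed check, and the teardown calls are illustrative assumptions, not taken from the commit.

# init_sketch.py: hypothetical standalone example mirroring the commit's __main__
# changes, first bring up torch.distributed, then Megatron-Core parallel state.
import megatron.core.parallel_state as parallel_state
import torch.distributed

# torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so the default
# env:// rendezvous works even with a single process.
torch.distributed.init_process_group()

# The defaults create tensor/pipeline/expert parallel groups of size 1, which is
# all the conversion needs: the groups only have to exist so that
# get_tensor_and_expert_parallel_world_size() can be queried by the MoE layers.
parallel_state.initialize_model_parallel()
print(parallel_state.get_tensor_and_expert_parallel_world_size())  # expected: 1

# Clean up so the process exits without warnings.
parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()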
