diff --git a/scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py b/scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py index 25c5596ec8ee7..8fccc3ef16c9c 100644 --- a/scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py @@ -21,21 +21,19 @@ [--fast-swiglu\ """ +import json import os from argparse import ArgumentParser from collections import OrderedDict import torch +import torch.nn from omegaconf import OmegaConf from pytorch_lightning.core.saving import _load_state as ptl_load_state from pytorch_lightning.trainer.trainer import Trainer - -import torch.nn -import json from sentencepiece import SentencePieceProcessor from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel - from nemo.collections.nlp.parts.nlp_overrides import ( GradScaler, MegatronHalfPrecisionPlugin, @@ -46,7 +44,6 @@ from nemo.utils import logging - def get_args(): parser = ArgumentParser() parser.add_argument( @@ -133,6 +130,7 @@ def load_config(mistral_config, tokenizer_path): return nemo_config + def load_mistral_ckpt(dir): params_file = os.path.join(dir, 'params.json') assert os.path.exists(params_file) @@ -148,6 +146,7 @@ def load_mistral_ckpt(dir): assert tokenizer.get_piece_size() == model_args['vocab_size'] return model_args, ckpt, tokenizer + def convert(args): logging.info(f"loading checkpoint {args.in_file}") @@ -229,8 +228,6 @@ def convert(args): if mcore_gpt: assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.' - - for l in range(int(num_layers)): print(f"converting layer {l}") old_tensor_shape = ckpt[f'layers.{l}.attention.wq.weight'].size() @@ -279,7 +276,6 @@ def convert(args): mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight' checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight) - # LayerNorm input_ln_weight = ckpt[f'layers.{l}.attention_norm.weight'] @@ -332,6 +328,7 @@ def convert(args): model.save_to(args.out_file) logging.info(f'NeMo model saved to: {args.out_file}') + if __name__ == '__main__': args = get_args() - convert(args) \ No newline at end of file + convert(args)