-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
r""" | ||
Conversion script to convert NeMo Mistral-7B checkpoints into HuggingFace checkpoint. | ||
Example to run this conversion script: | ||
python3 convert_nemo_mistral_7b_to_hf.py \ | ||
--in-file <path_to_nemo_checkpoints_folder> \ | ||
--out-file <path_to_output_hf_file> | ||
""" | ||
|
||
from argparse import ArgumentParser | ||
from collections import OrderedDict | ||
|
||
import torch | ||
import torch.nn | ||
from pytorch_lightning.trainer.trainer import Trainer | ||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | ||
|
||
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel | ||
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy | ||
from nemo.utils import logging | ||
|
||
|
||
def get_args(): | ||
parser = ArgumentParser() | ||
parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to NeMo Mistral-7B checkpoint") | ||
parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output HF checkpoint.") | ||
parser.add_argument('--hf-model-name', type=str, default="mistralai/Mistral-7B-v0.1", help="Name of HF checkpoint") | ||
parser.add_argument("--precision", type=str, default="32", help="Model precision") | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def load_config(hf_model_name, nemo_config): | ||
hf_config = AutoConfig.from_pretrained(hf_model_name) | ||
# SWA; nemo_config.window_size is list [left-bound, right-bound] | ||
hf_config.sliding_window = nemo_config.window_size[0] | ||
hf_config.max_position_embeddings = nemo_config.encoder_seq_length | ||
hf_config.num_hidden_layers = nemo_config.num_layers | ||
hf_config.hidden_size = nemo_config.hidden_size | ||
hf_config.intermediate_size = nemo_config.ffn_hidden_size | ||
hf_config.num_attention_heads = nemo_config.num_attention_heads | ||
hf_config.max_position_embeddings = nemo_config.max_position_embeddings | ||
hf_config.initializer_range = nemo_config.init_method_std | ||
hf_config.rms_norm_eps = nemo_config.layernorm_epsilon | ||
hf_config.num_key_value_heads = nemo_config.num_query_groups | ||
if nemo_config.activation == 'fast-swiglu': | ||
hf_config.activation = 'silu' | ||
else: | ||
logging.warning(f"Got unknown activation function {nemo_config.activation}") | ||
|
||
hf_config.rope_theta = nemo_config['rotary_base'] | ||
return hf_config | ||
|
||
|
||
def convert(in_file, precision=None) -> None: | ||
""" | ||
Convert NeMo checkpoint to HF checkpoint | ||
""" | ||
|
||
logging.info(f'Loading NeMo checkpoint from: {in_file}') | ||
|
||
dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) | ||
model_config = MegatronGPTModel.restore_from(in_file, trainer=dummy_trainer, return_config=True) | ||
model_config.tensor_model_parallel_size = 1 | ||
model_config.pipeline_model_parallel_size = 1 | ||
cpu_only = True | ||
if cpu_only: | ||
map_location = torch.device('cpu') | ||
model_config.use_cpu_initialization = True | ||
else: | ||
map_location = None | ||
Check warning Code scanning / CodeQL Unreachable code Warning
This statement is unreachable.
|
||
|
||
if cpu_only: | ||
logging.info("******** Loading model on CPU. This will take a significant amount of time.") | ||
model = MegatronGPTModel.restore_from( | ||
in_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location | ||
) | ||
ckpt = model.state_dict() | ||
nemo_config = model.cfg | ||
|
||
mcore_gpt = nemo_config.mcore_gpt | ||
hidden_size = nemo_config.hidden_size | ||
head_num = nemo_config.num_attention_heads | ||
head_size = hidden_size // head_num | ||
Check warning Code scanning / CodeQL Variable defined multiple times Warning
This assignment to 'head_size' is unnecessary as it is
redefined Error loading related location Loading |
||
num_layers = nemo_config.num_layers | ||
Check warning Code scanning / CodeQL Variable defined multiple times Warning
This assignment to 'num_layers' is unnecessary as it is
redefined Error loading related location Loading |
||
|
||
if precision is None: | ||
precision = model.cfg.precision | ||
if precision in [32, "32"]: | ||
dtype = torch.float32 | ||
elif precision in [16, "16", "16-mixed"]: | ||
dtype = torch.float16 | ||
elif precision in ["bf16", "bf16-mixed"]: | ||
dtype = torch.bfloat16 | ||
else: | ||
logging.warning(f"Precision string {precision} is not recognized, falling back to fp32") | ||
dtype = torch.float32 # fallback | ||
param_to_weights = lambda param: param.to(dtype) | ||
|
||
state_dict = OrderedDict() | ||
|
||
hf_embed_weight_name = f'model.embed_tokens.weight' | ||
if mcore_gpt: | ||
embed_weights_base_name = f'model.embedding.word_embeddings.weight' | ||
else: | ||
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight' | ||
state_dict[hf_embed_weight_name] = param_to_weights(ckpt[embed_weights_base_name]) | ||
|
||
if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num: | ||
num_query_groups = head_num | ||
else: | ||
num_query_groups = nemo_config.num_query_groups | ||
assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups' | ||
if mcore_gpt: | ||
assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.' | ||
|
||
hidden_size = model.cfg.hidden_size | ||
head_num = model.cfg.num_attention_heads | ||
num_layers = model.cfg.num_layers | ||
num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B | ||
|
||
head_size = hidden_size // head_num | ||
heads_per_group = head_num // num_query_groups | ||
qkv_total_dim = head_num + 2 * num_query_groups | ||
|
||
# Embedding | ||
embed_weight = model.state_dict()[f'model.embedding.word_embeddings.weight'] | ||
embed_weights_base_name = f'model.embed_tokens.weight' | ||
state_dict[embed_weights_base_name] = param_to_weights(embed_weight) | ||
|
||
for l in range(int(num_layers)): | ||
print(f"converting layer {l}") | ||
|
||
qkv_weights = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'] | ||
qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) | ||
|
||
q_slice = torch.cat( | ||
[ | ||
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) | ||
for i in range(num_query_groups) | ||
] | ||
) | ||
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) | ||
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) | ||
|
||
for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: | ||
weight_name = f'model.layers.{l}.self_attn.{name}.weight' | ||
state_dict[weight_name] = param_to_weights(qkv_weights[slice].reshape(-1, hidden_size)) | ||
|
||
# attention dense | ||
hf_o_weight_name = f'model.layers.{l}.self_attn.o_proj.weight' | ||
if mcore_gpt: | ||
o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight' | ||
else: | ||
o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight' | ||
state_dict[hf_o_weight_name] = param_to_weights(ckpt[o_weight_base_name]) | ||
|
||
# # MLP | ||
if mcore_gpt: | ||
mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight' | ||
else: | ||
raise Exception("not implemented") | ||
gate_proj_weight, up_proj_weight = torch.chunk(ckpt[mlp_down_base_name], 2, dim=0) | ||
hf_gate_proj_name = f'model.layers.{l}.mlp.gate_proj.weight' | ||
hf_up_proj_name = f'model.layers.{l}.mlp.up_proj.weight' | ||
state_dict[hf_gate_proj_name] = param_to_weights(gate_proj_weight) | ||
state_dict[hf_up_proj_name] = param_to_weights(up_proj_weight) | ||
|
||
hf_mlp_up_weight_name = f'model.layers.{l}.mlp.down_proj.weight' | ||
if mcore_gpt: | ||
mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight' | ||
else: | ||
raise Exception("not implemented") | ||
state_dict[hf_mlp_up_weight_name] = param_to_weights(ckpt[mlp_up_base_name]) | ||
|
||
# LayerNorm | ||
hf_input_ln_weight_name = f'model.layers.{l}.input_layernorm.weight' | ||
if mcore_gpt: | ||
input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' | ||
else: | ||
input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight' | ||
Check warning Code scanning / CodeQL Unreachable code Warning
This statement is unreachable.
|
||
state_dict[hf_input_ln_weight_name] = param_to_weights(ckpt[input_ln_base_name]) | ||
|
||
hf_post_attn_ln_weight_name = f'model.layers.{l}.post_attention_layernorm.weight' | ||
if mcore_gpt: | ||
post_attn_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' | ||
else: | ||
post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight' | ||
Check warning Code scanning / CodeQL Unreachable code Warning
This statement is unreachable.
|
||
state_dict[hf_post_attn_ln_weight_name] = param_to_weights(ckpt[post_attn_ln_base_name]) | ||
|
||
hf_final_ln_weight_name = 'model.norm.weight' | ||
if mcore_gpt: | ||
final_ln_base_name = 'model.decoder.final_layernorm.weight' | ||
else: | ||
final_ln_base_name = 'model.language_model.encoder.final_layernorm.weight' | ||
state_dict[hf_final_ln_weight_name] = param_to_weights(ckpt[final_ln_base_name]) | ||
|
||
hf_output_layer_weight_name = 'lm_head.weight' | ||
if mcore_gpt: | ||
output_layer_base_name = 'model.output_layer.weight' | ||
else: | ||
output_layer_base_name = 'model.language_model.output_layer.weight' | ||
state_dict[hf_output_layer_weight_name] = param_to_weights(ckpt[output_layer_base_name]) | ||
return state_dict, nemo_config | ||
|
||
|
||
if __name__ == '__main__': | ||
args = get_args() | ||
hf_state_dict, nemo_config = convert(args.in_file, args.precision) | ||
|
||
config = load_config(args.hf_model_name, nemo_config) | ||
model = AutoModelForCausalLM.from_config(config) | ||
model.load_state_dict(hf_state_dict) | ||
model.save_pretrained(args.out_file) | ||
hf_tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name) | ||
hf_tokenizer.save_pretrained(args.out_file) | ||
logging.info(f'HF checkpoint saved to: {args.out_file}') |