Minor MPT-7B fixes and creation script update #6982

Merged (3 commits) on Jul 11, 2023
@@ -221,6 +221,7 @@ def _build_tokenizer(self):
merges_file=self.register_artifact("tokenizer.merge_file", self._cfg.tokenizer.get('merge_file', None)),
use_fast=self.cfg.tokenizer.get('use_fast', False),
delimiter=self.cfg.tokenizer.get('delimiter', None),
special_tokens=self.cfg.tokenizer.get('special_tokens', None),
legacy=legacy,
)

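The hunk above adds an optional `special_tokens` entry that `_build_tokenizer` now forwards from the model config. The sketch below shows what a tokenizer config carrying such an entry might look like; it is a hypothetical illustration (the token names and extra keys are assumptions, not part of this PR), built with the OmegaConf objects the conversion script already uses.

```python
# Hypothetical sketch: a tokenizer config whose optional 'special_tokens' entry
# would be picked up by the new `self.cfg.tokenizer.get('special_tokens', None)`.
from omegaconf import OmegaConf

tokenizer_cfg = OmegaConf.create(
    {
        'library': 'huggingface',           # illustrative values, not from this PR
        'type': 'EleutherAI/gpt-neox-20b',
        'use_fast': True,
        'delimiter': None,
        'special_tokens': {                 # the new optional block
            'pad_token': '<|padding|>',
            'eos_token': '<|endoftext|>',
        },
    }
)

# Returns the special-tokens mapping if present, otherwise None.
print(tokenizer_cfg.get('special_tokens', None))
```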
scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py (44 changes: 33 additions & 11 deletions)
@@ -34,11 +34,19 @@
TP/PP values you want:
NeMo/examples/nlp/language_modeling/megatron_change_num_partitions.py

* Please note: when using the above script, you MUST also pass the `--megatron_legacy` flag.
Failure to do this will result in a corrupt model! *

This script also requires a baseline config file from which to override default parameters.
You can specify the location of this file using the -c argument. You can use any appropriate
NeMo config file, but in the default case we highly recommend you use the following:
NeMo/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml


Here is an example usage command:

```python
python scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py -i /path/to/mpt_7b -o /path/to/save
python scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py -c /path/to/megatron_gpt_config.yaml -i /path/to/mpt_7b -o /path/to/save
```

"""
@@ -49,6 +57,7 @@

import pytorch_lightning as pl
import torch
import yaml
from omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron import GPTModel
@@ -60,6 +69,9 @@
parser.add_argument(
'-i', '--input', required=True, type=str, help='path to the two MPT-7B .bin weight files from HuggingFace'
)
parser.add_argument(
'-c', '--config', required=True, type=str, help='the path to the megatron_gpt_config.yaml file'
)
parser.add_argument(
'-o', '--output', required=False, default=None, type=str, help='path to dir where to store output .nemo file'
)
@@ -71,22 +83,37 @@
logging.critical(f'Input directory [ {args.input} ] does not exist or cannot be found. Aborting.')
exit(255)

model_dict = {
'micro_batch_size': 4,
'global_batch_size': 8,
if not os.path.exists(args.config):
logging.critical(f'Path to config file [ {args.config} ] does not exist or cannot be found. Aborting.')
exit(255)

with open(args.config, 'r', encoding='utf_8') as fr:
orig_cfg = yaml.safe_load(fr)

model_dict = orig_cfg['model']
if 'tokenizer' in model_dict:
del model_dict['tokenizer']
if 'data' in model_dict:
del model_dict['data']

override_model_dict = {
'micro_batch_size': 1,
'global_batch_size': 4,
'rampup_batch_size': None,
'tensor_model_parallel_size': 1,
'pipeline_model_parallel_size': 1,
'virtual_pipeline_model_parallel_size': None,
'megatron_amp_O2': True,
'transformer_engine': False,
'use_cpu_initialization': True,
'use_cpu_initialization': False,
'hidden_size': 4096,
'encoder_seq_length': 2048,
'max_position_embeddings': 2048,
'num_layers': 32,
'num_attention_heads': 32,
'ffn_hidden_size': 4 * 4096,
'precision': 'bf16',
'layernorm_epsilon': 1e-5,
'pre_process': True,
'post_process': True,
'num_tokentypes': 0,
@@ -114,11 +141,6 @@
'type': 'EleutherAI/gpt-neox-20b',
'use_fast': True,
}
optim_dict = {
'name': 'fused_adam',
'lr': 2e-4,
'weight_decay': 0.01,
}
trainer_dict = {
'devices': 1,
'num_nodes': 1,
@@ -139,8 +161,8 @@
'enable_model_summary': False,
}

model_dict.update(override_model_dict)
model_dict['tokenizer'] = tokeniser_dict
model_dict['optim'] = optim_dict

omega_cfg = OmegaConf.create(model_dict)

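Taken together, the hunks above amount to the following config-assembly pattern: load the baseline YAML passed via -c, drop its tokenizer and data sections, overlay the MPT-7B-specific overrides, attach the tokenizer settings, and build the OmegaConf the model is constructed from. Below is a condensed, hedged sketch of that flow; keys hidden by the collapsed diff regions are elided, and exact values should be taken from the script itself.

```python
# Condensed sketch of the updated config flow; keys hidden by collapsed diff
# regions are elided, and exact values should be taken from the script itself.
import yaml
from omegaconf import OmegaConf

with open('/path/to/megatron_gpt_config.yaml', 'r', encoding='utf_8') as fr:
    orig_cfg = yaml.safe_load(fr)

model_dict = orig_cfg['model']
model_dict.pop('tokenizer', None)   # replaced below
model_dict.pop('data', None)        # not needed for conversion

model_dict.update(
    {
        'micro_batch_size': 1,
        'global_batch_size': 4,
        'tensor_model_parallel_size': 1,
        'pipeline_model_parallel_size': 1,
        'megatron_amp_O2': True,
        'use_cpu_initialization': False,
        'hidden_size': 4096,
        'encoder_seq_length': 2048,
        'num_layers': 32,
        'num_attention_heads': 32,
        'ffn_hidden_size': 4 * 4096,
        'precision': 'bf16',
        # ... remaining MPT-7B overrides as in the script
    }
)
model_dict['tokenizer'] = {
    'type': 'EleutherAI/gpt-neox-20b',
    'use_fast': True,
    # other tokenizer keys not shown in this diff
}

omega_cfg = OmegaConf.create(model_dict)
print(omega_cfg.num_layers, omega_cfg.tokenizer.type)
```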