
Commit

refactor
Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>
hsiehjackson committed May 8, 2024
1 parent 76f341e commit fe02284
Showing 2 changed files with 54 additions and 35 deletions.
@@ -54,6 +54,7 @@ exp_manager:
 
 model:
   model_parent: 'gpt' # gpt, t5, bert
+  model_architecture: 'transformer'
   model_name: 'gpt' # gpt, t5, bert
   megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
   micro_batch_size: 4 # limited by GPU memory
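Editor's sketch, not part of the commit: how the new 'model_architecture' key sits next to 'model_name' once a config section like the one above is loaded with OmegaConf. The inline YAML is a trimmed, hypothetical copy of that section.

from omegaconf import OmegaConf

yaml_snippet = """
model:
  model_parent: 'gpt'               # gpt, t5, bert
  model_architecture: 'transformer' # key added by this commit
  model_name: 'gpt'                 # gpt, t5, bert
"""
cfg = OmegaConf.create(yaml_snippet)

# Mirrors the cfg.get(...) calls added to __init__ in the diff below.
assert cfg.model.get('model_architecture') == 'transformer'
print(cfg.model.model_name, cfg.model.model_architecture)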
@@ -30,7 +30,7 @@
 from nemo.utils import logging
 
 try:
-    from megatron.core import parallel_state
+    from megatron.core import ModelParallelConfig, parallel_state
     from megatron.core.transformer.transformer_config import TransformerConfig
     from megatron.core.utils import init_method_normal, scaled_init_method_normal
     from megatron.core.models.gpt import GPTModel
@@ -63,8 +63,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         MegatronBaseModel.__init__(self, cfg, trainer)
         self.model_parent = cfg.get('model_parent')
+        self.model_architecture = cfg.get('model_architecture')
         self.model_name = cfg.get('model_name')
         self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
+        self.general_config = self.build_general_config()
         self.architecture_config = self.build_architecture_config()
         self.model_config = self.build_model_config()
 
@@ -93,72 +95,88 @@ def list_available_models(self):
         return None
 
     def model_provider_func(self, pre_process, post_process, add_encoder=False, add_decoder=False):
+        config = self.architecture_config
+        config.__dict__.update(self.general_config.__dict__)
         match self.model_name:
             case 'gpt':
                 model = GPTModel(
-                    config=self.architecture_config,
+                    config=config,
                     pre_process=pre_process,
                     post_process=post_process,
                     **self.model_config,
                 )
             case 't5':
                 model = T5Model(
-                    config=self.architecture_config,
+                    config=config,
                     pre_process=pre_process,
                     post_process=post_process,
                     **self.model_config,
                 )
             case 'bert':
                 model = BertModel(
-                    config=self.architecture_config,
+                    config=config,
                     pre_process=pre_process,
                     post_process=post_process,
                     **self.model_config,
                 )
             case _:
-                raise NotImplementedError(f'{self.model_name} model architecture is not supported')
-        return model
+                raise NotImplementedError(f'{self.model_name} model name is not supported')
 
-    def build_architecture_config(self):
-        # create a dictionary copy of the model config
+        return model
+
+    def build_general_config(self):
         cfg_cli = OmegaConf.to_container(self.cfg, resolve=True)
         cfg_cli_convert = {
+            # general config
            'fp16': self.torch_dtype == torch.float16 and self.megatron_amp_O2,
            'bf16': self.torch_dtype == torch.bfloat16 and self.megatron_amp_O2,
            'params_dtype': self.torch_dtype if self.torch_dtype in [torch.bfloat16, torch.float16] and self.megatron_amp_O2 else torch.float32,
            'timers': self.megatron_timers,
            'async_tensor_model_parallel_allreduce': cfg_cli.get('tensor_model_parallel_world_size', 1) > 1 and not cfg_cli.get('sequence_parallel', False),
            'pipeline_dtype': self.torch_dtype,
-           'grad_scale_func': self.trainer.precision_plugin.scaler.scale if self.trainer.precision in ["16", "16-mixed"] else None,
+           'grad_scale_func': self.trainer.precision_plugin.scaler.scale if self.trainer.precision in [16, "16", "16-mixed"] else None,
            'enable_autocast': self.torch_dtype in [torch.bfloat16, torch.float16] and not self.megatron_amp_O2,
            'autocast_dtype': self.torch_dtype,
-           # transformer config
-           'activation_func': activation_to_func(cfg_cli.get('activation', 'gelu')),
-           'gated_linear_unit': cfg_cli.get('activation', 'gelu').endswith('glu'),
-           'init_method': init_method_normal(cfg_cli.get('init_method_std', 0.02)),
-           'output_layer_init_method': scaled_init_method_normal(cfg_cli.get('init_method_std', 0.02), num_layers=cfg_cli.get('num_layers', 1)) \
-               if cfg_cli.get('use_scaled_init_method', True) else init_method_normal(cfg_cli.get('init_method_std', 0.02)),
-           'apply_query_key_layer_scaling': cfg_cli.get('apply_query_key_layer_scaling', False) and self.trainer.precision in [16, '16', '16-mixed'],
-           'attention_softmax_in_fp32': cfg_cli.get('attention_softmax_in_fp32', True)or cfg_cli.get('apply_query_key_layer_scaling', False) ,
        }
 
-        # create a dict to store the architecture config arguments
-        architecture_config_dict = {}
-        for field in fields(TransformerConfig):
+        config_dict = {}
+        for field in fields(ModelParallelConfig):
            if field.name in cfg_cli:
-               architecture_config_dict[field.name] = cfg_cli[field.name]
+               config_dict[field.name] = cfg_cli[field.name]
            elif field.name in cfg_cli_convert:
-               architecture_config_dict[field.name] = cfg_cli_convert[field.name]
+               config_dict[field.name] = cfg_cli_convert[field.name]
            else:
-               logging.warning(
-                   f"The model: {self} does not have the argument: {field.name} in its cfg. "
-                   f"Add this key to cfg to make to make it configurable."
-               )
+               logging.warning(f"[General] {self} does not have the general argument: {field.name} in its cfg.")
 
-        architecture_config = TransformerConfig(**architecture_config_dict)
-        return architecture_config
+        return ModelParallelConfig(**config_dict)
 
+    def build_architecture_config(self):
+        # create a dictionary copy of the model config
+        cfg_cli = OmegaConf.to_container(self.cfg, resolve=True)
+        match self.model_architecture:
+            case 'transformer':
+                cfg_cli_convert = {
+                    'activation_func': activation_to_func(cfg_cli.get('activation', 'gelu')),
+                    'gated_linear_unit': cfg_cli.get('activation', 'gelu').endswith('glu'),
+                    'init_method': init_method_normal(cfg_cli.get('init_method_std', 0.02)),
+                    'output_layer_init_method': scaled_init_method_normal(cfg_cli.get('init_method_std', 0.02), num_layers=cfg_cli.get('num_layers', 1)) \
+                        if cfg_cli.get('use_scaled_init_method', True) else init_method_normal(cfg_cli.get('init_method_std', 0.02)),
+                    'apply_query_key_layer_scaling': cfg_cli.get('apply_query_key_layer_scaling', False) and self.trainer.precision in [16, '16', '16-mixed'],
+                    'attention_softmax_in_fp32': cfg_cli.get('attention_softmax_in_fp32', True)or cfg_cli.get('apply_query_key_layer_scaling', False) ,
+                }
+                architecture_config = TransformerConfig
+            case _:
+                raise NotImplementedError(f'{self.model_architecture} model architecture is not supported')
+
+        config_dict = {}
+        for field in fields(architecture_config):
+            if field.name in cfg_cli:
+                config_dict[field.name] = cfg_cli[field.name]
+            elif field.name in cfg_cli_convert:
+                config_dict[field.name] = cfg_cli_convert[field.name]
+            else:
+                logging.warning(f"[Architecture] {self} does not have the argument: {field.name} in its cfg.")
+
+        return architecture_config(**config_dict)
 
     def build_model_config(self):
         # create a dictionary copy of the model config
         cfg = OmegaConf.to_container(self.cfg, resolve=True)
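Editor's sketch, not part of the commit: the two mechanisms used above, reduced to a toy example. ToyParallelConfig and ToyTransformerConfig are hypothetical stand-ins for ModelParallelConfig and TransformerConfig (in megatron.core the latter builds on the former, so the architecture config already declares the general fields that the merge fills in).

from dataclasses import dataclass, fields

@dataclass
class ToyParallelConfig:                           # stand-in for ModelParallelConfig
    pipeline_dtype: str = 'float32'
    sequence_parallel: bool = False

@dataclass
class ToyTransformerConfig(ToyParallelConfig):     # stand-in for TransformerConfig
    num_layers: int = 1
    gated_linear_unit: bool = False

def build_config(config_cls, cfg_cli, cfg_cli_convert):
    # Same filtering loop as build_general_config / build_architecture_config:
    # keep only the keys the target dataclass declares, preferring raw cfg values
    # and falling back to derived ones (warnings omitted here for brevity).
    config_dict = {}
    for field in fields(config_cls):
        if field.name in cfg_cli:
            config_dict[field.name] = cfg_cli[field.name]
        elif field.name in cfg_cli_convert:
            config_dict[field.name] = cfg_cli_convert[field.name]
    return config_cls(**config_dict)

cfg_cli = {'num_layers': 12, 'sequence_parallel': True, 'unrelated_key': 0}
general = build_config(ToyParallelConfig, cfg_cli, {'pipeline_dtype': 'bfloat16'})
arch = build_config(ToyTransformerConfig, cfg_cli, {'gated_linear_unit': True})

# model_provider_func's merge: copy every attribute of the general config onto
# the architecture config instance before handing it to the model class.
arch.__dict__.update(general.__dict__)
print(arch)
# ToyTransformerConfig(pipeline_dtype='bfloat16', sequence_parallel=True,
#                      num_layers=12, gated_linear_unit=True)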
@@ -209,15 +227,15 @@ def build_model_config(self):
                     'add_binary_head': True,
                     'num_tokentypes': 0 if cfg.get('add_binary_head') == False else 2,
                 }
 
             case _:
                 raise NotImplementedError(f'{self.model_name} model name is not supported')
 
         for key in model_config:
             if key in cfg:
                 model_config[key] = cfg[key]
             else:
-                logging.warning(
-                    f"The model: {self} does not have the argument: {key} in its cfg. "
-                    f"Add this key to cfg to make to make it configurable."
-                )
+                logging.warning(f"[Model] {self} does not have the argument: {key} in its cfg.")
 
         return model_config
 
     def build_transformer_config(self):
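One more editor's sketch, not part of the commit: build_architecture_config selects a config class with match/case and only instantiates it after the field filtering, so additional architectures can be added as extra case arms. A minimal, hypothetical version of that dispatch:

from dataclasses import dataclass

@dataclass
class ToyTransformerConfig:          # hypothetical stand-in for TransformerConfig
    num_layers: int = 1

def pick_architecture_config(model_architecture: str):
    # Return the class itself; the caller instantiates it once the kwargs are known.
    match model_architecture:
        case 'transformer':
            return ToyTransformerConfig
        case _:
            raise NotImplementedError(f'{model_architecture} model architecture is not supported')

config_cls = pick_architecture_config('transformer')
print(config_cls(num_layers=12))     # ToyTransformerConfig(num_layers=12)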

