diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index d5f124f4..b6db7b22 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -31,7 +31,7 @@
 )
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
-from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
 from internlm.core.trainer import TrainState
 from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
 from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
@@ -81,7 +81,17 @@
 logger = get_logger(__file__)
 
 
-def set_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
+def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)) and gpc.config.model.get("use_fp32_norm", False):
+                set_fp32_attr_to_module(module)
+
+
+def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
     def _check_module(module):
         # layer_norm
         if isinstance(module, (RMSNorm, nn.LayerNorm)):
@@ -111,6 +121,7 @@ def _check_module(module):
         if isinstance(_chunk, NaiveAMPModel):
             _chunk = _chunk.model
 
+        # set param parallel attribute
         for name, module in _chunk.named_modules():
             _check_module(module)
 
@@ -124,7 +135,7 @@ def _check_module(module):
 
 
 @llm_timeout(func_name="initialize_model")
-def initialize_model():
+def initialize_model(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None):
     """
     Initialize model with Automatic Mixed Precision.
 
@@ -132,8 +143,15 @@
     Returns:
         torch.nn.Module: The neural network model to be trained or evaluated.
     """
-
+    if pre_process_func:
+        pre_process_output = pre_process_func()
     model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
+    if post_process_func:
+        post_process_func(pre_process_output)
+
+    # should be set before NaiveAMPModel
+    set_fp32_attr_for_model(model)
+
     if isinstance(model, nn.ModuleList):
         model = nn.ModuleList(
             [
@@ -154,7 +172,7 @@
             sync_buffer=False,
         )
 
-    set_attr_for_param_groups(model)
+    set_parallel_attr_for_param_groups(model)
 
     # This sync is very important, cause the model weights kept in optimizer are copied
     # from the origin parameters in the memory, so we should make sure the dp sync
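
A minimal usage sketch of the hooks added above, assuming the InternLM global context (`gpc`) has already been initialized. The callback bodies and the timing use case are illustrative assumptions; only the `initialize_model(pre_process_func=..., post_process_func=...)` signature, `set_fp32_attr_for_model`, and the `use_fp32_norm` config flag come from the diff.

```python
# Sketch only: the callbacks below are hypothetical examples. The diff guarantees
# that pre_process_func() runs before the model is built and that its return value
# is forwarded to post_process_func() right after construction.
import time

from internlm.train.training_internlm import initialize_model


def record_start():
    # Runs immediately before MODEL_INITIALIZER builds the model.
    return time.time()


def report_build_time(start_time):
    # Receives whatever the pre-process hook returned.
    print(f"model construction took {time.time() - start_time:.2f}s")


# With gpc.config.model.use_fp32_norm = True, set_fp32_attr_for_model() additionally
# marks RMSNorm / nn.LayerNorm modules to stay in fp32 before NaiveAMPModel wraps them.
model = initialize_model(pre_process_func=record_start, post_process_func=report_build_time)
```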