
Commit

feat(training_internlm.py): update initialize_model func to adapt to private repo
huangting4201 committed Jan 22, 2024
1 parent c606bb5 commit 4e9b276
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions internlm/train/training_internlm.py
@@ -31,7 +31,7 @@
 )
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
-from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
 from internlm.core.trainer import TrainState
 from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
 from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
@@ -81,7 +81,17 @@
 logger = get_logger(__file__)


-def set_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
+def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)) and gpc.config.model.get("use_fp32_norm", False):
+                set_fp32_attr_to_module(module)
+
+
+def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
     def _check_module(module):
         # layer_norm
         if isinstance(module, (RMSNorm, nn.LayerNorm)):
@@ -111,6 +121,7 @@ def _check_module(module):
         if isinstance(_chunk, NaiveAMPModel):
             _chunk = _chunk.model

+        # set param parallel attribute
         for name, module in _chunk.named_modules():
             _check_module(module)

@@ -124,16 +135,23 @@ def _check_module(module):


 @llm_timeout(func_name="initialize_model")
-def initialize_model():
+def initialize_model(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None):
     """
     Initialize model with Automatic Mixed Precision.
     Returns:
         torch.nn.Module:
             The neural network model to be trained or evaluated.
     """

+    if pre_process_func:
+        pre_process_output = pre_process_func()
     model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
+    if post_process_func:
+        post_process_func(pre_process_output)

+    # should be set before NaiveAMPModel
+    set_fp32_attr_for_model(model)
+
     if isinstance(model, nn.ModuleList):
         model = nn.ModuleList(
             [
@@ -154,7 +172,7 @@ def initialize_model():
             sync_buffer=False,
         )

-    set_attr_for_param_groups(model)
+    set_parallel_attr_for_param_groups(model)

     # This sync is very important, cause the model weights kept in optimizer are copied
     # from the origin parameters in the memory, so we should make sure the dp sync
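
For orientation, a minimal sketch of how a caller (for example in the private repo the commit message refers to) might use the new pre_process_func / post_process_func hooks. The hook names _register_custom_modules and _consume_registration and their bodies are hypothetical; only the initialize_model signature and the call order (pre-hook before model construction, post-hook afterwards with the pre-hook's return value) come from the diff.

    from internlm.train.training_internlm import initialize_model

    def _register_custom_modules():
        # hypothetical pre-hook: runs before MODEL_INITIALIZER builds the model,
        # e.g. to register extra modules or patch the model config
        return {"patched": True}

    def _consume_registration(pre_output):
        # hypothetical post-hook: receives whatever the pre-hook returned,
        # once the model object exists
        assert pre_output["patched"]

    model = initialize_model(
        pre_process_func=_register_custom_modules,
        post_process_func=_consume_registration,
    )

Note that, as written in the diff, post_process_func is called with pre_process_output, so supplying only a post-hook without a matching pre-hook would leave that variable unbound.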
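The new set_fp32_attr_for_model pass is gated on a use_fp32_norm flag read from gpc.config.model. A sketch of enabling it in a training config; every key other than use_fp32_norm is illustrative only.

    # fragment of a training config (illustrative values)
    model = dict(
        num_layers=32,         # hypothetical existing hyperparameters
        hidden_size=4096,
        use_fp32_norm=True,    # read via gpc.config.model.get("use_fp32_norm", False);
                               # tags RMSNorm / nn.LayerNorm modules with set_fp32_attr_to_module
    )

With the flag set, the tagging happens before the NaiveAMPModel wrapper is applied (per the "# should be set before NaiveAMPModel" comment in the diff), presumably so those norm layers are kept in fp32 under mixed precision.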
