diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py
index fd917b77..bd744d10 100644
--- a/internlm/core/communication/isp.py
+++ b/internlm/core/communication/isp.py
@@ -9,7 +9,6 @@
 from torch import distributed as dist
 from torch import nn
 
-from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
@@ -26,6 +25,7 @@ class ISPCommModelConfig:
 
     dtype: torch.dtype = torch.half
     device: torch.device = torch.device("cuda")
+    activation_checkpointing: float = 0.0
     module_shapes: Dict[str, torch.Size] = None
@@ -131,7 +131,8 @@ def __init__(self) -> None:
         self.num_blocks: int = 0
         self.embedding: List[nn.Module] = []
         self.head: List[nn.Module] = []
-        self.last_block: nn.Moudle = None
+        self.ckpt_block_num: int = 0
+        self.last_ckpt_block: nn.Module = None
         self.isp_outs: List[nn.Module] = []
         self.isp_modules: List[nn.Module] = []
         self.index_to_isp_module: Dict[int, nn.Module] = {}
@@ -152,12 +153,10 @@ def __init__(
         model: Union[nn.Module, nn.ModuleList],
         model_conf: ISPCommModelConfig,
         overlap: bool = False,
-        activation_checkpointing: bool = False,
         enable_memory_pool: bool = False,
         process_group: dist.ProcessGroup = None,
     ) -> None:
         self.process_group = process_group
-        self.model_checkpoint = activation_checkpointing
         self.overlap = overlap
         self.enable_memory_pool = overlap and enable_memory_pool
         self.model_conf = model_conf
@@ -172,7 +171,8 @@ def __init__(
         self._num_blocks = None
         self._head = None
         self._embedding = None
-        self._last_block = None
+        self._ckpt_block_num = None
+        self._last_ckpt_block = None
         self._isp_outs = None
         self._isp_modules = None
         # key: isp module; value: module global all-gather op handle
@@ -222,7 +222,10 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
             elif isinstance(children, Embedding1D):
                 self._overlap_states[cid].embedding.append(children)
             elif isinstance(children, nn.ModuleList):
-                self._overlap_states[cid].last_block = children[-1]
+                self._overlap_states[cid].ckpt_block_num = int(self.model_conf.activation_checkpointing * len(children))
+                self._overlap_states[cid].last_ckpt_block = children[
+                    max(0, self._overlap_states[cid].ckpt_block_num - 1)
+                ]
                 for idx, block in enumerate(children):
                     self._overlap_states[cid].index_to_isp_module[idx] = []
@@ -335,7 +338,7 @@ def _post_forward_hook_for_embedding(self, *args):  # pylint: disable=W0613
 
     def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args):  # pylint: disable=W0613
         block_index = self._module_to_index[module]
-        if self.model_checkpoint and self.is_forward is False:
+        if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
             if block_index - 1 >= 0:
                 self._all_gather_block_weight(block_index - 1)
             else:
@@ -350,13 +353,12 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: dis
         self._wait_handle(module)
 
     def _pre_forward_hook_for_block(self, *args):  # pylint: disable=W0613
-        for module in self._index_to_isp_module[self._num_blocks - 1]:
+        for module in self._index_to_isp_module[self._ckpt_block_num - 1]:
             self._all_gather_module_weight(module)
-            self._wait_handle(module)
 
     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         self._clear_handle(module)
-        if not (self.model_checkpoint and self.is_forward is False):
+        if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_weight(module)
 
     def _post_backward_hook_for_head(self, *args):  # pylint: disable=W0613
@@ -377,7 +379,8 @@ def _pre_backward_hook_for_module(self, module: nn.Module, *args):  # pylint: di
         module_index = self._isp_modules.index(module)
         if module_index - 1 >= 0:
             next_module = self._isp_modules[module_index - 1]
-            self._all_gather_module_weight(next_module)
+            if self._module_to_index[next_module] >= self._ckpt_block_num:
+                self._all_gather_module_weight(next_module)
 
     def _post_backward_hook_for_module(self, module, *args):  # pylint: disable=W0613
         self._clear_handle(module)
@@ -396,12 +399,8 @@ def _register_sync_parameters_hook(self) -> None:
         for embedding in self._embedding:
             embedding.register_forward_hook(self._post_forward_hook_for_embedding)
 
-        if self.model_checkpoint:
-            if gpc.is_last_rank(parallel_mode=ParallelMode.PIPELINE):
-                for head in self._head:
-                    head.register_full_backward_pre_hook(self._pre_backward_hook_for_head)
-            else:
-                self._last_block.register_forward_pre_hook(self._pre_forward_hook_for_block)
+        if self._ckpt_block_num >= 1:
+            self._last_ckpt_block.register_forward_pre_hook(self._pre_forward_hook_for_block)
 
         for out_proj in self._isp_outs:
             out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj)
@@ -414,7 +413,7 @@ def _register_sync_parameters_hook(self) -> None:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @isp_module to release resource
-        if not self.model_checkpoint:
+        if self._ckpt_block_num < self._num_blocks:
             for head in self._head:
                 head.register_full_backward_hook(self._post_backward_hook_for_head)
 
@@ -443,7 +442,8 @@ def switch_current_model_chunk(self, chunk_id: int) -> None:
         self._bias_global_output = self._overlap_states[chunk_id].bias_global_output
         self._module_to_index = self._overlap_states[chunk_id].module_to_index
         self._index_to_isp_module = self._overlap_states[chunk_id].index_to_isp_module
-        self._last_block = self._overlap_states[chunk_id].last_block
+        self._ckpt_block_num = self._overlap_states[chunk_id].ckpt_block_num
+        self._last_ckpt_block = self._overlap_states[chunk_id].last_ckpt_block
         self._head = self._overlap_states[chunk_id].head
         self._embedding = self._overlap_states[chunk_id].embedding
         self._num_blocks = self._overlap_states[chunk_id].num_blocks
@@ -514,7 +514,7 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None:
         self._zero_optim = zero_optim
 
     def before_forward(self, scheduler, inputs) -> None:
-        if self._isp_communicator.model_checkpoint:
+        if self._isp_communicator._ckpt_block_num > 0:
             self._isp_communicator.is_forward = True
         # switch model chunk before forward
         chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
@@ -530,7 +530,7 @@ def after_criterion(self, scheduler, loss) -> None:
         pass
 
     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if self._isp_communicator.model_checkpoint:
+        if self._isp_communicator._ckpt_block_num > 0:
             self._isp_communicator.is_forward = False
         # switch model chunk before backward
         chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py
index 9a70e1e2..d44eaec9 100644
--- a/internlm/train/__init__.py
+++ b/internlm/train/__init__.py
@@ -2,6 +2,7 @@
     get_scheduler_hooks,
     get_train_data_loader,
     get_validation_data_loader,
+    initialize_isp_communicator,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,
@@ -17,6 +18,7 @@
     "get_validation_data_loader",
     "initialize_llm_profile",
     "initialize_model",
+    "initialize_isp_communicator",
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 2ca66be5..62a9d060 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -216,24 +216,7 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
-    if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
-        isp_communicator = None
-    else:
-        isp_communicator = ISPCommunicator(
-            model,
-            ISPCommModelConfig(
-                gpc.config.model.dtype,
-                get_current_device(),
-            ),
-            gpc.config.parallel.weight.overlap,
-            gpc.config.model.checkpoint,
-            gpc.config.parallel.weight.memory_pool,
-            gpc.get_group(ParallelMode.WEIGHT),
-        )
-        # register communicator for isp linear.
-        ISPLinear.register_communicator(isp_communicator)
-
-    return model, isp_communicator
+    return model
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
@@ -269,6 +252,36 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     return model
 
 
+def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]):
+    """
+    Initialize communicator for isp tensor parallel mode.
+
+    Args:
+        model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
+
+    Returns:
+        An isp communicator for managing comp/comm overlap and memory pool.
+    """
+    if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
+        isp_communicator = None
+    else:
+        isp_communicator = ISPCommunicator(
+            model,
+            ISPCommModelConfig(
+                gpc.config.model.dtype,
+                get_current_device(),
+                gpc.config.model.checkpoint,
+            ),
+            gpc.config.parallel.weight.overlap,
+            gpc.config.parallel.weight.memory_pool,
+            gpc.get_group(ParallelMode.WEIGHT),
+        )
+        # register communicator for isp linear.
+        ISPLinear.register_communicator(isp_communicator)
+
+    return isp_communicator
+
+
 @llm_timeout(func_name="initialize_optimizer")
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None):
     """
diff --git a/train.py b/train.py
index 9620268d..150f5463 100644
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@
     get_scheduler_hooks,
     get_train_data_loader,
     get_validation_data_loader,
+    initialize_isp_communicator,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,
@@ -96,7 +97,10 @@ def main(args):
     uniscale_logger = initialize_llm_logger(start_time=current_time)
 
     # initialize model
-    model, isp_communicator = initialize_model()
+    model = initialize_model()
+
+    # initialize isp communicator
+    isp_communicator = initialize_isp_communicator(model)
 
     with open(args.config, "r") as f:
         config_lines = f.readlines()
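
Reading aid, not part of the patch: with this change `activation_checkpointing` is carried as a float ratio on `ISPCommModelConfig` (fed from `gpc.config.model.checkpoint`), and `_parse_model_structure` turns it into a per-chunk `ckpt_block_num`. Below is a minimal sketch of that bookkeeping under toy assumptions; `blocks` and the ratio value are stand-ins, not repo code.

# Sketch only: mirrors the bookkeeping the patch adds in _parse_model_structure.
# `blocks` and `activation_checkpointing` here are hypothetical stand-ins.
from torch import nn

activation_checkpointing = 0.5  # ratio of transformer blocks that are recomputed
blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(32))  # stand-in block list

ckpt_block_num = int(activation_checkpointing * len(blocks))  # -> 16
last_ckpt_block = blocks[max(0, ckpt_block_num - 1)]          # -> blocks[15]

# The hooks then branch on the block index: blocks with index < ckpt_block_num
# are recomputed during backward, so their weights are re-gathered by the
# forward-style prefetch hooks; blocks with index >= ckpt_block_num keep the
# plain backward prefetch path registered via the head/module backward hooks.
assert ckpt_block_num == 16 and last_ckpt_block is blocks[15]

At the extremes the ratio roughly recovers the old boolean behaviour: with 0.0 the `ckpt_block_num >= 1` branch never registers the block pre-forward hook, and with 1.0 the `ckpt_block_num < num_blocks` guard skips the head backward hook, much like the removed `model_checkpoint` flag did.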