Commit

Merge pull request #4 from huangting4201/feat/isp-communicator-support-0.x-activation-ckpt

feat(isp.py): isp communicator support 0.x activation ckpt
huangting4201 authored Jan 29, 2024
2 parents 85dd51f + 8c45118 commit e74f2dd
Showing 4 changed files with 59 additions and 40 deletions.
42 changes: 21 additions & 21 deletions internlm/core/communication/isp.py
@@ -9,7 +9,6 @@
from torch import distributed as dist
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
@@ -26,6 +25,7 @@ class ISPCommModelConfig:

dtype: torch.dtype = torch.half
device: torch.device = torch.device("cuda")
activation_checkpointing: float = 0.0
module_shapes: Dict[str, torch.Size] = None
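For illustration, the new activation_checkpointing field is a fraction of the transformer blocks rather than an on/off flag. A minimal construction of the config, assuming the dataclass fields shown in this hunk (values hypothetical):

import torch

from internlm.core.communication.isp import ISPCommModelConfig

# Checkpoint half of the transformer blocks; 0.0 disables activation
# checkpointing for the communicator, 1.0 checkpoints every block.
model_conf = ISPCommModelConfig(
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    activation_checkpointing=0.5,
)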


@@ -131,7 +131,8 @@ def __init__(self) -> None:
self.num_blocks: int = 0
self.embedding: List[nn.Module] = []
self.head: List[nn.Module] = []
self.last_block: nn.Moudle = None
self.ckpt_block_num: int = 0
self.last_ckpt_block: nn.Module = None
self.isp_outs: List[nn.Module] = []
self.isp_modules: List[nn.Module] = []
self.index_to_isp_module: Dict[int, nn.Module] = {}
@@ -152,12 +153,10 @@ def __init__(
model: Union[nn.Module, nn.ModuleList],
model_conf: ISPCommModelConfig,
overlap: bool = False,
activation_checkpointing: bool = False,
enable_memory_pool: bool = False,
process_group: dist.ProcessGroup = None,
) -> None:
self.process_group = process_group
self.model_checkpoint = activation_checkpointing
self.overlap = overlap
self.enable_memory_pool = overlap and enable_memory_pool
self.model_conf = model_conf
@@ -172,7 +171,8 @@ def __init__(
self._num_blocks = None
self._head = None
self._embedding = None
self._last_block = None
self._ckpt_block_num = None
self._last_ckpt_block = None
self._isp_outs = None
self._isp_modules = None
# key: isp module; value: module global all-gather op handle
@@ -222,7 +222,10 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
elif isinstance(children, Embedding1D):
self._overlap_states[cid].embedding.append(children)
elif isinstance(children, nn.ModuleList):
self._overlap_states[cid].last_block = children[-1]
self._overlap_states[cid].ckpt_block_num = int(self.model_conf.activation_checkpointing * len(children))
self._overlap_states[cid].last_ckpt_block = children[
max(0, self._overlap_states[cid].ckpt_block_num - 1)
]

for idx, block in enumerate(children):
self._overlap_states[cid].index_to_isp_module[idx] = []
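A quick worked example of the block-count arithmetic introduced above, assuming a hypothetical 32-block ModuleList:

# activation_checkpointing = 0.25 checkpoints the first quarter of the blocks.
num_blocks = 32
activation_checkpointing = 0.25

ckpt_block_num = int(activation_checkpointing * num_blocks)  # 8 blocks are checkpointed
last_ckpt_block_index = max(0, ckpt_block_num - 1)           # block 7 is the last checkpointed one

assert (ckpt_block_num, last_ckpt_block_index) == (8, 7)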
@@ -335,7 +338,7 @@ def _post_forward_hook_for_embedding(self, *args): # pylint: disable=W0613
def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args): # pylint: disable=W0613
block_index = self._module_to_index[module]

if self.model_checkpoint and self.is_forward is False:
if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
else:
@@ -350,13 +353,12 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis
self._wait_handle(module)

def _pre_forward_hook_for_block(self, *args): # pylint: disable=W0613
for module in self._index_to_isp_module[self._num_blocks - 1]:
for module in self._index_to_isp_module[self._ckpt_block_num - 1]:
self._all_gather_module_weight(module)
self._wait_handle(module)

def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
self._clear_handle(module)
if not (self.model_checkpoint and self.is_forward is False):
if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
self._clear_weight(module)

def _post_backward_hook_for_head(self, *args): # pylint: disable=W0613
@@ -377,7 +379,8 @@ def _pre_backward_hook_for_module(self, module: nn.Module, *args): # pylint: di
module_index = self._isp_modules.index(module)
if module_index - 1 >= 0:
next_module = self._isp_modules[module_index - 1]
self._all_gather_module_weight(next_module)
if self._module_to_index[next_module] >= self._ckpt_block_num:
self._all_gather_module_weight(next_module)

def _post_backward_hook_for_module(self, module, *args): # pylint: disable=W0613
self._clear_handle(module)
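The guard added in this hunk skips the backward-time weight prefetch when the previous ISP module lives in a checkpointed block, because recomputation will all-gather its weights again anyway. A small sketch of that condition, with hypothetical module names and indices:

ckpt_block_num = 8  # blocks 0..7 are checkpointed
module_to_index = {"blocks.7.attn.wqkv": 7, "blocks.8.attn.wqkv": 8}  # hypothetical mapping

def should_prefetch_in_backward(next_module_name: str) -> bool:
    # Mirrors `self._module_to_index[next_module] >= self._ckpt_block_num`.
    return module_to_index[next_module_name] >= ckpt_block_num

assert should_prefetch_in_backward("blocks.8.attn.wqkv")      # plain block: prefetch now
assert not should_prefetch_in_backward("blocks.7.attn.wqkv")  # checkpointed block: re-gathered on recompute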
@@ -396,12 +399,8 @@ def _register_sync_parameters_hook(self) -> None:
for embedding in self._embedding:
embedding.register_forward_hook(self._post_forward_hook_for_embedding)

if self.model_checkpoint:
if gpc.is_last_rank(parallel_mode=ParallelMode.PIPELINE):
for head in self._head:
head.register_full_backward_pre_hook(self._pre_backward_hook_for_head)
else:
self._last_block.register_forward_pre_hook(self._pre_forward_hook_for_block)
if self._ckpt_block_num >= 1:
self._last_ckpt_block.register_forward_pre_hook(self._pre_forward_hook_for_block)

for out_proj in self._isp_outs:
out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj)
@@ -414,7 +413,7 @@ def _register_sync_parameters_hook(self) -> None:
# 1. register post_backward_hook @head module to prefetch for the last block's last module
# 2. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
# 3. register post_backward_hook @isp_module to release resource
if not self.model_checkpoint:
if self._ckpt_block_num < self._num_blocks:
for head in self._head:
head.register_full_backward_hook(self._post_backward_hook_for_head)
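With a fractional ratio, the two registration paths (the pre-forward hook on the last checkpointed block above, and the post-backward hook on the head here) can both be active at once. A sketch of how the conditions interact at the boundary values, using names from this diff with hypothetical block counts:

def active_hooks(ckpt_block_num: int, num_blocks: int) -> dict:
    return {
        # pre-forward hook on the last checkpointed block (prefetch for recomputation)
        "pre_forward_on_last_ckpt_block": ckpt_block_num >= 1,
        # post-backward hook on the head (prefetch for the last block's backward)
        "post_backward_on_head": ckpt_block_num < num_blocks,
    }

assert active_hooks(0, 32) == {"pre_forward_on_last_ckpt_block": False, "post_backward_on_head": True}   # ratio 0.0
assert active_hooks(8, 32) == {"pre_forward_on_last_ckpt_block": True, "post_backward_on_head": True}    # ratio 0.25
assert active_hooks(32, 32) == {"pre_forward_on_last_ckpt_block": True, "post_backward_on_head": False}  # ratio 1.0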

@@ -443,7 +442,8 @@ def switch_current_model_chunk(self, chunk_id: int) -> None:
self._bias_global_output = self._overlap_states[chunk_id].bias_global_output
self._module_to_index = self._overlap_states[chunk_id].module_to_index
self._index_to_isp_module = self._overlap_states[chunk_id].index_to_isp_module
self._last_block = self._overlap_states[chunk_id].last_block
self._ckpt_block_num = self._overlap_states[chunk_id].ckpt_block_num
self._last_ckpt_block = self._overlap_states[chunk_id].last_ckpt_block
self._head = self._overlap_states[chunk_id].head
self._embedding = self._overlap_states[chunk_id].embedding
self._num_blocks = self._overlap_states[chunk_id].num_blocks
@@ -514,7 +514,7 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None:
self._zero_optim = zero_optim

def before_forward(self, scheduler, inputs) -> None:
if self._isp_communicator.model_checkpoint:
if self._isp_communicator._ckpt_block_num > 0:
self._isp_communicator.is_forward = True
# switch model chunk before forward
chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
@@ -530,7 +530,7 @@ def after_criterion(self, scheduler, loss) -> None:
pass

def before_backward(self, scheduler, outputs, outputs_grad) -> None:
if self._isp_communicator.model_checkpoint:
if self._isp_communicator._ckpt_block_num > 0:
self._isp_communicator.is_forward = False
# switch model chunk before backward
chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
2 changes: 2 additions & 0 deletions internlm/train/__init__.py
@@ -2,6 +2,7 @@
get_scheduler_hooks,
get_train_data_loader,
get_validation_data_loader,
initialize_isp_communicator,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
@@ -17,6 +18,7 @@
"get_validation_data_loader",
"initialize_llm_profile",
"initialize_model",
"initialize_isp_communicator",
"initialize_optimizer",
"load_new_batch",
"record_current_batch_training_metrics",
49 changes: 31 additions & 18 deletions internlm/train/training_internlm.py
@@ -216,24 +216,7 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f
# if fsdp enabled, wrap the model
model = wrap_FSDP_model(model)

if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
isp_communicator = None
else:
isp_communicator = ISPCommunicator(
model,
ISPCommModelConfig(
gpc.config.model.dtype,
get_current_device(),
),
gpc.config.parallel.weight.overlap,
gpc.config.model.checkpoint,
gpc.config.parallel.weight.memory_pool,
gpc.get_group(ParallelMode.WEIGHT),
)
# register communicator for isp linear.
ISPLinear.register_communicator(isp_communicator)

return model, isp_communicator
return model


def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
@@ -269,6 +252,36 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
return model


def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize communicator for isp tensor parallel mode.
Args:
model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
Returns:
An isp communicator for managing comp/comm overlap and memory pool.
"""
if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
isp_communicator = None
else:
isp_communicator = ISPCommunicator(
model,
ISPCommModelConfig(
gpc.config.model.dtype,
get_current_device(),
gpc.config.model.checkpoint,
),
gpc.config.parallel.weight.overlap,
gpc.config.parallel.weight.memory_pool,
gpc.get_group(ParallelMode.WEIGHT),
)
# register communicator for isp linear.
ISPLinear.register_communicator(isp_communicator)

return isp_communicator


@llm_timeout(func_name="initialize_optimizer")
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None):
"""
6 changes: 5 additions & 1 deletion train.py
@@ -22,6 +22,7 @@
get_scheduler_hooks,
get_train_data_loader,
get_validation_data_loader,
initialize_isp_communicator,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
@@ -96,7 +97,10 @@ def main(args):
uniscale_logger = initialize_llm_logger(start_time=current_time)

# initialize model
model, isp_communicator = initialize_model()
model = initialize_model()

# initialize isp communicator
isp_communicator = initialize_isp_communicator(model)

with open(args.config, "r") as f:
config_lines = f.readlines()
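End to end, the checkpoint ratio reaches the communicator through gpc.config.model.checkpoint (passed into ISPCommModelConfig above). A hypothetical config fragment that enables 0.x checkpointing would set it as a fraction:

# Hypothetical training-config fragment: with this commit the value is treated
# as a ratio over the transformer blocks, not just a boolean on/off switch.
model = dict(
    checkpoint=0.5,  # checkpoint activations for 50% of the blocks
)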
