feat/refactor partition strategy #13

Merged
Changes from 4 commits
191 commits
10aa63f
support optimized sp
yingtongxiong Oct 7, 2023
e5a2909
Merge remote-tracking branch 'upstream/develop' into feat/deepspeed_sp
yingtongxiong Oct 7, 2023
bf475b6
debug
yingtongxiong Oct 8, 2023
bd4af3a
modify the all2all
yingtongxiong Oct 8, 2023
189a313
support fstp and refactor code
yingtongxiong Oct 9, 2023
21c1a7f
support evaluation with fstp
yingtongxiong Oct 9, 2023
949431f
modify the config
yingtongxiong Oct 9, 2023
0fa1083
Merge remote-tracking branch 'upstream/develop' into feat/fstp
yingtongxiong Oct 9, 2023
54e5616
remove useless code for no-pp
yingtongxiong Oct 9, 2023
144731c
fix evaluation bug in pp
yingtongxiong Oct 9, 2023
ef9e7cc
modify the config
yingtongxiong Oct 9, 2023
5d39c33
restore train.py
yingtongxiong Oct 9, 2023
29df765
refactor code
yingtongxiong Oct 9, 2023
f191853
fix lint
yingtongxiong Oct 9, 2023
007e58a
merge upstream develop
yingtongxiong Oct 9, 2023
a8dea63
fix the ci incompatible in config
yingtongxiong Oct 9, 2023
1b7935d
merge upstream develop
yingtongxiong Oct 9, 2023
dd67ab9
merge develop
yingtongxiong Oct 9, 2023
db63754
fix lint
yingtongxiong Oct 9, 2023
5fb6d99
feat(configs/7B_sft.py): update parallel config comment
huangting4201 Oct 10, 2023
0fac845
overlap grad_input computation and grad_weight reduce_scatter
yingtongxiong Oct 10, 2023
c94be64
merge origin
yingtongxiong Oct 10, 2023
792b066
communication overlap
yingtongxiong Oct 11, 2023
5fd5a8a
support fine-grained overlap
yingtongxiong Oct 11, 2023
d0b1346
feat(model/linear.py): support block allgather overlap
huangting4201 Oct 12, 2023
d0f0c22
feat(model/linear.py): change pre backward from wqkv to block
huangting4201 Oct 13, 2023
82204ee
support hybrid overlap
yingtongxiong Oct 16, 2023
0d1fa03
feat(model/linear.py): set block 0 full weight
huangting4201 Oct 16, 2023
d1af0d6
feat(model/linear.py): block-grained backward
huangting4201 Oct 17, 2023
229cc5c
impl reduce scatter async
Oct 17, 2023
4e99a7f
feat(train/training_internlm.py): remove abnormal tgs when calculatin…
huangting4201 Oct 17, 2023
6682f5d
fix reduce scatter async bug
Oct 17, 2023
b51cf4e
Merge branch 'feat/fstp' of github.com:yingtongxiong/InternLM into fe…
Oct 17, 2023
6408b94
support fine grained
yingtongxiong Oct 17, 2023
a5c6e45
Merge branch 'feat/fstp' of https://github.com/yingtongxiong/InternLM…
yingtongxiong Oct 17, 2023
5c38cb6
add head overlap
yingtongxiong Oct 17, 2023
5abe519
remove full weight for block 0
yingtongxiong Oct 17, 2023
16ef7b7
add test
yingtongxiong Oct 17, 2023
a5aeab2
memory profiling test
yingtongxiong Oct 17, 2023
4742271
add memory pool
yingtongxiong Oct 19, 2023
ed72327
support reduce scatter memory pool
yingtongxiong Oct 20, 2023
815a584
feat(model/linear.py): remove useless code
huangting4201 Oct 20, 2023
95488d8
update optimizer accumulate grad impl when fstp
Oct 20, 2023
d91a5d9
feat(initialize/launch.py): refactor config for fstp
huangting4201 Oct 20, 2023
3c69254
feat(optimizer/hybrid_zero_optim.py): resolve conflicts
huangting4201 Oct 20, 2023
eac382a
feat(optimizer/hybrid_zero_optim.py): fix lint error
huangting4201 Oct 20, 2023
2acf9b8
feat(utils/gputest.py): fix lint error
huangting4201 Oct 20, 2023
f22e5b3
Merge pull request #4 from yingtongxiong/fstp/refactor-config
yingtongxiong Oct 20, 2023
dcd89ed
refactor linear
yingtongxiong Oct 20, 2023
1804d01
merge reduce-scatter
yingtongxiong Oct 20, 2023
85ad917
feat(model/overlap_handler.py): refactor overlap hook handle
huangting4201 Oct 20, 2023
b20f47a
feat(model/overlap_handler.py): move handler to gpc
huangting4201 Oct 23, 2023
e7f9f1d
feat(model/overlap_handler.py): optimize reduce scatter mem pool
huangting4201 Oct 23, 2023
f6a5086
support bias
yingtongxiong Oct 23, 2023
0d693cf
feat(model/overlap_handler.py): fix lint error
huangting4201 Oct 23, 2023
03cc7f9
feat(model/overlap_handler.py): fix lint error
huangting4201 Oct 23, 2023
9cf1ff0
feat(solver/optimizer/hybrid_zero_optim.py): minor update
huangting4201 Oct 23, 2023
b2c1a70
feat(train/training_internlm.py): fix lint error
huangting4201 Oct 23, 2023
b48687a
Merge pull request #5 from yingtongxiong/fstp/refactor-hook-handle
huangting4201 Oct 23, 2023
0996c47
fix accumulate grads bug
Oct 23, 2023
97dcefc
support model activation checkpoint
yingtongxiong Oct 24, 2023
5d83136
feat(model/overlap_handler.py): fix head post backward hook when acti…
huangting4201 Oct 24, 2023
262de4b
support tflops computation and generate test py files
yingtongxiong Oct 24, 2023
0d3592a
Merge branch 'feat/fstp_refactor' of https://github.com/yingtongxiong…
yingtongxiong Oct 24, 2023
41cfa1a
feat(model/overlap_handler.py): fix overlap handler None bug
huangting4201 Oct 24, 2023
0bac166
add test
yingtongxiong Oct 25, 2023
918dff7
reset moe
yingtongxiong Oct 25, 2023
363275b
add memory print
yingtongxiong Oct 25, 2023
985465c
merge upstream
yingtongxiong Oct 25, 2023
cc20fa2
reset print memory
yingtongxiong Oct 25, 2023
d831ddc
modify the config
yingtongxiong Oct 26, 2023
1aae39b
Merge remote-tracking branch 'upstream/develop' into feat/fstp_refactor
yingtongxiong Oct 26, 2023
cbd4f04
add synchronize
yingtongxiong Oct 26, 2023
3253cbf
add a new get_tflops_func
mwiacx Oct 26, 2023
4d83e10
Merge branch 'feat/fstp_refactor' of https://github.com/yingtongxiong…
yingtongxiong Oct 26, 2023
8aefb74
add flash tflops
yingtongxiong Oct 26, 2023
aa3840f
fix some bugs
yingtongxiong Oct 26, 2023
3778c66
feat(model/overlap_handler.py): fix overlap hander to support pp(non-…
huangting4201 Oct 27, 2023
bc5a85c
Merge pull request #6 from yingtongxiong/fstp/overlap-support-pp
yingtongxiong Oct 27, 2023
4c1cd5d
fix async reduce scatter
mwiacx Oct 31, 2023
6b84325
fix(optimizer/hybrid_zero_optim.py): remove redundant _accum_grad_buc…
huangting4201 Oct 31, 2023
b3def4c
fix(optimizer/hybrid_zero_optim.py): add reduce_scatter_overlap switch
huangting4201 Oct 31, 2023
10b5056
fix all-gather overlap the model_checkpoint is 0
yingtongxiong Nov 1, 2023
4851291
fix(optimizer/hybrid_zero_optim.py): fix bucket size full judge condi…
huangting4201 Nov 2, 2023
5a18b3b
fix(model/overlap_handler.py): fix last block hook when pp with activ…
huangting4201 Nov 2, 2023
9b1265c
modify the sp allreduce and support tf32 for fstp linear
yingtongxiong Nov 6, 2023
c517ec5
feat(model/overlap_handler.py): delete reduce_scatter_overlap switch
huangting4201 Nov 6, 2023
7c6d293
reset the sp allreduce in optimizer
yingtongxiong Nov 6, 2023
b80e6cd
merge origin
yingtongxiong Nov 6, 2023
b5e4d04
fix conflicts
yingtongxiong Nov 6, 2023
7475439
feat(model/overlap_handler.py): add memory_pool switch and refactor o…
huangting4201 Nov 13, 2023
3c07423
feat(model/overlap_handler.py): release weight
huangting4201 Nov 14, 2023
a1fd877
fix(train.py): clear memory pool before optim step
huangting4201 Nov 15, 2023
a80fcf8
feat(model): refactor weight and os and data patition strategy
huangting4201 Nov 28, 2023
cab9abd
fix(training_internlm.py): fix loss accuracy(optim init and seed set)
huangting4201 Nov 29, 2023
d3ee3ef
fix(model): reset embedding and head
huangting4201 Nov 30, 2023
6cd271c
fix(model): fix process group error
huangting4201 Dec 1, 2023
0817b8c
fix(model): fix FSTP linear Torch process group
huangting4201 Dec 1, 2023
1b7d2dc
fix(overlap_handler.py): release module post backward when model ckpt is
huangting4201 Dec 7, 2023
fd5a144
feat(model): embedding and head use sp group and refactor parameter g…
huangting4201 Dec 11, 2023
ac72710
feat(model): modify grad norm compute func
huangting4201 Dec 12, 2023
76be8c2
fix(model/utils.py): fix fstp linear reduce scatter sum->avg
huangting4201 Dec 14, 2023
d30aecd
feat(core/context): support pp for initializing isp/msp/fsp process g…
huangting4201 Dec 19, 2023
e9cd521
feat(model): refactor model and optimizer for msp/fsp/isp
huangting4201 Dec 20, 2023
e0cafb0
fix(overlap_handler.py): fix hook error and param group split
huangting4201 Dec 21, 2023
7974a32
fix(overlap_handler.py): fix clear weight error when activation ckpt …
huangting4201 Dec 22, 2023
3361350
fix(parallel_context.py): fix seed mode when TENSOR parallel
huangting4201 Dec 25, 2023
9b22258
feat(*) refactor fstp handler
mwiacx Dec 26, 2023
8e3196b
Merge branch 'feat/refactor-partition-strategy' into feat/refactor-fs…
mwiacx Dec 26, 2023
fe6fed7
feat(*): fix bug
mwiacx Dec 28, 2023
a80fbe3
fix(train/utils.py): fix zp size cheak and embed_param group
huangting4201 Jan 12, 2024
c01f015
Merge branch 'feat/refactor-partition-strategy' into feat/refactor-fs…
mwiacx Jan 12, 2024
1aebcd9
fix(model/util): force to pass communictor
mwiacx Jan 12, 2024
917ab0d
fix(model/utils.py): fix param set
huangting4201 Jan 12, 2024
b77787f
fix(hybrid_zero_optim.py): fix reduce scatter error when wp_size=1
huangting4201 Jan 12, 2024
594d61d
feat(model_checkpoint.py): model and optimizer save/load ckpt adapt t…
huangting4201 Jan 15, 2024
d87d9f9
Merge pull request #9 from yingtongxiong/feat/refactor-fstp-handler
huangting4201 Jan 15, 2024
e4d1ff8
fix(model_checkpoint.py): fix dp/zo size check
huangting4201 Jan 16, 2024
f2f88a7
support sequence parallel for moe
blankde Dec 27, 2023
6e012b1
modify expert groups
blankde Jan 17, 2024
18e6e78
feat(isp): support interleaved pipeline parallel scheduler
mwiacx Jan 17, 2024
55ebba0
add moe group
blankde Jan 17, 2024
ab039d5
fix(isp.py): fix comment
huangting4201 Jan 17, 2024
c113443
Merge pull request #1 from huangting4201/feat/support-interleaved-pp-…
huangting4201 Jan 17, 2024
8347ab4
feat(model): remove useless debug print
huangting4201 Jan 17, 2024
7ed1109
feat(model): fix lint error
huangting4201 Jan 17, 2024
ba254e3
merge huangting/feat/refactor-partition-strategy
blankde Jan 18, 2024
ccc2108
refactor code
blankde Jan 18, 2024
71543b3
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 18, 2024
fac2b20
refactor code
blankde Jan 18, 2024
a83a94f
merge huangting/feat/refactor-partition-strategy
blankde Jan 18, 2024
05fa04a
feat(multi_head_attention.py): set bias=True
huangting4201 Jan 19, 2024
91bd3f9
fix bugs
blankde Jan 19, 2024
20f6b36
support moe checkpoint
blankde Jan 19, 2024
7cdeea8
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
f959781
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
e873668
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
b99a642
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
7ac53bf
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
bb5835e
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
d5872e7
fix(tests): fix ci test error
huangting4201 Jan 22, 2024
18da3fc
Merge branch 'feat/refactor-partition-strategy' of https://github.com…
blankde Jan 22, 2024
0aebd2c
update moe config file
blankde Jan 22, 2024
15610f6
adapt grad profiling
JiaoPL Jan 22, 2024
b007c43
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 22, 2024
c8b100e
fix(communication/isp.py): fix bias switch for mem pool
huangting4201 Jan 22, 2024
e1676f0
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 22, 2024
c606bb5
fix(model/utils.py): fix boolean value ambiguous error
huangting4201 Jan 22, 2024
d646f91
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 22, 2024
70a17d6
test grad profiling with mtp,msp,fsp,isp
JiaoPL Jan 22, 2024
4e9b276
feat(training_internlm.py): update initialize_model func to adapt to …
huangting4201 Jan 22, 2024
32df5ad
feat(training_internlm.py): move get_scheduler_hooks from train.py to…
huangting4201 Jan 22, 2024
d388ddc
feat(model): fix dict has no attri mode error
huangting4201 Jan 23, 2024
8e1b619
feat(training_internlm.py): move use_fp32_norm config to gpc.config
huangting4201 Jan 23, 2024
fd349f1
Merge branch 'feat/refactor-partition-strategy' of https://github.com…
blankde Jan 23, 2024
978cea8
feat(version): update internevo version and torch verion
huangting4201 Jan 24, 2024
d5fe8fe
feat(context/parallel_context.py): set default parallel size in paral…
huangting4201 Jan 24, 2024
0c8e0cf
Merge pull request #2 from blankde/feat/support_moe_for_isp
huangting4201 Jan 24, 2024
1d64a22
feat(format): fix ci lint check error
huangting4201 Jan 24, 2024
b0c6a20
feat(format): fix ci lint check error
huangting4201 Jan 24, 2024
571d83c
feat(format): fix ci lint check error
huangting4201 Jan 24, 2024
83517ca
feat(evaluation.py): fix evaluation error when msp/fsp with pp
huangting4201 Jan 24, 2024
48aca7f
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 25, 2024
0ec9b67
fix moe param groups
JiaoPL Jan 25, 2024
aa388b5
modify the distributedAttention for different data pack mode
yingtongxiong Jan 25, 2024
34b9479
feat(model/multi_head_attention.py): fix return output
huangting4201 Jan 25, 2024
10309b8
feat(utils/evaluation.py): rename gpc.evaluation to gpc.is_evaluating
huangting4201 Jan 25, 2024
4c8324a
feat(multi_head_attention.py): rename gpc.evaluation to gpc.is_evalua…
huangting4201 Jan 25, 2024
1b64785
Merge pull request #3 from JiaoPL/feat/adapt_grad_norm_profiling_for_…
huangting4201 Jan 25, 2024
f186a75
feat(communication/isp.py): refactor isp communicator to adapt to dif…
huangting4201 Jan 25, 2024
f064880
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 26, 2024
3d7402d
fix(tests): fix ci test error
huangting4201 Jan 26, 2024
8170641
fix(tests): fix ci pipeline test error
huangting4201 Jan 26, 2024
85dd51f
feat(utils/common.py): remove func get_megatron_flops_2
huangting4201 Jan 26, 2024
971c8eb
feat(communication/isp.py): isp communicator support 0.x activation ckpt
huangting4201 Jan 29, 2024
6853bab
feat(train/training_internlm.py): move isp init to func initialize_is…
huangting4201 Jan 29, 2024
8c45118
feat(communication/isp.py): fix prefetch last ckpt block wait handle
huangting4201 Jan 29, 2024
e74f2dd
Merge pull request #4 from huangting4201/feat/isp-communicator-suppor…
huangting4201 Jan 29, 2024
011edcf
feat(utils/parallel.py): add func is_using_isp
huangting4201 Jan 29, 2024
f02523e
fix(tests): fix ci tests error
huangting4201 Jan 29, 2024
23ab67f
feat(model/modeling_llama.py): update model llama
huangting4201 Jan 30, 2024
f11422e
feat(model/utils.py): simplify code
huangting4201 Jan 30, 2024
4a27957
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 30, 2024
8e1ee6f
feat(model/linear.py): update FeedForward class to internlm2
huangting4201 Jan 30, 2024
b5f9ada
Merge pull request #5 from huangting4201/feat/support-feedforwardv2-ckpt
huangting4201 Jan 30, 2024
d7928a6
fix(parallel_context.py): fix private repo ci tests error
huangting4201 Jan 30, 2024
1960dc0
feat(parallel_context.py): set zero1 parallel size >= 1
huangting4201 Jan 30, 2024
52ace84
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 30, 2024
62a665d
feat(tests): add e2e test case for isp and enable pytorch expandable_…
huangting4201 Jan 31, 2024
e91acb4
feat(doc): update doc torch and flashattn version
huangting4201 Jan 31, 2024
2e4f749
Merge branch 'develop' into feat/refactor-partition-strategy
sunpengsdu Feb 1, 2024
42 changes: 21 additions & 21 deletions internlm/core/communication/isp.py
@@ -9,7 +9,6 @@
 from torch import distributed as dist
 from torch import nn

-from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
@@ -26,6 +25,7 @@ class ISPCommModelConfig:

     dtype: torch.dtype = torch.half
     device: torch.device = torch.device("cuda")
+    activation_checkpointing: float = 0.0
     module_shapes: Dict[str, torch.Size] = None


@@ -131,7 +131,8 @@ def __init__(self) -> None:
         self.num_blocks: int = 0
         self.embedding: List[nn.Module] = []
         self.head: List[nn.Module] = []
-        self.last_block: nn.Moudle = None
+        self.ckpt_block_num: int = 0
+        self.last_ckpt_block: nn.Module = None
         self.isp_outs: List[nn.Module] = []
         self.isp_modules: List[nn.Module] = []
         self.index_to_isp_module: Dict[int, nn.Module] = {}
@@ -152,12 +153,10 @@ def __init__(
         model: Union[nn.Module, nn.ModuleList],
         model_conf: ISPCommModelConfig,
         overlap: bool = False,
-        activation_checkpointing: bool = False,
         enable_memory_pool: bool = False,
         process_group: dist.ProcessGroup = None,
     ) -> None:
         self.process_group = process_group
-        self.model_checkpoint = activation_checkpointing
         self.overlap = overlap
         self.enable_memory_pool = overlap and enable_memory_pool
         self.model_conf = model_conf
@@ -172,7 +171,8 @@ def __init__(
         self._num_blocks = None
         self._head = None
         self._embedding = None
-        self._last_block = None
+        self._ckpt_block_num = None
+        self._last_ckpt_block = None
         self._isp_outs = None
         self._isp_modules = None
         # key: isp module; value: module global all-gather op handle
@@ -222,7 +222,10 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
             elif isinstance(children, Embedding1D):
                 self._overlap_states[cid].embedding.append(children)
             elif isinstance(children, nn.ModuleList):
-                self._overlap_states[cid].last_block = children[-1]
+                self._overlap_states[cid].ckpt_block_num = int(self.model_conf.activation_checkpointing * len(children))
+                self._overlap_states[cid].last_ckpt_block = children[
+                    max(0, self._overlap_states[cid].ckpt_block_num - 1)
+                ]

                 for idx, block in enumerate(children):
                     self._overlap_states[cid].index_to_isp_module[idx] = []
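A minimal sketch of the arithmetic in the hunk above, in plain Python with no InternLM imports. The helper name `ckpt_block_split` is hypothetical and only mirrors the two expressions in the diff: `activation_checkpointing` is treated as a ratio of the block count and truncated by `int()`, and the index of the last checkpointed block is clamped at 0. Because `bool` is a subclass of `int` in Python, a `True`/`False` value for the ratio would still evaluate to all or no blocks; whether the config still allows a bool here is an assumption, not something this diff shows.

```python
from typing import Tuple, Union


def ckpt_block_split(activation_checkpointing: Union[float, bool], num_blocks: int) -> Tuple[int, int]:
    """Return (ckpt_block_num, last_ckpt_block_index) using the expressions from the hunk.

    Hypothetical helper for illustration only; it is not part of the PR.
    """
    # int() truncates toward zero, so a ratio of 0.25 over 30 blocks gives 7, not 8.
    ckpt_block_num = int(activation_checkpointing * num_blocks)
    # Clamp so that a ratio of 0.0 still yields a valid index (block 0).
    last_ckpt_block_index = max(0, ckpt_block_num - 1)
    return ckpt_block_num, last_ckpt_block_index


if __name__ == "__main__":
    for ratio in (0.0, 0.25, 0.5, 1.0):
        print(ratio, ckpt_block_split(ratio, 32))
```

With a ratio of 0.0 the clamped index still points at block 0, which appears to be why the hook registration on `_last_ckpt_block` elsewhere in this file is guarded by `self._ckpt_block_num >= 1`.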
@@ -335,7 +338,7 @@ def _post_forward_hook_for_embedding(self, *args): # pylint: disable=W0613
     def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args):  # pylint: disable=W0613
         block_index = self._module_to_index[module]

-        if self.model_checkpoint and self.is_forward is False:
+        if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
             if block_index - 1 >= 0:
                 self._all_gather_block_weight(block_index - 1)
         else:
@@ -350,13 +353,12 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis
         self._wait_handle(module)

     def _pre_forward_hook_for_block(self, *args):  # pylint: disable=W0613
-        for module in self._index_to_isp_module[self._num_blocks - 1]:
+        for module in self._index_to_isp_module[self._ckpt_block_num - 1]:
             self._all_gather_module_weight(module)
-            self._wait_handle(module)

     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         self._clear_handle(module)
-        if not (self.model_checkpoint and self.is_forward is False):
+        if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_weight(module)

     def _post_backward_hook_for_head(self, *args):  # pylint: disable=W0613
@@ -377,7 +379,8 @@ def _pre_backward_hook_for_module(self, module: nn.Module, *args): # pylint: di
         module_index = self._isp_modules.index(module)
         if module_index - 1 >= 0:
             next_module = self._isp_modules[module_index - 1]
-            self._all_gather_module_weight(next_module)
+            if self._module_to_index[next_module] >= self._ckpt_block_num:
+                self._all_gather_module_weight(next_module)

     def _post_backward_hook_for_module(self, module, *args):  # pylint: disable=W0613
         self._clear_handle(module)
@@ -396,12 +399,8 @@ def _register_sync_parameters_hook(self) -> None:
         for embedding in self._embedding:
             embedding.register_forward_hook(self._post_forward_hook_for_embedding)

-        if self.model_checkpoint:
-            if gpc.is_last_rank(parallel_mode=ParallelMode.PIPELINE):
-                for head in self._head:
-                    head.register_full_backward_pre_hook(self._pre_backward_hook_for_head)
-            else:
-                self._last_block.register_forward_pre_hook(self._pre_forward_hook_for_block)
+        if self._ckpt_block_num >= 1:
+            self._last_ckpt_block.register_forward_pre_hook(self._pre_forward_hook_for_block)

         for out_proj in self._isp_outs:
             out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj)
@@ -414,7 +413,7 @@ def _register_sync_parameters_hook(self) -> None:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @isp_module to release resource
-        if not self.model_checkpoint:
+        if self._ckpt_block_num < self._num_blocks:
             for head in self._head:
                 head.register_full_backward_hook(self._post_backward_hook_for_head)

@@ -443,7 +442,8 @@ def switch_current_model_chunk(self, chunk_id: int) -> None:
         self._bias_global_output = self._overlap_states[chunk_id].bias_global_output
         self._module_to_index = self._overlap_states[chunk_id].module_to_index
         self._index_to_isp_module = self._overlap_states[chunk_id].index_to_isp_module
-        self._last_block = self._overlap_states[chunk_id].last_block
+        self._ckpt_block_num = self._overlap_states[chunk_id].ckpt_block_num
+        self._last_ckpt_block = self._overlap_states[chunk_id].last_ckpt_block
         self._head = self._overlap_states[chunk_id].head
         self._embedding = self._overlap_states[chunk_id].embedding
         self._num_blocks = self._overlap_states[chunk_id].num_blocks
@@ -514,7 +514,7 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None:
         self._zero_optim = zero_optim

     def before_forward(self, scheduler, inputs) -> None:
-        if self._isp_communicator.model_checkpoint:
+        if self._isp_communicator._ckpt_block_num > 0:
             self._isp_communicator.is_forward = True
         # switch model chunk before forward
         chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
@@ -530,7 +530,7 @@ def after_criterion(self, scheduler, loss) -> None:
         pass

     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if self._isp_communicator.model_checkpoint:
+        if self._isp_communicator._ckpt_block_num > 0:
             self._isp_communicator.is_forward = False
         # switch model chunk before backward
         chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
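The changes to `isp.py` replace the boolean `model_checkpoint` switch with comparisons against `_ckpt_block_num`. The sketch below restates those conditions as one plain-Python function so the boundary cases (no block, some blocks, all blocks checkpointed) are easy to compare. The function name `hook_plan` and the dictionary keys are hypothetical, and the summary reflects one reading of the diff rather than a statement from the PR authors.

```python
from typing import Dict, List, Union


def hook_plan(num_blocks: int, ckpt_block_num: int) -> Dict[str, Union[bool, List[int]]]:
    """Summarize which hooks and prefetches apply for a given checkpoint block count.

    Hypothetical helper; each entry quotes the condition it mirrors from the diff.
    """
    blocks = range(num_blocks)
    return {
        # _register_sync_parameters_hook: `if self._ckpt_block_num >= 1`
        "pre_forward_hook_on_last_ckpt_block": ckpt_block_num >= 1,
        # _register_sync_parameters_hook: `if self._ckpt_block_num < self._num_blocks`
        "post_backward_hook_on_head": ckpt_block_num < num_blocks,
        # _post_forward_hook_for_module: the weight is kept only when the block index
        # is below ckpt_block_num and the forward pass is a recomputation
        # (is_forward is False); otherwise it is cleared.
        "keep_weight_after_recompute": [i for i in blocks if i < ckpt_block_num],
        # _pre_backward_hook_for_module: the next module is all-gathered only when
        # its block index is >= ckpt_block_num.
        "backward_prefetch_allowed": [i for i in blocks if i >= ckpt_block_num],
    }


if __name__ == "__main__":
    for n_ckpt in (0, 2, 4):
        print(n_ckpt, hook_plan(num_blocks=4, ckpt_block_num=n_ckpt))
```

The two index lists are complementary for every value of `ckpt_block_num`, so each block is handled either by the recomputation path or by the backward prefetch path.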
2 changes: 2 additions & 0 deletions internlm/train/__init__.py
@@ -2,6 +2,7 @@
     get_scheduler_hooks,
     get_train_data_loader,
     get_validation_data_loader,
+    initialize_isp_communicator,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,
@@ -17,6 +18,7 @@
     "get_validation_data_loader",
     "initialize_llm_profile",
     "initialize_model",
+    "initialize_isp_communicator",
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
49 changes: 31 additions & 18 deletions internlm/train/training_internlm.py
@@ -216,24 +216,7 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

-    if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
-        isp_communicator = None
-    else:
-        isp_communicator = ISPCommunicator(
-            model,
-            ISPCommModelConfig(
-                gpc.config.model.dtype,
-                get_current_device(),
-            ),
-            gpc.config.parallel.weight.overlap,
-            gpc.config.model.checkpoint,
-            gpc.config.parallel.weight.memory_pool,
-            gpc.get_group(ParallelMode.WEIGHT),
-        )
-        # register communicator for isp linear.
-        ISPLinear.register_communicator(isp_communicator)
-
-    return model, isp_communicator
+    return model


 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
@@ -269,6 +252,36 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     return model


+def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]):
+    """
+    Initialize communicator for isp tensor parallel mode.
+
+    Args:
+        model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
+
+    Returns:
+        An isp communicator for managing comp/comm overlap and memory pool.
+    """
+    if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
+        isp_communicator = None
+    else:
+        isp_communicator = ISPCommunicator(
+            model,
+            ISPCommModelConfig(
+                gpc.config.model.dtype,
+                get_current_device(),
+                gpc.config.model.checkpoint,
+            ),
+            gpc.config.parallel.weight.overlap,
+            gpc.config.parallel.weight.memory_pool,
+            gpc.get_group(ParallelMode.WEIGHT),
+        )
+        # register communicator for isp linear.
+        ISPLinear.register_communicator(isp_communicator)
+
+    return isp_communicator
+
+
 @llm_timeout(func_name="initialize_optimizer")
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None):
     """
6 changes: 5 additions & 1 deletion train.py
@@ -22,6 +22,7 @@
     get_scheduler_hooks,
     get_train_data_loader,
     get_validation_data_loader,
+    initialize_isp_communicator,
     initialize_llm_profile,
     initialize_model,
     initialize_optimizer,
@@ -96,7 +97,10 @@ def main(args):
     uniscale_logger = initialize_llm_logger(start_time=current_time)

     # initialize model
-    model, isp_communicator = initialize_model()
+    model = initialize_model()
+
+    # initialize isp communicator
+    isp_communicator = initialize_isp_communicator(model)

     with open(args.config, "r") as f:
         config_lines = f.readlines()
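After this change `initialize_model()` returns only the model, and the communicator is built in a separate step that yields `None` unless the tensor parallel mode is `"isp"`. The sketch below imitates that two-step flow with stand-ins so it runs without InternLM or a GPU. `StubCommunicator`, `init_communicator_sketch`, and the config values are all hypothetical; only the mode gate and the config lookups are taken from the diff.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class StubCommunicator:
    """Hypothetical stand-in for ISPCommunicator; it only records its inputs."""

    model: Any
    activation_checkpointing: float
    overlap: bool
    memory_pool: bool


def init_communicator_sketch(model: Any, config: Dict[str, Any]) -> Optional[StubCommunicator]:
    # Same gate as the diff: any mode other than "isp" (default "mtp") yields None.
    if config["parallel"]["tensor"].get("mode", "mtp") != "isp":
        return None
    return StubCommunicator(
        model=model,
        # In the diff the checkpoint setting now travels inside ISPCommModelConfig.
        activation_checkpointing=config["model"]["checkpoint"],
        overlap=config["parallel"]["weight"]["overlap"],
        memory_pool=config["parallel"]["weight"]["memory_pool"],
    )


# Hypothetical config values, shaped after the gpc.config lookups in the diff.
config = {
    "model": {"checkpoint": 0.5},
    "parallel": {
        "tensor": {"size": 2, "mode": "isp"},
        "weight": {"size": 4, "overlap": True, "memory_pool": True},
    },
}

model = object()  # placeholder for the value returned by initialize_model()
communicator = init_communicator_sketch(model, config)
print(communicator is not None)
```

Callers that previously unpacked a `(model, isp_communicator)` tuple from `initialize_model()` would need the same two-call pattern that `train.py` adopts in this diff.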