Commit

fix ci
li126com committed Jul 9, 2024
1 parent 4b41bb5 commit 97cb5d1
Showing 6 changed files with 9 additions and 7 deletions.
1 change: 0 additions & 1 deletion internlm/core/context/parallel_context.py
@@ -159,7 +159,6 @@ def __init__(self):
         self.virtual_pipeline_parallel_rank = None
         self._expert_parallel_group_names = []
         self.is_evaluating = False
-        self.recompute_forward_no_comm = False
 
     @property
     def config(self):
2 changes: 2 additions & 0 deletions internlm/initialize/launch.py
@@ -294,6 +294,8 @@ def args_sanity_check():
         "torch.tf32",
     ]
 
+    gpc.config._add_item("recompute_forward_no_comm", False)
+
     if "checkpoint" in model:
         if "checkpoint_tp_no_comm" not in model:
             gpc.config.model._add_item("checkpoint_tp_no_comm", True)
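The hunk above pairs with the deletion in parallel_context.py: the flag now lives as a default item on gpc.config rather than as an attribute initialized on the parallel context. A rough sketch of that idea, assuming a hypothetical dict-backed Config stand-in (not InternEvo's actual Config class):

class Config(dict):
    """Hypothetical stand-in for a dict-backed config with attribute access."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value

    def _add_item(self, key, value):
        # Register a default so later attribute reads never fail.
        self[key] = value


config = Config()
config._add_item("recompute_forward_no_comm", False)
assert config.recompute_forward_no_comm is False  # safe to read everywhere

# The flag can later be flipped temporarily, e.g. around a recomputed forward.
config.recompute_forward_no_comm = True

Registering the default in args_sanity_check() means every consumer can read gpc.config.recompute_forward_no_comm without first checking whether the attribute exists.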
2 changes: 1 addition & 1 deletion internlm/model/modeling_internlm.py
@@ -216,7 +216,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
         hidden_states = self.mlp(hidden_states)
 
         # pad residual
-        if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
+        if gpc.config.recompute_forward_no_comm and is_using_sequence_parallel():
             residual = padding_residual(residual)
 
         return hidden_states + residual
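The guarded branch above pads the residual only when the recomputed forward skipped communication under sequence parallelism, presumably to reconcile a shape mismatch between the locally sharded residual and the un-reduced MLP output; that reading is an assumption, since padding_residual's body is not part of this diff. A zero-padding illustration of such a mismatch (pad_residual_to is a made-up helper, not the real padding_residual):

import torch
import torch.nn.functional as F


def pad_residual_to(residual: torch.Tensor, full_seq_len: int) -> torch.Tensor:
    """Zero-pad `residual` along the sequence dim (dim 1) up to `full_seq_len`."""
    pad_len = full_seq_len - residual.shape[1]
    if pad_len <= 0:
        return residual
    # F.pad pads the last dims first: (hidden_left, hidden_right, seq_left, seq_right).
    return F.pad(residual, (0, 0, 0, pad_len))


hidden_states = torch.randn(2, 8, 16)  # full sequence length
residual = torch.randn(2, 4, 16)       # local shard under sequence parallelism
out = hidden_states + pad_residual_to(residual, hidden_states.shape[1])
print(out.shape)  # torch.Size([2, 8, 16])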
2 changes: 1 addition & 1 deletion internlm/model/modeling_internlm2.py
@@ -261,7 +261,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
         hidden_states = self.feed_forward(hidden_states)
 
         # pad residual
-        if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
+        if gpc.config.recompute_forward_no_comm and is_using_sequence_parallel():
             residual = padding_residual(residual)
 
         return hidden_states + residual
2 changes: 1 addition & 1 deletion internlm/model/modules/mlp.py
@@ -99,7 +99,7 @@ def forward(self, x):
         else:
             fussed_out = self.fused_w1_w3(x)
             w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
-            out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.recompute_forward_no_comm)
+            out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.config.recompute_forward_no_comm)
         return out
 
 
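For readers skimming the hunk above: fused_w1_w3 produces the concatenated gate/up activations in one matmul, which are split and gated before the down-projection w2, with no_communication forwarded to InternEvo's parallel linear layer. A self-contained sketch of the split-and-gate step, assuming Silu(a, b) is the usual gated activation F.silu(a) * b and omitting the parallel-specific keyword:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedGateMLP(nn.Module):
    """Hypothetical minimal gated MLP with a fused w1/w3 projection."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fused_w1_w3 = nn.Linear(dim, 2 * hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused_out = self.fused_w1_w3(x)
        w1_o, w3_o = torch.split(fused_out, fused_out.shape[-1] // 2, dim=-1)
        # Gate the up-projection with SiLU, then project back down.
        return self.w2(F.silu(w1_o) * w3_o)


x = torch.randn(2, 16)
print(FusedGateMLP(16, 32)(x).shape)  # torch.Size([2, 16])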
7 changes: 4 additions & 3 deletions internlm/solver/activation_checkpoint.py
@@ -47,8 +47,8 @@ def recompute_forward_context(args, no_communication):
     handle = None
     try:
         # Set True when entering the context
-        if no_communication:
-            gpc.recompute_forward_no_comm = True
+        if no_communication and hasattr(gpc.config, "recompute_forward_no_comm"):
+            gpc.config.recompute_forward_no_comm = True
             if is_using_sequence_parallel():
                 # overlap all_gather
                 grad_output = args[0]
@@ -58,7 +58,8 @@
         yield
     finally:
         # Set False when exiting the context
-        gpc.recompute_forward_no_comm = False
+        if hasattr(gpc.config, "recompute_forward_no_comm"):
+            gpc.config.recompute_forward_no_comm = False
 
         if handle:
             handle.wait()
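The two hunks above wrap the flag toggle in hasattr checks so the context manager degrades gracefully if the config key was never registered. A condensed sketch of that set-on-enter, reset-in-finally pattern (flag_scope is an illustrative name; the real recompute_forward_context also overlaps an all-gather for sequence parallelism, elided here):

from contextlib import contextmanager


@contextmanager
def flag_scope(config, name, enabled):
    """Temporarily set config.<name> to True, always restoring False on exit."""
    try:
        if enabled and hasattr(config, name):
            setattr(config, name, True)
        yield
    finally:
        # Reset even if the wrapped forward recomputation raises.
        if hasattr(config, name):
            setattr(config, name, False)


# Hypothetical usage mirroring the diff:
# with flag_scope(gpc.config, "recompute_forward_no_comm", no_communication):
#     outputs = run_recomputed_forward(*args)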
