change
li126com committed Jul 9, 2024
1 parent f8002fb commit 8951edb
Showing 5 changed files with 44 additions and 36 deletions.
8 changes: 5 additions & 3 deletions internlm/initialize/launch.py
@@ -415,9 +415,11 @@ def args_sanity_check():
         ), "only support interleaved pipeline scheduler with overlap"
 
     # when not use tp or sp, checkpoint_tp_no_comm should always be False
-    if (gpc.config.parallel["tensor"]["mode"] == "isp" or gpc.config.parallel["tensor"]["size"] <= 1) and getattr(
-        gpc.config.model, "checkpoint_tp_no_comm", False
-    ):
+    if (
+        gpc.config.parallel["tensor"]["mode"] == "isp"
+        or gpc.config.parallel["tensor"]["size"] <= 1
+        or gpc.config.model_type not in ["INTERNLM", "INTERNLM2_PUBLIC"]
+    ) and getattr(gpc.config.model, "checkpoint_tp_no_comm", False):
         gpc.config.model.checkpoint_tp_no_comm = False
 
     # monitoring default config
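For readers of the new sanity check: checkpoint_tp_no_comm is now only honored when tensor parallel is enabled in a non-ISP mode and the model type is INTERNLM or INTERNLM2_PUBLIC; otherwise it is silently forced to False. A minimal standalone sketch of that rule follows (the helper name and config dicts are hypothetical, not the repo's gpc object; "mtp" is used only as an example of a non-ISP tensor mode and "LLAMA2" only as an example of an unsupported model type):

def resolve_checkpoint_tp_no_comm(parallel_cfg: dict, model_type: str, requested: bool) -> bool:
    """Hypothetical mirror of the sanity check: the value the trainer will actually use."""
    tensor_cfg = parallel_cfg["tensor"]
    unsupported = (
        tensor_cfg["mode"] == "isp"
        or tensor_cfg["size"] <= 1
        or model_type not in ["INTERNLM", "INTERNLM2_PUBLIC"]
    )
    return False if unsupported else requested


# ISP mode, tensor size <= 1, or a model type outside the list all disable the option.
assert resolve_checkpoint_tp_no_comm({"tensor": {"mode": "isp", "size": 4}}, "INTERNLM2_PUBLIC", True) is False
assert resolve_checkpoint_tp_no_comm({"tensor": {"mode": "mtp", "size": 1}}, "INTERNLM", True) is False
assert resolve_checkpoint_tp_no_comm({"tensor": {"mode": "mtp", "size": 4}}, "LLAMA2", True) is False
assert resolve_checkpoint_tp_no_comm({"tensor": {"mode": "mtp", "size": 4}}, "INTERNLM", True) is True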
6 changes: 2 additions & 4 deletions internlm/model/modeling_internlm.py
@@ -213,12 +213,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)
 
-        no_communication = gpc.recompute_forward_no_comm
-
-        hidden_states = self.mlp(hidden_states, no_communication=no_communication)
+        hidden_states = self.mlp(hidden_states)
 
         # pad residual
-        if no_communication and is_using_sequence_parallel():
+        if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
             residual = padding_residual(residual)
 
         return hidden_states + residual
6 changes: 2 additions & 4 deletions internlm/model/modeling_internlm2.py
@@ -258,12 +258,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)
 
-        no_communication = gpc.recompute_forward_no_comm
-
-        hidden_states = self.feed_forward(hidden_states, no_communication=no_communication)
+        hidden_states = self.feed_forward(hidden_states)
 
         # pad residual
-        if no_communication and is_using_sequence_parallel():
+        if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
             residual = padding_residual(residual)
 
         return hidden_states + residual
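In both modeling files the pattern is the same: the block no longer threads no_communication into the MLP call; it only consults gpc.recompute_forward_no_comm to decide whether the residual must be padded before the residual add. The padding is presumably needed because, when the w2 communication is skipped during recomputation under sequence parallelism, hidden_states comes back at the full sequence length while the residual still holds only the local sequence shard. A shape-only sketch of that bookkeeping, assuming the sequence dimension is dim 1 and that padding_residual pads up to the full length (both are assumptions for illustration; the sketch says nothing about the numerics of the real padding_residual):

import torch
import torch.nn.functional as F

# Hypothetical sizes: 4 sequence-parallel ranks, full sequence 8 -> local shard of 2.
full_seq, local_seq, hidden = 8, 2, 16

hidden_states = torch.randn(1, full_seq, hidden)  # MLP output with its collective skipped: full length
residual = torch.randn(1, local_seq, hidden)      # residual is still the local sequence shard

# padding_residual stand-in: right-pad the sequence dim so the elementwise add is well-formed.
residual = F.pad(residual, (0, 0, 0, full_seq - local_seq))
out = hidden_states + residual
print(out.shape)  # torch.Size([1, 8, 16])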
5 changes: 3 additions & 2 deletions internlm/model/modules/mlp.py
@@ -6,6 +6,7 @@
 import torch
 from torch import nn
 
+from internlm.core.context.parallel_context import global_context as gpc
 from internlm.model.modules.linear import new_linear
 from internlm.model.modules.utils import Silu
 from internlm.utils.logger import get_logger
@@ -91,14 +92,14 @@ def __init__(
         self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype)
         self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype)
 
-    def forward(self, x, no_communication=False):
+    def forward(self, x):
         if not self.mlp_layer_fusion:
             w1_o = self.w1(x)
             w3_o = self.w3(x)
         else:
             fussed_out = self.fused_w1_w3(x)
             w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
-        out = self.w2(Silu(w1_o, w3_o), no_communication=no_communication)
+        out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.recompute_forward_no_comm)
         return out
 
 
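The design choice here is to stop threading the flag through the block API: the MLP's forward takes only x again, and the no-communication decision is read from the global context at the w2 call site, so only the checkpoint code that flips the flag and the parallel linear that honors it need to know about it. A toy sketch of that call-site pattern, with a stubbed global context and a stand-in for the w2 parallel linear (both hypothetical; the real w2 comes from new_linear and performs the actual tensor/sequence-parallel collective):

import torch
from torch import nn


class _GlobalContextStub:
    """Hypothetical stand-in for internlm's gpc; only the flag used here."""
    recompute_forward_no_comm = False


gpc = _GlobalContextStub()


class W2Sketch(nn.Module):
    """Toy stand-in for the w2 parallel linear."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, no_communication=False):
        out = self.proj(x)
        if not no_communication:
            pass  # placeholder for the reduce-scatter / all-reduce over the TP group
        return out


w2 = W2Sketch(8)
x = torch.randn(2, 4, 8)

y = w2(x, no_communication=gpc.recompute_forward_no_comm)  # normal forward: flag is False

gpc.recompute_forward_no_comm = True  # as flipped by the checkpoint context manager
y = w2(x, no_communication=gpc.recompute_forward_no_comm)  # recomputation: collective skipped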
55 changes: 32 additions & 23 deletions internlm/solver/activation_checkpoint.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import weakref
+from contextlib import contextmanager
 
 import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
@@ -41,6 +42,30 @@ def copy_to_device(obj, device):
     return obj
 
 
+@contextmanager
+def recompute_forward_context(args, no_communication):
+    handle = None
+    try:
+        # Set True when entering the context
+        if no_communication:
+            gpc.recompute_forward_no_comm = True
+            if is_using_sequence_parallel():
+                # overlap all_gather
+                grad_output = args[0]
+                grad_output, handle = all_gather_raw(
+                    grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM
+                )
+        yield
+    finally:
+        # Set False when exiting the context
+        gpc.recompute_forward_no_comm = False
+
+        if handle:
+            handle.wait()
+            args = list(args)
+            args[0] = grad_output
+
+
 class CheckpointFunction(torch.autograd.Function):
     """
     Checkpoint Function
@@ -132,29 +157,13 @@ def backward(ctx, *args):
 
         detached_inputs = detach_variable(tuple(inputs))
 
-        handle = None
-        if no_communication:
-            gpc.recompute_forward_no_comm = True
-            if is_using_sequence_parallel():
-                grad_output = args[0]
-                grad_output, handle = all_gather_raw(
-                    grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM
-                )
-
-        if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), internlm_accelerator.amp.autocast():
-                outputs = ctx.run_function(*detached_inputs)
-        else:
-            with torch.enable_grad():
-                outputs = ctx.run_function(*detached_inputs)
-
-        if gpc.recompute_forward_no_comm:
-            gpc.recompute_forward_no_comm = False
-
-        if handle:
-            handle.wait()
-            args = list(args)
-            args[0] = grad_output
+        with recompute_forward_context(args, no_communication):
+            if ctx.had_autocast_in_fwd:
+                with torch.enable_grad(), internlm_accelerator.amp.autocast():
+                    outputs = ctx.run_function(*detached_inputs)
+            else:
+                with torch.enable_grad():
+                    outputs = ctx.run_function(*detached_inputs)
 
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
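The refactor moves the flag handling out of CheckpointFunction.backward into a context manager, whose main benefit is the try/finally: gpc.recompute_forward_no_comm is cleared even if the recomputed forward raises, instead of relying on straight-line code after the recompute. A minimal, self-contained sketch of just that pattern, with a stubbed context and no distributed calls:

from contextlib import contextmanager


class _GlobalContextStub:
    """Hypothetical stand-in for gpc; only the flag used here."""
    recompute_forward_no_comm = False


gpc = _GlobalContextStub()


@contextmanager
def recompute_flag(enabled):
    """Set the flag on entry and always clear it on exit."""
    try:
        if enabled:
            gpc.recompute_forward_no_comm = True
        yield
    finally:
        # Runs even if the recomputed forward raises, so the flag cannot leak
        # into later, non-recompute forward passes.
        gpc.recompute_forward_no_comm = False


with recompute_flag(True):
    assert gpc.recompute_forward_no_comm is True
    # ... recomputed forward would run here ...
assert gpc.recompute_forward_no_comm is False

One Python-semantics note on the new helper: `args = list(args)` followed by `args[0] = grad_output` inside the finally block rebinds and mutates a fresh local list, so the gathered grad_output is not visible through the caller's args in backward; if the caller is meant to consume it, it would have to be handed back some other way (for example by mutating a container the caller also holds). Whether that matters depends on how backward uses args after the context exits, which is outside this hunk.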
