Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(checkpoint): TP recomputation communication optimization #275

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configs/7B_internlm2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
JOB_NAME = "7b_internlm2_train"
model_type="INTERNLM2_PUBLIC"
model_type = "INTERNLM2_PUBLIC"
DO_ALERT = False

VOCAB_SIZE = 92544
Expand Down Expand Up @@ -128,6 +128,7 @@
use_fp32_norm = False
model = dict(
checkpoint=False,
# checkpoint_tp_no_comm=True, # whether use TP recomputation communication optimization
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
Expand Down
1 change: 1 addition & 0 deletions configs/7B_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
use_fp32_norm = False
model = dict(
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
# checkpoint_tp_no_comm=True, # whether use TP recomputation communication optimization
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
Expand Down
1 change: 1 addition & 0 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(self):
self.virtual_pipeline_parallel_rank = None
self._expert_parallel_group_names = []
self.is_evaluating = False
self.recompute_forward_no_comm = False

@property
def config(self):
Expand Down
41 changes: 28 additions & 13 deletions internlm/core/parallel/comm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def input_hook(

@abstractmethod
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False
self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for grad_output when backward.
Expand All @@ -81,7 +81,9 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
pass

@abstractmethod
def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for output when forward.
"""
Expand Down Expand Up @@ -116,7 +118,10 @@ def input_hook(
return _input, DUMMY_HANDLE_CONST

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
tensor parallel should do nothing for grad_output.
Expand All @@ -132,11 +137,13 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T

return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op)

def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all reduce output only for row parallel linear when forward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
return output, DUMMY_HANDLE_CONST

return all_reduce_raw(output, process_group=self._process_group, async_op=async_op)
Expand Down Expand Up @@ -182,12 +189,12 @@ def input_hook(
return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False
self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather grad_output only for row parallel linear when backward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
return grad_output, DUMMY_HANDLE_CONST

return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)
Expand All @@ -203,11 +210,13 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM
)

def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
reduce scatter output only for row parallel linear when forward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
return output, DUMMY_HANDLE_CONST

return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM)
Expand All @@ -225,7 +234,10 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True)
self._retain_out_sharded = retain_out_sharded

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
Expand All @@ -236,7 +248,7 @@ def grad_output_hook(
return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
Expand Down Expand Up @@ -266,7 +278,10 @@ def __init__(

# rewrite grad_output communication hook
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
Expand All @@ -278,7 +293,7 @@ def grad_output_hook(

# rewrite ouput communication hook
def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
Expand Down
11 changes: 11 additions & 0 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,13 @@ def args_sanity_check():
]

if "checkpoint" in model:
if "checkpoint_tp_no_comm" not in model:
gpc.config.model._add_item("checkpoint_tp_no_comm", True)
if model.checkpoint is True:
model.checkpoint = 1
elif model.checkpoint is False:
model.checkpoint = 0
model.checkpoint_tp_no_comm = False
else:
assert (
model.checkpoint >= 0 and model.checkpoint <= 1
Expand Down Expand Up @@ -411,6 +414,14 @@ def args_sanity_check():
gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True
), "only support interleaved pipeline scheduler with overlap"

# when not use tp or sp, checkpoint_tp_no_comm should always be False
if (
gpc.config.parallel["tensor"]["mode"] == "isp"
or gpc.config.parallel["tensor"]["size"] <= 1
or gpc.config.model_type not in ["INTERNLM", "INTERNLM2_PUBLIC"]
) and getattr(gpc.config.model, "checkpoint_tp_no_comm", False):
gpc.config.model.checkpoint_tp_no_comm = False

# monitoring default config
monitor_default_config = {
"alert_address": None, # compatible with old alert config
Expand Down
3 changes: 3 additions & 0 deletions internlm/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module
kwargs["checkpoint"] = float(kwargs.get("checkpoint", False))
kwargs["device"] = get_current_device()

if "checkpoint_tp_no_comm" in kwargs:
kwargs.pop("checkpoint_tp_no_comm")

model_buidler = model_initializer.get_module(module_name=model_type)

if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
Expand Down
6 changes: 6 additions & 0 deletions internlm/model/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
internlm1_mha_save_convert,
padding_residual,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_sequence_parallel

logger = get_logger(__file__)

Expand Down Expand Up @@ -213,6 +215,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):

hidden_states = self.mlp(hidden_states)

# pad residual
if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
residual = padding_residual(residual)

return hidden_states + residual


Expand Down
7 changes: 7 additions & 0 deletions internlm/model/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from internlm.model.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
padding_residual,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_sequence_parallel

logger = get_logger(__file__)

Expand Down Expand Up @@ -255,8 +257,13 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):

if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.feed_forward(hidden_states)

# pad residual
if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
residual = padding_residual(residual)

return hidden_states + residual
else:
assert residual is None
Expand Down
17 changes: 12 additions & 5 deletions internlm/model/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def forward(
bias: Optional[torch.Tensor],
communicator: TPCommunicator,
return_residual=False,
no_communication=False,
):
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.communicator = communicator
ctx.no_communication = no_communication

if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
Expand Down Expand Up @@ -77,7 +79,7 @@ def forward(

# parallel strategy-specific communication callback 2.
# see more details in the communicator for different parallel strategies.
output, _ = communicator.output_hook(output, async_op=False)
output, _ = communicator.output_hook(output, async_op=False, no_communication=no_communication)

saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x
ctx.save_for_backward(saved_x, weight)
Expand All @@ -91,7 +93,9 @@ def backward(ctx, grad_output, *args):

# parallel strategy-specific communication callback 3.
# see more details in the communicator for different parallel strategies.
grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)
grad_output, _ = communicator.grad_output_hook(
grad_output, no_communication=ctx.no_communication, async_op=False
)
grad_output = grad_output.contiguous()

if ctx.return_residual:
Expand Down Expand Up @@ -264,6 +268,7 @@ def fused_dense_func(
module: Optional[nn.Module] = None,
bias: Optional[torch.Tensor] = None,
return_residual: bool = False,
no_communication=False,
):
if communicator.communication_mode() == "wp":
return WPFusedDenseFunc.apply(
Expand All @@ -281,6 +286,7 @@ def fused_dense_func(
bias,
communicator,
return_residual,
no_communication,
)


Expand Down Expand Up @@ -343,16 +349,16 @@ def __init__(
else:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622
def forward(self, input: torch.Tensor, no_communication=False) -> torch.Tensor: # pylint: disable=W0622
SolenoidWGT marked this conversation as resolved.
Show resolved Hide resolved
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."

return fused_dense_func(
input,
self.weight,
communicator=self._communicator,
module=self,
bias=self.bias,
no_communication=no_communication,
)


Expand Down Expand Up @@ -465,7 +471,7 @@ def __init__(
self.first_eval_flag = True
self.tmp_weight = None

def forward(self, input): # pylint: disable=W0622
def forward(self, input, no_communication=False): # pylint: disable=W0622
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."

Expand Down Expand Up @@ -496,6 +502,7 @@ def forward(self, input): # pylint: disable=W0622
communicator=self._communicator,
module=self,
bias=self.bias,
no_communication=no_communication,
)


Expand Down
3 changes: 2 additions & 1 deletion internlm/model/modules/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch import nn

from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.modules.linear import new_linear
from internlm.model.modules.utils import Silu
from internlm.utils.logger import get_logger
Expand Down Expand Up @@ -98,7 +99,7 @@ def forward(self, x):
else:
fussed_out = self.fused_w1_w3(x)
w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
out = self.w2(Silu(w1_o, w3_o))
out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.recompute_forward_no_comm)
return out


Expand Down
27 changes: 27 additions & 0 deletions internlm/model/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Any, Dict, List

import torch

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.parallel.comm.tensor import _GATHER_DIM
from internlm.model.modules.mha import MHA


Expand Down Expand Up @@ -51,3 +56,25 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]:
kwargs["max_seqlen"] = args[3]

return kwargs


def padding_residual(residual):
requires_grad = residual.requires_grad
pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
pad_after = (
gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1
) * residual.shape[_GATHER_DIM]

pad_before_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
pad_after_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad)

return residual
Loading
Loading