feat(checkpoint): TP recomputation communication optimization #275

Open · wants to merge 9 commits into base: develop · Changes from 7 commits
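What this PR does, as reflected in the diff below: it adds a `checkpoint_tp_no_comm` model config option (defaulting to `True` whenever activation checkpointing is enabled). For transformer blocks whose forward will be recomputed during backward anyway, the forward pass skips the output communication of the block's last row-parallel linear (`w2`): the all-reduce under tensor parallelism, or the reduce-scatter under sequence parallelism. The residual is zero-padded to the full sequence length so it can still be added to the un-reduced MLP output, and, in the sequence-parallel case, the matching grad_output all-gather is skipped in backward. A minimal config sketch, using the field names from `configs/7B_sft.py` in this diff (all other model fields omitted):

```python
model = dict(
    checkpoint=True,             # activation checkpointing: True/False or a ratio in [0-1]
    checkpoint_tp_no_comm=True,  # skip TP/SP output communication in recomputed forwards
    # ... remaining model fields as in configs/7B_sft.py ...
)
```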
3 changes: 2 additions & 1 deletion configs/7B_internlm2.py
@@ -1,5 +1,5 @@
JOB_NAME = "7b_internlm2_train"
model_type="INTERNLM2_PUBLIC"
model_type = "INTERNLM2_PUBLIC"
DO_ALERT = False

VOCAB_SIZE = 92544
@@ -128,6 +128,7 @@
use_fp32_norm = False
model = dict(
checkpoint=False,
# checkpoint_tp_no_comm=True, # whether to use TP recomputation communication optimization
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
1 change: 1 addition & 0 deletions configs/7B_sft.py
@@ -141,6 +141,7 @@
use_fp32_norm = False
model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
# checkpoint_tp_no_comm=True, # whether to use TP recomputation communication optimization
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
1 change: 1 addition & 0 deletions internlm/core/context/parallel_context.py
@@ -159,6 +159,7 @@ def __init__(self):
self.virtual_pipeline_parallel_rank = None
self._expert_parallel_group_names = []
self.is_evaluating = False
self.recompute_forward_no_comm = False

@property
def config(self):
77 changes: 62 additions & 15 deletions internlm/core/parallel/comm/tensor.py
@@ -66,7 +66,9 @@ def input_hook(

@abstractmethod
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False
self,
grad_output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for grad_output when backward.
@@ -81,7 +83,11 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
pass

@abstractmethod
def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self,
output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for output when forward.
"""
@@ -93,13 +99,14 @@ class TensorParallelCommunicator(TPCommunicator):
tensor parallel communicator for linear
"""

def __init__(self, process_group: dist.ProcessGroup, role: LinearRole) -> None:
def __init__(self, process_group: dist.ProcessGroup, role: LinearRole, last_block_layer=False) -> None:
assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}"

self._process_group = process_group
self._role = role

self._save_total_input = False
self.last_block_layer = last_block_layer

def save_total_input(self) -> bool:
return self._save_total_input
@@ -116,7 +123,9 @@ def input_hook(
return _input, DUMMY_HANDLE_CONST

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
tensor parallel should do nothing for grad_output.
@@ -132,11 +141,19 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T

return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op)

def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self,
output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all reduce output only for row parallel linear when forward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if (
(self.last_block_layer and gpc.recompute_forward_no_comm)
or dist.get_world_size(self._process_group) <= 1
or self._role == LinearRole.COLUMN
):
return output, DUMMY_HANDLE_CONST

return all_reduce_raw(output, process_group=self._process_group, async_op=async_op)
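The condition above can be restated as a small pure function (`should_all_reduce` is a hypothetical helper, not part of this PR): the forward all-reduce is now also skipped for a linear marked as the last of its block while `gpc.recompute_forward_no_comm` is set, on top of the pre-existing single-rank and column-parallel cases.

```python
# Hypothetical restatement of the gate in output_hook above.
def should_all_reduce(last_block_layer: bool, recompute_no_comm: bool,
                      world_size: int, role: str) -> bool:
    skip = (last_block_layer and recompute_no_comm) or world_size <= 1 or role == "column"
    return not skip

assert should_all_reduce(True, True, 8, "row") is False   # optimization kicks in
assert should_all_reduce(False, True, 8, "row") is True   # other row linears still reduce
assert should_all_reduce(True, False, 8, "row") is True   # normal forward still reduces
```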
@@ -148,14 +165,20 @@ class SequenceParallelCommunicator(TPCommunicator):
"""

def __init__(
self, process_group: dist.ProcessGroup, role: LinearRole, save_total_input_as_activation: bool = False
self,
process_group: dist.ProcessGroup,
role: LinearRole,
save_total_input_as_activation: bool = False,
last_block_layer=False,
) -> None:
assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}"

self._process_group = process_group
self._role = role

self._save_total_input = save_total_input_as_activation
self.last_block_layer = last_block_layer
self.no_communication = False

def save_total_input(self) -> bool:
return self._save_total_input
@@ -182,12 +205,19 @@ def input_hook(
return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False
self,
grad_output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather grad_output only for row parallel linear when backward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if (
(self.last_block_layer and self.no_communication)
or dist.get_world_size(self._process_group) <= 1
or self._role == LinearRole.COLUMN
):
self.no_communication = False
return grad_output, DUMMY_HANDLE_CONST

return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)
@@ -203,11 +233,20 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM
)

def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self,
output: torch.Tensor,
async_op: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
reduce scatter output only for row parallel linear when forward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
self.no_communication = gpc.recompute_forward_no_comm
if (
(self.last_block_layer and self.no_communication)
or dist.get_world_size(self._process_group) <= 1
or self._role == LinearRole.COLUMN
):
return output, DUMMY_HANDLE_CONST

return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM)
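Unlike the tensor-parallel communicator, the sequence-parallel path needs a forward/backward handshake: `output_hook` latches `gpc.recompute_forward_no_comm` into `self.no_communication`, and `grad_output_hook` consumes that latch to skip the matching all-gather (grad_output already carries the full sequence length, since the reduce-scatter never happened) and then resets it. A minimal standalone sketch of the latch, with hypothetical names and the collectives replaced by strings:

```python
class NoCommLatch:
    """Toy model of the no_communication flag used by the two hooks above."""

    def __init__(self) -> None:
        self.no_communication = False

    def output_hook(self, recompute_forward_no_comm: bool) -> str:
        # forward: record whether the reduce-scatter was skipped
        self.no_communication = recompute_forward_no_comm
        return "skip reduce-scatter" if self.no_communication else "reduce-scatter"

    def grad_output_hook(self) -> str:
        # backward: skip the matching all-gather exactly once, then reset
        if self.no_communication:
            self.no_communication = False
            return "skip all-gather"
        return "all-gather"

latch = NoCommLatch()
assert latch.output_hook(True) == "skip reduce-scatter"
assert latch.grad_output_hook() == "skip all-gather"
assert latch.grad_output_hook() == "all-gather"  # flag was reset after one use
```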
@@ -225,7 +264,9 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True)
self._retain_out_sharded = retain_out_sharded

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
@@ -236,7 +277,9 @@ def grad_output_hook(
return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
@@ -266,7 +309,9 @@ def __init__(

# rewrite grad_output communication hook
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
@@ -278,7 +323,9 @@ def grad_output_hook(

# rewrite output communication hook
def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
output: torch.Tensor,
async_op: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
11 changes: 11 additions & 0 deletions internlm/initialize/launch.py
@@ -295,10 +295,13 @@ def args_sanity_check():
]

if "checkpoint" in model:
if "checkpoint_tp_no_comm" not in model:
gpc.config.model._add_item("checkpoint_tp_no_comm", True)
if model.checkpoint is True:
model.checkpoint = 1
elif model.checkpoint is False:
model.checkpoint = 0
model.checkpoint_tp_no_comm = False
else:
assert (
model.checkpoint >= 0 and model.checkpoint <= 1
@@ -411,6 +414,14 @@ def args_sanity_check():
gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True
), "only support interleaved pipeline scheduler with overlap"

# when not using tp or sp, checkpoint_tp_no_comm should always be False
if (
gpc.config.parallel["tensor"]["mode"] == "isp"
or gpc.config.parallel["tensor"]["size"] <= 1
or gpc.config.model_type not in ["INTERNLM", "INTERNLM2_PUBLIC"]
) and getattr(gpc.config.model, "checkpoint_tp_no_comm", False):
gpc.config.model.checkpoint_tp_no_comm = False

# monitoring default config
monitor_default_config = {
"alert_address": None, # compatible with old alert config
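Taken together with the `checkpoint` handling earlier in `args_sanity_check`, the effective rule appears to be: `checkpoint_tp_no_comm` defaults to `True` when checkpointing is configured, and is forced to `False` when checkpointing is off, when the tensor mode is `isp`, when the TP group is trivial, or when the model type is unsupported. A hypothetical pure-function restatement (not the real `args_sanity_check` signature):

```python
def resolve_checkpoint_tp_no_comm(model: dict, tensor_mode: str,
                                  tensor_size: int, model_type: str) -> bool:
    flag = model.get("checkpoint_tp_no_comm", True)  # default added by this PR
    if not model.get("checkpoint"):
        flag = False  # nothing is recomputed, so there is nothing to skip
    if tensor_mode == "isp" or tensor_size <= 1 \
            or model_type not in ("INTERNLM", "INTERNLM2_PUBLIC"):
        flag = False  # unsupported parallel mode, trivial TP group, or model
    return flag

assert resolve_checkpoint_tp_no_comm({"checkpoint": 1}, "msp", 8, "INTERNLM2_PUBLIC")  # any non-isp mode
assert not resolve_checkpoint_tp_no_comm({"checkpoint": 1}, "isp", 8, "INTERNLM2_PUBLIC")
assert not resolve_checkpoint_tp_no_comm({"checkpoint": 0}, "msp", 8, "INTERNLM")
```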
3 changes: 3 additions & 0 deletions internlm/model/builder.py
@@ -21,6 +21,9 @@ def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module
kwargs["checkpoint"] = float(kwargs.get("checkpoint", False))
kwargs["device"] = get_current_device()

if "checkpoint_tp_no_comm" in kwargs:
kwargs.pop("checkpoint_tp_no_comm")

model_buidler = model_initializer.get_module(module_name=model_type)

if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
6 changes: 6 additions & 0 deletions internlm/model/modeling_internlm.py
@@ -21,9 +21,11 @@
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
internlm1_mha_save_convert,
padding_residual,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_sequence_parallel

logger = get_logger(__file__)

@@ -213,6 +215,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):

hidden_states = self.mlp(hidden_states)

# pad residual
if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
residual = padding_residual(residual)

return hidden_states + residual


7 changes: 7 additions & 0 deletions internlm/model/modeling_internlm2.py
@@ -21,9 +21,11 @@
from internlm.model.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
padding_residual,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_sequence_parallel

logger = get_logger(__file__)

@@ -255,8 +257,13 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):

if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.feed_forward(hidden_states)

# pad residual
if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
residual = padding_residual(residual)

return hidden_states + residual
else:
assert residual is None
6 changes: 4 additions & 2 deletions internlm/model/modules/linear.py
@@ -346,7 +346,6 @@ def __init__(
def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."

return fused_dense_func(
input,
self.weight,
@@ -589,14 +588,17 @@ def new_linear(
dtype,
)
elif split_mode == "row":
return RowParallelLinear(
linear = RowParallelLinear(
in_features,
out_features,
bias,
multiple_of,
device,
dtype,
)
if name == "w2":
Review comment (Contributor): Wouldn't this be very hardcoded?
setattr(linear, "last_block_layer", True)
return linear
else:
err_msg = (
f"Parallel strategies for linear is unsupported, which is named as {name}.\n"
21 changes: 21 additions & 0 deletions internlm/model/utils.py
@@ -1,5 +1,9 @@
from typing import Any, Dict, List

import torch

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.modules.mha import MHA


@@ -51,3 +55,20 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]:
kwargs["max_seqlen"] = args[3]

return kwargs


def padding_residual(residual):
requires_grad = residual.requires_grad
_GATHER_DIM = 1
total_size = gpc.get_world_size(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
zero_padding_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], total_size, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
start_idx = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
end_idx = start_idx + residual.shape[_GATHER_DIM]
zero_padding_tensor[:, start_idx:end_idx, :] = residual
residual = zero_padding_tensor.requires_grad_(requires_grad)

return residual
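Why this padding exists: with the reduce-scatter skipped under sequence parallelism, the MLP output keeps the full (gathered) sequence length while the residual is still this rank's local shard, so `padding_residual` places the shard at its rank offset within a zero tensor of the full length before the two are added. A standalone sketch of the same arithmetic, with a hypothetical `pad_local_shard` that takes the rank and world size explicitly instead of reading them from `gpc`:

```python
import torch

def pad_local_shard(residual: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    gather_dim = 1  # sequence dimension, matching _GATHER_DIM above
    total = world_size * residual.shape[gather_dim]
    padded = torch.zeros(
        (*residual.shape[:gather_dim], total, *residual.shape[gather_dim + 1 :]),
        dtype=residual.dtype,
        device=residual.device,
    )
    start = rank * residual.shape[gather_dim]
    padded[:, start : start + residual.shape[gather_dim], :] = residual
    return padded

x = torch.ones(2, 3, 8)                      # local residual: (batch, seq_local, hidden)
y = pad_local_shard(x, rank=1, world_size=4)
assert y.shape == (2, 12, 8)                 # full sequence length restored
assert y[:, 3:6, :].eq(1).all() and y[:, :3, :].eq(0).all()
```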