Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(checkpoint): TP recomputation communication optimization #275

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configs/7B_internlm2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
JOB_NAME = "7b_internlm2_train"
model_type="INTERNLM2_PUBLIC"
model_type = "INTERNLM2_PUBLIC"
DO_ALERT = False

VOCAB_SIZE = 92544
Expand Down Expand Up @@ -128,6 +128,7 @@
use_fp32_norm = False
model = dict(
checkpoint=False,
# checkpoint_tp_no_comm=True, # whether use TP recomputation communication optimization
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
Expand Down
1 change: 1 addition & 0 deletions configs/7B_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
use_fp32_norm = False
model = dict(
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
# checkpoint_tp_no_comm=True, # whether use TP recomputation communication optimization
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
Expand Down
41 changes: 28 additions & 13 deletions internlm/core/parallel/comm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def input_hook(

@abstractmethod
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False
self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for grad_output when backward.
Expand All @@ -81,7 +81,9 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
pass

@abstractmethod
def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
communication for output when forward.
"""
Expand Down Expand Up @@ -116,7 +118,10 @@ def input_hook(
return _input, DUMMY_HANDLE_CONST

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
tensor parallel should do nothing for grad_output.
Expand All @@ -132,11 +137,13 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T

return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op)

def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all reduce output only for row parallel linear when forward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
return output, DUMMY_HANDLE_CONST

return all_reduce_raw(output, process_group=self._process_group, async_op=async_op)
Expand Down Expand Up @@ -182,12 +189,12 @@ def input_hook(
return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False
self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather grad_output only for row parallel linear when backward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
return grad_output, DUMMY_HANDLE_CONST

return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)
Expand All @@ -203,11 +210,13 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM
)

def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
def output_hook(
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
reduce scatter output only for row parallel linear when forward.
"""
if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
return output, DUMMY_HANDLE_CONST

return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM)
Expand All @@ -225,7 +234,10 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True)
self._retain_out_sharded = retain_out_sharded

def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
Expand All @@ -236,7 +248,7 @@ def grad_output_hook(
return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
Expand Down Expand Up @@ -266,7 +278,10 @@ def __init__(

# rewrite grad_output communication hook
def grad_output_hook(
self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self,
grad_output: torch.Tensor,
async_op: bool = False,
no_communication: bool = False, # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
split grad_output if retain_out_sharded is False.
Expand All @@ -278,7 +293,7 @@ def grad_output_hook(

# rewrite ouput communication hook
def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613
) -> Tuple[torch.Tensor, AsyncCommHandle]:
"""
all gather output for head layer if retain_out_sharded is False.
Expand Down
7 changes: 7 additions & 0 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,13 @@ def args_sanity_check():
]

if "checkpoint" in model:
if "checkpoint_tp_no_comm" not in model:
gpc.config.model._add_item("checkpoint_tp_no_comm", True)
if model.checkpoint is True:
model.checkpoint = 1
elif model.checkpoint is False:
model.checkpoint = 0
model.checkpoint_tp_no_comm = False
else:
assert (
model.checkpoint >= 0 and model.checkpoint <= 1
Expand Down Expand Up @@ -411,6 +414,10 @@ def args_sanity_check():
gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True
), "only support interleaved pipeline scheduler with overlap"

# when not use tp or sp, checkpoint_tp_no_comm should always be False
if gpc.config.parallel["tensor"]["size"] <= 1 and getattr(gpc.config.model, "checkpoint_tp_no_comm", False):
li126com marked this conversation as resolved.
Show resolved Hide resolved
gpc.config.model.checkpoint_tp_no_comm = False

# monitoring default config
monitor_default_config = {
"alert_address": None, # compatible with old alert config
Expand Down
3 changes: 3 additions & 0 deletions internlm/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module
kwargs["checkpoint"] = float(kwargs.get("checkpoint", False))
kwargs["device"] = get_current_device()

if "checkpoint_tp_no_comm" in kwargs:
kwargs.pop("checkpoint_tp_no_comm")

model_buidler = model_initializer.get_module(module_name=model_type)

if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
Expand Down
27 changes: 26 additions & 1 deletion internlm/model/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.naive_amp import set_output_attr_to_module
from internlm.core.parallel.comm.tensor import _GATHER_DIM
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
Expand All @@ -24,6 +25,7 @@
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel

logger = get_logger(__file__)

Expand Down Expand Up @@ -179,6 +181,8 @@ def _forward(self, hidden_states, *args, **kwargs):
cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
indexes: the length of index is same as hidden states, which stand for the current position
"""
no_communication = args[4] if len(args) > 4 else False
li126com marked this conversation as resolved.
Show resolved Hide resolved
args = args[:4]

def _dropout_and_norm_attn(_hidden_states):
_dropped = self.dropout1(_hidden_states)
Expand Down Expand Up @@ -211,7 +215,28 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, no_communication=no_communication)

# pad residual
if no_communication and is_using_sequence_parallel() and not is_using_isp():
requires_grad = residual.requires_grad
pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
pad_after = (
gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1
) * residual.shape[_GATHER_DIM]

pad_before_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
pad_after_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)

residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad)

return hidden_states + residual

Expand Down
29 changes: 28 additions & 1 deletion internlm/model/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.parallel.comm.tensor import _GATHER_DIM
from internlm.initialize.initialize_tensor import (
normal_,
scaled_init_method_normal,
Expand All @@ -24,6 +25,7 @@
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel

logger = get_logger(__file__)

Expand Down Expand Up @@ -216,6 +218,8 @@ def _forward(self, hidden_states, residual, *args, **kwargs):
cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
indexes: the length of index is same as hidden states, which stand for the current position
"""
no_communication = args[4] if len(args) > 4 else False
args = args[:4]
if self.prenorm:

def _dropout_and_norm_attn(_residual, _hidden_states):
Expand Down Expand Up @@ -255,7 +259,30 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):

if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.feed_forward(hidden_states)
hidden_states = self.feed_forward(hidden_states, no_communication=no_communication)

# pad residual
if no_communication and is_using_sequence_parallel() and not is_using_isp():
requires_grad = residual.requires_grad
li126com marked this conversation as resolved.
Show resolved Hide resolved
pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
pad_after = (
gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1
) * residual.shape[_GATHER_DIM]

pad_before_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
pad_after_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)

residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(
requires_grad
)

return hidden_states + residual
else:
Expand Down
17 changes: 12 additions & 5 deletions internlm/model/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def forward(
bias: Optional[torch.Tensor],
communicator: TPCommunicator,
return_residual=False,
no_communication=False,
):
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.communicator = communicator
ctx.no_communication = no_communication

if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
Expand Down Expand Up @@ -77,7 +79,7 @@ def forward(

# parallel strategy-specific communication callback 2.
# see more details in the communicator for different parallel strategies.
output, _ = communicator.output_hook(output, async_op=False)
output, _ = communicator.output_hook(output, async_op=False, no_communication=no_communication)

saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x
ctx.save_for_backward(saved_x, weight)
Expand All @@ -91,7 +93,9 @@ def backward(ctx, grad_output, *args):

# parallel strategy-specific communication callback 3.
# see more details in the communicator for different parallel strategies.
grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)
grad_output, _ = communicator.grad_output_hook(
grad_output, no_communication=ctx.no_communication, async_op=False
)
grad_output = grad_output.contiguous()

if ctx.return_residual:
Expand Down Expand Up @@ -264,6 +268,7 @@ def fused_dense_func(
module: Optional[nn.Module] = None,
bias: Optional[torch.Tensor] = None,
return_residual: bool = False,
no_communication=False,
):
if communicator.communication_mode() == "wp":
return WPFusedDenseFunc.apply(
Expand All @@ -281,6 +286,7 @@ def fused_dense_func(
bias,
communicator,
return_residual,
no_communication,
)


Expand Down Expand Up @@ -343,16 +349,16 @@ def __init__(
else:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622
def forward(self, input: torch.Tensor, no_communication=False) -> torch.Tensor: # pylint: disable=W0622
SolenoidWGT marked this conversation as resolved.
Show resolved Hide resolved
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."

return fused_dense_func(
input,
self.weight,
communicator=self._communicator,
module=self,
bias=self.bias,
no_communication=no_communication,
)


Expand Down Expand Up @@ -465,7 +471,7 @@ def __init__(
self.first_eval_flag = True
self.tmp_weight = None

def forward(self, input): # pylint: disable=W0622
def forward(self, input, no_communication=False): # pylint: disable=W0622
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."

Expand Down Expand Up @@ -496,6 +502,7 @@ def forward(self, input): # pylint: disable=W0622
communicator=self._communicator,
module=self,
bias=self.bias,
no_communication=no_communication,
)


Expand Down
4 changes: 2 additions & 2 deletions internlm/model/modules/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ def __init__(
self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype)
self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype)

def forward(self, x):
def forward(self, x, no_communication=False):
if not self.mlp_layer_fusion:
w1_o = self.w1(x)
w3_o = self.w3(x)
else:
fussed_out = self.fused_w1_w3(x)
w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
out = self.w2(Silu(w1_o, w3_o))
out = self.w2(Silu(w1_o, w3_o), no_communication=no_communication)
return out


Expand Down
Loading
Loading