From b54f31b45b673c2fd0aaab250d84bab76d094ab4 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Fri, 1 Mar 2024 15:03:07 +0800
Subject: [PATCH 1/2] feat(model/linear.py): support norm head for model
 internlm2

---
 internlm/core/communication/isp.py   | 10 ++--
 internlm/model/linear.py             | 70 ++++++++++++++++++++++++----
 internlm/model/modeling_internlm.py  | 11 ++---
 internlm/model/modeling_internlm2.py | 13 ++----
 internlm/model/modeling_llama.py     | 12 ++---
 internlm/model/modeling_moe.py       | 12 ++---
 6 files changed, 80 insertions(+), 48 deletions(-)

diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py
index c72c0c63..e7e30bc7 100644
--- a/internlm/core/communication/isp.py
+++ b/internlm/core/communication/isp.py
@@ -12,7 +12,7 @@
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
-from internlm.model.linear import ISPLinear, ScaleColumnParallelLinear
+from internlm.model.linear import BaseScaleColumnParallelLinear, ISPLinear
 from internlm.model.utils import all_gather_raw, reduce_scatter_raw
 from internlm.utils.common import SchedulerHook
 
@@ -220,7 +220,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
 
         # Important: only works for llama-class models
         for _, children in model.named_children():
-            if isinstance(children, ScaleColumnParallelLinear):
+            if isinstance(children, BaseScaleColumnParallelLinear):
                 setattr(children, "isp_name", "head")
                 self._overlap_states[cid].head.append(children)
             elif isinstance(children, Embedding1D):
@@ -534,9 +534,9 @@ def reduce_scatter(
             self.process_group,
             op=op,
             async_op=True,
-            memory_pool_allocator=self.memory_pool.allocate_reduce_scatter_memory
-            if self.enable_memory_pool
-            else None,
+            memory_pool_allocator=(
+                self.memory_pool.allocate_reduce_scatter_memory if self.enable_memory_pool else None
+            ),
         )
 
         result, handle = (
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 9ce91632..ad353778 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -16,6 +16,9 @@
     isp_fused_dense_func,
     megatron_fused_dense_func,
 )
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
 
 
 class BaseScaleColumnParallelLinear(nn.Linear):
@@ -59,7 +62,7 @@ class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
     ScaleColumnParallelLinear in flash implementation.
     """
 
-    def forward(self, input, gather_dim=0):  # pylint: disable=W0622
+    def forward(self, input, gather_dim=0, tp_mode: str = "mtp"):  # pylint: disable=W0622
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
@@ -67,7 +70,9 @@ def forward(self, input, gather_dim=0):  # pylint: disable=W0622
             weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
         else:
             weight = self.weight
-        return fused_dense_func(
+
+        _fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func
+        return _fused_func(
             input,
             weight,
             self.bias,
@@ -77,20 +82,65 @@ def forward(self, input, gather_dim=0):  # pylint: disable=W0622
         )
 
 
-class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
+class InternLM2ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
     """
-    ScaleColumnParallelLinear in megatron implementation.
+    ScaleColumnParallelLinear for InternLM2.
+
+    Args:
+        in_features (int): size of each input sample
+        out_features (int): size of each output sample
+        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
+        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
+                     in the config.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        weight_scale (int): For training stability. 1 by default.
     """
 
-    def forward(self, input, gather_dim=0):  # pylint: disable=W0622
-        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
-        # we do an all_gather of x before doing the matmul.
-        # If not, then the input is already gathered.
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        process_group: Optional[torch.distributed.ProcessGroup],
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        weight_scale: int = 1,
+        norm_head: bool = False,
+    ) -> None:
+        super().__init__(
+            in_features, out_features, process_group, bias=bias, device=device, dtype=dtype, weight_scale=weight_scale
+        )
+
+        self.norm_head = norm_head
+        if self.norm_head:
+            logger.info("Notice norm head is used.")
+            self.first_eval_flag = True
+            self.tmp_weight = None
+
+    def forward(self, input, gather_dim=0, tp_mode: str = "mtp"):  # pylint: disable=W0622
         if self.weight_scale != 1:
             weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
         else:
             weight = self.weight
-        return megatron_fused_dense_func(
+        if self.norm_head:
+            if self.training:
+                if not self.first_eval_flag:
+                    self.first_eval_flag = True
+                    self.tmp_weight = None
+                # We normalized the output Embedding so that the dot product
+                # is not affected by the norm of embedding. Ref: https://arxiv.org/pdf/2309.10305.pdf
+                weight = nn.functional.normalize(weight)
+            else:
+                if self.first_eval_flag:
+                    # cache l2 norm of head to accelerate infer.
+                    self.first_eval_flag = False
+                    self.tmp_weight = nn.functional.normalize(weight)
+
+                weight = self.tmp_weight
+
+        _fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func
+        return _fused_func(
             input,
             weight,
             self.bias,
@@ -100,7 +150,7 @@ def forward(self, input, gather_dim=0):  # pylint: disable=W0622
         )
 
 
-class RewardModelLinear(ScaleColumnParallelLinear):
+class RewardModelLinear(BaseScaleColumnParallelLinear):
     """
     RewardModelLinear.
     Args:
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 1e11b445..85b6da33 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -15,7 +15,6 @@
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
-    MegatronScaleColumnParallelLinear,
     RewardModelLinear,
     ScaleColumnParallelLinear,
     get_mlp_cls,
@@ -311,11 +310,7 @@ def __init__(
         if is_reward:
             head_cls = RewardModelLinear
         else:
-            head_cls = (
-                ScaleColumnParallelLinear
-                if self.tp_mode in ["mtp", "fsp", "isp"]
-                else MegatronScaleColumnParallelLinear
-            )
+            head_cls = ScaleColumnParallelLinear
         if first:
             if embed_split_hidden:
                 self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -422,9 +417,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
         if hasattr(self, "head"):
             # Evaluation
             if hidden_states.ndim == 3:
-                hidden_states = self.head(hidden_states, gather_dim=1)
+                hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
             else:  # Training
-                hidden_states = self.head(hidden_states, gather_dim=0)
+                hidden_states = self.head(hidden_states, gather_dim=0, tp_mode=self.tp_mode)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py
index ad3ae057..beee22dc 100644
--- a/internlm/model/modeling_internlm2.py
+++ b/internlm/model/modeling_internlm2.py
@@ -32,9 +32,8 @@
     RotaryEmbedding,
 )
 from internlm.model.linear import (
-    MegatronScaleColumnParallelLinear,
+    InternLM2ScaleColumnParallelLinear,
     RewardModelLinear,
-    ScaleColumnParallelLinear,
     get_linear_cls,
     get_mlp_cls,
 )
@@ -851,11 +850,7 @@ def __init__(
         if is_reward:
             head_cls = RewardModelLinear
         else:
-            head_cls = (
-                ScaleColumnParallelLinear
-                if self.tp_mode in ["mtp", "fsp", "isp"]
-                else MegatronScaleColumnParallelLinear
-            )
+            head_cls = InternLM2ScaleColumnParallelLinear
 
         sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)
 
@@ -985,9 +980,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
         if hasattr(self, "output"):
             # Evaluation
             if gpc.is_evaluating is True:
-                hidden_states = self.output(hidden_states, gather_dim=1)
+                hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
             else:  # Training
-                hidden_states = self.output(hidden_states, gather_dim=0)
+                hidden_states = self.output(hidden_states, gather_dim=0, tp_mode=self.tp_mode)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py
index 00529796..0ac8d165 100644
--- a/internlm/model/modeling_llama.py
+++ b/internlm/model/modeling_llama.py
@@ -18,7 +18,6 @@
 )
 from internlm.model.embedding import Embedding1D, RotaryEmbedding
 from internlm.model.linear import (
-    MegatronScaleColumnParallelLinear,
     RewardModelLinear,
     ScaleColumnParallelLinear,
     get_linear_cls,
@@ -836,11 +835,8 @@ def __init__(
         if is_reward:
             head_cls = RewardModelLinear
         else:
-            head_cls = (
-                ScaleColumnParallelLinear
-                if self.tp_mode in ["mtp", "fsp", "isp"]
-                else MegatronScaleColumnParallelLinear
-            )
+            head_cls = ScaleColumnParallelLinear
+
         if first:
             if embed_split_hidden:
                 self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -992,9 +988,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
         if hasattr(self, "output"):
             # Evaluation
             if gpc.is_evaluating is True:
-                hidden_states = self.output(hidden_states, gather_dim=1)
+                hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
             else:  # Training
-                hidden_states = self.output(hidden_states, gather_dim=0)
+                hidden_states = self.output(hidden_states, gather_dim=0, tp_mode=self.tp_mode)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index e386743d..5b830cc9 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -15,7 +15,6 @@
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
-    MegatronScaleColumnParallelLinear,
     RewardModelLinear,
     ScaleColumnParallelLinear,
     get_mlp_cls,
@@ -334,11 +333,8 @@ def __init__(
         if is_reward:
             head_cls = RewardModelLinear
         else:
-            head_cls = (
-                ScaleColumnParallelLinear
-                if self.tp_mode in ["mtp", "fsp", "isp"]
-                else MegatronScaleColumnParallelLinear
-            )
+            head_cls = ScaleColumnParallelLinear
+
         if first:
             if embed_split_hidden:
                 self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -446,9 +442,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
         if hasattr(self, "head"):
             # Evaluation
             if hidden_states.ndim == 3:
-                hidden_states = self.head(hidden_states, gather_dim=1)
+                hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
             else:  # Training
-                hidden_states = self.head(hidden_states, gather_dim=0)
+                hidden_states = self.head(hidden_states, gather_dim=0, tp_mode=self.tp_mode)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)

From 1c14206c838c03fdc318b1939388d4c188929ef7 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Fri, 1 Mar 2024 16:19:12 +0800
Subject: [PATCH 2/2] feat(model/linear.py): fix comment of norm head arg

---
 internlm/model/linear.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index ad353778..271fc48f 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -95,6 +95,9 @@ class InternLM2ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
         device (Optional[Union[str, torch.device]]): The device will be used.
         dtype (Optional[torch.dtype]): The type of data.
         weight_scale (int): For training stability. 1 by default.
+        norm_head (bool): Normalize the output embedding in order to let the calculation of logits not affected by
+                          the norm of embedding. The implementation is referred to baichuan2,
+                          see https://huggingface.co/baichuan-inc/Baichuan2-7B-Base for more information. False by default.
     """
 
     def __init__(
@@ -114,7 +117,7 @@ def __init__(
 
         self.norm_head = norm_head
         if self.norm_head:
-            logger.info("Notice norm head is used.")
+            logger.info("Notice that norm head is enabled to normalize head weight.")
             self.first_eval_flag = True
             self.tmp_weight = None
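
For reference, a minimal single-device sketch of the norm-head behaviour these patches add: the rows of the head weight are L2-normalized before the logits are computed, and the normalized weight is cached during evaluation. This is an illustrative sketch against plain torch.nn only; the NormHead name, and the omission of tensor parallelism, weight_scale, and the fused dense kernels, are assumptions of the sketch rather than part of the patch.

import torch
from torch import nn


class NormHead(nn.Linear):
    """Single-GPU sketch of the norm-head trick (see the InternLM2 head above)."""

    def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False):
        super().__init__(hidden_size, vocab_size, bias=bias)
        # Mirrors the first_eval_flag / tmp_weight bookkeeping in the patch.
        self.first_eval_flag = True
        self.tmp_weight = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Invalidate the eval cache and normalize each vocab row on the fly,
            # so logits are not skewed by the norm of the output embedding.
            self.first_eval_flag = True
            self.tmp_weight = None
            weight = nn.functional.normalize(self.weight)
        else:
            if self.first_eval_flag:
                # Cache the normalized weight once per eval phase to speed up inference.
                self.first_eval_flag = False
                self.tmp_weight = nn.functional.normalize(self.weight)
            weight = self.tmp_weight
        return nn.functional.linear(hidden_states, weight, self.bias)


if __name__ == "__main__":
    head = NormHead(hidden_size=8, vocab_size=32)
    x = torch.randn(4, 8)
    print(head(x).shape)  # torch.Size([4, 32]), normalized on the fly in training mode
    head.eval()
    print(head(x).shape)  # same shape, served from the cached normalized weight

Caching the normalized weight on the first evaluation step is the same design choice the patch makes: repeated decode steps then reuse the cached tensor instead of re-normalizing the full vocabulary projection every forward.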