feat(model/linear.py): support norm head for model internlm2 #68

Merged
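Not part of the diff: a minimal sketch of how the new behaviour might be switched on from a training config, assuming the model builder forwards a `norm_head` field to the InternLM2 head. The key name and the surrounding values are illustrative, not taken from the repository.

# Illustrative config fragment (not part of the PR diff); every value here is an
# assumption except the idea of a norm_head flag, which this PR adds to the
# InternLM2 output head.
model = dict(
    hidden_size=4096,
    num_layers=32,
    vocab_size=92544,
    norm_head=True,  # normalize the output head weight, Baichuan2-style
)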
10 changes: 5 additions & 5 deletions internlm/core/communication/isp.py
@@ -12,7 +12,7 @@
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.linear import ISPLinear, ScaleColumnParallelLinear
from internlm.model.linear import BaseScaleColumnParallelLinear, ISPLinear
from internlm.model.utils import all_gather_raw, reduce_scatter_raw
from internlm.utils.common import SchedulerHook

@@ -220,7 +220,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:

# Important: only works for llama-class models
for _, children in model.named_children():
if isinstance(children, ScaleColumnParallelLinear):
if isinstance(children, BaseScaleColumnParallelLinear):
setattr(children, "isp_name", "head")
self._overlap_states[cid].head.append(children)
elif isinstance(children, Embedding1D):
@@ -534,9 +534,9 @@ def reduce_scatter(
self.process_group,
op=op,
async_op=True,
memory_pool_allocator=self.memory_pool.allocate_reduce_scatter_memory
if self.enable_memory_pool
else None,
memory_pool_allocator=(
self.memory_pool.allocate_reduce_scatter_memory if self.enable_memory_pool else None
),
)

result, handle = (
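The isp.py change above widens the head detection from the concrete ScaleColumnParallelLinear to the shared base class, so the new InternLM2 head is picked up as well. A minimal sketch of why that matters, with the class bodies stubbed out and only the inheritance relationship kept:

# Sketch, not part of the PR diff: stubbed-down class hierarchy showing why
# isp.py now checks against BaseScaleColumnParallelLinear.
import torch.nn as nn

class BaseScaleColumnParallelLinear(nn.Linear):
    pass

class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
    pass

class InternLM2ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
    pass

head = InternLM2ScaleColumnParallelLinear(8, 8)
assert not isinstance(head, ScaleColumnParallelLinear)   # old check misses the new head
assert isinstance(head, BaseScaleColumnParallelLinear)   # new check matches every head variant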
73 changes: 63 additions & 10 deletions internlm/model/linear.py
@@ -16,6 +16,9 @@
isp_fused_dense_func,
megatron_fused_dense_func,
)
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


class BaseScaleColumnParallelLinear(nn.Linear):
@@ -59,15 +62,17 @@ class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
ScaleColumnParallelLinear in flash implementation.
"""

def forward(self, input, gather_dim=0): # pylint: disable=W0622
def forward(self, input, gather_dim=0, tp_mode: str = "mtp"): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fused_dense_func(

_fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func
return _fused_func(
input,
weight,
self.bias,
@@ -77,20 +82,68 @@ def forward(self, input, gather_dim=0): # pylint: disable=W0622
)


class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
class InternLM2ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
"""
ScaleColumnParallelLinear in megatron implementation.
ScaleColumnParallelLinear for InternLM2.

Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
Review comment (Collaborator): norm head is missing here.

Reply (Author): Updated in 1c14206.

norm_head (bool): Normalize the output embedding in order to let the calculation of logits not affected by
the norm of embedding. The implementation is referred to baichuan2,
see https://huggingface.co/baichuan-inc/Baichuan2-7B-Base for more information. False by default.
"""

def forward(self, input, gather_dim=0): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
def __init__(
self,
in_features: int,
out_features: int,
process_group: Optional[torch.distributed.ProcessGroup],
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
norm_head: bool = False,
) -> None:
super().__init__(
in_features, out_features, process_group, bias=bias, device=device, dtype=dtype, weight_scale=weight_scale
)

self.norm_head = norm_head
if self.norm_head:
logger.info("Notice that norm head is enabled to normalize head weight.")
self.first_eval_flag = True
self.tmp_weight = None

def forward(self, input, gather_dim=0, tp_mode: str = "mtp"): # pylint: disable=W0622
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return megatron_fused_dense_func(
if self.norm_head:
if self.training:
if not self.first_eval_flag:
self.first_eval_flag = True
self.tmp_weight = None
# We normalized the output Embedding so that the dot product
# is not affected by the norm of embedding. Ref: https://arxiv.org/pdf/2309.10305.pdf
weight = nn.functional.normalize(weight)
else:
if self.first_eval_flag:
# cache l2 norm of head to accelerate infer.
self.first_eval_flag = False
self.tmp_weight = nn.functional.normalize(weight)

weight = self.tmp_weight

_fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func
return _fused_func(
input,
weight,
self.bias,
@@ -100,7 +153,7 @@ def forward(self, input, gather_dim=0): # pylint: disable=W0622
)


class RewardModelLinear(ScaleColumnParallelLinear):
class RewardModelLinear(BaseScaleColumnParallelLinear):
"""
RewardModelLinear.
Args:
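Outside the diff, the norm-head behaviour added to InternLM2ScaleColumnParallelLinear above can be summarised in a standalone, non-parallel sketch: during training the head weight is re-normalized on every forward, while in eval mode the normalized weight is computed once and cached until training resumes.

# Simplified sketch, not the repository implementation: no tensor parallelism,
# no weight_scale, no fused kernels; only the norm-head caching behaviour.
import torch
import torch.nn as nn
import torch.nn.functional as F

class NormHeadSketch(nn.Linear):
    def __init__(self, in_features, out_features, norm_head=False):
        super().__init__(in_features, out_features, bias=False)
        self.norm_head = norm_head
        self.first_eval_flag = True   # True until the first eval-mode forward
        self.tmp_weight = None        # cached normalized weight for eval

    def forward(self, x):
        weight = self.weight
        if self.norm_head:
            if self.training:
                # Back in training: invalidate the eval cache and normalize each
                # output row to unit L2 norm (Baichuan2-style head).
                self.first_eval_flag = True
                self.tmp_weight = None
                weight = F.normalize(weight)
            else:
                if self.first_eval_flag:
                    # First eval forward: normalize once and cache to speed up inference.
                    self.first_eval_flag = False
                    self.tmp_weight = F.normalize(weight)
                weight = self.tmp_weight
        return F.linear(x, weight)

head = NormHeadSketch(16, 100, norm_head=True)
logits = head(torch.randn(4, 16))   # training-mode forward, weight normalized on the fly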
11 changes: 3 additions & 8 deletions internlm/model/modeling_internlm.py
@@ -15,7 +15,6 @@
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
MegatronScaleColumnParallelLinear,
RewardModelLinear,
ScaleColumnParallelLinear,
get_mlp_cls,
@@ -311,11 +310,7 @@ def __init__(
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = (
ScaleColumnParallelLinear
if self.tp_mode in ["mtp", "fsp", "isp"]
else MegatronScaleColumnParallelLinear
)
head_cls = ScaleColumnParallelLinear
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -422,9 +417,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
if hasattr(self, "head"):
# Evaluation
if hidden_states.ndim == 3:
hidden_states = self.head(hidden_states, gather_dim=1)
hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
else: # Training
hidden_states = self.head(hidden_states, gather_dim=0)
hidden_states = self.head(hidden_states, gather_dim=0, tp_mode=self.tp_mode)

if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
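The modeling files now pass the model's tensor-parallel mode straight into the head's forward, and linear.py picks the fused GEMM implementation from that string. A small sketch of the dispatch, with placeholder callables standing in for the fused kernels imported at the top of linear.py:

# Sketch, not part of the PR diff: placeholder callables stand in for the real
# fused kernels; only the tp_mode selection logic is shown.
def fused_dense_func(*args, **kwargs):
    ...

def megatron_fused_dense_func(*args, **kwargs):
    ...

def pick_fused_func(tp_mode: str = "mtp"):
    # Mirrors the selection inside the heads' forward(): the flash-style kernel
    # for "mtp", "fsp" and "isp", the megatron-style kernel for any other mode.
    return fused_dense_func if tp_mode in ("mtp", "fsp", "isp") else megatron_fused_dense_func

assert pick_fused_func("isp") is fused_dense_func
assert pick_fused_func("msp") is megatron_fused_dense_func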
13 changes: 4 additions & 9 deletions internlm/model/modeling_internlm2.py
@@ -32,9 +32,8 @@
RotaryEmbedding,
)
from internlm.model.linear import (
MegatronScaleColumnParallelLinear,
InternLM2ScaleColumnParallelLinear,
RewardModelLinear,
ScaleColumnParallelLinear,
get_linear_cls,
get_mlp_cls,
)
@@ -851,11 +850,7 @@ def __init__(
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = (
ScaleColumnParallelLinear
if self.tp_mode in ["mtp", "fsp", "isp"]
else MegatronScaleColumnParallelLinear
)
head_cls = InternLM2ScaleColumnParallelLinear

sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)

@@ -985,9 +980,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
if hasattr(self, "output"):
# Evaluation
if gpc.is_evaluating is True:
hidden_states = self.output(hidden_states, gather_dim=1)
hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
else: # Training
hidden_states = self.output(hidden_states, gather_dim=0)
hidden_states = self.output(hidden_states, gather_dim=0, tp_mode=self.tp_mode)

if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
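A quick check, outside the diff, of the property the norm head is after (cf. the Baichuan2 reference in the new docstring): once the head weight rows are L2-normalized, the logits depend only on the direction of each output embedding, not on its norm.

# Sketch, not part of the PR diff: the normalized head is invariant to rescaling
# of the output embedding rows.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)                         # hidden states
w = torch.randn(100, 16)                       # head weight: [vocab, hidden]

logits_a = x @ F.normalize(w).t()              # normalized head
logits_b = x @ F.normalize(3.0 * w).t()        # every row rescaled by 3
print(torch.allclose(logits_a, logits_b, atol=1e-6))   # True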
12 changes: 4 additions & 8 deletions internlm/model/modeling_llama.py
@@ -18,7 +18,6 @@
)
from internlm.model.embedding import Embedding1D, RotaryEmbedding
from internlm.model.linear import (
MegatronScaleColumnParallelLinear,
RewardModelLinear,
ScaleColumnParallelLinear,
get_linear_cls,
@@ -836,11 +835,8 @@ def __init__(
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = (
ScaleColumnParallelLinear
if self.tp_mode in ["mtp", "fsp", "isp"]
else MegatronScaleColumnParallelLinear
)
head_cls = ScaleColumnParallelLinear

if first:
if embed_split_hidden:
self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -992,9 +988,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
if hasattr(self, "output"):
# Evaluation
if gpc.is_evaluating is True:
hidden_states = self.output(hidden_states, gather_dim=1)
hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
else: # Training
hidden_states = self.output(hidden_states, gather_dim=0)
hidden_states = self.output(hidden_states, gather_dim=0, tp_mode=self.tp_mode)

if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
12 changes: 4 additions & 8 deletions internlm/model/modeling_moe.py
@@ -15,7 +15,6 @@
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
MegatronScaleColumnParallelLinear,
RewardModelLinear,
ScaleColumnParallelLinear,
get_mlp_cls,
@@ -334,11 +333,8 @@ def __init__(
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = (
ScaleColumnParallelLinear
if self.tp_mode in ["mtp", "fsp", "isp"]
else MegatronScaleColumnParallelLinear
)
head_cls = ScaleColumnParallelLinear

if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -446,9 +442,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
if hasattr(self, "head"):
# Evaluation
if hidden_states.ndim == 3:
hidden_states = self.head(hidden_states, gather_dim=1)
hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode)
else: # Training
hidden_states = self.head(hidden_states, gather_dim=0)
hidden_states = self.head(hidden_states, gather_dim=0, tp_mode=self.tp_mode)

if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
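The gather_dim values in these head calls are unchanged by the PR; as I read it, they simply track where the sequence dimension sits: evaluation keeps a batched [batch, seq, hidden] layout (gather on dim 1), while training uses a packed [tokens, hidden] layout (gather on dim 0). A small illustration of that assumption, with plain torch.cat standing in for the sequence-parallel all-gather:

# Sketch, not part of the PR diff: torch.cat stands in for the all-gather across
# two sequence-parallel ranks.
import torch

eval0, eval1 = torch.randn(2, 8, 16), torch.randn(2, 8, 16)   # eval: [batch, seq, hidden]
full_eval = torch.cat([eval0, eval1], dim=1)                  # gather_dim=1 -> [2, 16, 16]

train0, train1 = torch.randn(8, 16), torch.randn(8, 16)       # training: packed [tokens, hidden]
full_train = torch.cat([train0, train1], dim=0)               # gather_dim=0 -> [16, 16]

print(full_eval.shape, full_train.shape)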