From c6b42f9ec063acf5c6da9d813380da95bfe0aa1d Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Thu, 1 Feb 2024 15:15:45 +0800 Subject: [PATCH] reconstruct embedding rotary related code and fix conflict --- internlm/model/embedding.py | 359 ++++++++++++++-------------- internlm/model/linear.py | 17 +- internlm/model/utils.py | 143 ----------- internlm/train/training_internlm.py | 13 +- 4 files changed, 205 insertions(+), 327 deletions(-) diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py index 2faff8078..f158fcbdd 100644 --- a/internlm/model/embedding.py +++ b/internlm/model/embedding.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Tuple +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, repeat from torch import Tensor, nn from internlm.core.context import ParallelMode @@ -60,192 +60,220 @@ def forward(self, input_: Tensor) -> Tensor: return output -def apply_rotary_torch(x1, x2, cos, sin, conj): - assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device" - assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype" - assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" - assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) - if conj: - out1 = x1 * cos + x2 * sin - out2 = -x1 * sin + x2 * cos - else: - out1 = x1 * cos - x2 * sin - out2 = x1 * sin + x2 * cos - return out1, out2 +def apply_rotary_torch(x, cos, sin): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)") + sin = repeat(sin, "... d -> ... 1 (2 d)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]], + dim=-1, + ) -class ApplyRotaryEmbQKV_(torch.autograd.Function): +class ApplyRotaryEmb(torch.autograd.Function): """ - ApplyRotaryEmbQKV_ + ApplyRotaryEmb """ @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None): - """ - qkv: (total, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. - """ - _, three, _, headdim = qkv.shape - assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1) + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): if gpc.config.model.use_flash_attn: - import rotary_emb - - rotary_emb.apply_rotary( - q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False + from flash_attn.ops.triton.rotary import apply_rotary + + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, ) else: - q1, q2 = apply_rotary_torch(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), False) - k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1) - if gpc.config.model.use_flash_attn: - rotary_emb.apply_rotary( - k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False - ) + out = apply_rotary_torch(x, cos, sin) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets else: - k1, k2 = apply_rotary_torch( - k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), False - ) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - return qkv + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1) - if gpc.config.model.use_flash_attn: - import rotary_emb - - rotary_emb.apply_rotary( - dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True - ) + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: - dq1, dq2 = apply_rotary_torch( - dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), True - ) - dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1) + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() if gpc.config.model.use_flash_attn: - rotary_emb.apply_rotary( - dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True + from flash_attn.ops.triton.rotary import apply_rotary + + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, ) else: - dk1, dk2 = apply_rotary_torch( - dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), True - ) - return dqkv, None, None, None, None + dx = apply_rotary_torch(do, cos, sin) + return dx, None, None, None, None, None, None, None -class TorchApplyRotaryEmb(torch.autograd.Function): - """ - TorchApplyRotaryEmb - """ +apply_rotary_emb = ApplyRotaryEmb.apply - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - _, seqlen, _, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) - x1, x2 = apply_rotary_torch( - x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False - ) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - return x - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, _ = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - do_ro = do[..., :rotary_dim] - do1, do2 = do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) - do1, do2 = apply_rotary_torch( - do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True - ) - return do, None, None, None, None - - -class TorchApplyRotaryEmbQKV_(torch.autograd.Function): +class ApplyRotaryEmbQKV_(torch.autograd.Function): """ - TorchApplyRotaryEmbQKV_ + ApplyRotaryEmbQKV_ """ @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): """ - qkv: (batch_size, seqlen, 3, nheads, headdim) + qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of q and k. """ - _, seqlen, three, _, headdim = qkv.shape + if len(qkv.shape == 4): + three = qkv.shape[1] + elif len(qkv.shape == 5): + three = qkv.shape[2] assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q_ro = qkv[:, :, 0, :, :rotary_dim] - q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) - q1, q2 = apply_rotary_torch( - q1, q2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False - ) - k_ro = qkv[:, :, 1, :, :rotary_dim] - k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) - k1, k2 = apply_rotary_torch( - k1, k2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), False - ) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved + + if gpc.config.model.use_flash_attn: + from flash_attn.ops.triton.rotary import apply_rotary + + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + if len(qkv.shape == 4): + qk = rearrange(qkv[:, :2], "t t h d -> t (t h) d") + elif len(qkv.shape == 5): + qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + if gpc.config.model.use_flash_attn: + apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=False, inplace=True) + else: + qk = apply_rotary_torch(qk, cos, sin) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + if len(qkv.shape == 4): + q, k = qkv[:, 0], qkv[:, 1] + elif len(qkv.shape == 5): + q, k = qkv[:, :, 0], qkv[:, :, 1] + if gpc.config.model.use_flash_attn: + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=False, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=False, inplace=True) + else: + q = apply_rotary_torch(q, cos, sin) + k = apply_rotary_torch(k, cos_k, sin_k) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None return qkv @staticmethod def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - _, seqlen, _, _, _ = dqkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, :, 0, :, :rotary_dim] - dq1, dq2 = dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2]) - dq1, dq2 = apply_rotary_torch( - dq1, dq2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True - ) - dk_ro = dqkv[:, :, 1, :, :rotary_dim] - dk1, dk2 = dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2]) - dk1, dk2 = apply_rotary_torch( - dk1, dk2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), True - ) - return dqkv, None, None, None, None, None + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cos_k, sin_k = ctx.saved_tensors + + if gpc.config.model.use_flash_attn: + from flash_attn.ops.triton.rotary import apply_rotary + + if cos_k is None and sin_k is None and dqkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + if len(dqkv.shape == 4): + dqk = rearrange(dqkv[:, :2], "t t h d -> t (t h) d") + elif len(dqkv.shape == 5): + dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") + if gpc.config.model.use_flash_attn: + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=False, + inplace=True, + conjugate=True, + ) + else: + dqk = apply_rotary_torch(dqk, cos, sin) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + if len(dqkv.shape == 4): + dq, dk = dqkv[:, 0], dqkv[:, 1] + elif len(dqkv.shape == 5): + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + if gpc.config.model.use_flash_attn: + apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=False, inplace=True, conjugate=True) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=False, + inplace=True, + conjudate=True, + ) + else: + dq = apply_rotary_torch(dq, cos, sin) + dk = apply_rotary_torch(dk, cos_k, sin_k) + return dqkv, None, None, None, None, None, None apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply @@ -338,33 +366,16 @@ def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Te self._sin_k_cached[indexes], ) - def _get_legacy_apply_rotary_functions(self): - if gpc.config.model.use_flash_attn: - from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb - from flash_attn.layers.rotary import ( - ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_, - ) - - legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply - legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply - else: - legacy_apply_rotary_embed_qkv = TorchApplyRotaryEmbQKV_.apply - legacy_apply_rotary_embed = TorchApplyRotaryEmb.apply - return legacy_apply_rotary_embed_qkv, legacy_apply_rotary_embed - def _eval_forward(self, qkv, seqlen_offset=0): """ seqlen_offset: can be used in generation where the qkv being passed in is only the last token in the batch. """ self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1]) - legacy_apply_rotary_embed_qkv, _ = self._get_legacy_apply_rotary_functions() if self.scale is None: - return legacy_apply_rotary_embed_qkv( - qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:] - ) + return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) else: - return legacy_apply_rotary_embed_qkv( + return apply_rotary_emb_qkv_( qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], @@ -376,15 +387,13 @@ def _single_forward(self, x, indexes=0): assert self.scale is None self._update_cos_sin_cache(x, indexes) x = x[None, ...] - _, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions() - ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0) + ret = apply_rotary_emb(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0) return ret def _single_eval_forward(self, x, seqlen_offset=0): assert self.scale is None self._update_cos_sin_cache(x, seqlen_offset + x.shape[1]) - _, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions() - return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) + return apply_rotary_emb(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) class LinearRotaryEmbedding(RotaryEmbedding): diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 1d609c2ba..5b27bc82d 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -11,11 +11,10 @@ from internlm.core.context import global_context as gpc from internlm.model.utils import ( Silu, + all_reduce, fused_dense_func, isp_fused_dense_func, megatron_fused_dense_func, - all_reduce, - fused_dense_func_torch, reduce_scatter, ) @@ -206,7 +205,11 @@ def forward(self, x, gather_dim=0): ) -class MegatronColumnParallelLinearTorch(ColumnParallelLinear): +class MegatronColumnParallelLinearTorch(ColumnParallelLinearTorch): + """ + MegatronColumnParallelLinearTorch + """ + def forward(self, x, gather_dim=0): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. @@ -280,7 +283,11 @@ def forward(self, x): return reduce_fn(out, self.process_group) -class MegatronRowParallelLinearTorch(RowParallelLinear): +class MegatronRowParallelLinearTorch(RowParallelLinearTorch): + """ + MegatronRowParallelLinearTorch. + """ + def forward(self, x): """ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then @@ -442,7 +449,7 @@ def __init__( ) -class ISPLinear(ColumnParallelLinear): +class ISPLinear(ColumnParallelLinearTorch): """ Linear class for isp tensor parallel mode. """ diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 13980c0ab..68a919e83 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -16,27 +16,6 @@ logger = get_logger(__file__) -# Raw operation, does not support autograd, but does support async -def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - world_size = torch.distributed.get_world_size(process_group) - output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device) - handle = torch.distributed.all_gather_into_tensor( - output, input_.contiguous(), group=process_group, async_op=async_op - ) - return output, handle - - -# Raw operation, does not support autograd, but does support async -def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - world_size = torch.distributed.get_world_size(process_group) - assert input_.shape[0] % world_size == 0 - output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device) - handle = torch.distributed.reduce_scatter_tensor( - output, input_.contiguous(), group=process_group, async_op=async_op - ) - return output, handle - - # Raw operation, does not support autograd, but does support async def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): input_ = input_.contiguous() @@ -150,128 +129,6 @@ def gather_forward_split_backward(input_, parallel_mode, dim): return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) -def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): - assert my_input.dtype == grad_output.dtype - grad_weight = torch.matmul(grad_output.t(), my_input) - grad_bias = grad_output.sum(dim=0) if has_d_bias else None - return grad_weight, grad_bias - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py -class FusedDenseFuncTorch(torch.autograd.Function): - """A custom PyTorch module extending FusedDenseFunc.""" - - @staticmethod - @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) - else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel: - handle_x.wait() - # we remove the cuda independence, which is different from flash_attn. - grad_weight, grad_bias = linear_bias_wgrad_torch( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None - - -def fused_dense_func_torch( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - if ( - gpc.config.model.use_flash_attn - and x.is_cuda - and weight.is_cuda - and (bias is None or bias.is_cuda) - and dtype_eligible - ): - from flash_attn.ops.fused_dense import FusedDenseFunc - - return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel) - else: - return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel) - - class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 8fbd57580..42a022ebe 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -49,11 +49,11 @@ from internlm.model.embedding import Embedding1D from internlm.model.linear import ( BaseScaleColumnParallelLinear, - ColumnParallelLinear, + ColumnParallelLinearTorch, FeedForward, ISPLinear, RewardModelLinear, - RowParallelLinear, + RowParallelLinearTorch, ScaleColumnParallelLinear, ) from internlm.model.metrics import SchedulerMetricHook @@ -115,7 +115,12 @@ def _check_module(module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) # embedding and head - if isinstance(module, (Embedding1D, ParallelGPT2Embeddings, BaseScaleColumnParallelLinear)): + if gpc.config.model.use_flash_attn: + from flash_attn.modules.embedding import ParallelGPT2Embeddings + + if isinstance(module, (Embedding1D, BaseScaleColumnParallelLinear)) or ( + gpc.config.model.use_flash_attn and isinstance(module, ParallelGPT2Embeddings) + ): for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): setattr(param, IS_TENSOR_DATA_PARALLEL, True) @@ -123,7 +128,7 @@ def _check_module(module): setattr(param, IS_TENSOR_ZERO_PARALLEL, True) # for linear module - if isinstance(module, (ColumnParallelLinear, RowParallelLinear)): + if isinstance(module, (ColumnParallelLinearTorch, RowParallelLinearTorch)): for param in module.parameters(): if gpc.is_initialized(ParallelMode.EXPERT_DATA) and is_moe_param(param): # module should be MoE experts's linear