From 4d590ca2c1450cdf94a36a2088160c45cb5ff3ae Mon Sep 17 00:00:00 2001
From: sallyjunjun
Date: Thu, 1 Feb 2024 15:15:45 +0800
Subject: [PATCH] reconstruct embedding rotary related code

---
 internlm/model/embedding.py | 360 ++++++++++++++++++------------------
 1 file changed, 185 insertions(+), 175 deletions(-)

diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index 2faff8078..8b2f9f39c 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -1,11 +1,11 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import Tensor, nn
 
 from internlm.core.context import ParallelMode
@@ -60,192 +60,221 @@ def forward(self, input_: Tensor) -> Tensor:
         return output
 
 
-def apply_rotary_torch(x1, x2, cos, sin, conj):
-    assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device"
-    assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype"
-    assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes"
-    assert cos.size() == sin.size(), "Input cos and sin must have the same sizes"
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
 
-    if conj:
-        out1 = x1 * cos + x2 * sin
-        out2 = -x1 * sin + x2 * cos
-    else:
-        out1 = x1 * cos - x2 * sin
-        out2 = x1 * sin + x2 * cos
 
-    return out1, out2
+def apply_rotary_torch(x, cos, sin, conjugate=False):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    conjugate: if True, apply the inverse rotation (needed by the backward passes)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    if conjugate:
+        sin = -sin
+    cos = repeat(cos, "... d -> ... 1 (2 d)")
+    sin = repeat(sin, "... d -> ... 1 (2 d)")
+    return torch.cat(
+        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]],
+        dim=-1,
+    )
 
 
-class ApplyRotaryEmbQKV_(torch.autograd.Function):
+class ApplyRotaryEmb(torch.autograd.Function):
     """
-    ApplyRotaryEmbQKV_
+    ApplyRotaryEmb
     """
 
     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
-        """
-        qkv: (total, 3, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
-        rotary_dim must be <= headdim
-        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
- """ - _, three, _, headdim = qkv.shape - assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1) + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): if gpc.config.model.use_flash_attn: - import rotary_emb - - rotary_emb.apply_rotary( - q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False + from flash_attn.ops.triton.rotary import apply_rotary + + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, ) else: - q1, q2 = apply_rotary_torch(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), False) - k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1) - if gpc.config.model.use_flash_attn: - rotary_emb.apply_rotary( - k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False - ) + out = apply_rotary_torch(x, cos, sin) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets else: - k1, k2 = apply_rotary_torch( - k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), False - ) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - return qkv + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1) - if gpc.config.model.use_flash_attn: - import rotary_emb - - rotary_emb.apply_rotary( - dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True - ) + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: - dq1, dq2 = apply_rotary_torch( - dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), True - ) - dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1) + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. 
+        if not ctx.interleaved and not ctx.inplace:
+            do = do.clone()
         if gpc.config.model.use_flash_attn:
-            rotary_emb.apply_rotary(
-                dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
+            from flash_attn.ops.triton.rotary import apply_rotary
+
+            dx = apply_rotary(
+                do,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+                interleaved=ctx.interleaved,
+                inplace=ctx.inplace,
+                conjugate=True,
             )
         else:
-            dk1, dk2 = apply_rotary_torch(
-                dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), True
-            )
-        return dqkv, None, None, None, None
+            dx = apply_rotary_torch(do, cos, sin, conjugate=True)
+        return dx, None, None, None, None, None, None, None
 
 
-class TorchApplyRotaryEmb(torch.autograd.Function):
-    """
-    TorchApplyRotaryEmb
-    """
+apply_rotary_emb = ApplyRotaryEmb.apply
 
-    @staticmethod
-    def forward(ctx, x, cos, sin, interleaved=False):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
-        """
-        _, seqlen, _, headdim = x.shape
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-        x_ro = x[..., :rotary_dim]
-        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
-        x1, x2 = apply_rotary_torch(
-            x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False
-        )
-        ctx.save_for_backward(cos, sin)
-        ctx.interleaved = interleaved
-        return x
 
-    @staticmethod
-    def backward(ctx, do):
-        cos, sin = ctx.saved_tensors
-        _, seqlen, _, _ = do.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        do_ro = do[..., :rotary_dim]
-        do1, do2 = do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
-        do1, do2 = apply_rotary_torch(
-            do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True
-        )
-        return do, None, None, None, None
-
-
-class TorchApplyRotaryEmbQKV_(torch.autograd.Function):
+class ApplyRotaryEmbQKV_(torch.autograd.Function):
     """
-    TorchApplyRotaryEmbQKV_
+    ApplyRotaryEmbQKV_
     """
 
     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
+    def forward(
+        ctx,
+        qkv,
+        cos,
+        sin,
+        cos_k=None,
+        sin_k=None,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+    ):
         """
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
-        1st half and 2nd half (GPT-NeoX style). rotary_dim must be <= headdim
+        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
-        _, seqlen, three, _, headdim = qkv.shape
+        if len(qkv.shape) == 4:
+            three = qkv.shape[1]
+        elif len(qkv.shape) == 5:
+            three = qkv.shape[2]
         assert three == 3
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        cos_k = cos if cos_k is None else cos_k
-        sin_k = sin if sin_k is None else sin_k
-        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        q_ro = qkv[:, :, 0, :, :rotary_dim]
-        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
-        q1, q2 = apply_rotary_torch(
-            q1, q2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False
-        )
-        k_ro = qkv[:, :, 1, :, :rotary_dim]
-        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
-        k1, k2 = apply_rotary_torch(
-            k1, k2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), False
-        )
-        ctx.save_for_backward(cos, sin, cos_k, sin_k)
-        ctx.interleaved = interleaved
+
+        if gpc.config.model.use_flash_attn:
+            from flash_attn.ops.triton.rotary import apply_rotary
+
+        if cos_k is None and sin_k is None and qkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            if len(qkv.shape) == 4:
+                qk = rearrange(qkv[:, :2], "t two h d -> t (two h) d")
+            elif len(qkv.shape) == 5:
+                qk = rearrange(qkv[:, :, :2], "b s two h d -> b s (two h) d")
+            if gpc.config.model.use_flash_attn:
+                apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=False, inplace=True)
+            else:
+                # apply_rotary_torch is out-of-place: write the result back into the qkv view
+                qk.copy_(apply_rotary_torch(qk, cos, sin))
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            if len(qkv.shape) == 4:
+                q, k = qkv[:, 0], qkv[:, 1]
+            elif len(qkv.shape) == 5:
+                q, k = qkv[:, :, 0], qkv[:, :, 1]
+            if gpc.config.model.use_flash_attn:
+                # TODO: distinguish the 4-dim (packed) and 5-dim (batched) cases in these arguments
+                apply_rotary(q, cos, sin, seqlen_offsets, interleaved=False, inplace=True)
+                apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=False, inplace=True)
+            else:
+                q.copy_(apply_rotary_torch(q, cos, sin))
+                k.copy_(apply_rotary_torch(k, cos_k, sin_k))
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cos_k, sin_k)
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
+            ctx.seqlen_offsets = None
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
-        cos, sin, cos_k, sin_k = ctx.saved_tensors
-        _, seqlen, _, _, _ = dqkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
-        dq1, dq2 = dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
-        dq1, dq2 = apply_rotary_torch(
-            dq1, dq2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True
-        )
-        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
-        dk1, dk2 = dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
-        dk1, dk2 = apply_rotary_torch(
-            dk1, dk2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), True
-        )
-        return dqkv, None, None, None, None, None
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cos_k, sin_k = ctx.saved_tensors
+
+        if gpc.config.model.use_flash_attn:
+            from flash_attn.ops.triton.rotary import apply_rotary
+
+        if cos_k is None and sin_k is None and dqkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            if len(dqkv.shape) == 4:
+                dqk = rearrange(dqkv[:, :2], "t two h d -> t (two h) d")
+            elif len(dqkv.shape) == 5:
+                dqk = rearrange(dqkv[:, :, :2], "b s two h d -> b s (two h) d")
+            if gpc.config.model.use_flash_attn:
+                apply_rotary(
+                    dqk,
+                    cos,
+                    sin,
+                    seqlen_offsets=seqlen_offsets,
+                    interleaved=False,
+                    inplace=True,
+                    conjugate=True,
+                )
+            else:
+                dqk.copy_(apply_rotary_torch(dqk, cos, sin, conjugate=True))
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            if len(dqkv.shape) == 4:
+                dq, dk = dqkv[:, 0], dqkv[:, 1]
+            elif len(dqkv.shape) == 5:
+                dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            if gpc.config.model.use_flash_attn:
+                apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=False, inplace=True, conjugate=True)
+                apply_rotary(
+                    dk,
+                    cos_k,
+                    sin_k,
+                    seqlen_offsets,
+                    interleaved=False,
+                    inplace=True,
+                    conjugate=True,
+                )
+            else:
+                dq.copy_(apply_rotary_torch(dq, cos, sin, conjugate=True))
+                dk.copy_(apply_rotary_torch(dk, cos_k, sin_k, conjugate=True))
+        return dqkv, None, None, None, None, None
 
 
 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
@@ -338,33 +367,16 @@ def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Te
             self._sin_k_cached[indexes],
         )
 
-    def _get_legacy_apply_rotary_functions(self):
-        if gpc.config.model.use_flash_attn:
-            from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
-            from flash_attn.layers.rotary import (
-                ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_,
-            )
-
-            legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
-            legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
-        else:
-            legacy_apply_rotary_embed_qkv = TorchApplyRotaryEmbQKV_.apply
-            legacy_apply_rotary_embed = TorchApplyRotaryEmb.apply
-        return legacy_apply_rotary_embed_qkv, legacy_apply_rotary_embed
-
     def _eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
         """
         self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
-        legacy_apply_rotary_embed_qkv, _ = self._get_legacy_apply_rotary_functions()
         if self.scale is None:
-            return legacy_apply_rotary_embed_qkv(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
-            )
+            return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
         else:
-            return legacy_apply_rotary_embed_qkv(
+            return apply_rotary_emb_qkv_(
                 qkv,
                 self._cos_cached[seqlen_offset:],
                 self._sin_cached[seqlen_offset:],
@@ -376,15 +388,13 @@ def _single_forward(self, x, indexes=0):
         assert self.scale is None
         self._update_cos_sin_cache(x, indexes)
         x = x[None, ...]
-        _, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions()
-        ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
+        ret = apply_rotary_emb(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
         return ret
 
     def _single_eval_forward(self, x, seqlen_offset=0):
         assert self.scale is None
         self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
-        _, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions()
-        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
+        return apply_rotary_emb(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
 
 
 class LinearRotaryEmbedding(RotaryEmbedding):
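
Review notes with small standalone sketches (plain PyTorch, no flash-attn or gpc needed; these are illustrations, not part of the patch). First, the pure-PyTorch fallback can be sanity-checked in isolation. The sketch below inlines rotate_half and apply_rotary_torch as patched above, including the conjugate flag that the backward passes rely on, and verifies that the conjugate rotation inverts the forward one, which is exactly the identity ApplyRotaryEmb.backward depends on. The cos/sin tables are built inline for illustration only; in the module they come from _update_cos_sin_cache.

import torch
from einops import repeat

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_torch(x, cos, sin, conjugate=False):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim / 2)
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    if conjugate:
        sin = -sin  # rotate by -theta, as the backward passes need
    cos = repeat(cos, "... d -> ... 1 (2 d)")
    sin = repeat(sin, "... d -> ... 1 (2 d)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]],
        dim=-1,
    )

batch, seqlen, nheads, headdim = 2, 8, 4, 64
x = torch.randn(batch, seqlen, nheads, headdim)
# Illustrative rotary tables in the same layout the cache uses.
inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2).float() / headdim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, headdim / 2)
cos, sin = freqs.cos(), freqs.sin()

out = apply_rotary_torch(x, cos, sin)
# Rotating by -theta undoes rotating by theta, so gradients flow back through
# the conjugate rotation.
assert torch.allclose(apply_rotary_torch(out, cos, sin, conjugate=True), x, atol=1e-5)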
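
Second, the contiguous fast path in ApplyRotaryEmbQKV_.forward merges the q and k heads into one (2 * nheads) view so a single rotary application replaces two kernel calls. The check below reuses apply_rotary_torch and the x/cos/sin setup from the first sketch; the equality holds because the rotation is applied per position and simply broadcasts over the head axis, so regrouping heads cannot change the result.

from einops import rearrange

qkv = torch.randn(batch, seqlen, 3, nheads, headdim)  # batched 5-dim layout

# Reference: rotate q and k separately.
q_ref = apply_rotary_torch(qkv[:, :, 0], cos, sin)
k_ref = apply_rotary_torch(qkv[:, :, 1], cos, sin)

# Fast path: one application over the merged (2 * nheads) head axis.
qk = rearrange(qkv[:, :, :2], "b s two h d -> b s (two h) d")
qk_rot = rearrange(apply_rotary_torch(qk, cos, sin), "b s (two h) d -> b s two h d", two=2)
assert torch.allclose(qk_rot[:, :, 0], q_ref)
assert torch.allclose(qk_rot[:, :, 1], k_ref)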
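
Finally, _eval_forward and _single_eval_forward support incremental decoding by slicing the cached tables at seqlen_offset rather than passing an offset into the kernel. A quick consistency check, again reusing the names from the first sketch:

# Rotating only the last position with tables sliced at seqlen - 1 matches
# rotating the whole sequence and taking its last position.
full = apply_rotary_torch(x, cos, sin)
last = apply_rotary_torch(x[:, -1:], cos[seqlen - 1 :], sin[seqlen - 1 :])
assert torch.allclose(full[:, -1:], last)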