add cu_seqlens and max_seqlen parameter in packed scenario
sallyjunjun committed Feb 26, 2024
1 parent 3817f00 commit 7ebc1c3
Showing 4 changed files with 54 additions and 14 deletions.
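
Background note: in the packed scenario, several variable-length sequences are concatenated into a single row, and flash-attention style kernels locate the sequence boundaries through cu_seqlens (cumulative sequence lengths, int32, one entry more than the number of sequences) and max_seqlen (the length of the longest sequence in the pack). A minimal sketch of how such metadata is typically built; the helper name and values below are illustrative, not code from this commit.

    import torch
    import torch.nn.functional as F

    def build_packed_metadata(seqlens):
        # seqlens: per-sample lengths of the sequences packed back-to-back into one row
        lengths = torch.tensor(seqlens, dtype=torch.int32)
        # boundaries [0, l0, l0+l1, ...] in the flash-attn varlen convention
        cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
        max_seqlen = int(lengths.max())
        return cu_seqlens, max_seqlen

    cu_seqlens, max_seqlen = build_packed_metadata([5, 3, 8])
    # cu_seqlens -> tensor([0, 5, 8, 16], dtype=torch.int32), max_seqlen -> 8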
28 changes: 20 additions & 8 deletions internlm/model/embedding.py
@@ -204,10 +204,10 @@ def forward(
         sin,
         cos_k=None,
         sin_k=None,
-        interleaved=False,
-        seqlen_offsets: Union[int, torch.Tensor] = 0,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
     ):
         """
         qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim)
@@ -298,7 +298,7 @@ def backward(ctx, dqkv):
         # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
         # dimensions, we get the same tensor
         if len(dqkv.shape) == 4:
-            dqk = rearrange(dqkv[:, :2], "t t h d -> t (t h) d")
+            dqk = rearrange(dqkv[:, :2], "a t h d -> a (t h) d")
         elif len(dqkv.shape) == 5:
             dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
         if gpc.config.model.use_flash_attn:
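
A note on the rearrange change above: einops rejects patterns that repeat an axis name, so the old "t t h d -> t (t h) d" spelling would raise in the 4-D (packed) branch; renaming the first axis fixes it without changing the resulting shape. A standalone check with illustrative shapes:

    import torch
    from einops import rearrange

    dqkv = torch.randn(16, 3, 4, 8)                        # (total, 3, nheads, headdim)
    dqk = rearrange(dqkv[:, :2], "a t h d -> a (t h) d")   # ok: shape (16, 8, 8)
    # rearrange(dqkv[:, :2], "t t h d -> t (t h) d")       # EinopsError: duplicate axis "t"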
@@ -429,23 +429,31 @@ def _update_cos_sin_cache(self, x, indexes):
 
     def forward(self, qkv: torch.Tensor, **kwargs):
         if kwargs.get("indexes", None) is not None:
-            return self._forward(qkv, kwargs.pop("indexes"))
+            cu_seqlens = kwargs.get("cu_seqlens", None)
+            max_seqlen = kwargs.get("max_seqlen", None)
+            return self._forward(qkv, kwargs.pop("indexes"), cu_seqlens, max_seqlen)
         if kwargs.get("inference_params", None) is not None:
             return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward(
+        self, qkv: torch.Tensor, indexes=0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         self._update_cos_sin_cache(qkv, indexes)
         if self.scale is None:
-            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
+            return apply_rotary_emb_qkv_(
+                qkv, self._cos_cached[indexes], self._sin_cached[indexes], None, None, cu_seqlens, max_seqlen
+            )
         else:
             return apply_rotary_emb_qkv_(
                 qkv,
                 self._cos_cached[indexes],
                 self._sin_cached[indexes],
                 self._cos_k_cached[indexes],
                 self._sin_k_cached[indexes],
+                cu_seqlens,
+                max_seqlen,
             )
 
     def _eval_forward(self, qkv, seqlen_offset=0):
@@ -465,11 +473,15 @@ def _eval_forward(self, qkv, seqlen_offset=0):
                 self._sin_k_cached[seqlen_offset:],
             )
 
-    def _single_forward(self, x, indexes=0):
+    def _single_forward(
+        self, x, indexes=0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None
+    ):
         assert self.scale is None
         self._update_cos_sin_cache(x, indexes)
         x = x[None, ...]
-        ret = apply_rotary_emb(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
+        ret = apply_rotary_emb(
+            x, self._cos_cached[indexes], self._sin_cached[indexes], False, False, 0, cu_seqlens, max_seqlen
+        ).squeeze(0)
         return ret
 
     def _single_eval_forward(self, x, seqlen_offset=0):
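Note on the embedding.py hunks above: cu_seqlens and max_seqlen now come directly after sin_k and before interleaved/seqlen_offsets in the forward() signature, which is what lets _forward pass them positionally as apply_rotary_emb_qkv_(qkv, cos, sin, None, None, cu_seqlens, max_seqlen). A hedged sketch of how the packed path reaches this code; the variable names are illustrative and not part of the diff:

    # indexes: position ids for every token of the packed row, restarting at 0 per sequence
    # cu_seqlens / max_seqlen: packed-batch metadata in the flash-attn varlen convention
    qkv = rotary_emb(qkv, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
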
12 changes: 10 additions & 2 deletions internlm/model/modeling_internlm2.py
@@ -249,15 +249,21 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
             else:
                 q = q.squeeze(1)
                 k = k.squeeze(1)
+                cu_seqlens = kwargs.get("cu_seqlens", None)
+                max_seqlen = kwargs.get("max_seqlen", None)
                 q = self.rotary_emb._single_forward(
                     q,
                     inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
                 k = self.rotary_emb._single_forward(
                     k,
                     inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
         else:
             raise NotImplementedError(
@@ -423,8 +429,10 @@ def _packed_forward(self, x, inference_params=None, **kwargs):
         k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
 
         indexes = kwargs.pop("indexes")
-        q = self.rotary_emb._single_forward(q, indexes=indexes)
-        k = self.rotary_emb._single_forward(k, indexes=indexes)
+        cu_seqlens = kwargs.pop("cu_seqlens")
+        max_seqlen = kwargs.pop("max_seqlen")
+        q = self.rotary_emb._single_forward(q, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+        k = self.rotary_emb._single_forward(k, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
 
         if inference_params is None:
             kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1)
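In _packed_forward above, the new values are taken with kwargs.pop() and no default, so the packed path now requires both of them, while the inference path (_forward) keeps them optional via kwargs.get(..., None). A hypothetical caller sketch, with illustrative names and example values that are not part of this commit:

    # Packed training step: the whole micro-batch is one row of `total` tokens.
    out = block._packed_forward(
        x,                      # packed hidden states
        indexes=indexes,        # per-token position ids, restarting at 0 for each sequence
        cu_seqlens=cu_seqlens,  # e.g. tensor([0, 5, 8, 16], dtype=torch.int32)
        max_seqlen=max_seqlen,  # e.g. 8, the longest sequence in the pack
    )
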
12 changes: 10 additions & 2 deletions internlm/model/modeling_llama.py
@@ -247,15 +247,21 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
             else:
                 q = q.squeeze(1)
                 k = k.squeeze(1)
+                cu_seqlens = kwargs.get("cu_seqlens", None)
+                max_seqlen = kwargs.get("max_seqlen", None)
                 q = self.rotary_emb._single_forward(
                     q,
                     inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
                 k = self.rotary_emb._single_forward(
                     k,
                     inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
         else:
             raise NotImplementedError(
@@ -421,8 +427,10 @@ def _packed_forward(self, x, inference_params=None, **kwargs):
         k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
 
         indexes = kwargs.pop("indexes")
-        q = self.rotary_emb._single_forward(q, indexes=indexes)
-        k = self.rotary_emb._single_forward(k, indexes=indexes)
+        cu_seqlens = kwargs.pop("cu_seqlens")
+        max_seqlen = kwargs.pop("max_seqlen")
+        q = self.rotary_emb._single_forward(q, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+        k = self.rotary_emb._single_forward(k, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
 
         if inference_params is None:
             kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1)
16 changes: 14 additions & 2 deletions internlm/model/multi_head_attention.py
@@ -458,6 +458,8 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
         q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
         kv = torch.stack([k, v], dim=2)
         assert self.rotary_emb_dim > 0, "You should use rotary_emb."
+        cu_seqlens = kwargs.get("cu_seqlens", None)
+        max_seqlen = kwargs.get("max_seqlen", None)
 
         if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
             empties = inference_params.attention_mask[..., -1].sum(dim=-1)
@@ -490,12 +492,16 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
                     inference_params.sequence_len_offset
                     * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
                 k = self.rotary_emb._single_forward(
                     k,
                     inference_params.sequence_len_offset
                     * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
             else:
                 q = q.squeeze(1)
@@ -504,6 +510,8 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
                     inference_params.sequence_len_offset
                     * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                     - empties,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
                 ).unsqueeze(1)
                 moved_k = k.clone()
                 for i in range(len(empties)):
@@ -516,8 +524,12 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
                     else:
                         k[i] = moved_k[i]
         else:
-            q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
-            k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
+            q = self.rotary_emb._single_forward(
+                q, inference_params.sequence_len_offset, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+            )
+            k = self.rotary_emb._single_forward(
+                k, inference_params.sequence_len_offset, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+            )
 
         kv = torch.stack([k, v], dim=2)
         kv = _update_kv_cache(kv, inference_params, self.layer_idx)
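The _forward path above reads the new arguments with kwargs.get(..., None), so inference callers written before this change keep the old behaviour; only callers that actually pack sequences need to supply the metadata. A small illustrative sketch (the mha name and call shapes are hypothetical):

    out = mha._forward(x, inference_params=params)  # cu_seqlens/max_seqlen default to None
    out = mha._forward(
        x, inference_params=params, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
    )  # packed-style metadata forwarded to the rotary embedding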
