MHA: sequence parallel support for DiT
KimmiShi committed Mar 1, 2024
1 parent 062c2ed commit 6072593
Showing 1 changed file with 7 additions and 1 deletion.
internlm/model/multi_head_attention.py (8 changes: 7 additions & 1 deletion)
@@ -86,10 +86,12 @@ def __init__(
         self,
         local_attention: Module,
         sequence_process_group: dist.ProcessGroup,
+        varlen=True,
     ) -> None:
         super().__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
+        self.varlen = varlen
         self._scatter_gather_idx = {}
 
         # scatter_gather_idx contains the scatter and gather index for different data packed mode
@@ -108,7 +110,11 @@ def forward(
             eval_scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()}
             return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs)
         else:
-            return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs)
+            if not self.varlen:
+                scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()}
+            else:
+                scatter_gather_idx = self._scatter_gather_idx
+            return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=scatter_gather_idx, **kwargs)
 
     def _forward(
         self,
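The change adds a varlen flag to the sequence parallel attention wrapper. The [scatter, gather] index pairs stored in self._scatter_gather_idx describe varlen-packed tensors, which carry no batch dimension; in a fixed-length batched workload such as DiT every tensor gains a leading batch axis, so both indices must shift right by one, exactly the shift the evaluation path already applies. Below is a minimal sketch of that shift; the packing keys, shapes, and the helper shift_for_batch_dim are illustrative assumptions, not values copied from this file.

    # Illustrative [scatter_dim, gather_dim] pairs for varlen-packed tensors
    # (assumed shapes; the concrete dict in multi_head_attention.py may differ):
    scatter_gather_idx = {
        "qkv": [2, 0],  # packed qkv: [seq, 3, heads, head_dim]; scatter heads (dim 2), gather seq (dim 0)
        "q": [1, 0],    # packed q:   [seq, heads, head_dim];    scatter heads (dim 1), gather seq (dim 0)
    }

    def shift_for_batch_dim(idx):
        # A leading batch dimension moves every axis one position to the right,
        # so both the scatter index and the gather index increase by one.
        return {key: [x + 1 for x in value] for key, value in idx.items()}

    # varlen=False, e.g. DiT with qkv of shape [batch, seq, 3, heads, head_dim]:
    print(shift_for_batch_dim(scatter_gather_idx))  # {'qkv': [3, 1], 'q': [2, 1]}

With this arrangement the stored indices always describe the packed layout and the batch-dimension shift is derived on the fly for batched inputs, which is why the new branch reuses the same dict comprehension as the eval path.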
