From 34b94790b08ca8e1260a398366cf44bfbb891318 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Thu, 25 Jan 2024 14:48:54 +0800 Subject: [PATCH] feat(model/multi_head_attention.py): fix return output --- internlm/model/multi_head_attention.py | 46 ++++++++++++++++---------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 01a88034..200d4a9f 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -85,25 +85,35 @@ def __init__( self.local_attn = local_attention self.spg = sequence_process_group self._scatter_gather_idx = {} - + # scatter_gather_idx contains the scatter and gather index for different data packed mode # key is the data packed mode, which should be in ['qkv', 'kv', 'q', 'output'] # value is the scatter and gather index in all2all - self._scatter_gather_idx['qkv'] = [2, 0] # qkv shape:[sequence, 3, head, head_dim] - self._scatter_gather_idx['kv'] = [2, 0] # kv shape: [sequence, 2, head, head_dim] - self._scatter_gather_idx['q'] = [1, 0] # q/k/v shape: [sequence, head, head_dim] - self._scatter_gather_idx['output'] = [0, 1] # output shape: [sequence, head, head_dim] - - - def forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, **kwargs: Any) -> Tensor: + self._scatter_gather_idx["qkv"] = [2, 0] # qkv shape:[sequence, 3, head, head_dim] + self._scatter_gather_idx["kv"] = [2, 0] # kv shape: [sequence, 2, head, head_dim] + self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim] + self._scatter_gather_idx["output"] = [0, 1] # output shape: [sequence, head, head_dim] + + def forward( + self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, **kwargs: Any + ) -> Tensor: if gpc.evaluation is True: # when conducting evaluation, the scatter and gather index should add 1. eval_scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()} - self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs) + return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs) else: - self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs) + return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs) - def _forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, scatter_gather: dict = None, **kwargs: Any) -> Tensor: + def _forward( + self, + qkv: Tensor = None, + kv: Tensor = None, + q: Tensor = None, + k: Tensor = None, + v: Tensor = None, + scatter_gather: dict = None, + **kwargs: Any, + ) -> Tensor: """forward Arguments: @@ -119,18 +129,18 @@ def _forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: T """ if qkv is not None: - qkv = _SeqAllToAll.apply(self.spg, qkv, scatter_gather['qkv'][0], scatter_gather['qkv'][1]) + qkv = _SeqAllToAll.apply(self.spg, qkv, scatter_gather["qkv"][0], scatter_gather["qkv"][1]) context_layer = self.local_attn(qkv, **kwargs) elif kv is not None: - q = _SeqAllToAll.apply(self.spg, q, scatter_gather['q'][0], scatter_gather['q'][1]) - kv = _SeqAllToAll.apply(self.spg, kv, scatter_gather['kv'][0], scatter_gather['kv'][1]) + q = _SeqAllToAll.apply(self.spg, q, scatter_gather["q"][0], scatter_gather["q"][1]) + kv = _SeqAllToAll.apply(self.spg, kv, scatter_gather["kv"][0], scatter_gather["kv"][1]) context_layer = self.local_attn(q, kv, **kwargs) else: - q = _SeqAllToAll.apply(self.spg, q, scatter_gather['q'][0], scatter_gather['q'][1]) - k = _SeqAllToAll.apply(self.spg, k, scatter_gather['q'][0], scatter_gather['q'][1]) - v = _SeqAllToAll.apply(self.spg, v, scatter_gather['q'][0], scatter_gather['q'][1]) + q = _SeqAllToAll.apply(self.spg, q, scatter_gather["q"][0], scatter_gather["q"][1]) + k = _SeqAllToAll.apply(self.spg, k, scatter_gather["q"][0], scatter_gather["q"][1]) + v = _SeqAllToAll.apply(self.spg, v, scatter_gather["q"][0], scatter_gather["q"][1]) context_layer = self.local_attn(q, k, v, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, scatter_gather['output'][0], scatter_gather['output'][1]) + output = _SeqAllToAll.apply(self.spg, context_layer, scatter_gather["output"][0], scatter_gather["output"][1]) # out e.g., [s/p::h] return output