From aa388b54d01edc9046b1e18f8b20b9debdecb72b Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Thu, 25 Jan 2024 14:32:48 +0800 Subject: [PATCH] modify the distributedAttention for different data pack mode --- internlm/core/context/parallel_context.py | 1 + internlm/model/multi_head_attention.py | 65 +++++++++++++---------- internlm/utils/evaluation.py | 7 ++- 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index fd53c4be..e1bdb601 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -157,6 +157,7 @@ def __init__(self): self.virtual_pipeline_parallel_size = None self.virtual_pipeline_parallel_rank = None self._expert_parallel_group_names = [] + self.evaluation = False @property def config(self): diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 825e3f21..01a88034 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -80,48 +80,57 @@ def __init__( self, local_attention: Module, sequence_process_group: dist.ProcessGroup, - first_scatter_idx: int = 2, - first_gather_idx: int = 0, - second_scatter_idx: int = 0, - second_gather_idx: int = 1, ) -> None: super().__init__() self.local_attn = local_attention self.spg = sequence_process_group - self.first_scatter_idx = first_scatter_idx - self.first_gather_idx = first_gather_idx - self.second_scatter_idx = second_scatter_idx - self.second_gather_idx = second_gather_idx + self._scatter_gather_idx = {} + + # scatter_gather_idx contains the scatter and gather index for different data packed mode + # key is the data packed mode, which should be in ['qkv', 'kv', 'q', 'output'] + # value is the scatter and gather index in all2all + self._scatter_gather_idx['qkv'] = [2, 0] # qkv shape:[sequence, 3, head, head_dim] + self._scatter_gather_idx['kv'] = [2, 0] # kv shape: [sequence, 2, head, head_dim] + self._scatter_gather_idx['q'] = [1, 0] # q/k/v shape: [sequence, head, head_dim] + self._scatter_gather_idx['output'] = [0, 1] # output shape: [sequence, head, head_dim] + + + def forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, **kwargs: Any) -> Tensor: + if gpc.evaluation is True: + # when conducting evaluation, the scatter and gather index should add 1. + eval_scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()} + self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs) + else: + self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs) - def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor: + def _forward(self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, scatter_gather: dict = None, **kwargs: Any) -> Tensor: """forward Arguments: - query (Tensor): query input to the layer - key (Tensor): key input to the layer - value (Tensor): value input to the layer + qkv (Tensor): packed qkv input to the layer + kv (Tensor): packed kv input to the layer + q (Tensor): q input to the layer + k (Tensor): k input to the layer + v (Tensor): v input to the layer args: other args Returns: * output (Tensor): context output """ - # Evaluation - if qkv.ndim == 5: - # in shape: [batch, seq/tp_size, 3, head, head_dim] - qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1) - # out shape : [batch, seq, head/tp_size, head_dim] - context_layer = self.local_attn(qkv, **kwargs) - # in shape: [batch, seq, head/tp_size, head_dim] - output = _SeqAllToAll.apply( - self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1 - ) - else: # training - # in shape: [seq/tp_size, 3, head, head_dim] - qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx) - # out shape : [seq, head/tp_size, head_dim] + + if qkv is not None: + qkv = _SeqAllToAll.apply(self.spg, qkv, scatter_gather['qkv'][0], scatter_gather['qkv'][1]) context_layer = self.local_attn(qkv, **kwargs) - # in shape: [seq, head/tp_size, head_dim] - output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx) + elif kv is not None: + q = _SeqAllToAll.apply(self.spg, q, scatter_gather['q'][0], scatter_gather['q'][1]) + kv = _SeqAllToAll.apply(self.spg, kv, scatter_gather['kv'][0], scatter_gather['kv'][1]) + context_layer = self.local_attn(q, kv, **kwargs) + else: + q = _SeqAllToAll.apply(self.spg, q, scatter_gather['q'][0], scatter_gather['q'][1]) + k = _SeqAllToAll.apply(self.spg, k, scatter_gather['q'][0], scatter_gather['q'][1]) + v = _SeqAllToAll.apply(self.spg, v, scatter_gather['q'][0], scatter_gather['q'][1]) + context_layer = self.local_attn(q, k, v, **kwargs) + output = _SeqAllToAll.apply(self.spg, context_layer, scatter_gather['output'][0], scatter_gather['output'][1]) # out e.g., [s/p::h] return output diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 1c1515b4..1d840ac4 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -47,8 +47,9 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape @contextmanager -def switch_sequence_parallel_mode(): +def switch_evaluation_mode(): prev_mode = gpc.config.parallel.sequence_parallel + prev_evaluation = gpc.evaluation try: # when training x.shape is torch.Size([1024, 4096]), linear all gather in dim=0(sequence dim) # but evaluation x.shape is torch.Size([1, 1024, 4096]), gather in dim=0 is error. @@ -56,9 +57,11 @@ def switch_sequence_parallel_mode(): gpc.config.parallel.sequence_parallel = True else: gpc.config.parallel.sequence_parallel = False + gpc.evaluation = True yield finally: gpc.config.parallel.sequence_parallel = prev_mode + gpc.evaluation = prev_evaluation def evaluate_on_val_dls( @@ -70,7 +73,7 @@ def evaluate_on_val_dls( update_panel: bool = False, streaming: bool = False, ): - with switch_sequence_parallel_mode(): + with switch_evaluation_mode(): torch.cuda.empty_cache() trainer.eval() verbose = gpc.is_rank_for_log()