From 1a13152dc5e60fd93bda8aaf36137e874502ea10 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Fri, 23 Feb 2024 13:55:34 +0800
Subject: [PATCH] feat(context/process_group_initializer.py): add gqa process group

---
 internlm/core/context/parallel_context.py     |  1 +
 .../core/context/process_group_initializer.py | 84 +++++++++++++++++++
 2 files changed, 85 insertions(+)

diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index b1c7034d..fe4af16e 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -564,6 +564,7 @@ def init_parallel_groups(self):
 
         # run initialization of different process groups
         initializers = []
+        initializers.append(pgroup_initializer.Initializer_GQA(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 0ca2a14e..51a722ab 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -60,6 +60,9 @@ class ParallelMode(Enum):
     # sequence parallel
     SEQUENCE = "sequence"
 
+    # grouped query attention
+    GQA = "gqa"
+
 
 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -849,3 +852,84 @@ def init_dist_group(self, use_cpu: bool = False):
             ranks_in_group = ranks
 
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
+class Initializer_GQA(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for all-reducing KV gradients across the query heads that share a common KV head (GQA).
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        weight_parallel_size (int): Size of model weight parallel.
+        weight_data_parallel_size (int): Size of data parallel for common weight.
+        sequence_parallel_size (int): Size of data sequence parallel.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+        nettest_parallel_size (int): Size of net testing parallel.
+        expert_parallel_size (int): Size of expert parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # TODO: should adapt to general case
+        self.num_kv_attention_heads = 8
+        self.NUM_ATTENTION_HEAD = 32
+        self.kv_head_repeats_num = self.NUM_ATTENTION_HEAD // self.num_kv_attention_heads
+        self.num_kv_group_per_tp = self.num_kv_attention_heads
+        self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size
+
+        assert self.world_size % self.tensor_parallel_size == 0
+        assert self.world_size % (self.pipeline_parallel_size * self.tensor_parallel_size) == 0
+        assert self.pipeline_parallel_size == 1
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize GQA parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
+                A GQA parallelism's information tuple.
+
+            n=128 sp=32 wp=64 zo1=1 with nopp
+            sp groups: [0-31] [32-63] [64-95] [96-127]
+            wp groups: [0-63] [64-127]
+            kv_head groups: [0,8,16,24]   [1,9,17,25]   [2,10,18,26]  [3,11,19,27]
+                            [4,12,20,28]  [5,13,21,29]  [6,14,22,30]  [7,15,23,31]
+                            [32,40,48,56] [33,41,49,57] [34,42,50,58] [35,43,51,59]
+                            [36,44,52,60] [37,45,53,61] [38,46,54,62] [39,47,55,63]
+                            ...
+                            ...
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.GQA
+
+        # TODO: consider PP
+        for i in range(self.data_parallel_size):
+            for j in range(self.num_kv_group_per_tp):
+                ranks = [
+                    i * self.tensor_parallel_size + j + k * self.num_kv_attention_heads
+                    for k in range(self.kv_head_repeats_num)
+                ]
+                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
+                if use_cpu:
+                    group_cpu = (
+                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else group
+                    )
+                else:
+                    group_cpu = None
+
+                if self.rank in ranks:
+                    local_rank = ranks.index(self.rank)
+                    group_world_size = len(ranks)
+                    process_group = group
+                    cpu_group = group_cpu
+                    ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
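
Note (not part of the patch): the rank grouping above can be sanity-checked offline. The following standalone Python sketch recomputes the kv_head groups from the docstring example; the concrete sizes (data_parallel_size=4, tensor_parallel_size=32) are assumptions read off the n=128 sp=32 wp=64 nopp layout, and the 32 query heads / 8 KV heads match the values hard-coded in Initializer_GQA.__init__. Variable names here are illustrative only.

    # Sketch: reproduce the kv_head groups without touching torch.distributed.
    num_kv_attention_heads = 8            # KV heads per tensor-parallel group
    num_attention_heads = 32              # query heads
    kv_head_repeats_num = num_attention_heads // num_kv_attention_heads  # 4
    data_parallel_size = 4                # assumed: 128 ranks / sp=32
    tensor_parallel_size = 32             # assumed: equal to sp in this layout

    kv_head_groups = []
    for i in range(data_parallel_size):
        for j in range(num_kv_attention_heads):
            # ranks whose query heads share the same KV head and therefore
            # all-reduce their KV gradients together
            ranks = [
                i * tensor_parallel_size + j + k * num_kv_attention_heads
                for k in range(kv_head_repeats_num)
            ]
            kv_head_groups.append(ranks)

    assert kv_head_groups[0] == [0, 8, 16, 24]
    assert kv_head_groups[8] == [32, 40, 48, 56]
    print(kv_head_groups[:8])  # the eight groups covering ranks 0-31

Each of the data_parallel_size * num_kv_attention_heads groups contains kv_head_repeats_num ranks, matching the group lists shown in the docstring.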