
feat(context/process_group_initializer.py): add gqa process group
huangting4201 committed Feb 23, 2024
1 parent 663ef2e commit 1a13152
Showing 2 changed files with 85 additions and 0 deletions.
1 change: 1 addition & 0 deletions internlm/core/context/parallel_context.py
@@ -564,6 +564,7 @@ def init_parallel_groups(self):

# run initialization of different process groups
initializers = []
initializers.append(pgroup_initializer.Initializer_GQA(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
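For orientation only: the line added above registers the new GQA initializer so that its process group is created alongside the other parallel groups. Once created, that group can be used to average key/value projection gradients across the ranks that hold the same kv head. The sketch below is a minimal illustration under that assumption; `gqa_group`, `wk_grad`, and `wv_grad` are hypothetical names and not part of this commit.

import torch
import torch.distributed as dist

def allreduce_kv_grads(wk_grad: torch.Tensor, wv_grad: torch.Tensor, gqa_group) -> None:
    # Average Wk/Wv gradients over the ranks sharing one kv head (the GQA group).
    world_size = dist.get_world_size(group=gqa_group)
    for grad in (wk_grad, wv_grad):
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=gqa_group)
        grad.div_(world_size)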
84 changes: 84 additions & 0 deletions internlm/core/context/process_group_initializer.py
@@ -60,6 +60,9 @@ class ParallelMode(Enum):
# sequence parallel
SEQUENCE = "sequence"

# grouped query attention
GQA = "gqa"


class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -849,3 +852,84 @@ def init_dist_group(self, use_cpu: bool = False):
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_GQA(ProcessGroupInitializer):
"""A ProcessGroupInitializer for allreduce kv gradients with common attention head.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
weight_parallel_size (int): Size of model weight parallel.
weight_data_parallel_size (int): Size of data parallel for common weight.
sequence_parallel_size (int): Size of sequence parallel.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
nettest_parallel_size (int): Size of net testing parallel.
expert_parallel_size (int): Size of expert parallel.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: should adapt to general case
self.num_kv_attention_heads = 8
self.NUM_ATTENTION_HEAD = 32
# each kv head is shared by NUM_ATTENTION_HEAD // num_kv_attention_heads query heads
self.kv_head_repeats_num = self.NUM_ATTENTION_HEAD // self.num_kv_attention_heads
self.num_kv_group_per_tp = self.num_kv_attention_heads
self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size

assert self.world_size % self.tensor_parallel_size == 0
assert self.world_size % (self.pipeline_parallel_size * self.tensor_parallel_size) == 0
assert self.pipeline_parallel_size == 1

def init_dist_group(self, use_cpu: bool = False):
"""Initialize weight's data parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A WEIGHT_DATA parallelism's information tuple.
n=128 sp=32 wp=64 zo1=1 with nopp
sp groups: [0-31] [32-63] [64-95] [96-127]
wp groups: [0-63] [64-127]
kv_head groups: [0,8,16,24] [1,9,17,25] [2,10,18,26] [3,11,19,27]
[4,12,20,28] [5,13,21,29] [6,14,22,30] [7,15,23,31]
[32,40,48,56] [33,41,49,57] [34,42,50,58] [35,43,51,59]
[36,44,52,60] [37,45,53,61] [38,46,54,62] [39,47,55,63]
...
...
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.GQA

# TODO: consider PP
for i in range(self.data_parallel_size):
for j in range(self.num_kv_group_per_tp):
ranks = [
i * self.tensor_parallel_size + j + k * self.num_kv_attention_heads
for k in range(self.kv_head_repeats_num)
]
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
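As a quick sanity check of the grouping above, the following standalone snippet reproduces the kv_head groups from the docstring example (assuming world_size=128, tensor/sequence parallel size 32, hence data_parallel_size=4, and the hard-coded 8 kv heads out of 32 attention heads). It is illustrative only and not part of the commit.

# Recompute the GQA rank groups from the docstring example with plain Python.
tensor_parallel_size = 32        # sp=32 in the example above
data_parallel_size = 4           # 128 // 32, with no pipeline parallelism
num_kv_attention_heads = 8
kv_head_repeats_num = 32 // num_kv_attention_heads   # 4 ranks share each kv head

for i in range(data_parallel_size):
    for j in range(num_kv_attention_heads):
        ranks = [
            i * tensor_parallel_size + j + k * num_kv_attention_heads
            for k in range(kv_head_repeats_num)
        ]
        print(ranks)
# prints [0, 8, 16, 24], [1, 9, 17, 25], ..., [7, 15, 23, 31], [32, 40, 48, 56], ...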
