Skip to content

Commit

Permalink
add comments for Initializer_Expert_Weight_Data
Browse files Browse the repository at this point in the history
  • Loading branch information
blankde committed Feb 4, 2024
1 parent 3ac714d commit 6d3855a
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions internlm/core/context/process_group_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,19 @@ def __init__(self, *args, **kwargs):
assert self.world_size % (self.pipeline_parallel_size * self.expert_data_parallel_size) == 0

def init_dist_group(self, use_cpu: bool = False):
"""Initialize expert parallel groups for isp, and assign local_ranks and groups to each gpu.
Returns:
list: [(local_rank, group_world_size, process_group, ranks_in_group, mode), ...]:
A length 3 list consists of expert parallelism's, expert weight parallelism's
and expert data parallelism's information tuple.
Example: n=16 ewp=2 ep=4 edp=2 with nopp
expert weight groups: [0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]
expert groups: [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]
expert (weight) data groups:[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]
"""

expert_parallel_groups = []
expert_weight_parallel_groups = []
expert_data_parallel_groups = []
Expand Down

0 comments on commit 6d3855a

Please sign in to comment.