forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pipeline] add stage manager (hpcaitech#4093)
* [pipeline] add stage manager * [test] add pipeline stage manager test * [pipeline] add docstring for stage manager
- Loading branch information
Showing
2 changed files
with
262 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from contextlib import contextmanager | ||
from typing import Dict, List, Optional, Tuple | ||
|
||
import torch.distributed as dist | ||
from torch.distributed import ProcessGroup | ||
|
||
from colossalai.cluster import ProcessGroupMesh | ||
|
||
|
||
class PipelineStageManager: | ||
"""PipelineStageManager is a helper class to manage pipeline stages. | ||
Args: | ||
pg_mesh (ProcessGroupMesh): Process group mesh. | ||
pipeline_axis (int): The axis along which the pipeline is constructed. | ||
Attributes: | ||
num_stages (int): Number of stages in the pipeline. | ||
stage (int): The current stage. | ||
num_virtual_stages (int): Number of virtual stages in the pipeline. | ||
virtual_stage (int): The current virtual stage. | ||
""" | ||
|
||
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: | ||
self.pg_mesh = pg_mesh | ||
self.pipeline_axis = pipeline_axis | ||
self.num_virtual_stages: Optional[int] = None | ||
self.virtual_stage: Optional[int] = None | ||
self.prev_rank: Optional[Tuple[int, ...]] = None | ||
self.next_rank: Optional[Tuple[int, ...]] = None | ||
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} | ||
# init prev and next coord | ||
coord = self.pg_mesh.coordinate() | ||
if self.stage > 0: | ||
prev_coord = coord[: self.pipeline_axis] + \ | ||
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] | ||
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) | ||
if self.stage < self.num_stages - 1: | ||
next_coord = coord[: self.pipeline_axis] + \ | ||
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] | ||
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) | ||
|
||
# init p2p process groups | ||
stages = list(range(self.num_stages)) | ||
for prev, cur in zip(stages[:-1], stages[1:]): | ||
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur]) | ||
if self.stage in [prev, cur]: | ||
ranks_in_group = self.pg_mesh.get_ranks_in_group(group) | ||
self.p2p_groups[tuple(ranks_in_group)] = group | ||
|
||
def is_first_stage(self, virtual: bool = False) -> bool: | ||
"""Is the current stage the first stage. | ||
Args: | ||
virtual (bool, optional): Whether to consider virtual stages. Defaults to False. | ||
Returns: | ||
bool: Whether the current stage is the first stage. | ||
""" | ||
if virtual: | ||
assert self.num_virtual_stages is not None | ||
return self.virtual_stage == 0 | ||
return self.stage == 0 | ||
|
||
def is_last_stage(self, virtual: bool = False) -> bool: | ||
"""Is the current stage the last stage. | ||
Args: | ||
virtual (bool, optional): Whether to consider virtual stages. Defaults to False. | ||
Returns: | ||
bool: Whether the current stage is the last stage. | ||
""" | ||
if virtual: | ||
assert self.num_virtual_stages is not None | ||
return self.virtual_stage == self.num_virtual_stages - 1 | ||
return self.stage == self.num_stages - 1 | ||
|
||
@property | ||
def num_stages(self) -> int: | ||
"""Number of stages in the pipeline. | ||
Returns: | ||
int: Number of stages in the pipeline. | ||
""" | ||
return self.pg_mesh.size(self.pipeline_axis) | ||
|
||
@property | ||
def stage(self) -> int: | ||
"""Current stage. | ||
Returns: | ||
int: Current stage. | ||
""" | ||
return self.pg_mesh.coordinate(self.pipeline_axis) | ||
|
||
def get_rank(self) -> int: | ||
"""Get the rank of the current process. | ||
Returns: | ||
int: Rank of the current process. | ||
""" | ||
return dist.get_rank() | ||
|
||
def get_prev_rank(self) -> int: | ||
"""Get the rank of the previous stage. | ||
Returns: | ||
int: Rank of the previous stage. | ||
""" | ||
assert not self.is_first_stage(), "Cannot get previous rank in the first stage." | ||
return self.prev_rank | ||
|
||
def get_next_rank(self) -> int: | ||
"""Get the rank of the next stage. | ||
Returns: | ||
int: Rank of the next stage. | ||
""" | ||
assert not self.is_last_stage(), "Cannot get next rank in the last stage." | ||
return self.next_rank | ||
|
||
def set_num_virtual_stages(self, num_virtual_stages: int) -> None: | ||
"""Set the number of virtual stages. | ||
Args: | ||
num_virtual_stages (int): Number of virtual stages. | ||
""" | ||
self.num_virtual_stages = num_virtual_stages | ||
|
||
def set_virtual_stage(self, virtual_stage: int) -> None: | ||
"""Set the virtual stage. | ||
Args: | ||
virtual_stage (int): Virtual stage. | ||
""" | ||
self.virtual_stage = virtual_stage | ||
|
||
@contextmanager | ||
def switch_virtual_stage(self, virtual_stage: int) -> None: | ||
"""A context manager to switch virtual stage. | ||
Args: | ||
virtual_stage (int): Target virtual stage. | ||
""" | ||
old_stage = self.virtual_stage | ||
try: | ||
self.set_virtual_stage(virtual_stage) | ||
yield | ||
finally: | ||
self.set_virtual_stage(old_stage) | ||
|
||
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: | ||
"""Get the p2p process group between two ranks. The order of the two ranks does not matter. | ||
Args: | ||
first_rank (int): The first rank. | ||
second_rank (int): The second rank. | ||
Returns: | ||
ProcessGroup: P2P process group between the two ranks. | ||
""" | ||
if first_rank > second_rank: | ||
first_rank, second_rank = second_rank, first_rank | ||
return self.p2p_groups[(first_rank, second_rank)] | ||
|
||
def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: | ||
"""Get the process group of the given stages. | ||
Args: | ||
stages (List[int]): List of stages. | ||
Returns: | ||
ProcessGroup: Process group of the given stages. | ||
""" | ||
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import pytest | ||
import torch.distributed as dist | ||
|
||
import colossalai | ||
from colossalai.cluster import ProcessGroupMesh | ||
from colossalai.pipeline.stage_manager import PipelineStageManager | ||
from colossalai.testing import spawn | ||
|
||
|
||
def check_stage_manager(): | ||
DP_DIM, PP_DIM = 0, 1 | ||
DP_SIZE, PP_SIZE = 2, 2 | ||
RANK_TO_COORDINATE = { | ||
0: (0, 0), | ||
1: (0, 1), | ||
2: (1, 0), | ||
3: (1, 1), | ||
} | ||
PP_RANKS_IN_GROUP = { | ||
0: [0, 1], | ||
1: [0, 1], | ||
2: [2, 3], | ||
3: [2, 3], | ||
} | ||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) | ||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM) | ||
rank = dist.get_rank() | ||
|
||
# check stage info | ||
assert stage_manager.num_stages == PP_SIZE | ||
assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM] | ||
|
||
# check is_first_stage | ||
ranks_in_group = PP_RANKS_IN_GROUP[rank] | ||
is_first_stage = ranks_in_group.index(rank) == 0 | ||
assert stage_manager.is_first_stage() == is_first_stage | ||
|
||
# check is_last_stage | ||
is_last_stage = ranks_in_group.index(rank) == len(ranks_in_group) - 1 | ||
assert stage_manager.is_last_stage() == is_last_stage | ||
|
||
# check prev rank | ||
if not is_first_stage: | ||
prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1] | ||
assert stage_manager.get_prev_rank() == prev_rank | ||
|
||
# check next rank | ||
if not is_last_stage: | ||
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] | ||
assert stage_manager.get_next_rank() == next_rank | ||
|
||
# check virtual stage | ||
stage_manager.set_num_virtual_stages(PP_SIZE * 2) | ||
assert stage_manager.num_virtual_stages == PP_SIZE * 2 | ||
stage_manager.set_virtual_stage(stage_manager.stage * 2) | ||
assert stage_manager.virtual_stage == stage_manager.stage * 2 | ||
with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): | ||
assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1 | ||
assert stage_manager.virtual_stage == stage_manager.stage * 2 | ||
|
||
# check p2p groups | ||
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): | ||
if rank in [prev, cur]: | ||
group = stage_manager.get_p2p_process_group(prev, cur) | ||
dist.barrier(group=group) | ||
|
||
# check stage groups | ||
pg_mesh = ProcessGroupMesh(4) | ||
stage_manager = PipelineStageManager(pg_mesh, 0) | ||
group = stage_manager.init_process_group_by_stages([0, 2]) | ||
if rank in [0, 2]: | ||
dist.barrier(group=group) | ||
|
||
|
||
def run_dist(rank, world_size, port): | ||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') | ||
check_stage_manager() | ||
|
||
|
||
@pytest.mark.dist | ||
def test_process_group_mesh(): | ||
spawn(run_dist, 4) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_process_group_mesh() |