Skip to content

Commit

Permalink
[pipeline] add stage manager (hpcaitech#4093)
Browse files Browse the repository at this point in the history
* [pipeline] add stage manager

* [test] add pipeline stage manager test

* [pipeline] add docstring for stage manager
  • Loading branch information
ver217 committed Jul 4, 2023
1 parent 3be0c35 commit 18c7539
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 0 deletions.
176 changes: 176 additions & 0 deletions colossalai/pipeline/stage_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.cluster import ProcessGroupMesh


class PipelineStageManager:
"""PipelineStageManager is a helper class to manage pipeline stages.
Args:
pg_mesh (ProcessGroupMesh): Process group mesh.
pipeline_axis (int): The axis along which the pipeline is constructed.
Attributes:
num_stages (int): Number of stages in the pipeline.
stage (int): The current stage.
num_virtual_stages (int): Number of virtual stages in the pipeline.
virtual_stage (int): The current virtual stage.
"""

def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None:
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.num_virtual_stages: Optional[int] = None
self.virtual_stage: Optional[int] = None
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
# init prev and next coord
coord = self.pg_mesh.coordinate()
if self.stage > 0:
prev_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape)
if self.stage < self.num_stages - 1:
next_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape)

# init p2p process groups
stages = list(range(self.num_stages))
for prev, cur in zip(stages[:-1], stages[1:]):
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
if self.stage in [prev, cur]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group

def is_first_stage(self, virtual: bool = False) -> bool:
"""Is the current stage the first stage.
Args:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
Returns:
bool: Whether the current stage is the first stage.
"""
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == 0
return self.stage == 0

def is_last_stage(self, virtual: bool = False) -> bool:
"""Is the current stage the last stage.
Args:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
Returns:
bool: Whether the current stage is the last stage.
"""
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == self.num_virtual_stages - 1
return self.stage == self.num_stages - 1

@property
def num_stages(self) -> int:
"""Number of stages in the pipeline.
Returns:
int: Number of stages in the pipeline.
"""
return self.pg_mesh.size(self.pipeline_axis)

@property
def stage(self) -> int:
"""Current stage.
Returns:
int: Current stage.
"""
return self.pg_mesh.coordinate(self.pipeline_axis)

def get_rank(self) -> int:
"""Get the rank of the current process.
Returns:
int: Rank of the current process.
"""
return dist.get_rank()

def get_prev_rank(self) -> int:
"""Get the rank of the previous stage.
Returns:
int: Rank of the previous stage.
"""
assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
return self.prev_rank

def get_next_rank(self) -> int:
"""Get the rank of the next stage.
Returns:
int: Rank of the next stage.
"""
assert not self.is_last_stage(), "Cannot get next rank in the last stage."
return self.next_rank

def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
"""Set the number of virtual stages.
Args:
num_virtual_stages (int): Number of virtual stages.
"""
self.num_virtual_stages = num_virtual_stages

def set_virtual_stage(self, virtual_stage: int) -> None:
"""Set the virtual stage.
Args:
virtual_stage (int): Virtual stage.
"""
self.virtual_stage = virtual_stage

@contextmanager
def switch_virtual_stage(self, virtual_stage: int) -> None:
"""A context manager to switch virtual stage.
Args:
virtual_stage (int): Target virtual stage.
"""
old_stage = self.virtual_stage
try:
self.set_virtual_stage(virtual_stage)
yield
finally:
self.set_virtual_stage(old_stage)

def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.
Args:
first_rank (int): The first rank.
second_rank (int): The second rank.
Returns:
ProcessGroup: P2P process group between the two ranks.
"""
if first_rank > second_rank:
first_rank, second_rank = second_rank, first_rank
return self.p2p_groups[(first_rank, second_rank)]

def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
"""Get the process group of the given stages.
Args:
stages (List[int]): List of stages.
Returns:
ProcessGroup: Process group of the given stages.
"""
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
86 changes: 86 additions & 0 deletions tests/test_pipeline/test_stage_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import spawn


def check_stage_manager():
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()

# check stage info
assert stage_manager.num_stages == PP_SIZE
assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM]

# check is_first_stage
ranks_in_group = PP_RANKS_IN_GROUP[rank]
is_first_stage = ranks_in_group.index(rank) == 0
assert stage_manager.is_first_stage() == is_first_stage

# check is_last_stage
is_last_stage = ranks_in_group.index(rank) == len(ranks_in_group) - 1
assert stage_manager.is_last_stage() == is_last_stage

# check prev rank
if not is_first_stage:
prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1]
assert stage_manager.get_prev_rank() == prev_rank

# check next rank
if not is_last_stage:
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
assert stage_manager.get_next_rank() == next_rank

# check virtual stage
stage_manager.set_num_virtual_stages(PP_SIZE * 2)
assert stage_manager.num_virtual_stages == PP_SIZE * 2
stage_manager.set_virtual_stage(stage_manager.stage * 2)
assert stage_manager.virtual_stage == stage_manager.stage * 2
with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1):
assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1
assert stage_manager.virtual_stage == stage_manager.stage * 2

# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
group = stage_manager.get_p2p_process_group(prev, cur)
dist.barrier(group=group)

# check stage groups
pg_mesh = ProcessGroupMesh(4)
stage_manager = PipelineStageManager(pg_mesh, 0)
group = stage_manager.init_process_group_by_stages([0, 2])
if rank in [0, 2]:
dist.barrier(group=group)


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_stage_manager()


@pytest.mark.dist
def test_process_group_mesh():
spawn(run_dist, 4)


if __name__ == '__main__':
test_process_group_mesh()

0 comments on commit 18c7539

Please sign in to comment.