Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pipeline] add stage manager #4093

Merged
merged 3 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions colossalai/pipeline/stage_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional, Tuple

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.cluster import ProcessGroupMesh


class PipelineStageManager:
    """PipelineStageManager is a helper class to manage pipeline stages.

    On construction it derives, from the process group mesh, the ranks of the
    neighbouring pipeline stages and the process groups shared with them.

    Args:
        pg_mesh (ProcessGroupMesh): Process group mesh.
        pipeline_axis (int): The axis of the mesh along which the pipeline is constructed.

    Attributes:
        num_stages (int): Number of stages in the pipeline (mesh size along ``pipeline_axis``).
        stage (int): The current stage (mesh coordinate along ``pipeline_axis``).
        num_virtual_stages (Optional[int]): Number of virtual stages, ``None`` until set.
        virtual_stage (Optional[int]): The current virtual stage, ``None`` until set.
        prev_rank (Optional[int]): Rank of the previous stage, ``None`` on the first stage.
        next_rank (Optional[int]): Rank of the next stage, ``None`` on the last stage.
        p2p_groups (Dict[Tuple[int, ...], ProcessGroup]): Groups shared with adjacent
            stages, keyed by the tuple of ranks reported by the mesh for that group.
    """

    def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None:
        self.pg_mesh = pg_mesh
        self.pipeline_axis = pipeline_axis
        self.num_virtual_stages: Optional[int] = None
        self.virtual_stage: Optional[int] = None
        # Neighbour ranks are the flat integers produced by ``pg_mesh.ravel``.
        self.prev_rank: Optional[int] = None
        self.next_rank: Optional[int] = None
        self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {}

        # init prev and next rank
        coord = self.pg_mesh.coordinate()
        if self.stage > 0:
            self.prev_rank = self._neighbor_rank(coord, -1)
        if self.stage < self.num_stages - 1:
            self.next_rank = self._neighbor_rank(coord, 1)

        # init p2p process groups
        # The mesh is asked for the group of *every* adjacent stage pair, even
        # pairs this process does not belong to (presumably because group
        # creation is collective -- confirm against ProcessGroupMesh). Only the
        # groups containing the current stage are remembered.
        stages = list(range(self.num_stages))
        for prev, cur in zip(stages[:-1], stages[1:]):
            group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
            if self.stage in (prev, cur):
                ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
                self.p2p_groups[tuple(ranks_in_group)] = group

    def _neighbor_rank(self, coord: Tuple[int, ...], offset: int) -> int:
        """Return the rank of the process ``offset`` stages away along the pipeline axis.

        The coordinate is rebuilt by index replacement rather than slicing so
        that a negative ``pipeline_axis`` (e.g. ``-1``) addresses the same entry
        that ``coord[pipeline_axis]`` reads, instead of producing a coordinate
        with extra entries.

        Args:
            coord (Tuple[int, ...]): Mesh coordinate of the current process.
            offset (int): Stage displacement, ``-1`` for previous and ``1`` for next.

        Returns:
            int: Rank obtained by ravelling the shifted coordinate over the mesh shape.
        """
        shifted = list(coord)
        shifted[self.pipeline_axis] += offset
        return self.pg_mesh.ravel(tuple(shifted), self.pg_mesh.shape)

    def is_first_stage(self, virtual: bool = False) -> bool:
        """Is the current stage the first stage.

        Args:
            virtual (bool, optional): Whether to consider virtual stages. Defaults to False.

        Returns:
            bool: Whether the current stage is the first stage.
        """
        if virtual:
            assert self.num_virtual_stages is not None
            return self.virtual_stage == 0
        return self.stage == 0

    def is_last_stage(self, virtual: bool = False) -> bool:
        """Is the current stage the last stage.

        Args:
            virtual (bool, optional): Whether to consider virtual stages. Defaults to False.

        Returns:
            bool: Whether the current stage is the last stage.
        """
        if virtual:
            assert self.num_virtual_stages is not None
            return self.virtual_stage == self.num_virtual_stages - 1
        return self.stage == self.num_stages - 1

    @property
    def num_stages(self) -> int:
        """Number of stages in the pipeline.

        Returns:
            int: Number of stages in the pipeline.
        """
        return self.pg_mesh.size(self.pipeline_axis)

    @property
    def stage(self) -> int:
        """Current stage.

        Returns:
            int: Current stage.
        """
        return self.pg_mesh.coordinate(self.pipeline_axis)

    def get_rank(self) -> int:
        """Get the rank of the current process.

        Returns:
            int: Rank of the current process.
        """
        return dist.get_rank()

    def get_prev_rank(self) -> int:
        """Get the rank of the previous stage.

        Returns:
            int: Rank of the previous stage.

        Raises:
            AssertionError: If the current stage is the first stage.
        """
        assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
        return self.prev_rank

    def get_next_rank(self) -> int:
        """Get the rank of the next stage.

        Returns:
            int: Rank of the next stage.

        Raises:
            AssertionError: If the current stage is the last stage.
        """
        assert not self.is_last_stage(), "Cannot get next rank in the last stage."
        return self.next_rank

    def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
        """Set the number of virtual stages.

        Args:
            num_virtual_stages (int): Number of virtual stages.
        """
        self.num_virtual_stages = num_virtual_stages

    def set_virtual_stage(self, virtual_stage: int) -> None:
        """Set the virtual stage.

        Args:
            virtual_stage (int): Virtual stage.
        """
        self.virtual_stage = virtual_stage

    @contextmanager
    def switch_virtual_stage(self, virtual_stage: int) -> Iterator[None]:
        """A context manager to switch virtual stage.

        The previous virtual stage is restored on exit, including when the
        body raises.

        Args:
            virtual_stage (int): Target virtual stage.
        """
        old_stage = self.virtual_stage
        try:
            self.set_virtual_stage(virtual_stage)
            yield
        finally:
            self.set_virtual_stage(old_stage)

    def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
        """Get the p2p process group between two ranks. The order of the two ranks does not matter.

        Args:
            first_rank (int): The first rank.
            second_rank (int): The second rank.

        Returns:
            ProcessGroup: P2P process group between the two ranks.

        Raises:
            KeyError: If no p2p group was recorded for this pair of ranks.
        """
        # Groups are stored under the (smaller, larger) ordering of the pair.
        if first_rank > second_rank:
            first_rank, second_rank = second_rank, first_rank
        key = (first_rank, second_rank)
        try:
            return self.p2p_groups[key]
        except KeyError:
            raise KeyError(f"No p2p process group between ranks {key}; "
                           f"known pairs: {sorted(self.p2p_groups)}") from None

    def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
        """Get the process group of the given stages.

        Args:
            stages (List[int]): List of stages.

        Returns:
            ProcessGroup: Process group of the given stages.
        """
        return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
86 changes: 86 additions & 0 deletions tests/test_pipeline/test_stage_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import spawn


def check_stage_manager():
    """Exercise PipelineStageManager on a 2 x 2 mesh, then on a flat mesh of 4.

    Must run inside an initialised distributed environment; the lookup tables
    below cover global ranks 0-3 only.
    """
    DP_DIM, PP_DIM = 0, 1
    DP_SIZE, PP_SIZE = 2, 2
    # Expected mesh coordinate of every global rank.
    RANK_TO_COORDINATE = {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
    # Expected global ranks that share a pipeline with each rank.
    PP_RANKS_IN_GROUP = {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}

    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
    rank = dist.get_rank()

    # stage bookkeeping
    assert stage_manager.num_stages == PP_SIZE
    assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM]

    # first / last stage flags
    pipeline_ranks = PP_RANKS_IN_GROUP[rank]
    position = pipeline_ranks.index(rank)
    expect_first = position == 0
    assert stage_manager.is_first_stage() == expect_first
    expect_last = position == len(pipeline_ranks) - 1
    assert stage_manager.is_last_stage() == expect_last

    # neighbouring ranks, only where a neighbour exists
    if not expect_first:
        assert stage_manager.get_prev_rank() == pipeline_ranks[position - 1]
    if not expect_last:
        assert stage_manager.get_next_rank() == pipeline_ranks[position + 1]

    # virtual stage setters and the temporary switch
    total_virtual = PP_SIZE * 2
    stage_manager.set_num_virtual_stages(total_virtual)
    assert stage_manager.num_virtual_stages == total_virtual
    base_virtual = stage_manager.stage * 2
    stage_manager.set_virtual_stage(base_virtual)
    assert stage_manager.virtual_stage == base_virtual
    with stage_manager.switch_virtual_stage(base_virtual + 1):
        assert stage_manager.virtual_stage == base_virtual + 1
    # leaving the context restores the previous virtual stage
    assert stage_manager.virtual_stage == base_virtual

    # p2p groups: synchronise over each adjacent pair this rank belongs to
    for left, right in zip(pipeline_ranks[:-1], pipeline_ranks[1:]):
        if rank == left or rank == right:
            dist.barrier(group=stage_manager.get_p2p_process_group(left, right))

    # stage groups: flat mesh of 4, group made of stages 0 and 2
    pg_mesh = ProcessGroupMesh(4)
    stage_manager = PipelineStageManager(pg_mesh, 0)
    group = stage_manager.init_process_group_by_stages([0, 2])
    if rank in (0, 2):
        dist.barrier(group=group)


def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai on localhost, then run the checks.

    Args:
        rank: Rank passed through to ``colossalai.launch``.
        world_size: World size passed through to ``colossalai.launch``.
        port: Port passed through to ``colossalai.launch``.
    """
    launch_kwargs = dict(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
    colossalai.launch(**launch_kwargs)
    check_stage_manager()


@pytest.mark.dist
def test_process_group_mesh():
    """Run the stage manager checks through ``spawn`` with ``run_dist`` and 4."""
    num_procs = 4
    spawn(run_dist, num_procs)


# Allow running this test file directly as a script, without going through pytest.
if __name__ == '__main__':
    test_process_group_mesh()
Loading