[cluster] add process group mesh (#4039)
* [cluster] add process group mesh

* [test] add process group mesh test

* force sync
ver217 authored Jun 20, 2023
1 parent 4a81faa commit 1015f04
Showing 3 changed files with 356 additions and 1 deletion.
3 changes: 2 additions & 1 deletion colossalai/cluster/__init__.py
@@ -1,5 +1,6 @@
 from .device_mesh_manager import DeviceMeshManager
 from .dist_coordinator import DistCoordinator
 from .process_group_manager import ProcessGroupManager
+from .process_group_mesh import ProcessGroupMesh

-__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
+__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh']
203 changes: 203 additions & 0 deletions colossalai/cluster/process_group_mesh.py
@@ -0,0 +1,203 @@
import itertools
from functools import reduce
from operator import mul
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup


def prod(nums: List[int]) -> int:
    """Product of a list of numbers.

    Args:
        nums (List[int]): A list of numbers.

    Returns:
        int: The product of the numbers.
    """
    return reduce(mul, nums)


class ProcessGroupMesh:
    """A helper class to manage the process group mesh. It only describes how process groups are organized and is
    decoupled from any particular parallel method: it just initializes process groups and caches them. The parallel
    method should manage them and use them to do the parallel computation.

    We use an N-D tuple for the mesh shape and an N-D coordinate to represent each process.
    For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.

    Args:
        *size (int): The size of each dimension of the process group mesh. The product of the sizes must equal the world size.

    Attributes:
        shape (Tuple[int, ...]): The shape of the process group mesh.
        rank (int): The rank of the current process.
    """

    def __init__(self, *size: int) -> None:
        assert dist.is_initialized(), "Please initialize torch.distributed first."
        assert prod(size) == dist.get_world_size(), "The product of the sizes must equal the world size."
        self._shape = size
        self._rank = dist.get_rank()
        self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
        self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def rank(self) -> int:
        return self._rank

    def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the size of the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh.
        """
        if dim is None:
            return self._shape
        else:
            return self._shape[dim]

    def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the coordinate of the current process in the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Coordinate at the target dimension, or the whole coordinate.
        """
        if dim is None:
            return self._coord
        else:
            return self._coord[dim]

    @staticmethod
    def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Convert a rank to a coordinate.

        Args:
            rank (int): Rank to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.

        Returns:
            Tuple[int, ...]: Coordinate of the rank.
        """
        return np.unravel_index(rank, shape)

    @staticmethod
    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
        """Convert a coordinate to a rank.

        Args:
            coord (Tuple[int, ...]): Coordinate to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.

        Returns:
            int: Rank of the coordinate.
        """
        return np.ravel_multi_index(coord, shape)

    def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
        """Get the process group with the given ranks. If the process group doesn't exist, it will be created.

        Args:
            ranks_in_group (List[int]): Ranks in the process group.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group with the given ranks.
        """
        ranks_in_group = sorted(ranks_in_group)
        if tuple(ranks_in_group) not in self._ranks_to_group:
            group = dist.new_group(ranks_in_group, backend=backend)
            self._ranks_to_group[tuple(ranks_in_group)] = group
            self._group_to_ranks[group] = tuple(ranks_in_group)
        return self._ranks_to_group[tuple(ranks_in_group)]

    def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
        """Get the ranks in the given process group. The process group must have been created by this class.

        Args:
            group (ProcessGroup): The process group.

        Returns:
            List[int]: Ranks in the process group.
        """
        return list(self._group_to_ranks[group])

    @staticmethod
    def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
                              indices_at_axis: List[int]) -> List[Tuple[int, ...]]:
        """Get coordinates along the given axis.

        Args:
            base_coord (Tuple[int, ...]): Base coordinate that the coordinates along the axis are based on.
            axis (int): Axis along which the coordinates are generated.
            indices_at_axis (List[int]): Indices at the axis.

        Returns:
            List[Tuple[int, ...]]: Coordinates along the axis.
        """
        coords_in_group = []
        for idx in indices_at_axis:
            coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:])
        return coords_in_group

    def create_group_along_axis(self,
                                axis: int,
                                indices_at_axis: Optional[List[int]] = None,
                                backend: Optional[str] = None) -> ProcessGroup:
        """Create all process groups along the given axis, and return the one that the current process belongs to.

        Args:
            axis (int): Axis along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis that the current process belongs to.
        """
        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
        reduced_shape = list(self._shape)
        # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
        reduced_shape[axis] = 1
        target_group = None
        # use Cartesian product to generate all combinations of coordinates
        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
            group = self.get_group(ranks_in_group, backend=backend)
            if self._rank in ranks_in_group:
                target_group = group
        return target_group

    def get_group_along_axis(self,
                             axis: int,
                             indices_at_axis: Optional[List[int]] = None,
                             backend: Optional[str] = None) -> ProcessGroup:
        """Get the process group along the given axis that the current process belongs to. If the process group
        doesn't exist, it will be created.

        Args:
            axis (int): Axis along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis that the current process belongs to.
        """
        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
        coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
        ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
        if ranks_in_group not in self._ranks_to_group:
            # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
            return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
        return self._ranks_to_group[ranks_in_group]
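
A minimal usage sketch of the new class (illustrative only, not part of this commit): it assumes a 4-process launch, e.g. via torchrun, uses the gloo backend, and arbitrarily treats axis 0 of a 2x2 mesh as tensor parallelism; the axis names and the broadcast are hypothetical.

import torch
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh

TP_DIM, DP_DIM = 0, 1    # hypothetical axis assignment for a 2x2 mesh


def main() -> None:
    dist.init_process_group(backend='gloo')    # rank/world-size env vars come from the launcher
    pg_mesh = ProcessGroupMesh(2, 2)           # product of sizes must equal the world size (4)
    tp_group = pg_mesh.get_group_along_axis(TP_DIM)
    # broadcast from the lowest global rank of this process's tensor-parallel group
    src = pg_mesh.get_ranks_in_group(tp_group)[0]
    t = torch.tensor([float(pg_mesh.rank)])
    dist.broadcast(t, src=src, group=tp_group)
    dist.destroy_process_group()


if __name__ == '__main__':
    main()

Each rank belongs to exactly one group per axis, and repeated calls with the same axis return the cached group rather than creating a new one.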
151 changes: 151 additions & 0 deletions tests/test_cluster/test_process_group_mesh.py
@@ -0,0 +1,151 @@
import pytest
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import spawn


def check_process_group_mesh_with_gpc():
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    pg_mesh = ProcessGroupMesh(1, 2, 2)

    # check world size
    assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size(
        TP_DIM), f'{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}'
    assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM)
    assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM)

    # check local rank (coordinate)
    assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate(
        TP_DIM), f'{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}'
    assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM)
    assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM)

    # check ranks in group
    tp_group = pg_mesh.get_group_along_axis(TP_DIM)
    assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group)
    pp_group = pg_mesh.get_group_along_axis(PP_DIM)
    assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group)
    dp_group = pg_mesh.get_group_along_axis(DP_DIM)
    assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group)

    # check prev rank
    coord = pg_mesh.coordinate()
    if not gpc.is_first_rank(ParallelMode.TENSOR):
        assert coord[TP_DIM] != 0
        prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1:]
        assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape)
    if not gpc.is_first_rank(ParallelMode.PIPELINE):
        assert coord[PP_DIM] != 0
        prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1:]
        assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape)

    # check next rank
    if not gpc.is_last_rank(ParallelMode.TENSOR):
        assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1
        next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1:]
        assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape)
    if not gpc.is_last_rank(ParallelMode.PIPELINE):
        assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1
        next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1:]
        assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape)


def check_process_group_mesh_with_cases():
    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2
    RANK_TO_COORDINATE = {
        0: (0, 0, 0),
        1: (0, 0, 1),
        2: (0, 1, 0),
        3: (0, 1, 1),
    }
    TP_RANKS_IN_GROUP = {
        0: [0, 1],
        1: [0, 1],
        2: [2, 3],
        3: [2, 3],
    }
    PP_RANKS_IN_GROUP = {
        0: [0, 2],
        1: [1, 3],
        2: [0, 2],
        3: [1, 3],
    }
    DP_RANKS_IN_GROUP = {
        0: [0],
        1: [1],
        2: [2],
        3: [3],
    }

    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

    rank = dist.get_rank()
    assert rank == pg_mesh.rank

    # check world size
    assert pg_mesh.size(TP_DIM) == 2
    assert pg_mesh.size(PP_DIM) == 2
    assert pg_mesh.size(DP_DIM) == 1

    # check coordinate
    assert pg_mesh.coordinate(TP_DIM) == RANK_TO_COORDINATE[rank][TP_DIM]
    assert pg_mesh.coordinate(PP_DIM) == RANK_TO_COORDINATE[rank][PP_DIM]
    assert pg_mesh.coordinate(DP_DIM) == RANK_TO_COORDINATE[rank][DP_DIM]

    # check ranks in group
    tp_group = pg_mesh.get_group_along_axis(TP_DIM)
    assert pg_mesh.get_ranks_in_group(tp_group) == TP_RANKS_IN_GROUP[rank]
    pp_group = pg_mesh.get_group_along_axis(PP_DIM)
    assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank]
    dp_group = pg_mesh.get_group_along_axis(DP_DIM)
    assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank]

    # check prev rank
    if RANK_TO_COORDINATE[rank][TP_DIM] != 0:
        prev_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + \
            RANK_TO_COORDINATE[rank][TP_DIM + 1:]
        prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1]
        assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank
    if RANK_TO_COORDINATE[rank][PP_DIM] != 0:
        prev_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + \
            RANK_TO_COORDINATE[rank][PP_DIM + 1:]
        prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1]
        assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank

    # check next rank
    if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1:
        next_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + \
            RANK_TO_COORDINATE[rank][TP_DIM + 1:]
        next_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1]
        assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank
    if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1:
        next_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + \
            RANK_TO_COORDINATE[rank][PP_DIM + 1:]
        next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1]
        assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode='1d', size=2))),
                      rank=rank,
                      world_size=world_size,
                      port=port,
                      host='localhost')
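    # the gpc config above mirrors the ProcessGroupMesh(1, 2, 2) layout used in the checks:
    # data=1 -> DP_DIM, pipeline=2 -> PP_DIM, tensor 1d size=2 -> TP_DIM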
    # TODO(ver217): this function should be removed when gpc is removed
    check_process_group_mesh_with_gpc()
    check_process_group_mesh_with_cases()


@pytest.mark.dist
def test_process_group_mesh():
    spawn(run_dist, 4)


if __name__ == '__main__':
    test_process_group_mesh()
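
Since ProcessGroupMesh.unravel and ProcessGroupMesh.ravel are thin wrappers around np.unravel_index and np.ravel_multi_index, the RANK_TO_COORDINATE table used above can be reproduced standalone (a quick sketch; no distributed setup required):

import numpy as np

shape = (1, 2, 2)    # (DP_SIZE, PP_SIZE, TP_SIZE), as in the test
for rank in range(4):
    coord = np.unravel_index(rank, shape)    # e.g. rank 3 -> (0, 1, 1)
    assert np.ravel_multi_index(coord, shape) == rank
    print(rank, tuple(int(c) for c in coord))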
