-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cluster] add process group mesh (#4039)
* [cluster] add process group mesh * [test] add process group mesh test * force sync
- Loading branch information
Showing
3 changed files
with
356 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .device_mesh_manager import DeviceMeshManager | ||
from .dist_coordinator import DistCoordinator | ||
from .process_group_manager import ProcessGroupManager | ||
from .process_group_mesh import ProcessGroupMesh | ||
|
||
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager'] | ||
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
import itertools | ||
from functools import reduce | ||
from operator import mul | ||
from typing import Dict, List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch.distributed as dist | ||
from torch.distributed import ProcessGroup | ||
|
||
|
||
def prod(nums: List[int]) -> int: | ||
"""Product of a list of numbers. | ||
Args: | ||
nums (List[int]): A list of numbers. | ||
Returns: | ||
int: The product of the numbers. | ||
""" | ||
return reduce(mul, nums) | ||
|
||
|
||
class ProcessGroupMesh: | ||
"""A helper class to manage the process group mesh. It only describes how to organize process groups, and it's decoupled with parallel method. | ||
It just initialize process groups and cache them. The parallel method should manage them and use them to do the parallel computation. | ||
We use a ND-tuple to represent the process group mesh. And a ND-coordinate is to represent each process. | ||
For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``. | ||
Args: | ||
*size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size. | ||
Attributes: | ||
shape (Tuple[int, ...]): The shape of the process group mesh. | ||
rank (int): The rank of the current process. | ||
""" | ||
|
||
def __init__(self, *size: int) -> None: | ||
assert dist.is_initialized(), "Please initialize torch.distributed first." | ||
assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size." | ||
self._shape = size | ||
self._rank = dist.get_rank() | ||
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) | ||
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} | ||
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} | ||
|
||
@property | ||
def shape(self) -> Tuple[int, ...]: | ||
return self._shape | ||
|
||
@property | ||
def rank(self) -> int: | ||
return self._rank | ||
|
||
def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: | ||
"""Get the size of the process group mesh. | ||
Args: | ||
dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. | ||
Returns: | ||
Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh. | ||
""" | ||
if dim is None: | ||
return self._shape | ||
else: | ||
return self._shape[dim] | ||
|
||
def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: | ||
"""Get the coordinate of the process group mesh. | ||
Args: | ||
dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. | ||
Returns: | ||
Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh. | ||
""" | ||
if dim is None: | ||
return self._coord | ||
else: | ||
return self._coord[dim] | ||
|
||
@staticmethod | ||
def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: | ||
"""Convert a rank to a coordinate. | ||
Args: | ||
rank (int): Rank to be converted. | ||
shape (Tuple[int, ...]): Shape of the process group mesh. | ||
Returns: | ||
Tuple[int, ...]: Coordinate of the rank. | ||
""" | ||
return np.unravel_index(rank, shape) | ||
|
||
@staticmethod | ||
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: | ||
"""Convert a coordinate to a rank. | ||
Args: | ||
coords (Tuple[int, ...]): Coordinate to be converted. | ||
shape (Tuple[int, ...]): Shape of the process group mesh. | ||
Returns: | ||
int: Rank of the coordinate. | ||
""" | ||
return np.ravel_multi_index(coord, shape) | ||
|
||
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: | ||
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created. | ||
Args: | ||
ranks_in_group (List[int]): Ranks in the process group. | ||
backend (Optional[str], optional): Backend of the process group. Defaults to None. | ||
Returns: | ||
ProcessGroup: The process group with the given ranks. | ||
""" | ||
ranks_in_group = sorted(ranks_in_group) | ||
if tuple(ranks_in_group) not in self._group_to_ranks: | ||
group = dist.new_group(ranks_in_group, backend=backend) | ||
self._ranks_to_group[tuple(ranks_in_group)] = group | ||
self._group_to_ranks[group] = tuple(ranks_in_group) | ||
return self._ranks_to_group[tuple(ranks_in_group)] | ||
|
||
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: | ||
"""Get the ranks in the given process group. The process group must be created by this class. | ||
Args: | ||
group (ProcessGroup): The process group. | ||
Returns: | ||
List[int]: Ranks in the process group. | ||
""" | ||
return list(self._group_to_ranks[group]) | ||
|
||
@staticmethod | ||
def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, | ||
indices_at_axis: List[int]) -> List[Tuple[int, ...]]: | ||
"""Get coordinates along the given axis. | ||
Args: | ||
base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on. | ||
axis (int): Axis along which the coordinates are generated. | ||
indices_at_axis (List[int]): Indices at the axis. | ||
Returns: | ||
List[Tuple[int, ...]]: Coordinates along the axis. | ||
""" | ||
coords_in_group = [] | ||
for idx in indices_at_axis: | ||
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:]) | ||
return coords_in_group | ||
|
||
def create_group_along_axis(self, | ||
axis: int, | ||
indices_at_axis: Optional[List[int]] = None, | ||
backend: Optional[str] = None) -> ProcessGroup: | ||
"""Create all process groups along the given axis, and return the one which the current process belongs to. | ||
Args: | ||
axis (int): Axis along which the process groups are created. | ||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. | ||
backend (Optional[str], optional): Backend of the process group. Defaults to None. | ||
Returns: | ||
ProcessGroup: The process group along the given axis which the current process belongs to. | ||
""" | ||
indices_at_axis = indices_at_axis or list(range(self._shape[axis])) | ||
reduced_shape = list(self._shape) | ||
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` | ||
reduced_shape[axis] = 1 | ||
target_group = None | ||
# use Cartesian product to generate all combinations of coordinates | ||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]): | ||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) | ||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) | ||
group = self.get_group(ranks_in_group, backend=backend) | ||
if self._rank in ranks_in_group: | ||
target_group = group | ||
return target_group | ||
|
||
def get_group_along_axis(self, | ||
axis: int, | ||
indices_at_axis: Optional[List[int]] = None, | ||
backend: Optional[str] = None) -> ProcessGroup: | ||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. | ||
Args: | ||
axis (int): Axis along which the process groups are created. | ||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. | ||
backend (Optional[str], optional): Backend of the process group. Defaults to None. | ||
Returns: | ||
ProcessGroup: The process group along the given axis which the current process belongs to. | ||
""" | ||
indices_at_axis = indices_at_axis or list(range(self._shape[axis])) | ||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) | ||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) | ||
if ranks_in_group not in self._ranks_to_group: | ||
# no need to cache it explicitly, since it will be cached in `create_group_along_axis` | ||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend) | ||
return self._ranks_to_group[ranks_in_group] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import pytest | ||
import torch.distributed as dist | ||
|
||
import colossalai | ||
from colossalai.cluster import ProcessGroupMesh | ||
from colossalai.testing import spawn | ||
|
||
|
||
def check_process_group_mesh_with_gpc(): | ||
from colossalai.context import ParallelMode | ||
from colossalai.core import global_context as gpc | ||
|
||
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 | ||
pg_mesh = ProcessGroupMesh(1, 2, 2) | ||
|
||
# check world size | ||
assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( | ||
TP_DIM), f'{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}' | ||
assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) | ||
assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) | ||
|
||
# check locak rank (coordinate) | ||
assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( | ||
TP_DIM), f'{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}' | ||
assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) | ||
assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) | ||
|
||
# check ranks in group | ||
tp_group = pg_mesh.get_group_along_axis(TP_DIM) | ||
assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group) | ||
pp_group = pg_mesh.get_group_along_axis(PP_DIM) | ||
assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group) | ||
dp_group = pg_mesh.get_group_along_axis(DP_DIM) | ||
assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group) | ||
|
||
# check prev rank | ||
coord = pg_mesh.coordinate() | ||
if not gpc.is_first_rank(ParallelMode.TENSOR): | ||
assert coord[TP_DIM] != 0 | ||
prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1:] | ||
assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) | ||
if not gpc.is_first_rank(ParallelMode.PIPELINE): | ||
assert coord[PP_DIM] != 0 | ||
prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1:] | ||
assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) | ||
|
||
# check next rank | ||
if not gpc.is_last_rank(ParallelMode.TENSOR): | ||
assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 | ||
next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1:] | ||
assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) | ||
if not gpc.is_last_rank(ParallelMode.PIPELINE): | ||
assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 | ||
next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1:] | ||
assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) | ||
|
||
|
||
def check_process_group_mesh_with_cases(): | ||
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 | ||
DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2 | ||
RANK_TO_COORDINATE = { | ||
0: (0, 0, 0), | ||
1: (0, 0, 1), | ||
2: (0, 1, 0), | ||
3: (0, 1, 1), | ||
} | ||
TP_RANKS_IN_GROUP = { | ||
0: [0, 1], | ||
1: [0, 1], | ||
2: [2, 3], | ||
3: [2, 3], | ||
} | ||
PP_RANKS_IN_GROUP = { | ||
0: [0, 2], | ||
1: [1, 3], | ||
2: [0, 2], | ||
3: [1, 3], | ||
} | ||
DP_RANKS_IN_GROUP = { | ||
0: [0], | ||
1: [1], | ||
2: [2], | ||
3: [3], | ||
} | ||
|
||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) | ||
|
||
rank = dist.get_rank() | ||
assert rank == pg_mesh.rank | ||
|
||
# check world size | ||
assert pg_mesh.size(TP_DIM) == 2 | ||
assert pg_mesh.size(PP_DIM) == 2 | ||
assert pg_mesh.size(DP_DIM) == 1 | ||
|
||
# check coordinate | ||
assert pg_mesh.coordinate(TP_DIM) == RANK_TO_COORDINATE[rank][TP_DIM] | ||
assert pg_mesh.coordinate(PP_DIM) == RANK_TO_COORDINATE[rank][PP_DIM] | ||
assert pg_mesh.coordinate(DP_DIM) == RANK_TO_COORDINATE[rank][DP_DIM] | ||
|
||
# check ranks in group | ||
tp_group = pg_mesh.get_group_along_axis(TP_DIM) | ||
assert pg_mesh.get_ranks_in_group(tp_group) == TP_RANKS_IN_GROUP[rank] | ||
pp_group = pg_mesh.get_group_along_axis(PP_DIM) | ||
assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] | ||
dp_group = pg_mesh.get_group_along_axis(DP_DIM) | ||
assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] | ||
|
||
# check prev rank | ||
if RANK_TO_COORDINATE[rank][TP_DIM] != 0: | ||
prev_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + \ | ||
RANK_TO_COORDINATE[rank][TP_DIM + 1:] | ||
prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1] | ||
assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank | ||
if RANK_TO_COORDINATE[rank][PP_DIM] != 0: | ||
prev_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + \ | ||
RANK_TO_COORDINATE[rank][PP_DIM + 1:] | ||
prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1] | ||
assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank | ||
|
||
# check next rank | ||
if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1: | ||
next_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + \ | ||
RANK_TO_COORDINATE[rank][TP_DIM + 1:] | ||
next_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1] | ||
assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank | ||
if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1: | ||
next_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + \ | ||
RANK_TO_COORDINATE[rank][PP_DIM + 1:] | ||
next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1] | ||
assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank | ||
|
||
|
||
def run_dist(rank, world_size, port): | ||
colossalai.launch(config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode='1d', size=2))), | ||
rank=rank, | ||
world_size=world_size, | ||
port=port, | ||
host='localhost') | ||
# TODO(ver217): this function should be removed when gpc is removed | ||
check_process_group_mesh_with_gpc() | ||
check_process_group_mesh_with_cases() | ||
|
||
|
||
@pytest.mark.dist | ||
def test_process_group_mesh(): | ||
spawn(run_dist, 4) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_process_group_mesh() |