forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pipeline] implement p2p communication (hpcaitech#4100)
* [pipeline] add p2p communication * [test] add p2p communication test * [test] add rerun decorator * [test] rename to avoid conflict
- Loading branch information
Showing
3 changed files
with
287 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
#!/usr/bin/env python | ||
# -*- encoding: utf-8 -*- | ||
|
||
import io | ||
import pickle | ||
from typing import Any, List, Optional, Union | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from torch.distributed import ProcessGroup | ||
from torch.distributed import distributed_c10d as c10d | ||
|
||
from .stage_manager import PipelineStageManager | ||
|
||
_unpickler = pickle.Unpickler | ||
|
||
|
||
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: | ||
"""transform tensor to object with unpickle. | ||
Info of the device in bytes stream will be modified into current device before unpickling | ||
Args: | ||
tensor (:class:`torch.tensor`): tensor to be unpickled | ||
tensor_size (:class:`torch.Size`): Size of the real info in bytes | ||
Returns: | ||
Any: object after unpickled | ||
""" | ||
buf = tensor.numpy().tobytes()[:tensor_size] | ||
if b'cuda' in buf: | ||
buf_array = bytearray(buf) | ||
device_index = torch.cuda.current_device() | ||
buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index | ||
buf = bytes(buf_array) | ||
|
||
io_bytes = io.BytesIO(buf) | ||
byte_pickler = _unpickler(io_bytes) | ||
unpickle = byte_pickler.load() | ||
|
||
return unpickle | ||
|
||
|
||
def _broadcast_object_list(object_list: List[Any], | ||
src: int, | ||
group: ProcessGroup, | ||
device: Optional[Union[torch.device, str, int]] = None): | ||
"""This is a modified version of the broadcast_object_list in torch.distribution | ||
The only difference is that object will be move to correct device after unpickled. | ||
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will | ||
be updated with data sent from rank src. | ||
Args: | ||
object_list (List[Any]): list of object to broadcast | ||
src (int): source rank to broadcast | ||
dst (int): dst rank to broadcast | ||
device (:class:`torch.device`): device to do broadcast. current device in default | ||
""" | ||
|
||
if c10d._rank_not_in_group(group): | ||
c10d._warn_not_in_group("broadcast_object_list") | ||
return | ||
|
||
my_rank = dist.get_rank() | ||
# Serialize object_list elements to tensors on src rank. | ||
if my_rank == src: | ||
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) | ||
object_sizes_tensor = torch.cat(size_list) | ||
else: | ||
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) | ||
|
||
is_nccl_backend = c10d._check_for_nccl_backend(group) | ||
current_device = None | ||
|
||
if device is not None: | ||
if is_nccl_backend and device.type != "cuda": | ||
raise ValueError("device type must be cuda for nccl backend") | ||
current_device = device | ||
else: | ||
current_device = torch.device("cpu") | ||
if is_nccl_backend: | ||
current_device = torch.device("cuda", torch.cuda.current_device()) | ||
if is_nccl_backend: | ||
object_sizes_tensor = object_sizes_tensor.to(current_device) | ||
|
||
# Broadcast object sizes | ||
c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) | ||
|
||
# Concatenate and broadcast serialized object tensors | ||
if my_rank == src: | ||
object_tensor = torch.cat(tensor_list) | ||
else: | ||
object_tensor = torch.empty( # type: ignore[call-overload] | ||
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] | ||
dtype=torch.uint8, | ||
) | ||
|
||
if is_nccl_backend: | ||
object_tensor = object_tensor.to(current_device) | ||
|
||
c10d.broadcast(object_tensor, src=src, group=group, async_op=False) | ||
|
||
# Deserialize objects using their stored sizes. | ||
offset = 0 | ||
|
||
if my_rank != src: | ||
for i, obj_size in enumerate(object_sizes_tensor): | ||
obj_view = object_tensor[offset:offset + obj_size] | ||
obj_view = obj_view.type(torch.uint8) | ||
if obj_view.device != torch.device("cpu"): | ||
obj_view = obj_view.cpu() | ||
offset += obj_size | ||
# unpickle | ||
unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) | ||
|
||
# unconsistence in device | ||
if isinstance(unpickle_object, | ||
torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): | ||
unpickle_object = unpickle_object.cuda() | ||
|
||
object_list[i] = unpickle_object | ||
|
||
|
||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: | ||
"""send anything to dst rank | ||
Args: | ||
object (Any): object needed to be sent | ||
dst (int): rank of the destination | ||
Returns: | ||
None | ||
""" | ||
# then broadcast safely | ||
_broadcast_object_list([object], src, group) | ||
|
||
|
||
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: | ||
"""recv anything from src | ||
Args: | ||
src (int): source rank of data. local rank will receive data from src rank. | ||
Returns: | ||
Any: Object received from src. | ||
""" | ||
object_list = [None] | ||
_broadcast_object_list(object_list, src, group) | ||
|
||
return object_list[0] | ||
|
||
|
||
class PipelineP2PCommunication: | ||
|
||
def __init__(self, stage_manager: PipelineStageManager) -> None: | ||
self.stage_manager = stage_manager | ||
|
||
def recv_forward(self, prev_rank: int = None) -> Any: | ||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. | ||
Args: | ||
prev_rank (int, optional): The rank of the source of the tensor. | ||
Returns: | ||
Any: The input tensor or input tensor list. | ||
""" | ||
if self.stage_manager.is_first_stage(): | ||
input_tensor = None | ||
else: | ||
if prev_rank is None: | ||
prev_rank = self.stage_manager.get_prev_rank() | ||
cur_rank = self.stage_manager.get_rank() | ||
input_tensor = _recv_object(prev_rank, cur_rank, | ||
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) | ||
|
||
return input_tensor | ||
|
||
def recv_backward(self, next_rank: int = None) -> Any: | ||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. | ||
Args: | ||
next_rank (int, optional): The rank of the source of the tensor. | ||
Returns: | ||
Any: The input gradient tensor or gradient tensor list. | ||
""" | ||
if self.stage_manager.is_last_stage(): | ||
output_tensor_grad = None | ||
else: | ||
if next_rank is None: | ||
next_rank = self.stage_manager.get_next_rank() | ||
cur_rank = self.stage_manager.get_rank() | ||
output_tensor_grad = _recv_object(next_rank, cur_rank, | ||
self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) | ||
|
||
return output_tensor_grad | ||
|
||
def send_forward(self, output_object: Any, next_rank: int = None) -> None: | ||
"""Sends the input tensor to the next stage in pipeline. | ||
Args: | ||
output_object (Any): Object to be sent. | ||
next_rank (int, optional): The rank of the recipient of the tensor. | ||
""" | ||
if not self.stage_manager.is_last_stage(): | ||
if next_rank is None: | ||
next_rank = self.stage_manager.get_next_rank() | ||
cur_rank = self.stage_manager.get_rank() | ||
_send_object(output_object, cur_rank, next_rank, | ||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) | ||
|
||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None: | ||
"""Sends the gradient tensor to the previous stage in pipeline. | ||
Args: | ||
input_object (Any): Object to be sent. | ||
prev_rank (int, optional): The rank of the recipient of the tensor | ||
""" | ||
if not self.stage_manager.is_first_stage(): | ||
if prev_rank is None: | ||
prev_rank = self.stage_manager.get_prev_rank() | ||
cur_rank = self.stage_manager.get_rank() | ||
_send_object(input_object, cur_rank, prev_rank, | ||
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pytest | ||
import torch | ||
import torch.distributed as dist | ||
|
||
import colossalai | ||
from colossalai.cluster import ProcessGroupMesh | ||
from colossalai.pipeline.p2p import PipelineP2PCommunication | ||
from colossalai.pipeline.stage_manager import PipelineStageManager | ||
from colossalai.testing import rerun_if_address_is_in_use, spawn | ||
from colossalai.utils import get_current_device | ||
|
||
|
||
def check_p2p_communication(): | ||
pg_mesh = ProcessGroupMesh(2) | ||
stage_manager = PipelineStageManager(pg_mesh, 0) | ||
p2p = PipelineP2PCommunication(stage_manager) | ||
|
||
rank = dist.get_rank() | ||
|
||
tensor = torch.ones(1, device=get_current_device()) | ||
|
||
if rank == 0: | ||
p2p.send_forward(tensor) | ||
p2p.send_forward([tensor]) | ||
p2p.send_forward({'tensor': tensor}) | ||
else: | ||
obj = p2p.recv_forward() | ||
assert torch.equal(obj, tensor) | ||
obj = p2p.recv_forward() | ||
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) | ||
obj = p2p.recv_forward() | ||
assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) | ||
|
||
if rank == 1: | ||
p2p.send_backward(tensor) | ||
p2p.send_backward([tensor]) | ||
p2p.send_backward({'tensor': tensor}) | ||
else: | ||
obj = p2p.recv_backward() | ||
assert torch.equal(obj, tensor) | ||
obj = p2p.recv_backward() | ||
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) | ||
obj = p2p.recv_backward() | ||
assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) | ||
|
||
|
||
def run_dist(rank, world_size, port): | ||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') | ||
check_p2p_communication() | ||
|
||
|
||
@pytest.mark.dist | ||
@rerun_if_address_is_in_use() | ||
def test_pipeline_p2p(): | ||
spawn(run_dist, 2) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_pipeline_p2p() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters