[legacy] move communication to legacy (#4640)
ver217 authored Sep 6, 2023
1 parent efba0f4 commit 0336d39
Showing 22 changed files with 52 additions and 42 deletions.
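Most of the changes below follow the same pattern: the communication package now lives under the legacy namespace, so imports of colossalai.communication (and its collective, p2p, p2p_v2, ring, and utils submodules) become imports of colossalai.legacy.communication. A minimal sketch of the update downstream code would need, assuming the public names are re-exported unchanged from the new location:

# Old import path, removed by this commit:
# from colossalai.communication import all_reduce, broadcast
# import colossalai.communication.p2p_v2 as comm

# New import path under the legacy namespace:
from colossalai.legacy.communication import all_reduce, broadcast
import colossalai.legacy.communication.p2p_v2 as comm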
@@ -1,9 +1,17 @@
-from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
-from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
-send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
-recv_forward, recv_backward)
+from .collective import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from .p2p import (
+recv_backward,
+recv_forward,
+send_backward,
+send_backward_recv_backward,
+send_backward_recv_forward,
+send_forward,
+send_forward_backward_recv_forward_backward,
+send_forward_recv_backward,
+send_forward_recv_forward,
+)
 from .ring import ring_forward
-from .utils import send_obj_meta, recv_obj_meta
+from .utils import recv_obj_meta, send_obj_meta
 
 __all__ = [
 'all_gather',
5 files renamed without changes.
2 changes: 1 addition & 1 deletion colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -6,7 +6,7 @@
 
 import torch.cuda
 
-import colossalai.communication as comm
+import colossalai.legacy.communication as comm
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
6 changes: 3 additions & 3 deletions colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -5,10 +5,10 @@
 
 import torch.cuda
 
-import colossalai.communication.p2p_v2 as comm
-from colossalai import engine
+import colossalai.legacy.communication.p2p_v2 as comm
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.engine import Engine
 from colossalai.utils.cuda import get_current_device
 
 from ._pipeline_schedule import PipelineSchedule
@@ -60,7 +60,7 @@ def data_process_func(stage_output, dataloader_output):
 """
 
 def forward_backward_step(self,
-engine: engine.Engine,
+engine: Engine,
 data_iter: Iterable,
 forward_only=False,
 return_loss=True,
2 changes: 1 addition & 1 deletion colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -7,9 +7,9 @@
 import torch
 import torch.distributed as dist
 
-from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.communication import all_reduce
 from colossalai.legacy.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage
 
2 changes: 1 addition & 1 deletion colossalai/nn/layer/parallel_1d/layers.py
@@ -10,11 +10,11 @@
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
-from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
+from colossalai.legacy.communication import broadcast
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.utils.checkpointing import (
21 changes: 11 additions & 10 deletions colossalai/nn/layer/parallel_2d/_operation.py
@@ -2,13 +2,14 @@
 
 import torch
 import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter
+from colossalai.utils import get_current_device
 
+
 def matmul_2d(
@@ -226,9 +227,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)

src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
pipeline_parallel_rank * tensor_parallel_size

opa = [None] * 2
opb = [None] * 2
@@ -351,9 +352,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)

src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
pipeline_parallel_rank * tensor_parallel_size
src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
pipeline_parallel_rank * tensor_parallel_size

opb = [None] * 2
opr = [None] * 2
@@ -484,9 +485,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)

src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
pipeline_parallel_rank * tensor_parallel_size

opa = [None] * 2
opr = [None] * 2
2 changes: 1 addition & 1 deletion colossalai/nn/layer/parallel_2d/layers.py
@@ -8,10 +8,10 @@
 from torch import Tensor
 from torch.nn import Parameter
 
-from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import broadcast
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
7 changes: 4 additions & 3 deletions colossalai/nn/layer/parallel_2p5d/_operation.py
@@ -2,12 +2,13 @@
 
 import torch
 import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
 from colossalai.utils import get_current_device
-from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 
 def get_parallel_group(parallel_mode: ParallelMode):
2 changes: 1 addition & 1 deletion colossalai/nn/layer/parallel_2p5d/layers.py
@@ -8,10 +8,10 @@
 from torch import Tensor
 from torch.nn import Parameter
 
-from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import broadcast
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.utils.checkpointing import (
2 changes: 1 addition & 1 deletion colossalai/nn/layer/parallel_3d/_operation.py
@@ -7,10 +7,10 @@
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
-from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
 
 from ._utils import get_parallel_mode_from_env, push_async_grad
 
2 changes: 1 addition & 1 deletion colossalai/nn/layer/parallel_3d/layers.py
@@ -8,11 +8,11 @@
 from torch import Tensor
 from torch.nn import Parameter
 
-from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import all_reduce, broadcast
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
6 changes: 3 additions & 3 deletions colossalai/nn/layer/parallel_sequence/_operation.py
@@ -3,13 +3,13 @@
 
 import torch
 from torch import distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
 
-from colossalai.communication import ring_forward
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
+from colossalai.legacy.communication import ring_forward
+from colossalai.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
 from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 
 class RingQK(torch.autograd.Function):
@@ -1,10 +1,10 @@
 import pytest
 import torch
 
-from colossalai.communication.p2p_v2 import _recv_object, _send_object
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
+from colossalai.legacy.communication.p2p_v2 import _recv_object, _send_object
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
@@ -2,10 +2,10 @@
 import torch
 import torch.distributed as dist
 
-from colossalai.communication import all_gather, all_reduce, reduce_scatter
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
+from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 
@@ -1,17 +1,17 @@
 import pytest
 import torch
 
-from colossalai.communication.p2p import (
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.legacy.communication.p2p import (
 recv_backward,
 recv_forward,
 send_backward,
 send_backward_recv_forward,
 send_forward,
 send_forward_recv_backward,
 )
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 CONFIG = dict(parallel=dict(pipeline=2))
@@ -1,10 +1,10 @@
 import pytest
 import torch
 
-from colossalai.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
+from colossalai.legacy.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
8 changes: 4 additions & 4 deletions tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
@@ -5,7 +5,10 @@
 import torch
 import torch.distributed as dist
 
-from colossalai.communication import (
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.legacy.communication import (
 recv_backward,
 recv_forward,
 recv_obj_meta,
@@ -15,9 +18,6 @@
 send_forward_recv_backward,
 send_obj_meta,
 )
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
