[Feature] Zigzag Ring attention (#5905)
* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Aug 16, 2024
1 parent 887d2d5 commit f5c84af
Showing 50 changed files with 1,853 additions and 309 deletions.
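
The headline change adds a new "ring_attn" sequence-parallelism mode (zigzag Ring Attention) to HybridParallelPlugin, together with an inner_ring_size option for 2D Ring Attention. A minimal usage sketch, assuming the usual Booster workflow; the parallel degrees and launch call below are illustrative, not taken from this commit:

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch()

    # e.g. on 4 GPUs: dp_size is derived as world_size // (tp_size * pp_size * sp_size) = 1
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        sp_size=4,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="ring_attn",  # new mode introduced by this commit
        enable_flash_attention=True,            # ring_attn switches this on internally anyway
        parallel_output=True,                   # ring_attn asserts this (no output gather yet)
        inner_ring_size=None,                   # let the topology heuristic choose
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, ... = booster.boost(model, optimizer, ...)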
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -12,6 +12,7 @@ repos:
hooks:
- id: isort
name: sort all imports (python)
args: ["--profile", "black"] # avoid conflict with black

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
50 changes: 35 additions & 15 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -32,7 +32,7 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -42,7 +42,7 @@

from .pp_plugin_base import PipelinePluginBase

SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -72,7 +72,7 @@ def __init__(
self.dp_group = dp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.use_dpp = use_ddp
self.use_ddp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather

@@ -139,8 +139,8 @@ def no_sync(self):
# Disable automatic gradient synchronization.
self.require_grad_sync = False
try:
if self.use_dpp:
# If using data parallel processing (use_dpp), disable synchronization too.
if self.use_ddp:
# If using data parallel processing (use_ddp), disable synchronization too.
with self.module.no_sync():
yield
else:
@@ -188,7 +188,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
"""

if self.shard_config.enable_sequence_parallelism:
if self.shard_config.sequence_parallelism_mode == "all_to_all":
if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
return

if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
@@ -970,6 +970,9 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable not to tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
"""

def __init__(
@@ -1017,6 +1020,7 @@ def __init__(
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
inner_ring_size: int = None,
) -> None:
super().__init__()

@@ -1041,9 +1045,11 @@ def __init__(
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]:
elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
if self.sequence_parallelism_mode == "ring_attn":
enable_flash_attention = True
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
@@ -1063,10 +1069,21 @@ def __init__(
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
self.sp_axis = 2
self.tp_axis = 3
else:
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size)
self.sp_axis = 2
self.tp_axis = 3
else:
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

self.stage_manager = None
self.schedule = None
@@ -1108,6 +1125,8 @@ def __init__(
)
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
assert parallel_output, "Ring Attention doesn't support gathering output yet."

self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1132,6 +1151,7 @@ def __init__(
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
inner_ring_size=inner_ring_size,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -1216,15 +1236,15 @@ def configure(
zero_stage = 0

if not isinstance(model, ModelWrapper):
# Shouldn't use pp (frequent grad accumulation) with torch ddp
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":

# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(dp_group)
else:
dp_group = self.dp_group
model = HybridParallelModule(
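
Note on the configure() change above: the data-parallel group used for DDP/ZeRO gradient averaging is now built across the DP and SP mesh axes whenever the sequence-parallel mode does not share ranks with tensor parallelism, as decided by the new is_share_sp_tp helper. A sketch of its assumed semantics (the helper lives in colossalai/shardformer/layer/utils.py and is not shown in this excerpt):

    # Assumed behaviour: "split_gather" and "ring" piggyback on the TP group for sequence
    # parallelism, while "all_to_all" and "ring_attn" keep a separate SP group, so gradients
    # must be averaged over the combined DP x SP ranks instead of DP alone.
    def is_share_sp_tp_sketch(sp_mode: str) -> bool:
        return sp_mode in ("split_gather", "ring")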
4 changes: 0 additions & 4 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -203,7 +203,6 @@ def save_sharded_model(
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)

# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
@@ -643,14 +642,12 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()

if self.dp_rank != 0:
return

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()

if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
@@ -660,7 +657,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
state_dict_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(state_dict_list, state_dict, self.pp_group)

# Only the master rank do the saving.
if self.coordinator.is_master():
complete_state_dict = dict()
4 changes: 0 additions & 4 deletions colossalai/lazy/pretrained.py
@@ -62,7 +62,6 @@ def new_from_pretrained(
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -116,7 +115,6 @@ def new_from_pretrained(
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
@@ -195,7 +193,6 @@ def new_from_pretrained(
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"user_agent": user_agent,
@@ -312,7 +309,6 @@ def new_from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
2 changes: 1 addition & 1 deletion colossalai/legacy/moe/openmoe/model/openmoe_policy.py
@@ -171,7 +171,7 @@ def module_policy(self):
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
# TODO: recursively assign ep group for all modules
new_item = {
OpenMoeForCausalLM: ModulePolicyDescription(
3 changes: 3 additions & 0 deletions colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -81,6 +81,9 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
# TODO: This seems to only work if you add torch.cuda.Event.wait()

# _ = torch.zeros(1, device=grad_output.device)

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
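
The TODO added to the legacy 1D linear backward refers to the Megatron-style trick of enqueueing a tiny dummy kernel right after the asynchronous all-reduce, so the communication kernel is scheduled before the large weight-gradient GEMM. A self-contained sketch of that pattern (the helper name and argument list are hypothetical; the commit itself leaves the dummy op commented out):

    import torch
    import torch.distributed as dist

    def linear_backward_with_overlap(grad_output, total_input, grad_input, use_bias, process_group):
        """Overlap the input-grad all-reduce with the weight-grad GEMM."""
        handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
        _ = torch.zeros(1, device=grad_output.device)      # dummy kernel: lets the all-reduce start first
        grad_weight = grad_output.t().matmul(total_input)   # overlaps with the in-flight all-reduce
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()                                        # grad_input is only valid after this
        return grad_input, grad_weight, grad_bias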
5 changes: 4 additions & 1 deletion colossalai/logging/logger.py
@@ -64,7 +64,10 @@ def __init__(self, name):
self._logger.propagate = False

DistributedLogger.__instances[name] = self
self.rank = dist.get_rank() if dist.is_initialized() else 0

@property
def rank(self):
return dist.get_rank() if dist.is_initialized() else 0

@staticmethod
def __get_call_info():
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/interleaved_pp.py
@@ -286,7 +286,6 @@ def forward_step(
# for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used

with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
1 change: 1 addition & 0 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -244,6 +244,7 @@ def forward_step(
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches

if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -1,5 +1,5 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
@@ -31,5 +31,7 @@
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"RingAttention",
"get_pad_info",
"all_to_all_comm",
]
38 changes: 28 additions & 10 deletions colossalai/shardformer/layer/_operation.py
@@ -2,6 +2,8 @@
import torch.distributed as dist
import torch.nn.functional as F

from .utils import is_share_sp_tp

try:
import fused_mix_prec_layer_norm_cuda
except:
@@ -93,7 +95,7 @@ def backward(ctx, grad_output):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

grad_weight = total_input.t().matmul(grad_output)
@@ -143,7 +145,9 @@ def backward(ctx, grad_output):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
_ = torch.zeros(1, device=grad_input.device)

# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
@@ -331,7 +335,7 @@ def backward(ctx, grad_output):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
@@ -646,8 +650,8 @@ def backward(ctx, grad_output):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function):
Args:
input_: input matrix.
parallel_mode: parallel mode.
process_group: communication group.
"""

@staticmethod
def forward(ctx, input_, process_group):
def forward(ctx, input_, process_group, grad_scale=None):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group)

@staticmethod
def backward(ctx, grad_output):
return grad_output, None
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None


class _ReduceBackward(torch.autograd.Function):
@@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


def reduce_forward(input_, process_group):
return _ReduceForward.apply(input_, process_group)
def reduce_forward(input_, process_group, grad_scale=None):
return _ReduceForward.apply(input_, process_group, grad_scale)


def reduce_backward(input_, process_group):
@@ -989,3 +997,13 @@ def reduce_backward(input_, process_group):

def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)


def gather_sp_output(hidden_states, sp_group, sp_mode):
"""
Gather the output of the last layer for cross entropy computation
"""
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
return hidden_states
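
The new gather_sp_output helper reassembles the sequence-sharded hidden states before the language-model head so cross-entropy sees the full sequence; for SP modes that do not share the TP group it also rescales the gradient by the SP world size, compensating for the ZeRO averaging applied over the combined DP x SP group. A hypothetical call site (lm_head and the shapes are illustrative, not from this diff):

    # Each rank holds hidden_states of shape [batch, seq_len // sp_size, hidden].
    # Gathering along dim 1 restores the full sequence on every rank.
    hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode="ring_attn")
    logits = lm_head(hidden_states)  # full-sequence logits for the cross-entropy loss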