
[Feature] Zigzag Ring attention #5905

Merged
merged 37 commits on Aug 16, 2024
Changes from 27 commits

Commits (37)
519818f
halfway
Jun 28, 2024
04b14a2
fix cross-PP-stage position id length diff bug
Jun 28, 2024
45b9ac1
fix typo
Jun 29, 2024
3047c4e
fix typo
Jun 29, 2024
c0a5048
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2024
748b0a1
unified cross entropy func for all shardformer models
Jul 2, 2024
0262e6b
remove redundant lines
Jul 2, 2024
7dfdac1
add basic ring attn; debug cross entropy
Jul 8, 2024
a4d4e6a
fwd bwd logic complete
Jul 13, 2024
7a4e284
fwd bwd logic complete; add experimental triton rescale
Jul 14, 2024
f8be40d
precision tests passed
Jul 18, 2024
c3d7a86
precision tests passed
Jul 21, 2024
313bc48
fix typos and remove misc files
Jul 22, 2024
98627e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2024
a3bb451
add sp_mode to benchmark; fix varlen interface
Jul 22, 2024
b104530
update softmax_lse shape by new interface
Jul 22, 2024
6cbc5f6
change tester name
Jul 22, 2024
bf00238
remove buffer clone; support packed seq layout
Jul 23, 2024
bd2d642
add varlen tests
Jul 24, 2024
ea11927
fix typo
Jul 26, 2024
2f8e188
all tests passed
Aug 1, 2024
919eff5
add dkv_group; fix mask
Aug 1, 2024
04d2f88
remove debug statements
Aug 1, 2024
392bde6
add comments
Aug 2, 2024
e760507
q1 index only once
Aug 5, 2024
e90e984
remove events to simplify stream sync
Aug 6, 2024
e26c910
clarify kv_comm.wait()
Edenzzzz Aug 7, 2024
b6b2333
use torch.compile; add nsys
Aug 9, 2024
d3831b4
simplify forward/backward logic
Aug 9, 2024
0094bc0
2d ring forward passed
Aug 12, 2024
581ec0f
2d ring backward passed
Aug 13, 2024
1344849
fixes
Aug 14, 2024
e6bcde2
fix ring attn loss
Aug 14, 2024
b4c0809
2D ring backward + llama passed
Aug 14, 2024
26b008e
follow conventions
Aug 15, 2024
a68dd2f
fix dist logger
Aug 15, 2024
be5fed5
add a manual inner ring size option
Aug 15, 2024
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -12,6 +12,7 @@ repos:
hooks:
- id: isort
name: sort all imports (python)
args: ["--profile", "black"] # avoid conflict with black

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
41 changes: 28 additions & 13 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -32,7 +32,7 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -42,7 +42,7 @@

from .pp_plugin_base import PipelinePluginBase

SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -72,7 +72,7 @@ def __init__(
self.dp_group = dp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.use_dpp = use_ddp
self.use_ddp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather

@@ -139,8 +139,8 @@ def no_sync(self):
# Disable automatic gradient synchronization.
self.require_grad_sync = False
try:
if self.use_dpp:
# If using data parallel processing (use_dpp), disable synchronization too.
if self.use_ddp:
# If using data parallel processing (use_ddp), disable synchronization too.
with self.module.no_sync():
yield
else:
@@ -188,7 +188,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
"""

if self.shard_config.enable_sequence_parallelism:
if self.shard_config.sequence_parallelism_mode == "all_to_all":
if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
return

if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
@@ -1041,9 +1041,11 @@ def __init__(
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]:
elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
if self.sequence_parallelism_mode == "ring_attn":
enable_flash_attention = True
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
@@ -1116,6 +1118,14 @@ def __init__(
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
# According to https://github.com/InternLM/InternEvo/blob/a53a4ff4fc45761f80d7fe8e9188bc2e02d487fc/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L405
# and https://zhuanlan.zhihu.com/p/706805407
# using a different proc group may put p2p comm on a new
# NCCL stream :)
dkv_group = None
if sequence_parallelism_mode == "ring_attn":
sp_ranks = dist.get_process_group_ranks(self.sp_group)
dkv_group = dist.new_group(ranks=sp_ranks)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@@ -1132,6 +1142,12 @@
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
sp_stream=(
torch.cuda.Stream()
if enable_sequence_parallelism and sequence_parallelism_mode == "ring_attn"
else None
),
dkv_group=dkv_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -1216,14 +1232,13 @@ def configure(
zero_stage = 0

if not isinstance(model, ModelWrapper):
# Shouldn't use pp (frequent grad accumulation) with torch ddp
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":

# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
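To see the plugin changes above in context, here is a minimal usage sketch of the new `ring_attn` mode. Only the plugin arguments (`sp_size`, `enable_sequence_parallelism`, `sequence_parallelism_mode`, `enable_flash_attention`) are taken from this diff; the surrounding Booster/model setup is an illustrative assumption, not part of the PR.

```python
# Minimal sketch, assuming a standard Booster workflow; only the plugin arguments
# below come from this diff, the rest is illustrative.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,                               # ring size = number of sequence-parallel ranks
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",   # new mode added by this PR
    enable_flash_attention=True,             # forced on internally for ring_attn (see diff above)
    precision="bf16",
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader are user-defined:
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```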
4 changes: 0 additions & 4 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -203,7 +203,6 @@ def save_sharded_model(
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)

# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
@@ -643,14 +642,12 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()

if self.dp_rank != 0:
return

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()

if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
@@ -660,7 +657,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
state_dict_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(state_dict_list, state_dict, self.pp_group)

# Only the master rank do the saving.
if self.coordinator.is_master():
complete_state_dict = dict()
4 changes: 0 additions & 4 deletions colossalai/lazy/pretrained.py
@@ -62,7 +62,6 @@ def new_from_pretrained(
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -116,7 +115,6 @@ def new_from_pretrained(
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
@@ -195,7 +193,6 @@ def new_from_pretrained(
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"user_agent": user_agent,
@@ -312,7 +309,6 @@ def new_from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
2 changes: 1 addition & 1 deletion colossalai/legacy/moe/openmoe/model/openmoe_policy.py
@@ -171,7 +171,7 @@ def module_policy(self):
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
# TODO: recursively assign ep group foe all modules
new_item = {
OpenMoeForCausalLM: ModulePolicyDescription(
3 changes: 3 additions & 0 deletions colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -81,6 +81,9 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
# TODO: This seems to only work if you add torch.cuda.Event.wait()

# _ = torch.zeros(1, device=grad_output.device)

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
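The TODO above concerns overlapping the asynchronous input-grad all-reduce with the weight-grad GEMM. The following standalone toy sketch illustrates that scheduling pattern; it is not code from this repo, the dummy kernel mirrors the commented-out `torch.zeros(1, ...)` line, and whether it is actually needed depends on CUDA scheduling, as the TODO notes.

```python
# Toy sketch of the comm/compute overlap pattern discussed in the comment above.
import torch
import torch.distributed as dist

def linear_backward_overlapped(grad_output, total_input, weight, process_group):
    # Input grad first, reduced asynchronously across tensor-parallel ranks.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    # Tiny dummy op so the all-reduce kernel gets scheduled before the big GEMM
    # (the effect is scheduler-dependent, per the TODO above).
    _ = torch.zeros(1, device=grad_output.device)
    # Weight/bias grads overlap with the in-flight all-reduce.
    grad_weight = grad_output.t().matmul(total_input)
    grad_bias = grad_output.sum(dim=0)
    handle.wait()  # ensure grad_input is fully reduced before it is used
    return grad_input, grad_weight, grad_bias
```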
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/interleaved_pp.py
@@ -286,7 +286,6 @@ def forward_step(
# for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used

with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
1 change: 1 addition & 0 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -244,6 +244,7 @@ def forward_step(
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches

if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -1,5 +1,5 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
Expand Down Expand Up @@ -31,5 +31,7 @@
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"RingAttention",
"get_pad_info",
"all_to_all_comm",
]
38 changes: 28 additions & 10 deletions colossalai/shardformer/layer/_operation.py
@@ -2,6 +2,8 @@
import torch.distributed as dist
import torch.nn.functional as F

from .utils import is_share_sp_tp

try:
import fused_mix_prec_layer_norm_cuda
except:
@@ -93,7 +95,7 @@ def backward(ctx, grad_output):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

grad_weight = total_input.t().matmul(grad_output)
@@ -143,7 +145,9 @@ def backward(ctx, grad_output):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
_ = torch.zeros(1, device=grad_input.device)

# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
@@ -331,7 +335,7 @@ def backward(ctx, grad_output):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
@@ -646,8 +650,8 @@ def backward(ctx, grad_output):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function):
Args:
input_: input matrix.
parallel_mode: parallel mode.
process_group: communication group.
"""

@staticmethod
def forward(ctx, input_, process_group):
def forward(ctx, input_, process_group, grad_scale=None):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group)

@staticmethod
def backward(ctx, grad_output):
return grad_output, None
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None


class _ReduceBackward(torch.autograd.Function):
@@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


def reduce_forward(input_, process_group):
return _ReduceForward.apply(input_, process_group)
def reduce_forward(input_, process_group, grad_scale=None):
return _ReduceForward.apply(input_, process_group, grad_scale)


def reduce_backward(input_, process_group):
Expand All @@ -989,3 +997,13 @@ def reduce_backward(input_, process_group):

def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)


def gather_sp_output(hidden_states, sp_group, sp_mode):
"""
Gather the output of the last layer for cross entropy computation
"""
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
return hidden_states
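
A short usage sketch of the new `gather_sp_output` helper follows; the dim-1 gather and the backward grad rescale mirror the code above, while the call site inside a model forward is an assumption for illustration.

```python
# Sketch of how a shardformer model forward might call the new helper before the LM head.
from colossalai.shardformer.layer._operation import gather_sp_output

def lm_forward_tail(hidden_states, lm_head, sp_group, sp_mode):
    # hidden_states: [batch, seq_len // sp_size, hidden] on each SP rank.
    # gather_sp_output concatenates along the sequence dim; for modes that do not
    # reuse the TP group (e.g. all_to_all, ring_attn), the backward grad is scaled
    # by the SP world size to undo ZeRO's gradient averaging over the DP * SP group.
    hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)  # [batch, seq_len, hidden]
    return lm_head(hidden_states)
```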