simplify forward/backward logic
Edenzzzz committed Aug 9, 2024
1 parent b6b2333 commit d3831b4
Showing 7 changed files with 89 additions and 152 deletions.
9 changes: 0 additions & 9 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1118,14 +1118,6 @@ def __init__(
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
# According to https://github.com/InternLM/InternEvo/blob/a53a4ff4fc45761f80d7fe8e9188bc2e02d487fc/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L405
# and https://zhuanlan.zhihu.com/p/706805407
# using a different proc group may put p2p comm on a new
# NCCL stream :)
dkv_group = None
if sequence_parallelism_mode == "ring_attn":
sp_ranks = dist.get_process_group_ranks(self.sp_group)
dkv_group = dist.new_group(ranks=sp_ranks)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@@ -1147,7 +1139,6 @@ def __init__(
if enable_sequence_parallelism and sequence_parallelism_mode == "ring_attn"
else None
),
dkv_group=dkv_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
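The `dkv_group` removed here does not simply disappear: as the attn.py diff below shows, `RingAttention.attention` now creates it lazily and caches it on the class, keyed by the ranks of the current `sp_group`. A minimal sketch of that caching pattern, assuming `torch.distributed` is already initialized and `sp_group` is the sequence-parallel process group:

```python
import torch.distributed as dist

class RingAttention:
    # Process group reserved for dKV p2p traffic in backward, so it can run on
    # a separate NCCL communicator from the forward KV ring.
    DKV_GROUP: dist.ProcessGroup = None

    @classmethod
    def _get_dkv_group(cls, sp_group: dist.ProcessGroup) -> dist.ProcessGroup:
        ranks = dist.get_process_group_ranks(sp_group)
        # Rebuild only when nothing is cached or the SP ranks have changed.
        if cls.DKV_GROUP is None or dist.get_process_group_ranks(cls.DKV_GROUP) != ranks:
            cls.DKV_GROUP = dist.new_group(ranks)
        return cls.DKV_GROUP
```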
203 changes: 76 additions & 127 deletions colossalai/shardformer/layer/attn.py
@@ -51,7 +51,7 @@ def get_pad_info(
"""Get padding information from padding mask.
Args:
padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Sq]
padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
invert (Optional[bool], optional): Whether to reverse the padding mask.
return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.
@@ -342,7 +342,9 @@ def _load_flash_attn():
_load_varlen_helpers()


@torch.compile
# NOTE: This can cause spawned processes to hang on exit
# with python 3.9
@torch.compile()
def _rescale_out_lse(out, block_out, lse, block_lse):
"""
Compute the new attention denominator:
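The rest of `_rescale_out_lse` is collapsed in this diff; the computation it names is the standard online-softmax merge applied between ring steps. A hedged sketch of that update (illustrative only, not the committed implementation), using the shapes the forward loop passes in:

```python
import torch

def rescale_out_lse_sketch(out, block_out, lse, block_lse):
    """Illustrative only: fold one block's attention output into the running result.

    out, block_out: (T, H, D) partial attention outputs
    lse, block_lse: (T, H, 1) log-sum-exp of the scores that produced them
    """
    # new_lse = log(exp(lse) + exp(block_lse))
    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
    # Reweight both partial outputs by their share of the combined denominator.
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return new_out, new_lse
```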
@@ -388,6 +390,7 @@ class RingAttention(torch.autograd.Function):
HALF_INDICES: Tuple = None
SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
ATTN_DONE: torch.cuda.Event = None
DKV_GROUP: dist.ProcessGroup = None

@staticmethod
def attention(
@@ -434,6 +437,7 @@ def attention(
softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
Shape should be [total_q_seqlen, nHeads]
"""
# Check input args
_load_flash_attn()
if RingAttention.ATTN_DONE is None:
RingAttention.ATTN_DONE = torch.cuda.Event()
@@ -444,6 +448,14 @@ def attention(
assert (
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
), f"Mask type {attention_mask_type} is not supported yet."
if dkv_group is None:
if RingAttention.DKV_GROUP is None or dist.get_process_group_ranks(
sp_group
) != dist.get_process_group_ranks(RingAttention.DKV_GROUP):
ranks = dist.get_process_group_ranks(sp_group)
RingAttention.DKV_GROUP = dkv_group = dist.new_group(ranks)
else:
dkv_group = RingAttention.DKV_GROUP

# (B, H, Sq, D) -> (B, Sq, H, D)
q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
@@ -529,10 +541,6 @@ def forward(
"return_softmax": False,
}

# For Flash Attn, indexing blocks of contiguous mem has the same perf
# as indexing one big contiguous block.
# Also the former avoids frequent mem copies, e.g. when indexing
# half of the seq dim and reshaping
if (
RingAttention.HALF_INDICES is not None
and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
@@ -550,7 +558,8 @@ def forward(
else:
b, sq, h, d = q.shape
t = b * sq
q, k, v = [x.view(t, h, d) for x in (q, k, v)]
# Be careful about GQA/MQA in reshape
q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]

kv_comms = [RingComm(sp_group) for _ in range(2)]
sp_size = kv_comms[0].world_size
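The GQA/MQA-aware reshape a few lines above (`x.view(t, *x.shape[-2:])` instead of `x.view(t, h, d)`, mirrored in backward) keeps each tensor's own head count, since K and V may have fewer heads than Q. A small illustration with assumed sizes:

```python
import torch

# Hypothetical sizes: 32 query heads but only 8 KV heads (grouped-query attention).
b, sq, h_q, h_kv, d = 2, 8, 32, 8, 64
q = torch.randn(b, sq, h_q, d)
k = torch.randn(b, sq, h_kv, d)
t = b * sq

# x.view(t, h, d) with a single shared h would raise for k, whose element count
# is t * h_kv * d; flattening only the batch/sequence dims preserves each
# tensor's own head dimension.
q_flat, k_flat = (x.view(t, *x.shape[-2:]) for x in (q, k))
assert q_flat.shape == (t, h_q, d)
assert k_flat.shape == (t, h_kv, d)
```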
@@ -573,6 +582,29 @@ def forward(
rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream]

def _forward(q, k, v, causal):
(
_,
_,
_,
_,
out,
softmax_lse,
_,
rng_state,
) = _flash_attn_forward(
q,
k,
v,
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
causal=causal,
**misc_kwargs,
)
return out, softmax_lse, rng_state

# Overlap output correction with next flash attn
for i in range(sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
@@ -592,25 +624,8 @@ def forward(
# Compute with local KV; no mask
kv_block = kv_buffers[0]
q_block = q
(
_,
_,
_,
_,
block_out[i % 2], # (B * Sq, H, D)
block_softmax_lse[i % 2], # (H, total_q_seqlen)
_,
rng_states[i],
) = _flash_attn_forward(
q_block,
kv_block[0],
kv_block[1],
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
causal=True,
**misc_kwargs,
(block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T)
q_block, kv_block[0], kv_block[1], causal=True
)
elif i <= sp_rank:
# Received the "surrounding" kv chunks
@@ -619,61 +634,28 @@ def forward(
kv_block = kv_buffers[i % 2][:, half_idx_front]
q_block = q
(
_,
_,
_,
_,
block_out[i % 2], # (B * Sq, H, D)
block_softmax_lse[i % 2], # (H, total_q_seqlen)
_,
block_out[i % 2], # (T, H, D)
block_softmax_lse[i % 2], # (H, T)
rng_states[i],
) = _flash_attn_forward(
q_block,
kv_block[0],
kv_block[1],
cu_seqlens_q,
cu_seqlens_half,
max_seqlen_q,
max_seqlen_half,
causal=False,
**misc_kwargs,
)

) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
else:
# Received the inner kv chunks
# Drop the first half of q
kv_block = kv_buffers[i % 2]
q_block = q1

(
_,
_,
_,
_,
block_out[i % 2], # (B * Sq // 2, H, D)
block_softmax_lse[i % 2], # (H, total_q_seqlen)
_,
block_out[i % 2], # (T, H, D)
block_softmax_lse[i % 2], # (H, T)
rng_states[i],
) = _flash_attn_forward(
q_block,
kv_block[0],
kv_block[1],
cu_seqlens_half,
cu_seqlens_kv,
max_seqlen_half,
max_seqlen_kv,
causal=False,
**misc_kwargs,
)
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()

block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
# Output and log sum exp correction
# Overlap output correction with next flash attn kernel
# In reality this always finishes before next flash attn
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
# In reality this always finishes before next flash attn; no need for extra sync.
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
@@ -683,12 +665,12 @@ def forward(
out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
)
torch.cuda.current_stream().wait_stream(sp_stream)

# torch.cuda.current_stream().wait_stream(sp_stream)
out = out.to(q.dtype)
if not is_packed:
out = out.view(b, sq, h, d)
q, k, v = [x.view(b, sq, h, d) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D)
q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D)
softmax_lse = softmax_lse.squeeze(-1)

ctx.sp_group = sp_group
@@ -743,7 +725,7 @@ def backward(ctx, dout, _):
else:
b, sq, h, d = q.shape
t = b * sq
q, k, v, out, dout = [x.view(t, h, d) for x in (q, k, v, out, dout)]
q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]

# Sequence parallel args
sp_group = ctx.sp_group
@@ -773,6 +755,26 @@ def backward(ctx, dout, _):
dkv_send = dkv_recv = None
del k, v

def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
causal=causal,
rng_state=rng_state,
**misc_kwargs,
)

# NOTE: We avoid using two streams since it requires doubling dkv and kv buffers,
# and backward is more communication intensive than forward
for i in range(sp_size):
@@ -788,76 +790,23 @@ def backward(ctx, dout, _):
k_, v_ = kv_buffers[i % 2]
q_, dout_, out_ = q, dout, out
dq_, dk_, dv_ = dq_block, dk_block, dv_block
_flash_attn_backward(
dout_,
q_,
k_,
v_,
out_,
softmax_lse,
dq_,
dk_,
dv_,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
causal=True,
rng_state=rng_states[i],
**misc_kwargs,
)
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)

elif i <= sp_rank:
# Drop the second half of kv
# (T, H, D) -> (T // 2, H, D)
k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
dq_, q_, out_, dout_ = (dq_block, q, out, dout)

_flash_attn_backward(
dout_,
q_,
k_,
v_,
out_,
softmax_lse,
dq_,
dk_,
dv_,
cu_seqlens_q,
cu_seqlens_half,
max_seqlen_q,
max_seqlen_half,
causal=False,
rng_state=rng_states[i],
**misc_kwargs,
)
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)

else:
# Drop the first half of q
k_, v_ = kv_buffers[i % 2]
dk_, dv_ = dk_block, dv_block
q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
dq_ = dq_block[: t // 2]

_flash_attn_backward(
dout_,
q_,
k_,
v_,
out_,
softmax_lse1,
dq_,
dk_,
dv_,
cu_seqlens_half,
cu_seqlens_kv,
max_seqlen_half,
max_seqlen_kv,
causal=False,
rng_state=rng_states[i],
**misc_kwargs,
)
_backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)

# Accumulate grads
dkv_send = dkv_buffers[i % 2]
Expand Down Expand Up @@ -889,7 +838,7 @@ def backward(ctx, dout, _):
dkv_recv = dkv_send
dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
if not is_packed:
dq, dk, dv = [x.view(b, sq, h, d) for x in (dq, dk, dv)]
dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None)

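The new `_forward` and `_backward` helpers in the attn.py diff above both feed the same zigzag dispatch: at ring step i each rank decides which query/KV halves participate and whether the kernel runs with a causal mask. A condensed sketch of that selection logic (names are illustrative, not lifted verbatim from the diff):

```python
# Illustrative zigzag ring-attention step selection.
# q_half is the second half of the local queries; half_idx selects the first
# half of a received KV chunk along the token dimension.
def select_block(i, sp_rank, q, q_half, kv_chunk, half_idx):
    if i == 0:
        # Step 0: attend to the local KV chunk with a causal mask.
        return q, kv_chunk, True
    if i <= sp_rank:
        # KV from a preceding rank: all queries attend to the first half of
        # that chunk, no causal mask needed.
        k, v = kv_chunk
        return q, (k[half_idx], v[half_idx]), False
    # KV from a succeeding rank: only the second half of the queries attends,
    # again without a causal mask.
    return q_half, kv_chunk, False
```

Consolidating the sequence-length bookkeeping into `_forward`/`_backward` is what removes the repeated `_flash_attn_forward`/`_flash_attn_backward` call sites shown above.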
2 changes: 1 addition & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -676,7 +676,7 @@ def forward(
position_ids = cache_position.unsqueeze(0)

if shard_config.enable_flash_attention:
mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
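The corrected `mask_shape` separates the two sequence axes: the query axis covers only the tokens processed in this step, while the key/value axis also spans the cache. A quick shape check with assumed numbers:

```python
import torch

# Assumed sizes for illustration only.
batch_size, seq_len, past_seen_tokens = 2, 16, 48

# Rows index the new queries; columns index cached plus new keys/values.
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
attn_mask = torch.ones(mask_shape, dtype=torch.bool)

assert attn_mask.shape == (2, 1, 16, 64)
```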
2 changes: 0 additions & 2 deletions colossalai/shardformer/shard/shard_config.py
@@ -33,7 +33,6 @@ class ShardConfig:
parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim.
For SP: set to True to NOT gather the output along the seq dim.
sp_stream (Optional[torch.cuda.Stream]): The stream for ring attention output correction. Defaults to None.
dkv_group (Optional[ProcessGroup]): The process group for using a new NCCL stream in ring attention backward.
"""

tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -56,7 +55,6 @@ class ShardConfig:
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
sp_stream: Optional[torch.cuda.Stream] = None
dkv_group: Optional[ProcessGroup] = None
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
1 change: 0 additions & 1 deletion examples/language/llama/benchmark.py
@@ -213,7 +213,6 @@ def empty_init():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,