2D ring backward + llama passed
Edenzzzz committed Aug 14, 2024
1 parent e6bcde2 commit c663265
Showing 3 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/attn.py
@@ -440,7 +440,7 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

logger.info(
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Pray for the speed-up!",
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for the speed-up!",
ranks=[0],
)
num_rings = sp_size // inner_ring_size
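
For reference, the inner/outer ring split that the assert, the log message, and num_rings above describe can be illustrated as follows. This is only a sketch that assumes contiguous rank grouping; the actual groups built by RingAttention.get_double_ring_groups may be laid out differently.

# Illustrative sketch only: contiguous grouping assumed, not necessarily the
# layout produced by RingAttention.get_double_ring_groups.
def double_ring_layout(sp_size: int, inner_ring_size: int):
    assert sp_size % inner_ring_size == 0, (
        f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
    )
    num_rings = sp_size // inner_ring_size
    # Inner rings: small rings of adjacent ranks, ideally within one node (NIC-friendly).
    inner_rings = [
        list(range(r * inner_ring_size, (r + 1) * inner_ring_size))
        for r in range(num_rings)
    ]
    # Inter-ring groups: the i-th member of every inner ring, used for the inter-node passes.
    inter_rings = [
        [r * inner_ring_size + i for r in range(num_rings)]
        for i in range(inner_ring_size)
    ]
    return inner_rings, inter_rings

# e.g. double_ring_layout(8, 4) -> ([[0, 1, 2, 3], [4, 5, 6, 7]], [[0, 4], [1, 5], [2, 6], [3, 7]])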
@@ -1090,7 +1090,7 @@ def _other_ring_backward(ring_num_idx, dq):
if not is_packed:
dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

- return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None)
+ return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)

@staticmethod
def prepare_varlen_batch(
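The two fewer trailing Nones in the backward return above follow from a general torch.autograd.Function rule: backward must return exactly one value per argument of forward, with None for arguments that receive no gradient, so removing forward arguments shrinks the None tail. A minimal, self-contained sketch of that rule (ScaledMatmul is a made-up example, not the RingAttention class):

import torch

class ScaledMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, scale, return_raw):  # 4 forward arguments
        ctx.save_for_backward(q, k)
        ctx.scale = scale
        return (q @ k.transpose(-1, -2)) * scale

    @staticmethod
    def backward(ctx, grad_out):
        q, k = ctx.saved_tensors
        dq = grad_out @ k * ctx.scale
        dk = grad_out.transpose(-1, -2) @ q * ctx.scale
        # Exactly one entry per forward argument: gradients for q and k,
        # None for the non-tensor arguments scale and return_raw.
        return dq, dk, None, None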
8 changes: 4 additions & 4 deletions colossalai/shardformer/layer/loss.py
@@ -182,11 +182,11 @@ def dist_cross_entropy(
split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward

if sp_mode == "ring_attn":
- # For Ring Attention, labels should be split and shifted by RingAttention.prepare_varlen_batch()
- # and parallel_output must be True
- if sp_rank == sp_size - 1:
+ # For Zigzag Ring Attention, labels should've been split and
+ # shifted by RingAttention.prepare_varlen_batch()
+ if sp_rank == 0:
logits = logits[..., :-1, :]
- logits = torch.cat([logits, torch.zeros_like(logits[:, :1, :])], dim=seq_dim)
+ logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
elif is_sp:
# Shift only once: either before splitting or in the last rank without splitting
if split_labels_here or (sp_rank == sp_size - 1):
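For context on the shift and masking handled in this hunk: in causal language modeling, the logits at position i predict token i + 1, so logits and labels are shifted by one before the loss and excluded positions are masked with an ignore index. A generic sketch of that pattern, not ColossalAI's dist_cross_entropy; IGNORE_IDX below stands in for the module's _IGNORE_IDX and is assumed to equal PyTorch's default ignore_index of -100:

import torch
import torch.nn.functional as F

IGNORE_IDX = -100  # assumption: matches F.cross_entropy's default ignore_index

def shifted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab); labels: (batch, seq)
    shift_logits = logits[..., :-1, :]  # position i predicts token i + 1
    shift_labels = labels[..., 1:]      # drop the first label
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=IGNORE_IDX,        # labels set to IGNORE_IDX contribute nothing
    )

# Example: positions marked with IGNORE_IDX are skipped by the loss.
logits = torch.randn(2, 8, 32000)
labels = torch.randint(0, 32000, (2, 8))
labels[:, -2:] = IGNORE_IDX
loss = shifted_cross_entropy(logits, labels)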
24 changes: 12 additions & 12 deletions tests/test_shardformer/test_model/test_shard_llama.py
@@ -154,18 +154,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"test_config",
[
# Double Ring Attention
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "sp_size": 4,
- # "num_microbatches": 1,
- # "enable_sequence_parallelism": True,
- # "sequence_parallelism_mode": "ring_attn",
- # "use_lazy_init": True,
- # "zero_stage": 0,
- # "precision": "fp16",
- # "initial_scale": 1,
- # },
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "sp_size": 4,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "ring_attn",
+ "use_lazy_init": True,
+ "zero_stage": 0,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
# Ring Attention + PP
{
"tp_size": 1,
