2d ring forward passed
Edenzzzz committed Aug 12, 2024
1 parent d3831b4 commit caecd90
Showing 8 changed files with 311 additions and 106 deletions.
21 changes: 14 additions & 7 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1065,10 +1065,21 @@ def __init__(
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
self.sp_axis = 2
self.tp_axis = 3
else:
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size)
self.sp_axis = 2
self.tp_axis = 3
else:
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

self.stage_manager = None
self.schedule = None
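For context on the swap above: the axis order handed to ProcessGroupMesh determines which ranks land in the same group, with the last axis varying fastest across consecutive ranks. Putting tp last keeps each TP group on neighbouring (typically intra-node) ranks, while the SP groups then stride across ranks and may span nodes, which the 2D ring variant is built to tolerate, as the in-line comment notes. Below is a minimal, self-contained sketch of that rank-to-group mapping using plain index math; the sizes are made up for illustration and this is not the ProcessGroupMesh implementation.

import numpy as np

# Hypothetical layout: dp=2, pp=1, sp=2, tp=2 on 8 ranks (illustration only).
dp, pp, sp, tp = 2, 1, 2, 2
mesh = np.arange(dp * pp * sp * tp).reshape(dp, pp, sp, tp)

# TP groups follow the last (fastest-varying) axis -> consecutive ranks.
tp_groups = [mesh[i, j, k, :].tolist() for i in range(dp) for j in range(pp) for k in range(sp)]
# SP groups follow axis 2 -> ranks separated by a stride of tp_size.
sp_groups = [mesh[i, j, :, l].tolist() for i in range(dp) for j in range(pp) for l in range(tp)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(sp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]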
@@ -1134,11 +1145,6 @@ def __init__(
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
sp_stream=(
torch.cuda.Stream()
if enable_sequence_parallelism and sequence_parallelism_mode == "ring_attn"
else None
),
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -1231,6 +1237,7 @@ def configure(
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(dp_group)
else:
dp_group = self.dp_group
model = HybridParallelModule(
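The configure change above makes the ZeRO group span both the DP and SP axes whenever sequence parallelism does not share the TP group, and self.dp_size is then re-read from that merged group, so optimizer-state sharding covers dp_size * sp_size ranks. A small self-contained sketch of such a merged-axis group, reusing the same illustrative 2x1x2x2 layout (again plain index math, not the create_group_along_axis implementation):

import numpy as np

dp, pp, sp, tp = 2, 1, 2, 2  # illustrative sizes only
mesh = np.arange(dp * pp * sp * tp).reshape(dp, pp, sp, tp)

# A group "along [dp_axis, sp_axis]" fixes pp and tp and merges the dp and sp axes.
dp_sp_groups = [mesh[:, j, :, l].flatten().tolist() for j in range(pp) for l in range(tp)]
print(dp_sp_groups)          # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(len(dp_sp_groups[0]))  # 4 == dp * sp, the world size assigned to self.dp_size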
317 changes: 241 additions & 76 deletions colossalai/shardformer/layer/attn.py

Large diffs are not rendered by default.

30 changes: 20 additions & 10 deletions colossalai/shardformer/layer/utils.py
@@ -433,26 +433,36 @@ def __init__(self, process_group: dist.ProcessGroup):
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size

if process_group is not None:
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

def send_recv(self, send_tensor: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

def send_recv(
self,
send_tensor: torch.Tensor,
recv_tensor: Optional[torch.Tensor] = None,
commit: bool = True,
) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(send_tensor)
else:
res = recv_tensor

# NOTE: looks like batch_isend_irecv doesn't deadlock even
# when we never swap send recv ops across ranks
# looks like batch_isend_irecv doesn't deadlock even
# when we don't swap send recv ops based on rank
send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.append(send_op)
self._ops.append(recv_op)
self._reqs = dist.batch_isend_irecv(self._ops)
self._ops.extend([send_op, recv_op])

if commit:
self._reqs = dist.batch_isend_irecv(self._ops)
return res

def commit(self):
assert len(self._ops) > 0, "No ops to commit"
self._reqs = dist.batch_isend_irecv(self._ops)

def wait(self):
assert len(self._reqs) > 0, "No requests to wait for"
for req in self._reqs:
req.wait()
self._reqs = []
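The new commit flag and the separate commit() method let a caller queue several point-to-point operations and launch them with a single batch_isend_irecv call, while wait() blocks only when the received buffer is actually consumed, so the ring's KV exchange can overlap with attention compute. A hedged, standalone sketch of the pattern this helper wraps, using only torch.distributed primitives (the function name, tensor, and shapes are illustrative, not part of the ColossalAI API):

import torch
import torch.distributed as dist

def ring_exchange(kv_send: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Send our KV block to the next rank and receive one from the previous rank,
    # launching both ops together the way commit()/commit=True does.
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    send_rank = dist.get_global_rank(group, (rank + 1) % world_size)
    recv_rank = dist.get_global_rank(group, (rank - 1) % world_size)

    kv_recv = torch.empty_like(kv_send)
    ops = [
        dist.P2POp(dist.isend, kv_send, send_rank, group=group),
        dist.P2POp(dist.irecv, kv_recv, recv_rank, group=group),
    ]
    reqs = dist.batch_isend_irecv(ops)

    # ... compute attention on the local KV block here to overlap comm and compute ...

    for req in reqs:  # the equivalent of wait() before using the received block
        req.wait()
    return kv_recv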
4 changes: 1 addition & 3 deletions colossalai/shardformer/modeling/llama.py
@@ -563,9 +563,7 @@ def forward(
value_states = repeat_kv(value_states, self.num_key_value_groups)

if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query_states, key_states, value_states, sp_group, shard_config.sp_stream, **attention_mask
)
attn_output = RingAttention.attention(query_states, key_states, value_states, sp_group, **attention_mask)

elif shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
3 changes: 0 additions & 3 deletions colossalai/shardformer/shard/shard_config.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

@@ -32,7 +31,6 @@ class ShardConfig:
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
parallel_output (bool): For TP: whether to parallelize the cross entropy computation along the feature dim.
For SP: set to True to NOT gather the output along the seq dim.
sp_stream (Optional[torch.cuda.Stream]): The stream for ring attention output correction. Defaults to None.
"""

tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -54,7 +52,6 @@ class ShardConfig:
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
sp_stream: Optional[torch.cuda.Stream] = None
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
2 changes: 2 additions & 0 deletions examples/language/llama/benchmark.py
@@ -332,6 +332,8 @@ def empty_init():
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
del outputs # free memory

if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
booster.backward(loss, optimizer)
25 changes: 19 additions & 6 deletions tests/test_shardformer/test_layer/test_ring_attn.py
@@ -21,7 +21,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
sp_stream = torch.cuda.Stream()

# Some outliers may seem large, but our errors are still lower
# than Megatron-LM context parallel's
@@ -37,7 +36,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
q.requires_grad = k.requires_grad = v.requires_grad = True

# Ring attention vs single GPU
ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True)
ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
@@ -60,6 +59,10 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
if dist.get_rank() == 0:
print(
f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.LOCAL_RING_GROUP)} passed."
)


@parameterize("seqlen", [4096])
@@ -71,7 +74,6 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
sp_stream = torch.cuda.Stream()
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
@@ -113,7 +115,6 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
k_ring,
v_ring,
sp_group,
sp_stream,
**mask_info,
pad_output=False,
return_softmax=True,
@@ -148,17 +149,29 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
assert_close(dv, dv_ring, atol=atol, rtol=rtol)


def launch(rank, world_size, port):
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
check_ring_attn()


def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn()


@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
spawn(launch, nprocs=world_size)
spawn(launch_single_ring, nprocs=world_size)


@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
spawn(launch_double_ring, nprocs=world_size)


if __name__ == "__main__":
test_ring_attn()
test_double_ring()
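The new test_double_ring case exercises the 2D ("double") ring layout from the commit title: the SP ranks are partitioned into inner rings, which exchange KV blocks over fast (typically intra-node) links, and inter rings that step across them, which is what the test's "inner ring size" print (via RingAttention.LOCAL_RING_GROUP) refers to. A sketch of that partition in plain index math follows; the helper name and the assumption that the inner ring size divides the SP size are illustrative, not the actual RingAttention group-creation code.

from typing import List, Tuple

def double_ring_groups(sp_ranks: List[int], inner_ring_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    # Inner rings: contiguous chunks of the SP group; inter rings: one rank per inner ring.
    assert len(sp_ranks) % inner_ring_size == 0
    inner = [sp_ranks[i:i + inner_ring_size] for i in range(0, len(sp_ranks), inner_ring_size)]
    inter = [sp_ranks[i::inner_ring_size] for i in range(inner_ring_size)]
    return inner, inter

# With sp_size=4 and inner_ring_size=2, as in the world_size=4 test above:
print(double_ring_groups([0, 1, 2, 3], 2))
# ([[0, 1], [2, 3]], [[0, 2], [1, 3]])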
15 changes: 14 additions & 1 deletion tests/test_shardformer/test_model/test_shard_llama.py
@@ -153,7 +153,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
# Zigzag Ring Attention + PP
# # Double Ring Attention
# {
# "tp_size": 1,
# "pp_size": 1,
# "sp_size": 4,
# "num_microbatches": 1,
# "enable_sequence_parallelism": True,
# "sequence_parallelism_mode": "ring_attn",
# "use_lazy_init": True,
# "zero_stage": 2,
# "precision": "bf16",
# "initial_scale": 1,
# },
# Ring Attention + PP
{
"tp_size": 1,
"pp_size": 2,
