From caecd90d2807630effe798950239a915015d8450 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 12 Aug 2024 11:09:15 +0000 Subject: [PATCH] 2d ring forward passed --- .../booster/plugin/hybrid_parallel_plugin.py | 21 +- colossalai/shardformer/layer/attn.py | 317 +++++++++++++----- colossalai/shardformer/layer/utils.py | 30 +- colossalai/shardformer/modeling/llama.py | 4 +- colossalai/shardformer/shard/shard_config.py | 3 - examples/language/llama/benchmark.py | 2 + .../test_layer/test_ring_attn.py | 25 +- .../test_model/test_shard_llama.py | 15 +- 8 files changed, 311 insertions(+), 106 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 69b7e6c0ea40..66e7ca7d2ad3 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1065,10 +1065,21 @@ def __init__( self.enable_sequence_parallelism = enable_sequence_parallelism if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + # Swap tp and sp since 2D Ring has better inter-node latency + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) else: self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None @@ -1134,11 +1145,6 @@ def __init__( parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, - sp_stream=( - torch.cuda.Stream() - if enable_sequence_parallelism and sequence_parallelism_mode == "ring_attn" - else None - ), ) self.amp_config = dict( initial_scale=initial_scale, @@ -1231,6 +1237,7 @@ def configure( # Apply Hybrid ZeRO across DP * SP ranks if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(dp_group) else: dp_group = self.dp_group model = HybridParallelModule( diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 919c66285318..c9f80322406a 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -166,8 +166,8 @@ def prepare_attn_kwargs( else: assert q_padding_mask.shape == ( b, - s_q, - ), f"q_padding_mask shape {q_padding_mask.shape} should be {b, s_q}." + s_kv, + ), f"q_padding_mask shape {q_padding_mask.shape} should be {b, s_kv}." 
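To make the tp/sp axis swap in the plugin change above concrete, here is a minimal sketch of which global ranks end up in the same sequence-parallel group before and after swapping the last two mesh axes. It assumes the row-major (C-order) rank enumeration that ProcessGroupMesh is understood to use, and the sizes (8 ranks, tp=2, sp=4) are hypothetical.

```python
# Illustrative sketch only: how swapping the last two mesh axes changes
# sequence-parallel group membership, assuming a row-major rank layout.
import numpy as np

dp, pp, tp, sp = 1, 1, 2, 4  # hypothetical sizes: 8 ranks, tp=2, sp=4

# Default ordering (dp, pp, tp, sp): sp is the fastest-varying axis,
# so each sp group is a block of consecutive global ranks.
mesh = np.arange(dp * pp * tp * sp).reshape(dp, pp, tp, sp)
sp_groups_default = [mesh[0, 0, i, :].tolist() for i in range(tp)]
# [[0, 1, 2, 3], [4, 5, 6, 7]]

# ring_attn ordering (dp, pp, sp, tp): tp ranks become consecutive,
# while each sp group is strided by tp_size.
mesh_swapped = np.arange(dp * pp * sp * tp).reshape(dp, pp, sp, tp)
sp_groups_ring = [mesh_swapped[0, 0, :, j].tolist() for j in range(tp)]
# [[0, 2, 4, 6], [1, 3, 5, 7]]

print(sp_groups_default, sp_groups_ring)
```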
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) if kv_padding_mask is None: # self attention @@ -382,6 +382,8 @@ class RingAttention(torch.autograd.Function): For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper, which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370; implemented in Jax and not optimized). + We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to minimize inter-node latency + by utilizing more NICs and fully utilize intra-node bandwidth. """ # Globle cache to avoid recomputation for same-lengthed sequences @@ -390,7 +392,65 @@ class RingAttention(torch.autograd.Function): HALF_INDICES: Tuple = None SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL) ATTN_DONE: torch.cuda.Event = None + SP_STREAM: torch.cuda.Stream = None + SP_GROUP: dist.ProcessGroup = None + # duplicate process group for concurrent NCCL streams + # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7) + # against this, in practice it seems to work fine. + INNER_RING_GROUP_COPY: dist.ProcessGroup = None DKV_GROUP: dist.ProcessGroup = None + LOCAL_RING_GROUP: dist.ProcessGroup = None + INTER_RING_GROUP: dist.ProcessGroup = None + + @staticmethod + def get_double_ring_groups(sp_group, inner_ring_size=None): + """ + Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size + shouldn't be larger than the number of NICs on each node. + Args: + sp_group (dist.ProcessGroup): Process group for sequence parallelism + inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None. + Returns: + Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + + if inner_ring_size is None: + if sp_size <= 4: + inner_ring_size = min(2, sp_size) + else: + inner_ring_size = min(4, sp_size) + else: + assert ( + inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 + ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + + if inner_ring_size == sp_size: + return sp_group, sp_group + assert ( + sp_size % inner_ring_size == 0 + ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + + num_rings = sp_size // inner_ring_size + inner_ring_group = None + inter_ring_group = None + + # Create inner ring groups + for i in range(inner_ring_size): + ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inner_ring_group = group + + # Create inter ring groups + for i in range(num_rings): + ranks = list(range(i, sp_size, num_rings)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inter_ring_group = group + + return inner_ring_group, inter_ring_group @staticmethod def attention( @@ -398,7 +458,6 @@ def attention( k, v, sp_group, - sp_stream, attention_mask_type, cu_seqlens=None, max_seqlen=None, @@ -408,6 +467,7 @@ def attention( deterministic=False, return_softmax=False, dkv_group=None, + inner_ring_size=None, **kwargs, ): """ @@ -430,7 +490,7 @@ def attention( softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax. deterministic (bool, optional): Whether to force deterministic backward pass. 
See https://github.com/Dao-AILab/flash-attention/issues/349 return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp). - dkv_group (Optional[dist.ProcessGroup]): Process group for using a new NCCL stream in ring attention backward. + dkv_group (Optional[dist.ProcessGroup]): Process group for using a concurrent NCCL stream in ring attention backward. Returns: out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False. @@ -441,6 +501,9 @@ def attention( _load_flash_attn() if RingAttention.ATTN_DONE is None: RingAttention.ATTN_DONE = torch.cuda.Event() + if RingAttention.SP_STREAM is None: + RingAttention.SP_STREAM = torch.cuda.Stream() + assert ( q.shape[2] == k.shape[2] ), "Q, K and V having different sequence lengths (inference or cross-attn)\ @@ -448,15 +511,27 @@ def attention( assert ( attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES ), f"Mask type {attention_mask_type} is not supported yet." + + # Register process groups locally to make it simple and self-contained to use if dkv_group is None: - if RingAttention.DKV_GROUP is None or dist.get_process_group_ranks( - sp_group - ) != dist.get_process_group_ranks(RingAttention.DKV_GROUP): + if RingAttention.DKV_GROUP is None or RingAttention.SP_GROUP is not sp_group: ranks = dist.get_process_group_ranks(sp_group) RingAttention.DKV_GROUP = dkv_group = dist.new_group(ranks) else: dkv_group = RingAttention.DKV_GROUP + if RingAttention.SP_GROUP is not sp_group: + RingAttention.SP_GROUP = sp_group + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) + ranks = dist.get_process_group_ranks(inner_ring_group) + inner_ring_group_copy = RingAttention.INNER_RING_GROUP_COPY = dist.new_group(ranks) + RingAttention.LOCAL_RING_GROUP = inner_ring_group + RingAttention.INTER_RING_GROUP = inter_ring_group + else: + inner_ring_group = RingAttention.LOCAL_RING_GROUP + inter_ring_group = RingAttention.INTER_RING_GROUP + inner_ring_group_copy = RingAttention.INNER_RING_GROUP_COPY + # (B, H, Sq, D) -> (B, Sq, H, D) q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)] pad_output = q.dim() == 4 @@ -487,7 +562,7 @@ def attention( k, v, sp_group, - sp_stream, + RingAttention.SP_STREAM, cu_seqlens, max_seqlen, dropout_p, @@ -496,6 +571,9 @@ def attention( return_softmax, attention_mask_type == AttnMaskType.PADDED_CAUSAL, dkv_group, + inner_ring_group, + inner_ring_group_copy, + inter_ring_group, ) if attention_mask_type == AttnMaskType.PADDED_CAUSAL: @@ -525,7 +603,11 @@ def forward( return_softmax: Optional[bool] = False, is_packed: Optional[bool] = False, dkv_group: Optional[dist.ProcessGroup] = None, + inner_ring_group: Optional[dist.ProcessGroup] = None, + inner_ring_group_copy: Optional[dist.ProcessGroup] = None, + inter_ring_group: Optional[dist.ProcessGroup] = None, ): + cu_seqlens_q = cu_seqlens_kv = cu_seqlens max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 @@ -561,11 +643,21 @@ def forward( # Be careful about GQA/MQA in reshape q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)] - kv_comms = [RingComm(sp_group) for _ in range(2)] - sp_size = kv_comms[0].world_size - sp_rank = kv_comms[0].rank + if inner_ring_group is None or inter_ring_group is None: + # Use one ring if not specified + inner_ring_group = inter_ring_group = sp_group - # Non-contiguous indexing creates a new contiguous tensor, + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + # Attempt to achieve concurrent 
comm in the two-stream forward + local_kv_comms = [RingComm(inner_ring_group), RingComm(inner_ring_group_copy)] + inter_ring_comm = RingComm(inter_ring_group) + local_sp_size = dist.get_world_size(inner_ring_group) + local_sp_rank = dist.get_rank(inner_ring_group) + inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0 + num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1 + + # Non-contiguous indexing copies to a new contiguous tensor, # so only do it once if sp_rank != sp_size - 1: q1 = q[half_idx_back] @@ -605,68 +697,138 @@ def _forward(q, k, v, causal): ) return out, softmax_lse, rng_state - # Overlap output correction with next flash attn - for i in range(sp_size): - with torch.cuda.stream(sp_streams[i % 2]): - # Wait for current kv from prev rank - # NOTE: waiting outside the current stream will NOT correctly synchronize. - if i > 0: - kv_comms[(i + 1) % 2].wait() - - # Avoid overwriting attn input when it shares mem with buffer - if not RingAttention.ATTN_DONE.query(): - kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) - - if i < sp_size - 1: - kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) - - if i == 0: - # Compute with local KV; no mask - kv_block = kv_buffers[0] - q_block = q - (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T) - q_block, kv_block[0], kv_block[1], causal=True - ) - elif i <= sp_rank: - # Received the "surrounding" kv chunks - # Drop the second half of received kv - # (2, t // 2, H, D) - kv_block = kv_buffers[i % 2][:, half_idx_front] - q_block = q - ( - block_out[i % 2], # (T, H, D) - block_softmax_lse[i % 2], # (H, T) - rng_states[i], - ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) - else: - # Received the inner kv chunks - # Drop the first half of q - kv_block = kv_buffers[i % 2] - q_block = q1 - ( - block_out[i % 2], # (T, H, D) - block_softmax_lse[i % 2], # (H, T) - rng_states[i], - ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) - RingAttention.ATTN_DONE.record() - - block_softmax_lse[i % 2] = ( - block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() - ) # (H, T) -> (T, H, 1) - assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] - # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel. - # In reality this always finishes before next flash attn; no need for extra sync. - if i == 0: - out = block_out[0] - softmax_lse = block_softmax_lse[0] - elif i <= sp_rank: - out, softmax_lse = _rescale_out_lse(out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]) + def _local_ring_forward(): + # (Hopefully) overlap output correction with next flash attn + for i in range(local_sp_size): + with torch.cuda.stream(sp_streams[i % 2]): + # Avoid overwriting attn input when it shares mem with buffer + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + # Wait for current kv from prev rank + # NOTE: waiting outside the current stream will NOT correctly synchronize. 
+ if i > 0: + local_kv_comms[(i + 1) % 2].wait() + + if i == 0: + # Compute with local KV; no mask + kv_block = kv_buffers[0] + q_block = q + (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T) + q_block, kv_block[0], kv_block[1], causal=True + ) + elif i <= local_sp_rank: + # Received the "surrounding" kv chunks + # Drop the second half of received kv + # (2, t // 2, H, D) + kv_block = kv_buffers[i % 2][:, half_idx_front] + q_block = q + ( + block_out[i % 2], # (T, H, D) + block_softmax_lse[i % 2], # (H, T) + rng_states[i], + ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) + else: + # Received the inner kv chunks + # Drop the first half of q + kv_block = kv_buffers[i % 2] + q_block = q1 + ( + block_out[i % 2], # (T, H, D) + block_softmax_lse[i % 2], # (H, T) + rng_states[i], + ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) # (H, T) -> (T, H, 1) + assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] + # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel. + # In reality this always finishes before next flash attn; no need for extra sync. + if i == 0: + out = block_out[0] + softmax_lse = block_softmax_lse[0] + elif i <= local_sp_rank: + out, softmax_lse = _rescale_out_lse( + out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2] + ) + else: + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) + + torch.cuda.current_stream().wait_stream(sp_stream) + return out, softmax_lse + + def _other_ring_forward(ring_num_idx, out, softmax_lse): + # Loop through the inner ring after receiving + # all new KVs from the previous inner ring + for i in range(local_sp_size): + with torch.cuda.stream(sp_streams[i % 2]): + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + # Send & recv KV + if i > 0: + local_kv_comms[(i + 1) % 2].wait() + + if ring_num_idx > inter_ring_rank: + kv_block = kv_buffers[i % 2] + ( + block_out[i % 2], + block_softmax_lse[i % 2], + rng_states[i + local_sp_size * ring_num_idx], + ) = _forward(q1, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) + else: + kv_block = kv_buffers[i % 2][:, half_idx_front] + ( + block_out[i % 2], + block_softmax_lse[i % 2], + rng_states[i + local_sp_size * ring_num_idx], + ) = _forward(q, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) + out, softmax_lse = _rescale_out_lse( + out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2] + ) + + torch.cuda.current_stream().wait_stream(sp_stream) + return out, softmax_lse + + # Send and recv KV between rings at once to maximize NIC util. 
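After each ring step, the partial flash-attention output is folded into the running result via the log-sum-exp correction performed by `_rescale_out_lse`. The helper's body is not shown in this patch; the standalone sketch below only illustrates the math it is assumed to perform, with block outputs shaped (T, H, D) and the LSE kept as (T, H, 1) as in the surrounding code.

```python
# Sketch of the LSE-based output correction applied after each ring step.
# This mirrors the role of _rescale_out_lse; shapes and formula are assumptions.
import torch

def rescale_out_lse_sketch(out, block_out, lse, block_lse):
    # Combined normalizer: log(exp(lse) + exp(block_lse))
    new_lse = torch.logaddexp(lse, block_lse)
    # Re-weight both partial outputs into the common normalizer and sum them.
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return new_out, new_lse

# Tiny usage example with random blocks
T, H, D = 8, 2, 4
out, block_out = torch.randn(T, H, D), torch.randn(T, H, D)
lse, block_lse = torch.randn(T, H, 1), torch.randn(T, H, 1)
out, lse = rescale_out_lse_sketch(out, block_out, lse, block_lse)
```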
+ for ring_num_idx in range(num_rings): + if ring_num_idx > 0: + inter_ring_comm.wait() + # Reset indices + kv_buffers[0] = inter_ring_kv + + if ring_num_idx < num_rings - 1: + if ring_num_idx == 0: + to_send = kv_buffers[0] else: - out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( - out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] - ) + # The last received KV + to_send = kv_buffers[(local_sp_size - 1) % 2] + inter_ring_comm.send_recv(to_send) + + if ring_num_idx == 0: + out, softmax_lse = _local_ring_forward() + else: + out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse) - # torch.cuda.current_stream().wait_stream(sp_stream) out = out.to(q.dtype) if not is_packed: out = out.view(b, sq, h, d) @@ -775,8 +937,9 @@ def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): **misc_kwargs, ) - # NOTE: We avoid using two streams since it requires doubling dkv and kv buffers, - # and backward is more communication intensive than forward + # NOTE: We avoid using two streams due to doubled buffers + # and that backward is more communication intensive. + # def _local_ring_backward(): for i in range(sp_size): if i > 0: kv_comm.wait() @@ -832,15 +995,17 @@ def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): # q blocks "surrounding" kv blocks dkv_recv[0] += dk_ dkv_recv[1] += dv_ - dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send) + dkv_comm.wait() dkv_recv = dkv_send + # return dq, dkv_recv + dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)] if not is_packed: dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] - return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None) + return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None) @staticmethod def prepare_varlen_batch( diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index a525eff05a2c..f880a760c558 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -433,26 +433,36 @@ def __init__(self, process_group: dist.ProcessGroup): self.send_rank = (self.rank + 1) % self.world_size self.recv_rank = (self.rank - 1) % self.world_size - if process_group is not None: - self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) - self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) - - def send_recv(self, send_tensor: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, + send_tensor: torch.Tensor, + recv_tensor: Optional[torch.Tensor] = None, + commit: bool = True, + ) -> torch.Tensor: if recv_tensor is None: res = torch.empty_like(send_tensor) else: res = recv_tensor - # NOTE: looks like batch_isend_irecv doesn't deadlock even - # when we never swap send recv ops across ranks + # looks like batch_isend_irecv doesn't deadlock even + # when we don't swap send recv ops based on rank send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group) recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - self._ops.append(send_op) - self._ops.append(recv_op) - self._reqs = dist.batch_isend_irecv(self._ops) + self._ops.extend([send_op, recv_op]) + + if commit: + self._reqs = 
dist.batch_isend_irecv(self._ops) return res + def commit(self): + assert len(self._ops) > 0, "No ops to commit" + self._reqs = dist.batch_isend_irecv(self._ops) + def wait(self): + assert len(self._reqs) > 0, "No requests to wait for" for req in self._reqs: req.wait() self._reqs = [] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 3ddb2ac89193..59583a273022 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -563,9 +563,7 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) if sp_mode == "ring_attn": - attn_output = RingAttention.attention( - query_states, key_states, value_states, sp_group, shard_config.sp_stream, **attention_mask - ) + attn_output = RingAttention.attention(query_states, key_states, value_states, sp_group, **attention_mask) elif shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index cb20bea5af82..505443b14012 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional -import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -32,7 +31,6 @@ class ShardConfig: enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim. For SP: set to True to NOT gather the output along the seq dim. - sp_stream (Optional[torch.cuda.Stream]): : The stream for ring attention output correction. Defaults to None. 
""" tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -54,7 +52,6 @@ class ShardConfig: # for moe related moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None - sp_stream: Optional[torch.cuda.Stream] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 2b5b4f279367..093377e7a034 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -332,6 +332,8 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] + del outputs # free memory + if dist.get_rank() == dist.get_world_size() - 1: print(f"Step {step} loss: {loss}") booster.backward(loss, optimizer) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 3463000fadd6..df18aefa2301 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -21,7 +21,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) device = get_current_device() sp_group = dist.group.WORLD - sp_stream = torch.cuda.Stream() # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's @@ -37,7 +36,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU - ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) + ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True @@ -60,6 +59,10 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) + if dist.get_rank() == 0: + print( + f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.LOCAL_RING_GROUP)} passed." 
+ ) @parameterize("seqlen", [4096]) @@ -71,7 +74,6 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() - sp_stream = torch.cuda.Stream() atol = rtol = 7e-3 torch.cuda.manual_seed(2) # Prepare varlen attention mask @@ -113,7 +115,6 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): k_ring, v_ring, sp_group, - sp_stream, **mask_info, pad_output=False, return_softmax=True, @@ -148,17 +149,29 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): assert_close(dv, dv_ring, atol=atol, rtol=rtol) -def launch(rank, world_size, port): +def launch_single_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) check_packed_seq() check_ring_attn() +def launch_double_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_ring_attn() + + @rerun_if_address_is_in_use() @parameterize("world_size", [2]) def test_ring_attn(world_size): - spawn(launch, nprocs=world_size) + spawn(launch_single_ring, nprocs=world_size) + + +@rerun_if_address_is_in_use() +@parameterize("world_size", [4]) +def test_double_ring(world_size): + spawn(launch_double_ring, nprocs=world_size) if __name__ == "__main__": test_ring_attn() + test_double_ring() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index beeada8f2e5b..581c578f5bef 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,7 +153,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Zigzag Ring Attention + PP + # # Double Ring Attention + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 4, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "bf16", + # "initial_scale": 1, + # }, + # Ring Attention + PP { "tp_size": 1, "pp_size": 2,
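The call pattern below is distilled from `check_ring_attn` above: after this patch, `sp_stream` is no longer passed (the stream is managed internally via `RingAttention.SP_STREAM`), and `inner_ring_size` is an optional new argument (per `get_double_ring_groups`, it defaults to `min(2, sp_size)` for `sp_size <= 4` and to 4 otherwise). The wrapper function, the `inner_ring_size=2` choice, and the tensor sizes are illustrative; a launched distributed CUDA environment as in the tests is assumed.

```python
# Minimal usage sketch of the updated RingAttention.attention signature.
import torch
import torch.distributed as dist
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention

def ring_attn_example(bs=2, nheads=5, seq_len=4096, d=128, dtype=torch.bfloat16):
    sp_group = dist.group.WORLD
    sp_size = dist.get_world_size()
    device = torch.cuda.current_device()
    # Each rank holds its local shard of the sequence: (B, nHeads, Sq / sp_size, D)
    q, k, v = [
        torch.randn(bs, nheads, seq_len // sp_size, d, device=device, dtype=dtype, requires_grad=True)
        for _ in range(3)
    ]
    out, lse = RingAttention.attention(
        q, k, v, sp_group,
        AttnMaskType.CAUSAL,
        return_softmax=True,
        inner_ring_size=2,  # hypothetical choice; omit to use the default heuristic
    )
    return out, lse
```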