diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 141baf3d3770..5872c64856b9 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -139,12 +139,11 @@ def prepare_attn_kwargs(
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+            if s_q != 1:
+                attention_mask = attention_mask.tril(diagonal=0)
+            attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
-            assert q_padding_mask.shape == (
-                b,
-                s_q,
-            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
@@ -156,7 +155,7 @@ def prepare_attn_kwargs(
                 b,
                 s_kv,
             ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
+            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -169,7 +168,8 @@ def prepare_attn_kwargs(
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if s_q != 1:
+                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
                 attention_mask = invert_mask(attention_mask).unsqueeze(1)
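
A quick standalone sketch of the reasoning behind the two fixes. This is an illustration only, not part of the patch: it assumes plain PyTorch and toy shapes, and the names s_q, s_kv, b, kv_padding_mask simply mirror those used in prepare_attn_kwargs.

    # Illustration only (not part of the patch). Toy sizes; names mirror prepare_attn_kwargs.
    import torch

    # 1) Why the `s_q != 1` guard matters: during single-token decoding the lone
    #    query sits at the last position and may attend to every cached key, but
    #    tril on a (1, s_kv) mask keeps only the first key column.
    s_q, s_kv = 1, 5
    naive = torch.ones(s_q, s_kv).tril(diagonal=0)
    print(naive)      # tensor([[1., 0., 0., 0., 0.]])  -> almost everything masked out

    guarded = torch.ones(s_q, s_kv)
    if s_q != 1:      # the guard the patch adds
        guarded = guarded.tril(diagonal=0)
    print(guarded)    # tensor([[1., 1., 1., 1., 1.]])  -> full view of the KV cache

    # 2) Why the padded branch must expand kv_padding_mask to (b, s_q, s_kv):
    #    with cross-attention shapes (s_q != s_kv), the old expression built from
    #    q_padding_mask of shape (b, s_q) cannot even broadcast to the right shape,
    #    and it only happened to work when s_q == s_kv.
    b, s_q, s_kv = 2, 3, 5
    kv_padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                    [1, 1, 1, 1, 0]])
    mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv)   # fixed expression
    print(mask.shape)  # torch.Size([2, 3, 5]) -- one row of key visibility per query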