[shardformer] hotfix attn mask (#5947)
ver217 authored Jul 29, 2024
1 parent 9664b1b commit 7b38964
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions colossalai/shardformer/layer/attn.py
@@ -139,12 +139,11 @@ def prepare_attn_kwargs(
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+            if s_q != 1:
+                attention_mask = attention_mask.tril(diagonal=0)
+            attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
-            assert q_padding_mask.shape == (
-                b,
-                s_q,
-            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
@@ -156,7 +155,7 @@ def prepare_attn_kwargs(
                 b,
                 s_kv,
             ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
+            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -169,7 +168,8 @@ def prepare_attn_kwargs(
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if s_q != 1:
+                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
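For context, here is a minimal standalone sketch of what the fixed mask construction produces (plain PyTorch, not the ColossalAI code itself; the names b, s_q and s_kv simply mirror the function's locals). The s_q != 1 guard presumably covers single-query decoding, where applying tril(diagonal=0) to a (1, s_kv) ones matrix would leave only the first key visible; the padded branch now builds the mask from kv_padding_mask and expands it to (b, s_q, s_kv), so the last dimension indexes keys.

# Minimal sketch (plain PyTorch, not the ColossalAI implementation) of the two
# mask behaviours touched by this commit; b, s_q, s_kv mirror the function's locals.
import torch

b, s_q, s_kv = 2, 1, 5  # e.g. a single-query decoding step against 5 cached keys

# Before the fix: tril() was applied unconditionally. On a (1, s_kv) ones matrix
# it keeps only column 0, so the lone query could attend to just the first key.
old_mask = torch.ones(s_q, s_kv).tril(diagonal=0).expand(b, s_q, s_kv)
print(old_mask[0])  # tensor([[1., 0., 0., 0., 0.]])

# After the fix: skip tril() when s_q == 1, so the query sees every cached key.
new_mask = torch.ones(s_q, s_kv)
if s_q != 1:
    new_mask = new_mask.tril(diagonal=0)
new_mask = new_mask.expand(b, s_q, s_kv)
print(new_mask[0])  # tensor([[1., 1., 1., 1., 1.]])

# Padded branch after the fix: the mask comes from kv_padding_mask and is
# expanded to (b, s_q, s_kv), so the last dimension indexes keys/values.
s_q, s_kv = 3, 5
kv_padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 0]])  # (b, s_kv), 0 = padding
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv)
print(attention_mask.shape)  # torch.Size([2, 3, 5])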
