fix mask-tying to sequence length #660

Merged 3 commits on Feb 3, 2023
58 changes: 58 additions & 0 deletions tests/test_attentions.py
@@ -204,6 +204,64 @@ def test_kqv_ordering(
assert torch.allclose(res_false[0, :, :], res_false[1, :, :])


@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
@pytest.mark.parametrize("device", DEVICES)
def test_different_seqlen(
attention_name: str,
heads: int,
device: torch.device,
):
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

# Check kqv are not flipped
# this will not catch all issues, but would catch a V being misplaced
# make k and q complimentary, so that QKt is all zero and attention is uniform

q = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))

k = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL, device=device)

# Normal call
res = multi_head(query=q, key=k, value=v)

# Changing sequence length by dividing by two to simulate differing sequence length
q2 = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ // 2, MODEL))

k2 = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ // 2, MODEL))

v2 = torch.rand(BATCH, SEQ // 2, MODEL, device=device)

res2 = multi_head(query=q2, key=k2, value=v2)
erip marked this conversation as resolved.
Show resolved Hide resolved

assert res.shape != res2.shape


@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])
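For reference, a standalone sketch (not part of the PR, with made-up sizes) of why the complementary q/k construction in test_different_seqlen above gives uniform attention: q is non-zero only in the first half of the feature dimension and k only in the second half, so every q·k dot product is zero and the softmax over the scores is flat.

```python
import torch

MODEL, SEQ = 8, 4  # illustrative sizes, not the test constants

# q is non-zero only in the first half of the feature dim, k only in the second half,
# so every q.k dot product is zero and softmax over the scores is uniform (1 / SEQ).
q = torch.cat((torch.rand(1, MODEL // 2), torch.zeros(1, MODEL // 2)), dim=1).expand(SEQ, MODEL)
k = torch.cat((torch.zeros(1, MODEL // 2), torch.rand(1, MODEL // 2)), dim=1).expand(SEQ, MODEL)

scores = q @ k.T                      # (SEQ, SEQ), all zeros
attn = torch.softmax(scores, dim=-1)  # every row is uniform
assert torch.allclose(attn, torch.full_like(attn, 1.0 / SEQ))
```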
9 changes: 5 additions & 4 deletions xformers/components/attention/scaled_dot_product.py
@@ -97,19 +97,20 @@ def forward(
)

         # Handle a possibly deferred causal mask handling
+        mask = self.mask
         if self.causal and self.mask is None:
-            self.mask = AttentionMask.make_causal(
+            mask = AttentionMask.make_causal(
                 seq_len=q.shape[-2],
                 to_seq_len=q.shape[-2],
                 device=q.device,
                 dtype=q.dtype,
             )

         # Merge the optional causal mask and the user-provided mask
-        if self.mask is not None:
-            self.mask = self.mask.to(dtype=q.dtype, device=q.device)
+        if mask is not None:
+            mask = mask.to(dtype=q.dtype, device=q.device)

-            att_mask = att_mask + self.mask if att_mask is not None else self.mask
+            att_mask = att_mask + mask if att_mask is not None else mask

         # Try to handle a case where the sequence is smaller than the mask
         if (
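For context, a minimal usage sketch (not from the PR) of what this change enables, assuming the ScaledDotProduct constructor takes dropout/causal arguments and its forward accepts q/k/v keyword inputs; shapes and values are made up for illustration. Before the change, the deferred causal mask built on the first call was stored on self.mask and reused, tying the module to the first sequence length it saw; keeping the mask in a local variable rebuilds it to match each call.

```python
import torch
from xformers.components.attention import ScaledDotProduct

BATCH, MODEL = 2, 16  # illustrative sizes, not from the PR

# causal=True with no fixed seq_len defers the causal mask to the forward pass
attn = ScaledDotProduct(dropout=0.0, causal=True)

def run(seq_len: int) -> torch.Tensor:
    q = torch.rand(BATCH, seq_len, MODEL)
    k = torch.rand(BATCH, seq_len, MODEL)
    v = torch.rand(BATCH, seq_len, MODEL)
    return attn(q=q, k=k, v=v)

out_long = run(8)   # builds an 8x8 causal mask for this call
out_short = run(4)  # with this PR, a fresh 4x4 mask is built instead of reusing a cached 8x8 one

assert out_long.shape == (BATCH, 8, MODEL)
assert out_short.shape == (BATCH, 4, MODEL)
```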