Skip to content

Commit

Permalink
fix mask-tying to sequence length (#660)
Browse files Browse the repository at this point in the history
* fix mask-tying to sequence length. WIP

* fix unbound variable.

* add test to ensure different seq-length path works.
  • Loading branch information
erip authored Feb 3, 2023
1 parent 48a77cc commit 7f4fdce
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
58 changes: 58 additions & 0 deletions tests/test_attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,64 @@ def test_kqv_ordering(
assert torch.allclose(res_false[0, :, :], res_false[1, :, :])


@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
@pytest.mark.parametrize("device", DEVICES)
def test_different_seqlen(
attention_name: str,
heads: int,
device: torch.device,
):
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

# Check kqv are not flipped
# this will not catch all issues, but would catch a V being misplaced
# make k and q complimentary, so that QKt is all zero and attention is uniform

q = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))

k = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL, device=device)

# Normal call
res = multi_head(query=q, key=k, value=v)

# Changing sequence length by dividing by two to simulate differing sequence length
q2 = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ // 2, MODEL))

k2 = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ // 2, MODEL))

v2 = torch.rand(BATCH, SEQ // 2, MODEL, device=device)

res2 = multi_head(query=q2, key=k2, value=v2)

assert res.shape != res2.shape


@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])
Expand Down
9 changes: 5 additions & 4 deletions xformers/components/attention/scaled_dot_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,20 @@ def forward(
)

# Handle a possibly deferred causal mask handling
mask = self.mask
if self.causal and self.mask is None:
self.mask = AttentionMask.make_causal(
mask = AttentionMask.make_causal(
seq_len=q.shape[-2],
to_seq_len=q.shape[-2],
device=q.device,
dtype=q.dtype,
)

# Merge the optional causal mask and the user-provided mask
if self.mask is not None:
self.mask = self.mask.to(dtype=q.dtype, device=q.device)
if mask is not None:
mask = mask.to(dtype=q.dtype, device=q.device)

att_mask = att_mask + self.mask if att_mask is not None else self.mask
att_mask = att_mask + mask if att_mask is not None else mask

# Try to handle a case where the sequence is smaller than the mask
if (
Expand Down

0 comments on commit 7f4fdce

Please sign in to comment.