[bugfix] Causal mask being skipped if no extra mask #104
Conversation
Force-pushed from b4d6adb to 1de96c7
The test I came up with is a bit strict: it is effective with scaled_dot_product, but it does not pass with all the approximations; they are mostly there, just not exactly. I think it would be nice to come up with a proper test for all of them, but in the meantime this needs to be fixed on main.
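For reference, a minimal sketch of the kind of parity test being discussed: compare an approximate attention output against exact scaled dot-product attention up to a tolerance. The helper names below are illustrative, not the actual xformers test API; the tolerances mirror the ones used in the test diff further down.

```python
import torch

def reference_sdp(q, k, v, causal=False):
    # Exact scaled dot-product attention, used as the ground truth.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if causal:
        seq = q.shape[-2]
        mask = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v

def check_close(approx_out, q, k, v, causal=False, rtol=5e-3, atol=1e-2):
    # Approximate mechanisms (Favor, Nystrom, ...) only match up to a tolerance.
    ref = reference_sdp(q, k, v, causal=causal)
    assert torch.allclose(approx_out, ref, rtol=rtol, atol=atol)
```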
@@ -24,6 +24,7 @@

 @dataclass
 class FavorAttentionConfig(AttentionConfig):
+    causal: Optional[bool]
causal was not being passed on by the factory because this field was missing.
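A simplified sketch of why the missing field matters: the factory builds the attention from the config's fields, so a field that is absent from the dataclass silently never reaches the constructor. The classes below are stand-ins that mirror the names in the diff, not the actual xformers implementations.

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class AttentionConfig:
    dropout: float = 0.0

@dataclass
class FavorAttentionConfig(AttentionConfig):
    causal: Optional[bool] = None  # without this field, causal was never forwarded

class FavorAttention:
    def __init__(self, dropout: float = 0.0, causal: Optional[bool] = None, **kwargs):
        self.causal = bool(causal)

def build_attention(config: FavorAttentionConfig) -> FavorAttention:
    # The factory unpacks the config fields; anything missing from the
    # dataclass simply never gets passed along.
    return FavorAttention(**asdict(config))

att = build_attention(FavorAttentionConfig(causal=True))
assert att.causal
```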
@@ -59,8 +59,9 @@ def _get_causal_mask(self, seq_len: int, to_seq_len: int) -> torch.Tensor:
     # Cache a mask so that multiple instances would reuse the same
     causal_mask = self._causal_mask
     if not causal_mask:
-        causal_mask = torch.tril(torch.ones(seq_len, to_seq_len), diagonal=0)
-        causal_mask[self._causal_mask == 1] = -float("inf")
+        causal_mask = torch.triu(
this was wrong with additive masks, pretty bad
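A hedged sketch of the idea behind the fix (the new torch.triu line is truncated in the diff, so this reproduces the intent, not the literal code): an additive causal mask should be 0 on and below the diagonal and -inf strictly above it, so that adding it to the attention logits only suppresses future positions. The removed tril-based construction instead produced a 0/1 matrix and then wrote -inf into the allowed positions, which is why additive masking broke.

```python
import torch

def additive_causal_mask(seq_len: int, to_seq_len: int) -> torch.Tensor:
    # -inf strictly above the diagonal (future positions), 0 elsewhere,
    # so adding it to the logits leaves allowed positions untouched.
    return torch.triu(
        torch.full((seq_len, to_seq_len), float("-inf")), diagonal=1
    )

logits = torch.randn(4, 8, 8)
masked = logits + additive_causal_mask(8, 8)  # broadcasts over the batch dim
```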
Back to draft: it looks like Favor causal needs improvements, long overdue (spotted by @fmassa long ago).
Force-pushed from 3ff3d90 to 2abf788
     # Algorithm 1 in the paper
-    ref_v = torch.ones_like(v[:, 0, :].unsqueeze(1))
+    ref_v = torch.ones_like(v.unsqueeze(2))  # BATCH x SEQ x 1 x EMB
This is a memory hog, that's the issue with this approach @fmassa, but the older state was not actually trainable (this was not being tested; fixed now). So I guess it's better than before, and improving on that could be a follow-up.
see #105 for tracking
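A back-of-the-envelope illustration of the memory comment above (shapes only, not the actual Favor code): the old reference tensor was BATCH x 1 x EMB, while the new one is BATCH x SEQ x 1 x EMB, i.e. SEQ times larger.

```python
import torch

B, S, E = 8, 1024, 64
v = torch.randn(B, S, E)

ref_v_old = torch.ones_like(v[:, 0, :].unsqueeze(1))  # (B, 1, E)
ref_v_new = torch.ones_like(v.unsqueeze(2))           # (B, S, 1, E)

print(ref_v_old.numel(), ref_v_new.numel())  # 512 vs 524288 elements here
```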
@dianaml0 it's breaking now around Nystrom + key padding mask; I'm seeing a few lines that I don't get (like mask and key_padding_mask being used interchangeably?).
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #104      +/-   ##
==========================================
+ Coverage   86.68%   86.86%   +0.18%
==========================================
  Files          49       49
  Lines        2493     2497       +4
==========================================
+ Hits         2161     2169       +8
+ Misses        332      328       -4
Taking a look now! For Nystrom only.
    # FIXME: @lefaudeux This is probably broken, we don´t test this I suppose
This should be tested in test_masking in test_nystrom_attention.py
It's this line that I thought could not possibly work, given the difference in shapes?
I think it should work, since we assert on the key_padding_mask dimension being (batch, 1, seq)? And we also convert it to an additive mask. But perhaps using masked_fill with a boolean key padding mask makes more sense?
Yes, you're right, sorry, I forgot again: the 1 will be enough to broadcast! All good, I'll remove the FIXME. I've changed the unit test to make sure that we do go through there, and it works indeed.
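A quick sanity check of the broadcasting discussed above (shapes are illustrative): an additive key-padding mask of shape (batch, 1, seq) adds cleanly onto attention logits of shape (batch, seq, seq), with the size-1 dimension broadcasting over query positions; masked_fill with a boolean mask of the same shape is the equivalent alternative mentioned.

```python
import torch

B, S = 2, 6
logits = torch.randn(B, S, S)

key_padding_bool = torch.zeros(B, 1, S, dtype=torch.bool)
key_padding_bool[:, :, -2:] = True  # pretend the last two keys are padding

# Additive form: -inf where the key is padded, 0 elsewhere.
additive = torch.zeros(B, 1, S).masked_fill(key_padding_bool, float("-inf"))

out_additive = logits + additive                                  # broadcasts over queries
out_filled = logits.masked_fill(key_padding_bool, float("-inf"))  # same result
assert torch.equal(out_additive, out_filled)
```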
tests/test_nystrom_attention.py
Outdated
@@ -85,8 +85,10 @@ def test_att_mask_ignored():
     assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

 def test_masking():
     nystrom_config["causal"] = True
     sdp_config["causal"] = True
     # FIXME: Masking seems fine as such, but whole lines are randomly masked when causal+random mask
Is this what you meant by key padding + causal being broken?
Yes, although it feels like it could be only an issue with the unit test: scaled_dot_product and Nystrom are not behaving in the same way (although there are plenty of NaNs around in both cases). I think it could be that Nystrom breaks earlier because of the pooling for anchors: if one anchor is full of -inf (easy with causal), it cannot compute the attention.
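A tiny illustration of the failure mode described above: once every logit in a row is -inf (which happens easily when causal and random masks pile up on a pooled landmark/anchor), softmax returns NaN for that whole row.

```python
import torch

row = torch.full((1, 4), float("-inf"))
print(torch.softmax(row, dim=-1))  # tensor([[nan, nan, nan, nan]])
```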
Fixed now with random masking that is a little less aggressive (not masking half of the words :)). I think it's still a good test.
Nice :) So can causal be used for the test now, or is that still an issue?
Not an issue anymore; I messed up the push when rebasing, but it works fine locally, let me update..
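A hedged sketch of what "less aggressive" random masking can look like in such a test (the actual test code is not shown here; the helper name is hypothetical): draw a key-padding mask with a low per-token probability and keep at least the first key visible, so no row ends up fully masked once the causal mask is added on top.

```python
import torch

def random_key_padding_mask(batch: int, seq: int, p_mask: float = 0.1) -> torch.Tensor:
    mask = torch.rand(batch, seq) < p_mask  # True means "padded / masked out"
    mask[:, 0] = False                      # never mask the first key
    return mask.unsqueeze(1)                # (batch, 1, seq), as discussed above

km = random_key_padding_mask(2, 16)
```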
tests/test_attentions.py
Outdated
@@ -81,7 +96,9 @@ def test_order_invariance(
     )

     # Check that a shuffled input produces the same results
-    seqs = [SEQ, SEQ - 16] if attention_name != "blocksparse" else [SEQ]
+    seqs = (
+        [SEQ, SEQ - 16] if (attention_name != "blocksparse" and not causal) else [SEQ]
How come we don't test SEQ-16 for causal?
Ah, good question! Basically we adjust the mask on the fly when the sequence is smaller than initially expected, but we don't have the codepath in place to also adjust the causal mask, so that broke (it "worked" before because we were not actually applying the causal mask..). I can fix that in this PR and then refactor all of this part in a follow-up?
Ah okay thanks for the explanation :) That sounds good
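An illustrative sketch of the on-the-fly adjustment mentioned above (not the actual xformers codepath): when the incoming sequence is shorter than the one a cached additive causal mask was built for, slicing the cached mask is enough, since the smaller causal pattern is the top-left corner of the larger one.

```python
import torch

cached = torch.triu(torch.full((128, 128), float("-inf")), diagonal=1)

def causal_mask_for(seq_len: int, cached: torch.Tensor = cached) -> torch.Tensor:
    # Reuse the cached mask for any sequence no longer than the cached size.
    assert seq_len <= cached.shape[0]
    return cached[:seq_len, :seq_len]

mask_112 = causal_mask_for(128 - 16)  # matches the SEQ - 16 case in the test
```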
@@ -74,24 +74,40 @@ def forward(
     - If the mask has the float type, then an additive mask is expected (masked values are -inf)

     """

     # Handle a possibly deferred causal mask handling
Nice catching of edge cases!
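A hedged sketch of the "deferred causal mask" handling hinted at in the diff above (the real forward() is longer and its details are not shown): if the layer was built with causal=True but no mask was passed in, build the causal mask at call time; if an additive float mask was passed, the two simply sum.

```python
import torch
from typing import Optional

def merge_masks(
    att_mask: Optional[torch.Tensor], causal: bool, seq_len: int
) -> Optional[torch.Tensor]:
    if not causal:
        return att_mask
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf")), diagonal=1
    )
    if att_mask is None:
        # Deferred case: no extra mask was provided, use the causal one alone.
        return causal_mask
    # Additive masks (float, -inf on masked positions) combine by addition.
    return att_mask + causal_mask
```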
…d to abstract that
…research#104)
* Adding some helpers on SparseCS + small unit testing
* unit test fix
* adding a device test, checking for facebookresearch#93
* catching the padding bug, fixing
What does this PR do?
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.