
[bugfix] Causal mask being skipped if no extra mask #104


Merged
blefaudeux merged 3 commits into main from fix_103 on Nov 17, 2021

Conversation

blefaudeux
Contributor

@blefaudeux blefaudeux commented Nov 15, 2021

What does this PR do?

  • Fixes [bug] Causal mask is skipped if no att_mask is passed #103
  • Adds a unit test checking that causality is correctly enforced for scaled dot product and FAVOR; Nystrom does not seem to pass it yet
  • Fixes Favor not picking up the causal flag
  • Fixes Favor causal not being trainable (but introduces a significant memory overhead)

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 15, 2021
@blefaudeux blefaudeux force-pushed the fix_103 branch 3 times, most recently from b4d6adb to 1de96c7 on November 15, 2021 at 22:43
@blefaudeux
Contributor Author

The test I came up with is a bit strict: it works with scaled_dot_product, but it does not pass with all the approximations, which are mostly causal but not exactly so. I think it would be nice to come up with a proper test for all of them, but in the meantime this needs to be fixed on main.
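
For reference, a minimal sketch of this kind of causality check (illustrative names and tolerances, assuming an attention callable that takes q, k, v and returns an output with the same sequence length; this is not the exact test added in the PR):

    import torch

    def check_causality(attention, batch=2, seq=16, dim=32, atol=1e-5):
        # Two inputs that only differ at the very last position
        a = torch.randn(batch, seq, dim)
        b = a.clone()
        b[:, -1, :] += 1.0

        out_a = attention(a, a, a)
        out_b = attention(b, b, b)

        # With a correct causal mask, positions 0..seq-2 never see position seq-1,
        # so their outputs must match up to numerical tolerance
        assert torch.allclose(out_a[:, :-1, :], out_b[:, :-1, :], atol=atol)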

@@ -24,6 +24,7 @@

@dataclass
class FavorAttentionConfig(AttentionConfig):
causal: Optional[bool]
Contributor Author

causal was not being passed on by the factory because this field was missing.
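
To illustrate why a missing dataclass field can make a kwarg disappear, here is a minimal sketch of a config-driven factory (build_config and the default values are hypothetical, not the xformers factory itself):

    from dataclasses import dataclass, fields
    from typing import Optional

    @dataclass
    class AttentionConfig:
        dropout: float

    @dataclass
    class FavorAttentionConfig(AttentionConfig):
        # Without this field, the factory below silently drops the "causal" kwarg
        causal: Optional[bool] = False

    def build_config(config_cls, **kwargs):
        # Only keep the kwargs that the config dataclass actually declares
        names = {f.name for f in fields(config_cls)}
        return config_cls(**{k: v for k, v in kwargs.items() if k in names})

    cfg = build_config(FavorAttentionConfig, dropout=0.1, causal=True)
    assert cfg.causal is True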

@blefaudeux blefaudeux changed the title [DRAFT][bugfix] Causal mask being skipped if no extra mask [bugfix] Causal mask being skipped if no extra mask Nov 15, 2021
@@ -59,8 +59,9 @@ def _get_causal_mask(self, seq_len: int, to_seq_len: int) -> torch.Tensor:
# Cache a mask so that multiple instances would reuse the same
causal_mask = self._causal_mask
if not causal_mask:
- causal_mask = torch.tril(torch.ones(seq_len, to_seq_len), diagonal=0)
- causal_mask[self._causal_mask == 1] = -float("inf")
+ causal_mask = torch.triu(
Contributor Author

this was wrong with additive masks, pretty bad
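
For context, an additive causal mask is usually built along these lines: 0 for visible positions and -inf strictly above the diagonal, so that adding it to the raw scores zeroes out future positions after the softmax (a sketch, not the exact replacement line, which is truncated in the hunk above):

    import torch

    def additive_causal_mask(seq_len: int, to_seq_len: int) -> torch.Tensor:
        # 0 on and below the diagonal (visible), -inf strictly above (future positions)
        return torch.triu(
            torch.full((seq_len, to_seq_len), float("-inf")), diagonal=1
        )

    scores = torch.randn(4, 4) + additive_causal_mask(4, 4)
    probs = torch.softmax(scores, dim=-1)
    # Probabilities strictly above the diagonal are exactly zero
    assert torch.allclose(torch.triu(probs, diagonal=1), torch.zeros(4, 4))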

@blefaudeux blefaudeux marked this pull request as draft November 15, 2021 23:38
@blefaudeux
Contributor Author

back to draft, looks like favor causal needs improvements, long overdue (spotted by @fmassa long ago)

@blefaudeux blefaudeux force-pushed the fix_103 branch 4 times, most recently from 3ff3d90 to 2abf788 on November 16, 2021 at 00:45
# Algorithm 1 in the paper
- ref_v = torch.ones_like(v[:, 0, :].unsqueeze(1))
+ ref_v = torch.ones_like(v.unsqueeze(2))  # BATCH x SEQ x 1 x EMB
Contributor Author

This is a memory hog, which is the issue with this approach @fmassa, but the older state was not actually trainable (this was not being tested; fixed now), so I guess it is better than before, and improving on that could be a follow-up.
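
Roughly, the tradeoff is the one below (a simplified sketch of causal linear attention with prefix sums, not the actual FAVOR code): the vectorized version materializes a BATCH x SEQ x FEATURES x EMB tensor, whereas a sequential implementation would only keep a running FEATURES x EMB state per step.

    import torch

    def causal_linear_attention(q_prime, k_prime, v, eps=1e-6):
        # q_prime, k_prime: (B, S, K) feature-mapped queries/keys; v: (B, S, E)
        # Vectorized prefix sums: materializes a (B, S, K, E) tensor, which is the
        # memory hog, but keeps everything inside a single differentiable graph
        kv = k_prime.unsqueeze(-1) * v.unsqueeze(-2)          # (B, S, K, E)
        kv_cum = kv.cumsum(dim=1)                             # running sum of k'_j v_j^T for j <= i
        k_cum = k_prime.cumsum(dim=1)                         # (B, S, K)
        num = torch.einsum("bsk,bske->bse", q_prime, kv_cum)  # causal numerator
        den = torch.einsum("bsk,bsk->bs", q_prime, k_cum)     # causal normalizer
        return num / (den.unsqueeze(-1) + eps)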

Contributor Author

see #105 for tracking

@blefaudeux blefaudeux marked this pull request as ready for review November 16, 2021 01:00
@blefaudeux blefaudeux marked this pull request as draft November 16, 2021 02:26
@blefaudeux
Contributor Author

@dianaml0 this is breaking now around Nystrom + key padding mask; I'm seeing a few lines that I don't get (like mask and key_padding_mask being used interchangeably?)

@codecov-commenter

codecov-commenter commented Nov 16, 2021

Codecov Report

Attention: Patch coverage is 92.59259% with 2 lines in your changes missing coverage. Please review.

Project coverage is 86.86%. Comparing base (c3021ed) to head (70c70dc).
Report is 996 commits behind head on main.

Files with missing lines                               Patch %   Lines
xformers/components/attention/nystrom.py               80.00%    1 Missing ⚠️
...formers/components/attention/scaled_dot_product.py  91.66%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #104      +/-   ##
==========================================
+ Coverage   86.68%   86.86%   +0.18%     
==========================================
  Files          49       49              
  Lines        2493     2497       +4     
==========================================
+ Hits         2161     2169       +8     
+ Misses        332      328       -4     
Flag     Coverage Δ
Python   86.86% <92.59%> (+0.18%) ⬆️

Flags with carried forward coverage won't be shown.


@dianaml0
Contributor

@dianaml0 this is breaking now around Nystrom + key padding mask; I'm seeing a few lines that I don't get (like mask and key_padding_mask being used interchangeably?)

Taking a look now! For Nystrom, only key_padding_mask is accepted, and if the causal flag is specified, key_padding_mask is combined with a causal mask into mask. I'll check the errors.
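
In other words, something along these lines (a sketch with hypothetical names, not the Nystrom code itself):

    import torch

    def combine_masks(key_padding_mask, causal, seq_len):
        # key_padding_mask: additive mask of shape (batch, 1, seq), -inf on padded keys
        mask = key_padding_mask
        if causal:
            causal_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf")), diagonal=1
            )
            # Broadcasts to (batch, seq, seq): padded keys and future positions are both -inf
            mask = mask + causal_mask
        return mask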


# FIXME: @lefaudeux This is probably broken, we don´t test this I suppose
Contributor

This should be tested in test_masking in test_nystrom_attention.py

Contributor Author

it's this line that I thought could not possibly work, given the difference in shapes?

Contributor

I think it should work since we assert on key_padding_mask dimension being (batch, 1, seq)? And we also convert it to an additive mask. But perhaps using masked_fill with a boolean key padding mask makes more sense?

Contributor Author

Yes, you're right, sorry, I forgot again: the 1 will be enough to broadcast! All good, I'll remove the FIXME. I've changed the unit test to make sure that we do go through this path, and it does work indeed.
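
For the record, the two options discussed above behave the same thanks to broadcasting of the (batch, 1, seq) mask (a small illustrative sketch, not the library code):

    import torch

    B, S = 2, 8
    scores = torch.randn(B, S, S)                   # (batch, queries, keys)
    key_padding_bool = torch.rand(B, 1, S) > 0.8    # True where a key is padded

    # Additive variant: a (batch, 1, seq) float mask broadcasts over the query dim
    additive = torch.zeros(B, 1, S).masked_fill(key_padding_bool, float("-inf"))
    masked_a = scores + additive

    # masked_fill variant: the boolean (batch, 1, seq) mask broadcasts the same way
    masked_b = scores.masked_fill(key_padding_bool, float("-inf"))

    assert torch.equal(masked_a, masked_b)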

@@ -85,8 +85,10 @@ def test_att_mask_ignored():
assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

def test_masking():
nystrom_config["causal"] = True
sdp_config["causal"] = True
# FIXME: Masking seems fine as such, but whole lines are randomly masked when causal+random mask
Contributor

Is this what you meant by key padding + causal being broken?

Contributor Author

Yes, although it feels like it could just be an issue with the unit test; scaled_dot_product and Nystrom are not behaving in the same way (although there are plenty of NaNs around in both cases). I think Nystrom may break earlier because of the pooling for the anchors: if one anchor is entirely -inf (easy with causal), it cannot compute the attention.
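
The failure mode described here can be reproduced in isolation (a toy sketch of the numerical issue, not the actual Nystrom landmark computation):

    import torch

    seq, num_landmarks = 8, 2
    scores = torch.randn(1, seq, seq)
    causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))

    # Landmark ("anchor") scores via segment-mean pooling over the key dimension
    pooled = scores.reshape(1, seq, num_landmarks, seq // num_landmarks).mean(dim=-1)

    # For the first query, every segment mean contains at least one -inf term, so the
    # whole row is -inf and the softmax over landmarks turns into NaN
    print(torch.softmax(pooled, dim=-1)[0, 0])  # tensor([nan, nan])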

Contributor Author

Fixed now with random masking that is a little less aggressive (not masking half of the words :)). I think it's still a good test.

Contributor

Nice :) So can causal be used for the test now, or is that still an issue?

Contributor Author

Not an issue anymore; I messed up the push when rebasing, but it works fine locally, let me update.

@@ -81,7 +96,9 @@ def test_order_invariance(
)

# Check that a shuffled input produces the same results
- seqs = [SEQ, SEQ - 16] if attention_name != "blocksparse" else [SEQ]
+ seqs = (
+     [SEQ, SEQ - 16] if (attention_name != "blocksparse" and not causal) else [SEQ]
Contributor

How come we don't test SEQ-16 for causal?

Contributor Author

Ah, good question! Basically, we adjust the mask on the fly when the sequence is smaller than initially expected, but we don't have the codepath in place to also adjust the causal mask, so that broke (it "worked" before because we were not actually applying the causal mask). I can fix that in this PR and then refactor this whole part in a follow-up?
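
Conceptually, the missing codepath amounts to invalidating the cached mask when the runtime sequence length changes (a sketch with hypothetical names, not the actual refactor):

    import torch

    class CausalMaskCache:
        def __init__(self):
            self._causal_mask = None

        def get(self, seq_len: int, to_seq_len: int) -> torch.Tensor:
            # Rebuild the cached mask whenever the runtime shape differs from the
            # one the cache was built for, instead of reusing a stale mask
            if self._causal_mask is None or self._causal_mask.shape != (seq_len, to_seq_len):
                self._causal_mask = torch.triu(
                    torch.full((seq_len, to_seq_len), float("-inf")), diagonal=1
                )
            return self._causal_mask

    cache = CausalMaskCache()
    assert cache.get(16, 16).shape == (16, 16)
    assert cache.get(12, 12).shape == (12, 12)  # shorter sequence: mask is rebuilt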

Contributor

Ah okay thanks for the explanation :) That sounds good



@@ -74,24 +74,40 @@ def forward(
- If the mask has the float type, then an additive mask is expected (masked values are -inf)

"""

# Handle a possibly deferred causal mask handling
Contributor

Nice catching of edge cases!
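
For readers landing here from issue #103, the gist of the fix is that the causal mask must be applied whether or not an att_mask is passed (a self-contained sketch under assumed names, not the exact xformers forward):

    import torch

    def attend(q, k, v, att_mask=None, causal=True):
        # q, k, v: (batch, seq, dim); att_mask: optional additive mask
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        if causal:
            causal_mask = torch.triu(
                torch.full(scores.shape[-2:], float("-inf")), diagonal=1
            )
            # The original bug: the causal mask was only merged into an existing
            # att_mask, so it was silently skipped whenever att_mask was None
            att_mask = causal_mask if att_mask is None else att_mask + causal_mask
        if att_mask is not None:
            scores = scores + att_mask
        return torch.softmax(scores, dim=-1) @ v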

@blefaudeux blefaudeux merged commit 1328ba7 into main Nov 17, 2021
@blefaudeux blefaudeux deleted the fix_103 branch November 17, 2021 04:26
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
…research#104)

* Adding some helpers on SparseCS + small unit testing
* unit test fix
* adding a device test, checking for facebookresearch#93
* catching the padding bug, fixing
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Development

Successfully merging this pull request may close these issues.

[bug] Causal mask is skipped if no att_mask is passed
5 participants