🚀 Feature
When I use memory_efficient_attention with a bias given as a torch.Tensor, I get the error "No operator found for this attention". I then found a remark in xformers.ops.fmha (class _fMHA): "Only gradients to Q/K/V is implemented. For instance, it's not possible to backpropagate through the attention mask". I take this to mean that gradients for the attention bias are not available. Could you add a feature to support gradients for the attention bias?
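For context, here is a minimal sketch of the kind of call that hits this for me; the shapes, dtypes, and device are my own assumptions, not from an actual AlphaFold2-style model:

```python
import torch
import xformers.ops as xops

B, M, H, K = 2, 128, 4, 64  # assumed batch, sequence length, heads, head dim

q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Additive attention bias (e.g. projected from a pairwise representation).
# Marking it as requiring gradients is what I cannot get to work.
attn_bias = torch.randn(B, H, M, M, device="cuda", dtype=torch.float16,
                        requires_grad=True)

# With the bias requiring gradients, this path fails for me with
# "No operator found for this attention".
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
out.sum().backward()
```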
Motivation
In models like AlphaFold2, biased attention is used extensively, and the pairwise representation that produces the bias needs to receive gradients.
Pitch
Support backpropagation through the attention bias in memory_efficient_attention.
Additional context
I also didn't see any argument for a key_padding_mask. Is it proper to masked_fill the attention bias with float('-inf') to emulate a key padding mask? A sketch of what I mean follows.
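For illustration, this is roughly what I have in mind; the shapes and the key_padding_mask tensor here are my own assumptions, not an xformers API:

```python
import torch

B, H, M = 2, 4, 128  # assumed batch, heads, sequence length

# Boolean mask: True where the key position is padding and should be ignored.
key_padding_mask = torch.zeros(B, M, dtype=torch.bool, device="cuda")
key_padding_mask[:, 100:] = True  # e.g. the trailing positions are padding

# Build an additive bias broadcastable to (B, H, M_query, M_key): padded key
# positions get -inf so their weight becomes zero after the softmax.
attn_bias = torch.zeros(B, H, M, M, device="cuda", dtype=torch.float16)
attn_bias = attn_bias.masked_fill(key_padding_mask[:, None, None, :],
                                  float("-inf"))

# attn_bias could then be passed via attn_bias= to memory_efficient_attention,
# possibly summed with a learned bias.
```

I'm not sure whether a dense -inf bias like this is accepted by every kernel, so please treat it only as a sketch of the masked_fill idea.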
Hi @EBGU
Thanks for your report.
There is an open PR adding support for exactly that: #587
We plan to merge it after we release 0.0.16, hopefully next week.