No operator found for this attention when attn_bias is a torch.Tensor #576

Open
pfeatherstone opened this issue Dec 10, 2022 · 12 comments

@pfeatherstone

🐛 Bug

I get the error "No operator found for this attention".

Command

Run the code below

To Reproduce

import torch
import torch.nn.functional as F
from einops import rearrange
import xformers.ops as xops

def attention(q, k, v, attn_bias, D, attn_drop=0, is_training=False):
    # Reference implementation: plain PyTorch attention with an additive bias.
    scale   = D ** -0.5
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', d = D), (q, k, v))
    # Scale the QK product before adding the bias, matching
    # memory_efficient_attention's semantics.
    dots    = torch.einsum('b h l d, b h s d -> b h l s', q, k) * scale + attn_bias
    attn    = dots.softmax(dim = -1)
    attn    = F.dropout(attn, p=attn_drop, training=is_training, inplace=True)
    out     = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
    out     = rearrange(out, 'b h l d -> b l (h d)')
    return out

def attention_efficient(q, k, v, attn_bias, D, attn_drop=0, is_training=False):
    # xformers expects inputs shaped (batch, seq, heads, head_dim).
    scale   = D ** -0.5
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', d = D), (q, k, v))
    drop    = attn_drop if is_training else 0
    out     = xops.memory_efficient_attention(q, k, v, attn_bias, p=drop, scale=scale, op=None)
    out     = rearrange(out, 'b l h d -> b l (h d)')
    return out

q            = torch.randn(2,512,128).cuda()
k            = torch.randn(2,256,128).cuda()
v            = torch.randn(2,256,128).cuda()
attn_bias    = torch.randn(2,4,512,256).cuda()   # (batch, heads, q_len, kv_len)

out1 = attention(q, k, v, attn_bias, 32)
out2 = attention_efficient(q, k, v, attn_bias, 32)
torch.testing.assert_close(out1, out2)

Expected behavior

I expect this to work.

Environment

Collecting environment information...
PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: 14.0.0-1ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.35

Python version: 3.10.6 (main, Nov  2 2022, 18:53:38) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1660 Ti
Nvidia driver version: 520.61.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] torch==1.13.0
[conda] Could not collect
@pfeatherstone
Author

It works if attn_bias is None. But even then, it never works if the inputs are on device('cpu').
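
For reference, a minimal sketch of the call that does go through for me on the GPU, reusing q, k, v and the imports from the repro above and dropping the bias:

q_, k_, v_ = (rearrange(t, 'b n (h d) -> b n h d', d=32) for t in (q, k, v))

# No bias: an operator is found and the call runs on CUDA.
out = xops.memory_efficient_attention(q_, k_, v_, attn_bias=None, scale=32 ** -0.5)

# The same call with q_, k_, v_ moved to the CPU fails with "No operator found".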

@danthe3rd
Contributor

Hi @pfeatherstone

That's correct.

  • Memory efficient attention works mostly on GPU (except for some very special cases: f32 & K <= 32)
  • We don't support arbitrary attention masks. However, you can use memory_efficient_attention with a causal mask (lower-triangular mask) by setting attn_bias=xformers.ops.LowerTriangularMask() (see the sketch below)
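
For example, a minimal sketch of the causal case (assuming a CUDA device and self-attention, so queries and keys share the same sequence length):

import torch
import xformers.ops as xops

B, M, H, K = 2, 512, 4, 32                  # batch, sequence length, heads, head dim
q = torch.randn(B, M, H, K, device='cuda')
k = torch.randn(B, M, H, K, device='cuda')
v = torch.randn(B, M, H, K, device='cuda')

# Causal (lower-triangular) masking without materialising a dense bias tensor.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())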

@pfeatherstone
Author

Hmm, I'm using attn_bias for both relative positional encoding and causal masking (a rough sketch of what I mean is below).
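
To make the use case concrete, here is a rough sketch of how that bias gets built for the plain PyTorch path (the learned per-head table rel_pos_table is hypothetical):

import torch

B, H, L = 2, 4, 512                        # batch, heads, sequence length (self-attention)
rel_pos_table = torch.randn(H, 2 * L - 1)  # hypothetical learned table, one row per head

# Look up the learned bias by relative offset (i - j), giving a (H, L, L) tensor.
idx  = torch.arange(L)
rel  = idx[:, None] - idx[None, :] + L - 1
bias = rel_pos_table[:, rel]

# Fold the causal constraint into the same additive bias: -inf above the diagonal.
causal    = torch.full((L, L), float('-inf')).triu(1)
attn_bias = (bias + causal).expand(B, H, L, L)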

@pfeatherstone
Author

So this isn't possible with xformers, then?

@pfeatherstone
Author

If not, then maybe the docs should be updated to say that attn_bias has some tight restrictions. Perhaps it could be replaced with a boolean called causal, a bit like how the lucidrains repos formulate it, which would either apply a lower-triangular mask or nothing at all.

@pfeatherstone
Author

I can see that FlashAttention has attention bias on their roadmap. Is that the case for xformers too?

@danthe3rd
Contributor

So this isn't possible with xformers, then?

Not at the moment. We have a few customers asking for it as well and it's on our radar, but it likely won't happen for some time.
We don't use a "causal" bool because we want to support different types of bias in the future without breaking the API.
We actually support a torch.Tensor mask, but only for the forward pass, and only if Triton is installed correctly.

xFormers uses FlashAttention depending on the inputs, so attention bias will be supported in xFormers if/when FlashAttention implements it.
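
In the meantime, a possible workaround on the caller side is to fall back to the plain implementation when no operator is available. A sketch, reusing the two functions from the repro above and assuming the "No operator found" failure surfaces as a NotImplementedError:

def attention_with_fallback(q, k, v, attn_bias, D, attn_drop=0, is_training=False):
    # Prefer the memory-efficient kernel; fall back to plain attention when
    # no xformers operator supports this bias / device / dtype combination.
    try:
        return attention_efficient(q, k, v, attn_bias, D, attn_drop, is_training)
    except NotImplementedError:
        return attention(q, k, v, attn_bias, D, attn_drop, is_training)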

@pfeatherstone
Author

Cool. Shall we keep this open for tracking purposes until a PR that fixes this is merged?

@danthe3rd
Contributor

Sure - let me just rename it though

@danthe3rd danthe3rd changed the title No operator found for this attention No operator found for this attention when attn_bias is a torch.Tensor Dec 10, 2022
@dingjingzhen

@danthe3rd Hi, is there a solution to this problem, please?
[screenshot of the failing code omitted]
When I use it this way, I get the error "No operator found for this attention"

@danthe3rd
Contributor

I expect this to work once this PR is merged:
#587

@jfc4050
Contributor

jfc4050 commented Jan 12, 2023

Hmm, I'm using attn_bias for both relative positional encoding and causal masking

So this isn't possible with xformers, then?

My team has a similar use case and I've proposed a change; you might be interested in following issue #640.
