Performance benchmarks? #9

Open
imoneoi opened this issue Dec 7, 2022 · 21 comments

Comments

@imoneoi

imoneoi commented Dec 7, 2022

Are there any benchmark results now? I'm looking forward to performance comparisons with the original attention and the official torch+CUDA implementation.

@jakubMitura14

I am also curious. Additionally, maybe it is possible to use CUDA code with JAX?

https://github.com/dfm/extending-jax

@jakubMitura14

jakubMitura14 commented Feb 27, 2023

Fantastic! Have you run an experiment with the same data on the original flash attention?

@OhadRubin

Not yet

@jon-chuang

Hello, could I ask if this works with TPUs?

@evanatyourservice

evanatyourservice commented Oct 21, 2023

Here's an updated notebook, for anyone interested, that precompiles the jitted functions and blocks until results are ready before timing:

https://colab.research.google.com/drive/11QKRdgMtcivrJNmjTrf2bXTE5yXkXl_Z?usp=sharing

It looks like JAX compiles vanilla attention into something faster than the JAX flash attention implementation, so there is no need to switch to flash attention if you use JAX.
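A minimal sketch of the timing pattern described above (compile once outside the timed region, then block on every result), written in JAX. The function names `vanilla_attention` and `benchmark`, the tensor shapes, and the iteration count are illustrative assumptions, not taken from the notebook.

```python
import time

import jax
import jax.numpy as jnp


def vanilla_attention(q, k, v):
    """Plain softmax attention; q, k, v have shape [batch, seq_len, head_dim]."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqd,bkd->bqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bqk,bkd->bqd", weights, v)


def benchmark(fn, *args, n_iters=10):
    """Mean wall-clock seconds per call of jax.jit(fn), excluding compile time."""
    jitted = jax.jit(fn)
    # Warm-up call: triggers tracing and XLA compilation, so it is not timed.
    jitted(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iters):
        # JAX dispatches work asynchronously; blocking makes us time the
        # actual computation rather than just the dispatch.
        jitted(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iters


# Hypothetical problem size, chosen only for illustration.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 128, 64))
k = jax.random.normal(kk, (2, 128, 64))
v = jax.random.normal(kv, (2, 128, 64))

out = vanilla_attention(q, k, v)
elapsed = benchmark(vanilla_attention, q, k, v)
print(f"vanilla attention: {elapsed * 1e3:.3f} ms per call")
```

The same helper can time any other attention function with the same signature, which is what a like-for-like vanilla vs. flash comparison needs.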

@SamuelGabriel

Wow, this has been open for almost a year...

I think someone could get a lot of citations / clicks if they did a proper benchmark of transformer training/inference across platforms (torch/JAX, GPU/TPU) with standard tricks. I would cite you straight away for sure. It would also just be nice to settle the dispute that Google employees seem to have with everyone else about whether JAX is meaningfully more efficient (it might just come down to TPU vs. GPU!?).

@niemiaszek

> Wow, this has been open for almost a year...
>
> I think someone could get a lot of citations / clicks if they did a proper benchmark of transformer training/inference across platforms (torch/JAX, GPU/TPU) with standard tricks. I would cite you straight away for sure. It would also just be nice to settle the dispute that Google employees seem to have with everyone else about whether JAX is meaningfully more efficient (it might just come down to TPU vs. GPU!?).

It would definitely be nice to see such a benchmark, but I can imagine how hard it is to compare JAX vs. PyTorch (GPU/TPU), given the many optimized implementations for each device. For PyTorch on GPU we have Triton/CUDA, but JAX has recently also added a Triton-like mechanism for writing custom kernels for GPU/TPU: Pallas. You can even find an implementation of attention in it here.
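To give a sense of what writing a Pallas kernel looks like, here is a minimal elementwise-add kernel modeled on the Pallas quickstart (deliberately not an attention kernel). It assumes a JAX version that ships `jax.experimental.pallas`; `interpret=True` runs the kernel through the Pallas interpreter so the sketch also works on CPU, whereas real GPU/TPU runs would drop that flag. The names `add_kernel` and `add` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernels operate on refs: read the input blocks, write the output block.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode, so no GPU/TPU is required
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
z = add(x, y)
print(z)
```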

@evanatyourservice

@niemiaszek I just recently saw they named and added docs for Pallas; it looks very interesting. JAX is also improving our ability to customize how networks are sharded across accelerators, and the team is publishing papers on their efficiency results, which I think is pretty cool. Unfortunately I don't have time to do a fair comparison between torch and JAX for attention, but it seems that whoever takes the time to delve into it, especially into JAX's recent improvements, would certainly benefit if they have a need.

Even if we don't take the time, it looks like the JAX team continually folds their efficiency findings into JAX as defaults, so we don't have to implement them ourselves.
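As a rough illustration of the sharding controls mentioned above, here is a sketch that uses `jax.sharding.Mesh`, `NamedSharding`, and `PartitionSpec` to split a batch across all visible devices (it also runs on a single CPU, where the "split" is trivial). The toy layer, the shapes, and the axis name `"data"` are assumptions for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every visible device, with one named axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
# Split the leading (batch) axis over "data"; replicate the feature axis.
batch_sharding = NamedSharding(mesh, P("data", None))

n_dev = len(jax.devices())
x = jax.device_put(jnp.ones((8 * n_dev, 16)), batch_sharding)
w = jnp.full((16, 4), 0.1)  # created without an explicit sharding


@jax.jit
def layer(x, w):
    # Under jit, XLA propagates the input shardings through the computation.
    return jnp.tanh(x @ w)


y = layer(x, w)
print(y.sharding)
```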

@lucidrains
Owner

lucidrains commented Nov 29, 2023

from what i've heard, flash attention doesn't work well on TPUs, but i haven't kept up with the latest iteration of their chip design.

Pallas is just a wrapper around Triton, developed at OpenAI for GPUs. you will basically be always limited by what the Triton compiler can do

@lucidrains
Owner

while this is a hard pill to swallow, i think existence of flash attention is a clear victory for having finely controlled GPGPU programming.
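For context on what flash attention changes algorithmically, here is a minimal pure-JAX sketch of its forward-pass idea: process keys/values block by block with an online softmax (a running row maximum and a running denominator), so the full score matrix is never formed at once. This is an illustrative sketch, not this repository's implementation; `blockwise_attention` and `block_size` are hypothetical names. Without a hand-fused kernel, XLA still decides where intermediates live, so this should not be expected to reproduce flash attention's memory-traffic savings.

```python
import jax
import jax.numpy as jnp


def vanilla_attention(q, k, v):
    scores = (q @ k.T) * q.shape[-1] ** -0.5
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_size=32):
    """Online-softmax attention. q: [seq_q, d]; k, v: [seq_k, d].

    Assumes seq_k is divisible by block_size.
    """
    scale = q.shape[-1] ** -0.5
    seq_k, d = k.shape
    k_blocks = k.reshape(seq_k // block_size, block_size, d)
    v_blocks = v.reshape(seq_k // block_size, block_size, v.shape[-1])

    def step(carry, kv_block):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv_block
        s = (q @ k_blk.T) * scale                    # scores for this block
        new_max = jnp.maximum(row_max, s.max(axis=-1))
        correction = jnp.exp(row_max - new_max)      # rescales earlier blocks
        p = jnp.exp(s - new_max[:, None])
        row_sum = row_sum * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ v_blk
        return (acc, new_max, row_sum), None

    seq_q = q.shape[0]
    init = (
        jnp.zeros((seq_q, v.shape[-1]), q.dtype),  # unnormalized output
        jnp.full((seq_q,), -jnp.inf, q.dtype),     # running row max
        jnp.zeros((seq_q,), q.dtype),              # running softmax denominator
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (64, 16))
k = jax.random.normal(kk, (128, 16))
v = jax.random.normal(kv, (128, 16))
max_err = jnp.abs(blockwise_attention(q, k, v) - vanilla_attention(q, k, v)).max()
print(max_err)
```

On the first block the running max is `-inf`, so `correction` is `exp(-inf) = 0` and the zero-initialized accumulators are simply replaced, avoiding any NaNs.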

@evanatyourservice

@lucidrains I'd agree as far as single-device optimizations go. I solely use jax because my work deals mainly with RL and I've already built everything out, but for things like language and vision models, resources like xformers are hard to beat. I do like jax's work toward multi-device customization especially from an RL perspective.

@jon-chuang

jon-chuang commented Nov 29, 2023

> while this is a hard pill to swallow, i think existence of flash attention is a clear victory for having finely controlled GPGPU programming.

Well, I would argue that these days it's no longer such a hard pill, given the wide adoption of tiled programming paradigms like Triton (e.g. PyTorch, for both codegen and the incoming custom kernels; JAX, e.g. Pallas; and hardware vendors including NVIDIA, AMD, and Intel), which greatly reduce the effort and complexity of getting SOTA performance on GPUs.

@lucidrains
Owner

@jon-chuang hmm, still a bit early to declare that imho

we'll see, i hope so!

@jon-chuang

Yes, Triton is still not 100% there (some matmul kernel sizes and certain kernels, like the flash attention backward pass, are still not SOTA). But it's certainly the direction the industry is investing in, and IMO it's good news for developers and tinkerers who want hackability at each layer of the stack.

I've already heard of some success stories with customizing flash attention kernels via Triton.

@jon-chuang

I think these newish attention replacements will take time to be adopted particularly because the dust has not settled on them and it takes a while for wide-scale experimentation and large-scale training with them to truly prove them out.

IMO all it takes is for a highly funded industrial lab to go out on a limb and train an LLM with one of these...

For instance, Mistral AI essentially has a linear-cost attention mechanism based on SWA (sliding window attention), though one could of course argue about how effective it is at truly capturing information across long contexts.
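A sliding-window mask of the kind mentioned above can be sketched in JAX as follows: each query attends causally to at most `window` of the most recent positions, itself included. This dense version still computes the full score matrix, so it only illustrates the masking pattern; the linear cost comes from computing just the band of scores, which this sketch does not do. The function names and window sizes are hypothetical, and this is not Mistral's implementation.

```python
import jax
import jax.numpy as jnp


def sliding_window_mask(seq_len, window):
    """True where query i may attend to key j: causal and within `window` positions."""
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)


def sliding_window_attention(q, k, v, window):
    """Dense (quadratic) reference; q, k, v have shape [seq_len, head_dim]."""
    scores = (q @ k.T) * q.shape[-1] ** -0.5
    mask = sliding_window_mask(q.shape[0], window)
    # Every row keeps at least its diagonal entry, so no row is all -inf.
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v


mask = sliding_window_mask(5, window=2)
print(mask.astype(jnp.int32))

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 8))
k = jax.random.normal(kk, (6, 8))
v = jax.random.normal(kv, (6, 8))
# With window=1 each position attends only to itself, so the output equals v.
out = sliding_window_attention(q, k, v, window=1)
```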

> all these frameworks cannot do.

I think this is an overstatement? It simply has not been tried in Triton yet, and it should not be that hard. Whether the performance matches is an open question, though.

I just hope that more devs become aware of how powerful Triton is, so that there's more experimentation with implementing these kernels.

@lucidrains
Owner

lucidrains commented Nov 29, 2023

@jon-chuang yea, let us just agree that we both wish for Triton and the like to succeed so us non-CUDA experts can have control over the entire stack

i just know it isn't there yet.

@jon-chuang

Interestingly, a basic building block for Mamba (associative scan) already has support in Triton: pytorch/pytorch#95408 (comment)
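For readers unfamiliar with the primitive: an associative scan evaluates a linear recurrence such as h_t = a_t * h_{t-1} + b_t (the kind of recurrence at the core of Mamba's selective scan) in logarithmic parallel depth by composing affine maps. JAX exposes this as `jax.lax.associative_scan`, which accepts a tuple of arrays as its input; the sketch below uses it and checks the result against a sequential `jax.lax.scan`. The function names are hypothetical, and this is not the Triton implementation linked above.

```python
import jax
import jax.numpy as jnp


def combine(left, right):
    """Compose two affine maps h -> a*h + b: apply `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def linear_recurrence(a, b):
    """h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, computed as a parallel scan."""
    _, h = jax.lax.associative_scan(combine, (a, b))
    return h


def linear_recurrence_sequential(a, b):
    """Reference: the same recurrence evaluated step by step."""
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h

    _, hs = jax.lax.scan(step, jnp.zeros_like(a[0]), (a, b))
    return hs


ka, kb = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(ka, (16,))  # decay factors in [0, 1)
b = jax.random.normal(kb, (16,))
h_parallel = linear_recurrence(a, b)
h_sequential = linear_recurrence_sequential(a, b)
```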

@lucidrains
Owner

lucidrains commented Nov 30, 2023

it doesn't support multiple inputs. also i heard it is still buggy in its current state

@lucidrains
Owner

@jon-chuang anyways, let us take the discussion elsewhere, as this is about flash attention

@MasterSkepticista

Flash attention is now available in jax-nightly with a cuDNN implementation: jax.nn.dot_product_attention. The cuDNN implementation only supports the Ampere architecture and later.

Note that the default implementation is XLA.


9 participants