feat(triton): InplaceNorm + InstanceNorm #50
Comments
That's a great question! We're definitely open to contributions, and this looks very reasonable; there's a good chance that Triton gives something a lot faster than PyTorch there. In terms of scope, I think it's very much OK, as xformers to me is also a zoo of optimized parts (with some automatic builders for them, but that's optional). Just a couple of caveats to begin with:
I'd love to run LayerNorm in place and ideally also add InstanceNorm (by extracting the core normalization from LayerNorm) as HomebrewNLP is currently using a slow PyTorch-level implementation with a correct backward pass.
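To make "extracting the core normalization" concrete, here is a minimal plain-Python sketch (not Triton, and not the xformers kernel). The function names `normalize`, `layer_norm` and `instance_norm` are hypothetical, the affine weight and bias are left out, and the `[batch][channels][spatial]` layout is an assumption. It only shows that both norms can share one zero-mean, unit-variance routine and differ in which axis the statistics are taken over.

```python
import math


def normalize(values, eps=1e-5):
    """Shared core: zero mean and unit variance over one list of floats."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd for v in values]


def layer_norm(x, eps=1e-5):
    """x is [batch][features]; statistics are taken per sample over features."""
    return [normalize(row, eps) for row in x]


def instance_norm(x, eps=1e-5):
    """x is [batch][channels][spatial] (assumed layout); statistics are taken
    per (sample, channel) pair over the spatial positions."""
    return [[normalize(channel, eps) for channel in sample] for sample in x]
```

Under these assumptions, InstanceNorm is the same computation as LayerNorm applied to the input viewed as `[batch * channels][spatial]`, which is why a single kernel could plausibly serve both.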
While we're at it, optionally fusing GLU and GLUv2 (`gelu(f(x)) * g(x) + gelu(h(x))`) with various activation functions and normalization might give another speed boost.
To add this myself, I'd need to fully understand Triton's pointers and how to access the output instead of the input in your LayerNorm implementation. Could you help me with that? Or would you rather implement this yourself? Is this even in the scope of xformers?
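As a sketch of what "access the output instead of the input" could mean mathematically: if the forward pass overwrites `x` with `y`, the backward pass can recover the normalized activations from `y`, provided the per-row reciprocal standard deviation is saved and every weight is nonzero. The plain-Python reference below is an illustration under those assumptions, with hypothetical function names; it is not the xformers Triton kernel and says nothing about how its pointers are laid out.

```python
import math


def layer_norm_fwd(x, w, b, eps=1e-5):
    """Forward pass for one row. Returns the output and the saved rstd."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    y = [(xi - mean) * rstd * wi + bi for xi, wi, bi in zip(x, w, b)]
    return y, rstd


def layer_norm_bwd_from_output(y, dy, w, b, rstd):
    """Gradient w.r.t. the input, computed from the OUTPUT y instead of x.

    Assumes every weight is nonzero so the affine step can be inverted.
    """
    n = len(y)
    # Undo the affine transform to recover the normalized activations.
    x_hat = [(yi - bi) / wi for yi, wi, bi in zip(y, w, b)]
    # Gradient flowing into the normalized activations.
    g = [dyi * wi for dyi, wi in zip(dy, w)]
    mean_g = sum(g) / n
    mean_gx = sum(gi * xh for gi, xh in zip(g, x_hat)) / n
    return [rstd * (gi - mean_g - xh * mean_gx) for gi, xh in zip(g, x_hat)]
```

The trade-off in this sketch is one extra division per element and a restriction on the weights, in exchange for not keeping the input tensor alive for the backward pass.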