Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to blocksparse for causal attention #334

Merged
merged 4 commits into from
Jun 27, 2022
Merged

Conversation

yuanandonly
Copy link
Contributor

What does this PR do?

Addresses issue here. Automatically switches to blocksparse when attention is causal, and mask is not sparse.

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@yuanandonly yuanandonly requested review from fmassa and dianaml0 June 14, 2022 00:22
@yuanandonly yuanandonly self-assigned this Jun 14, 2022
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 14, 2022
@blefaudeux
Copy link
Contributor

blefaudeux commented Jun 14, 2022

hey @yuanandonly, thanks for the PR ! I cannot review myself but it seems that the unit test caught something valid, the Favor mechanism can expose a causal attention but blocksparse (normal attention) should not be used here, it's just not the same mechanism. An easy fix is to limit your background optimization to the normal attention, aka scaled_dot_product

edit1: Wait a sec, I wrote too fast, you're already doing that. I forgot that favor was using this codepath..

edit2 : Ahh ok, so Favor does not use this codepath indeed, but this test does, since it compares favor with the normal attention. It looks like a good candidate for something which should work out of the box, looks like it's just a small dimension problem

Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @yuanandonly !

I've left a few comments, let me know what you think.

I believe test failures are due to the fact that you return a 4d tensor instead of a 3d one.

Also, could you add a benchmark script that checks the speed of using blocksparse vs the default case, and share the results here?

xformers/components/attention/core.py Outdated Show resolved Hide resolved
xformers/components/attention/core.py Outdated Show resolved Hide resolved
xformers/components/attention/utils.py Outdated Show resolved Hide resolved
xformers/components/attention/core.py Outdated Show resolved Hide resolved
xformers/components/attention/core.py Outdated Show resolved Hide resolved
xformers/components/attention/core.py Outdated Show resolved Hide resolved
tests/test_core_attention.py Outdated Show resolved Hide resolved
assert r_sparse.dtype == expected_device

if r_custom.dtype == r_att_mask.dtype:
assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-3)
Copy link
Contributor Author

@yuanandonly yuanandonly Jun 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@blefaudeux @fmassa @dianaml0 In this test I had to increase the tolerances of torch.allclose() from the default atol=1e-8 and rtol=1e-5 pretty significantly for the assert to pass. Something similar, assert_almost_equal(), from is used here to test parity between standard SDP attention and blocksparse attention.

Is this difference between SDP attention and blocksparse acceptable in this situation? I.e., we're silently switching to blocksparse as of now, but should we inform the user? Or are there any other steps I should take in the code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference might be due to the fact that block-sparse is using TF32 while PyTorch is not. I would say this is fine as long as it's only a matter of numerical differences.

Also, were you able to have the benchmarks for this case ready?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can force pytorch to use tf32 in that case also, else I think that it's fine if the tolerance relaxation is limited to this case ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa @dianaml0 Just pushed what I currently have as my benchmark file. It's a WIP so little messy, but its there in case you guys want to take a look.

@yuanandonly
Copy link
Contributor Author

@dianaml0 @fmassa @blefaudeux
Summary of changes since last review:

  • Added benchmark file
  • Added LRU cache to save blocksparse objects
  • Added more tests
  • Added plots to benchmark.md

Do i need to update the changelog or add any other documentation? Thanks!

@@ -210,19 +215,94 @@ def scaled_query_key_softmax(
return att


# 128 is default maxsize
@lru_cache(maxsize=128)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, I did not think of that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

layout_heads = 1

# TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size?
assert seq_len % block_size == 0, "Sequence length must be divisible by block size"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both TODO and assert are good here I think, in practice I doubt that it's a really significant limitation but good to write it down. Padding would trigger a memory copy and possibly allocation, not ideal either

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the future we should just have a custom set of operators (sddmm / softmax / spmm`) which work on causal structure. This would enable for fast execution without being dependent on block sizes, and should be faster than blocksparse as we wouldn't need to index on a set of (unused) indices

blocksparse_attention = _retrieve_blocksparse(layout_heads, seq_len, block_size)
# Dropout is a no-op in evaluation mode
if isinstance(dropout, torch.nn.Dropout):
blocksparse_attention.attn_drop = dropout
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh it would have been nice to fuse that (fuse the dropout op with the softmax for instance, not make it another call), I'm actually a little surprised that it's not the case already.. for another day

and not seq_len % block_size
and q.shape[-2] == k.shape[-2]
):
# print("switching to blocksparse...")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, dead code ?

seq_len = q.shape[-2]
if (
switch_to_blocksparse
and not seq_len % block_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, but could these two conditions be part of the tests just above, to decide whether to switch to blocksparse or not ? I think that it makes the flow a little easier to follow, there's a conditional branch and you can see all the factors in one place

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, thanks for pointing it out!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be a good cleanup to integrate those changes in the same switch_to_blocksparse (at least in the future)

@blefaudeux
Copy link
Contributor

looks great to me, thanks a lot @yuanandonly ! Small nits here (feel free to dispute), and I would definitely update the changelog with this as I think it can be significant perf wise for all GPT like workloads. Also letting @dianaml0 and @fmassa give the green light, but thanks already for the very thorough PR

@blefaudeux
Copy link
Contributor

Oh, and the mypy issue in the CI just require a rebase onto current main I think

Copy link
Contributor

@dianaml0 dianaml0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yuanandonly really great PR, appreciate all the thorough testing! I added a few comments. It would be a good idea to update the changelog also

expected_device = torch.float32
assert r_sparse.dtype == expected_device

if r_custom.dtype == r_att_mask.dtype:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How often is this not the case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the inputs are fp16, the datatypes for normal sdp attention and blocksparse are both the same (fp16). But when the inputs are fp32, they are different where blocksparse attention returns fp32 and sdp attention always returns an fp16 tensor. I looked into this a while ago and found that it had to do with pytorch matmuls, not completely sure though (this might be relevant)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if you cast the results to fp16 or fp32 and compare?

Copy link
Contributor Author

@yuanandonly yuanandonly Jun 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, added that!

# Checks if blocksparse object exists in cache

blocks = seq_len // block_size
print("Made uncached blocksparse")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use logging.info here instead of print

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general printing in library code can be quite annoying, specially if it happens often during training.

It's ok to leave it like this for now (as the overhead of creating the BlockSparseAttention object is high), but in the long run it would be good to remove this print

xformers/components/attention/core.py Show resolved Hide resolved
Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, this is looking great!

For a follow-up PR (so that we don't take too long to get the PR merged), it would be good to address some of the comments that were left by @blefaudeux @dianaml0 and myself, as I believe it could make the code a bit simpler.


# Reshape attention (B, nh, S, hs) back to (N, S, hs)
if orig_dim == 3:
return reshape_heads(att, *att.size())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this could be simplified with attn.flatten(0, 1).

Comment on lines +112 to +118
def split_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
return t.view(B, nH, S, Hs)


# (B, nh, S, hs) back to (N, S, hs)
def reshape_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
return t.view(B * nH, S, Hs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: looks like split_heads is not used anymore, and given that reshape_heads can be simplified as t.flatten(0, 1) in the main code, I think we could remove those two functions

seq_len = q.shape[-2]
if (
switch_to_blocksparse
and not seq_len % block_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be a good cleanup to integrate those changes in the same switch_to_blocksparse (at least in the future)

@yuanandonly yuanandonly merged commit e3aa730 into main Jun 27, 2022
fmassa added a commit that referenced this pull request Aug 10, 2022
* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

It wasn't needed to break it in two functions to begin with

* Add CUDA implementation for dropout

* clang-format

* Make p be drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments
fmassa added a commit that referenced this pull request Aug 25, 2022
* Enable masking in memory-efficient attention (#333)

* Add attention bias in memory-efficient attention

* Add gradient for attn_mask support

* Add CPU implementation

* clang-format

* Add benchmark scripts

* Add extra loop in benchmarks

* Move zeros array out of helper function

* clang-format

* Enable dropout in memory-efficient attention (#334)

* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

It wasn't needed to break it in two functions to begin with

* Add CUDA implementation for dropout

* clang-format

* Make p be drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments

* Fix masking corner case when full block is masked (#339)

* Add cutlass 2.9 - 858c735856a7f17bd33fe438ec76d3c9f0234e7f

* Option to load from shared memory for PredicatedTileIterator

* Add cutlass include dir

* Ignore files in third-party for flake8/coverage

* third-party -> third_party

* Address comments

* Revert some un-needed mods

* Add attention_forward_generic.cu

* Add tests

* Fix duplicate calculations on baseline for mem efficient transformers

* Always run all linters in CI

* clang-format attention_forward_generic.cu

* Benchmark: Add possibility to compare benchmarks

* [isort] Ignore third_party

* black autoformat

* Black again + ignore third_party properly

* black

* Fix memory leak between the 2 benchmarks in backward

* Exclude third_party/ without using pyproject.toml as it imposes isolated build which is a pain

* Remove progress bar when finished

* mypy

* flake8

* Save results to shared folder in home location

* run black

* clang-format with 'run-clang-format.py'

* Fix cutlass build for arch>=75

* Set tests precision for gradient more accurately

* Fix precision margin

* Revert changes to black

* [feat] Fix importing xformers when not built (#351)

authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Update black to 22.3.0

* Tweak precision for mem_eff_attention test

* mem-efficient impl for f16 (#352)

Co-authored-by: danthe3rd <danthe3rd>

* Add support for f16 with tensorcores [sm70/sm75/sm80] (#354)

* Add support for f16 with tensorcores

* sm75 minimum for tensorcores

* Run tests with CUDA_LAUNCH_BLOCKING=1

* Support sm70 properly

* Disable tensorcore when not correctly aligned - and use 32bit accessors

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Optimize backward of memory-efficient attention by ~20% (#355)

* Optimize backward by 15% by using equivalent formulation

* Unify everything into single kernel

* Remove unused implementation

* clang-format

* Remove unused tensor

* Display results as we progress during benchmark (#357)

Co-authored-by: danthe3rd <danthe3rd>

* RFC: Ops dispatch (#356)

* Ops dispatch

* CI: Fix doc build

* memory_efficient_attention raises when no implementation is available

* type: ignore

* Fix torch.device/str comparison

* Make mypy happy

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>

* [A100/f32] Use TensorCores for Q.K_t matmul with FastF32 (#358)

* Use TensorCores for MM0 on Float as well

* Use MultiStage MMA when available - change to FastF32 rather than FastF16

* Better alignment calculation

* Just use regular f32, no fastf32

* Hackfix to handle alignment

* HeuristicsMM0 -> GemmTypeQK

* No longer use f16 for matmul

* Add some doc

* Typo

* Fix build <sm80

* Alignment check based on current device compute capability

* Use TORCH_INTERNAL_ASSERT

Co-authored-by: danthe3rd <danthe3rd>

* FlashAttention implem and dispatch (#360)

* FlashAttention implem WIP

* Fix flashattention forward+backward

* Fix forward/backward for FlashAttention

* Enable tests (more permissive) for f16 backward

* Fix CI

* flashattn only supports Sm75 and above

* Fix CI2

* Disable K=128 when below sm80 for flashattn

Co-authored-by: danthe3rd <danthe3rd>

* Misc performance improvements for generic mem-efficient attention (#361)

* 3% speedup by calculating mi from registers

* Also compute m_prime/s_prime and exponentiate from registers

* Support for Simt tiles

* Fix TensorOp for V100

* Fix for A100

* Fix Simt alignment calculation

* clang-format

* WarpReduction before atomic call for Simt

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Update flashattention to support bf16 (#363)

* Update flashattention to support bf16

* bfloat16 only on sm80 and above

Co-authored-by: danthe3rd <danthe3rd>

* Flashattn causal (#364)

* Implement causal memory-efficient attention with FlashAttention

* Update benchmarks

* Fix mypy

Co-authored-by: danthe3rd <danthe3rd>

* Option to disable flashattention (long to build) (#362)

* Option to disable flashattention (long to build)

* Update setup.py

Co-authored-by: danthe3rd <danthe3rd>

* Remove code duplicate in attention_scaling_coefs_updater.h (#367)

Co-authored-by: danthe3rd <danthe3rd>

* Update .gitmodules (#366)

* MemoryEff attention forward: Properly fuse matmul and enable TensorCores on the second matmul (#368)

* Generic backwards

* Guard backward to sm75 only

* bounds checking for gradV

* clang-format

* Fused gemm working for Sm80/Sm75 f16/f32

* WIP

* Volta TensorOp for f16

* Working on A100 again

* SIMT working

* Code cleanup 1

* Code cleanup2

* BUGFIX for shared memory limit

* Remove code

* clang-format

* Remove code again

* Remove draft of backward

* Enforce alignment for fp16

* Fix tests

* Fix constraint on seq length when not using tensorcores

* Fix alignment requirements for V100/tensorcores

* Clang-format

* Update xformers/components/attention/csrc/cuda/attention_forward_generic.cu

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Address comments from fmassa

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Update install instructions with submodule (#365)

* Generic backward implem with cutlass (#371)

* Old bw code

* P100: gradV working

* gk/gq working (at least for small values of M, and on P100/f16)

* Further restrict supported values for bw

* Fix storage into smem for Simt

* More tooling for pruint/debug

* Remove tests we dont need for now

* Tests pass on P100 :D

* 4 warps per block

* Restraint on q length

* Use tensorcores on V100 for f16

* Support dynamic smem for bw

* Handle alignment and different dtype/arch

* Fix NaNS by initializing shared memory

* bw.py

* Fix launch bounds

* Faster 'computeDi'

* minus_lse can operate on arrays

* Output number of regs used etc...

* Code cleanup

* Hackfix for alignment check during forward

* zFill to avoid nans in Sm80 + fix launch bounds

* COde cleanup1

* clang-format

* Fix tests

* Add benchmark for K=64

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>

* Cutlass as submodule (#375)

* Make cutlass be back at 858c735856a7f17bd33fe438ec76d3c9f0234e7f

* Remove cutlass

* Update submodules

* Add submodule (properly)

* spaces / tab

* Make submodule init be recursive

* Fix bad rebase

* Bump tolerance for backward (#377)

* Add verbose flag to CI builds (#376)

* Add verbose flag to CI builds

* Spurious change to rebuild cache

* Add ninja

* Ninja wasn't visible before, install through conda

* Debugging

* Source env

* One more try

* Forgot to uncomment a line

* Another try

* Cleanup

* Fix for FlashAttention dispatch

It requires device capability >= 7.5

* Remove generated file

* Address some reviewer feedback

Remove unused function and typo fix

* Perf improvement on backward (#378)

* Fast again on V100

* Fix correctness - missing syncthreads

* Get rid of AttentionInfo

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com>
bertmaher pushed a commit to bertmaher/xformers that referenced this pull request Dec 20, 2024
* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

It wasn't needed to break it in two functions to begin with

* Add CUDA implementation for dropout

* clang-format

* Make p be drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants