[feat] cutlass FlashAttention bias+dropout support #587
Conversation
Hi @jfc4050! Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged accordingly.
If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Hi @jfc4050
First, that's a pretty impressive PR with significant changes, which must have taken a lot of effort :o
The main things that will matter (in order of importance) for me would be:
Once again, thanks a lot for putting all of this work in there!
EDIT: We seem to have some build errors on Windows that would need to be addressed (see CI)
Did a first pass on the code - a few questions and a few comments :)
Overall it looks really clean, and you made the effort to do things properly.
if (bias->dim() == 2) { // (n_queries, n_keys)
  TORCH_INTERNAL_ASSERT(bias->size(0) == M);
  TORCH_INTERNAL_ASSERT(bias->size(1) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(0));

  if (bias_requires_grad) {
    ASSIGN_CHECK_OVERFLOW(p.gB_strideB, 0);
    ASSIGN_CHECK_OVERFLOW(p.gB_strideH, 0);
    ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias.stride(0));
  }
} else if (bias->dim() == 3) { // (batch_sz * n_heads, n_queries, n_keys)
  TORCH_INTERNAL_ASSERT(bias->size(0) == B * nH);
  TORCH_INTERNAL_ASSERT(bias->size(1) == M);
  TORCH_INTERNAL_ASSERT(bias->size(2) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, nH * bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(1));

  if (bias_requires_grad) {
    ASSIGN_CHECK_OVERFLOW(p.gB_strideB, nH * grad_bias.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.gB_strideH, grad_bias.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias.stride(1));
  }
nit: we could assume the bias always has dimension 4 for simplicity, the user could torch.expand it if necessary (which will set strides to 0)
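For reference, a minimal sketch of what that torch.expand call does to the strides (the shapes below are arbitrary illustrations, not the test's actual sizes):

```python
import torch

B, H, M, N = 32, 16, 197, 197

# hypothetical 2D bias of shape (n_queries, n_keys)
attn_bias = torch.randn(M, N)

# broadcast to (batch, heads, n_queries, n_keys) without copying;
# expand() gives the broadcast dimensions a stride of 0
bias_4d = attn_bias.expand(B, H, M, N)
print(bias_4d.shape)     # torch.Size([32, 16, 197, 197])
print(bias_4d.stride())  # (0, 0, 197, 1) -> strideB and strideH are 0
```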
awesome, will do
qq: the tests use 3 dims. Would you prefer I change the tests (not sure if other implementations can handle 4 dims), or handle cases for both 3 and 4?
Let's only support dim=4 in the kernel, and expand properly in python before we call the C++ code.
Later we can simplify that to accept masks with or without head dimension (because we accept q/k/v in both BMHK and BMK shapes, so it makes sense to be coherent for the mask as well)
Ended up being simpler imo to do it in the C++ layer. Doing it in the Python layer means we don't have easy access to the shape variables, and the bias gradient ends up with the wrong shape unless it gets reshaped too. It also messed with autograd a bit, since the view from unsqueezing/expanding/reshaping doesn't require grad, so we'd have to either add it to the autograd graph or add another flag to the function signature to indicate that it should compute the grad.
You can see what it looks like in 296e6fa; lmk if you'd still prefer to do it in the Python layer and we can revert/redo.
if (bias->dim() == 2) { // (n_queries, n_keys)
  TORCH_CHECK(bias->size(0) == M);
  TORCH_CHECK(bias->size(1) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(0));
} else if (bias->dim() == 3) { // (batch_sz * n_heads, n_queries, n_keys)
  TORCH_CHECK(bias->size(0) == B * num_heads);
  TORCH_CHECK(bias->size(1) == M);
  TORCH_CHECK(bias->size(2) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, num_heads * bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(1));
same here, let's assume dim=4
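For reference, a small PyTorch sketch of why the 3D branch in the snippet above computes bias_strideB as num_heads * bias->stride(0) (assuming a contiguous bias; the sizes are arbitrary):

```python
import torch

B, H, M, N = 4, 3, 8, 8

# a (batch_sz * n_heads, n_queries, n_keys) bias, as accepted by the 3D branch
bias = torch.randn(B * H, M, N)

# viewing the same memory as (B, H, M, N) exposes the strides the kernel wants
bias_4d = bias.view(B, H, M, N)
assert bias_4d.stride() == (
    H * bias.stride(0),  # bias_strideB
    bias.stride(0),      # bias_strideH
    bias.stride(1),      # bias_strideM
    bias.stride(2),
)
```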
      typename DefaultGemm::Mma,
      typename MatmulQK::AccumulatorSharedStorage>;

  using DefaultMmaFromSmem = typename std::conditional<
We should avoid using std::* in this code, as we want to upstream it to CUTLASS, where it needs to build with nvrtc. I believe there should be the same functionality in platform::conditional.
Oh got it. I'll check for other uses of std elsewhere.
That's a relief to hear 😅
ref = ref_attention(query, key, value, attn_bias, mask, p)
assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"

num_trials = 1000
- p_val_tol = 0.0001
+ p_val_tol = 1e-6
Forgot to mention this: I dropped p_val_tol from 1e-4 -> 1e-6, otherwise the cutlass dropout mask fails the second binomial test.
I took a look at the results it was outputting and they don't seem outlandish or anything. For example, the test test_dropout[cutlassF-33-32-2-32-0.7-42] fails with one of the 2048 elements of the masks ending up with 248/1000 keeps (p=0.7), resulting in a p-value of 3.9172e-05.
It looks like the other implementation uses a new subsequence every 4 elements, which might have better independence guarantees, but it's unlikely to be as performant as the way the CUTLASS dropout implementation is done now.
What's your take on this? Maybe we can soften the constraint by requiring only a percentage of the elements to pass the test?
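For reference, this is roughly the kind of per-element check being discussed (a minimal sketch using scipy.stats.binomtest; the actual test_dropout may use a different statistic or sidedness, so the p-value will not necessarily match the 3.9172e-05 quoted above):

```python
from scipy.stats import binomtest

num_trials = 1000      # how many times each mask element is sampled
dropout_p = 0.7        # drop probability, so the keep probability is 0.3
observed_keeps = 248   # worst element from the example above (expected ~300)

result = binomtest(observed_keeps, n=num_trials, p=1.0 - dropout_p)
print(result.pvalue)   # small -> this element's keep rate looks biased
# A tighter tolerance like p_val_tol = 1e-6 only flags much more extreme
# deviations than the previous 1e-4 threshold did.
```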
cc @fmassa, who implemented this test and the smallK kernel, which also supports dropout
EDIT: The errors were unrelated to your code - now I can run it on f32 properly
@@ -546,7 +546,7 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
     )

     _out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
-        query, key, value
+        query, key, value, op=op
Good catch!
@@ -88,6 +102,7 @@ struct AttentionKernel {
   scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
   scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
   scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
+  scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
(You might not know the answer, but) I'm wondering how much we could gain by having:
(1) The bias already in the right format (eg if you know you want to use this bias only with MHA, you could store it in a format that is easy to load from gmem directly, without having to go through shared memory - this format would be different depending on the kernel running tho)
(2) A boolean mask or some datatype with even lower precision
- I looked briefly into trying to load the bias directly from gmem, but by using different threadmaps in the predicated tile iterator rather than a different input format for the bias. Wasn't able to get the loaded fragment to match the accumulator fragment. I'm not too sure about having a different bias format; it could be tricky and require the user to understand internals, because which elements of the accumulator tile each thread ends up with, and the way they are ordered in the fragment, depend on the architecture and MMA configuration. @hwu36 might know more.
- I'd imagine this would be an improvement for the use cases that don't need a floating-point bias: less memory traffic, and since each thread loads 128 bits at a time, fewer bits -> fewer loads. It would require more template specializations though.
Would you please let me know more about this bias? Is it applied after the first gemm or the second gemm? Is it a 2D matrix or a 1D vector? If it is a vector, does every row have different values, or every column?
The accumulator layout of one 1688 tensor core is like
t0 t0 t1 t1 t2 t2 t3 t3
t4 t4 t5 t5 t6 t6 t7 t7
t8 t8 t9 t9 t10 t10 t11 t11
...
t28 t28 t29 t29 t30 t30 t31 t31
t0 t0 t1 t1 t2 t2 t3 t3
t4 t4 t5 t5 t6 t6 t7 t7
t8 t8 t9 t9 t10 t10 t11 t11
...
t28 t28 t29 t29 t30 t30 t31 t31
Then you need to add different offsets such as the threadblock offset and warp offset.
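As an aside, a tiny sketch that just regenerates the ownership pattern in the diagram above (assuming the pattern continues linearly, which is how the rows elided by "..." are filled in):

```python
# Reproduce the accumulator-ownership diagram: 4 threads per row,
# 2 consecutive columns per thread, cycling through t0..t31 over each
# group of 8 rows of the 16-row accumulator tile.
for row in range(16):
    group_row = row % 8
    owners = []
    for col in range(8):
        lane = group_row * 4 + col // 2
        owners.append(f"t{lane}")
    print(" ".join(owners))
```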
It's applied after the 1st GEMM (Q @ K.T), and it's a 2D matrix.
If it is a 2D matrix, going through shared memory may not be bad. If you load directly, you can only load 2 elements at a time, which is not the most efficient use of the memory BW. If you stage it through shared memory, you can load 128 bits of data at a time to fully use the memory BW.
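A quick back-of-envelope for that point (a sketch; assumes an fp16 bias and 128-bit vectorized loads):

```python
# each direct load is limited to 2 fp16 elements (matching the accumulator layout)
bits_per_direct_load = 2 * 16      # 32-bit loads
# staging the tile through shared memory allows full-width vectorized loads
bits_per_vectorized_load = 128

print(bits_per_vectorized_load // bits_per_direct_load)  # 4x fewer load instructions
```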
        },
        [&](int accum_m) {});
  }

  // Mask out last if causal
  if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
Interesting - we can still have causal masking + custom mask. We would need to find a way to expose that properly in a follow-up PR
Yep. I think that for users who need both an additive bias and unidirectionality, using causal masking in addition to the attention bias might be faster than just using an attention bias with some of the values replaced with -inf, since the kernel can use the knowledge that it's causal to avoid some unnecessary compute. If that's true, it might be unintuitive to some users though.
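To make that concrete, here is a small reference-level sketch in plain PyTorch (softmax attention only, not the kernel; shapes are arbitrary). Both formulations give the same output; the difference is that the kernel's causal path can skip entirely-masked key blocks instead of computing and then discarding them:

```python
import torch

B, M, K = 2, 8, 16
q, k, v = (torch.randn(B, M, K) for _ in range(3))
scores = (q @ k.transpose(-2, -1)) * K**-0.5

# option 1: additive bias with -inf strictly above the diagonal
neg_inf_bias = torch.full((M, M), float("-inf")).triu(diagonal=1)
out_bias = torch.softmax(scores + neg_inf_bias, dim=-1) @ v

# option 2: an explicit causal mask
causal = torch.ones(M, M, dtype=torch.bool).tril()
out_causal = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v

assert torch.allclose(out_bias, out_causal, atol=1e-6)
```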
Benchmarks as of 0062b22:

A100 fw - maybe 5% slowdown on f32, similar perf on f16

[---------------- attention (attn_bias=<class 'NoneType'>) ----------------]
| pr587_0062 | main | eager
1 threads: -----------------------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 125.0 | 121.3 | 851.8
f32 B=384, M=197, H=1, K=88 | 463.3 | 447.4 | 719.5
f16 B=384, M=197, H=1, K=80 | 116.2 | 112.2 | 742.2
f32 B=384, M=197, H=1, K=80 | 459.6 | 443.5 | 692.4
f16 B=384, M=197, H=1, K=64 | 87.8 | 86.7 | 679.8
f32 B=384, M=197, H=1, K=64 | 281.1 | 268.7 | 640.5
f16 B=1024, M=197, H=1, K=88 | 314.1 | 306.6 | 2221.3
f32 B=1024, M=197, H=1, K=88 | 1220.4 | 1168.2 | 1765.2
f16 B=1024, M=197, H=1, K=80 | 294.5 | 285.5 | 1927.2
f32 B=1024, M=197, H=1, K=80 | 1212.9 | 1160.4 | 1700.6
f16 B=1024, M=197, H=1, K=64 | 212.5 | 208.7 | 1770.5
f32 B=1024, M=197, H=1, K=64 | 689.2 | 661.7 | 1568.6
f16 B=512, M=197, H=1, K=80 | 153.0 | 148.2 | 979.2
f32 B=512, M=197, H=1, K=80 | 614.9 | 592.4 | 881.7
f16 B=32, M=197, H=16, K=80 | 153.9 | 149.2 | 1064.0
f32 B=32, M=197, H=16, K=80 | 618.0 | 594.5 | 1051.0
f16 B=32, M=197, H=16, K=64 | 113.4 | 111.7 | 979.9
f32 B=32, M=197, H=16, K=64 | 355.1 | 341.4 | 949.7
f16 B=32, M=197, H=16, K=128 | 168.2 | 162.5 | 1717.6
f32 B=32, M=197, H=16, K=128 | 683.9 | 661.8 | 1324.7
f16 B=256, M=197, H=1, K=88 | 87.3 | 85.0 | 577.0
f32 B=256, M=197, H=1, K=88 | 318.7 | 306.9 | 477.8
f16 B=16, M=197, H=16, K=88 | 89.0 | 86.2 | 628.7
f32 B=16, M=197, H=16, K=88 | 320.7 | 308.4 | 572.9
f16 B=16, M=197, H=16, K=64 | 65.0 | 61.1 | 506.9
f32 B=16, M=197, H=16, K=64 | 195.9 | 186.9 | 494.9
f16 B=16, M=197, H=16, K=128 | 90.2 | 87.0 | 877.3
f32 B=16, M=197, H=16, K=128 | 352.8 | 339.6 | 692.8
f16 B=1, M=4096, H=160, K=128 | 15244.1 | 14853.7 | 21510.6
f32 B=1, M=4096, H=160, K=128 | 57973.4 | 56691.3 | 91407.5
f16 B=2, M=4096, H=160, K=128 | 30414.8 | 29633.5 | 43562.2
f32 B=2, M=4096, H=160, K=128 | 115750.4 | 113237.0 |
f16 B=1, M=8192, H=160, K=128 | 60840.4 | 59240.6 |
f32 B=1, M=8192, H=160, K=128 | 232306.8 | 226792.5 |
f16 B=2, M=8192, H=160, K=128 | 121618.1 | 118414.6 |
f32 B=2, M=8192, H=160, K=128 | 465145.6 | 453321.1 |
f16 B=1024, M=82, H=8, K=64 | 476.0 | 446.3 | 1785.2
f32 B=1024, M=82, H=8, K=64 | 1430.9 | 1314.3 | 3753.8
f16 B=150, M=256, H=16, K=64 | 512.3 | 503.1 | 1964.8
f32 B=150, M=256, H=16, K=64 | 1691.3 | 1602.7 | 5294.3
f16 B=64, M=256, H=12, K=64 | 172.2 | 169.8 | 664.2
f32 B=64, M=256, H=12, K=64 | 559.7 | 533.0 | 1760.4
f16 B=1, M=4096, H=16, K=40 | 857.3 | 870.6 | 1958.3
f32 B=1, M=4096, H=16, K=40 | 2865.9 | 2765.6 | 6914.2
f16 B=1, M=16384, H=16, K=40 | 12168.3 | 12346.3 | 30460.6
f32 B=1, M=16384, H=16, K=40 | 41930.8 | 40533.2 | 123282.2
f16 B=256, M=4096, H=16, K=64 | 181232.1 | 183488.9 |
f32 B=256, M=4096, H=16, K=64 | 665659.8 | 642960.3 |
f16 B=16, M=128, H=16, K=16 | 57.9 | 60.4 | 155.3
f32 B=16, M=128, H=16, K=16 | 58.5 | 60.7 | 151.5
f16 B=16, M=128, H=16, K=32 | 57.8 | 60.7 | 154.5
f32 B=16, M=128, H=16, K=32 | 58.7 | 60.8 | 175.9
f16 B=16, M=128, H=16, K=64 | 57.9 | 60.6 | 155.4
f32 B=16, M=128, H=16, K=64 | 66.0 | 62.2 | 217.8
f16 B=16, M=128, H=16, K=128 | 58.1 | 60.2 | 155.4
f32 B=16, M=128, H=16, K=128 | 124.1 | 117.2 | 311.9
f16 B=16, M=128, H=16, K=256 | 78.6 | 76.5 | 253.0
f32 B=16, M=128, H=16, K=256 | 224.9 | 210.4 | 534.7
f16 B=16, M=512, H=16, K=16 | 174.8 | 173.5 | 516.1
f32 B=16, M=512, H=16, K=16 | 576.9 | 553.6 | 1629.1
f16 B=16, M=512, H=16, K=32 | 182.5 | 180.5 | 564.7
f32 B=16, M=512, H=16, K=32 | 586.2 | 562.4 | 1781.3
f16 B=16, M=512, H=16, K=64 | 209.7 | 207.9 | 673.7
f32 B=16, M=512, H=16, K=64 | 709.2 | 678.4 | 2107.5
f16 B=16, M=512, H=16, K=128 | 363.0 | 353.8 | 856.0
f32 B=16, M=512, H=16, K=128 | 1539.8 | 1493.4 | 2760.0
f16 B=16, M=512, H=16, K=256 | 819.3 | 812.9 | 1230.8
f32 B=16, M=512, H=16, K=256 | 3121.5 | 2984.7 | 4922.2
f16 B=16, M=1024, H=16, K=16 | 667.3 | 662.9 | 1857.2
f32 B=16, M=1024, H=16, K=16 | 2209.8 | 2130.1 | 6088.4
f16 B=16, M=1024, H=16, K=32 | 673.6 | 669.0 | 1951.1
f32 B=16, M=1024, H=16, K=32 | 2233.2 | 2150.0 | 6591.0
f16 B=16, M=1024, H=16, K=64 | 762.9 | 766.7 | 2192.7
f32 B=16, M=1024, H=16, K=64 | 2701.4 | 2597.3 | 7687.2
f16 B=16, M=1024, H=16, K=128 | 1346.0 | 1311.5 | 2618.4
f32 B=16, M=1024, H=16, K=128 | 5926.2 | 5772.7 | 9881.6
f16 B=16, M=1024, H=16, K=256 | 3120.2 | 3096.7 | 3436.4
f32 B=16, M=1024, H=16, K=256 | 12178.1 | 11659.9 | 17896.5
f16 B=64, M=128, H=16, K=16 | 58.2 | 61.0 | 203.5
f32 B=64, M=128, H=16, K=16 | 166.8 | 157.7 | 489.0
f16 B=64, M=128, H=16, K=32 | 59.3 | 61.2 | 250.3
f32 B=64, M=128, H=16, K=32 | 174.7 | 165.0 | 574.8
f16 B=64, M=128, H=16, K=64 | 74.9 | 73.2 | 349.0
f32 B=64, M=128, H=16, K=64 | 215.3 | 203.6 | 737.7
f16 B=64, M=128, H=16, K=128 | 131.5 | 129.1 | 534.2
f32 B=64, M=128, H=16, K=128 | 449.1 | 427.7 | 1070.0
f16 B=64, M=128, H=16, K=256 | 257.1 | 253.4 | 914.4
f32 B=64, M=128, H=16, K=256 | 833.6 | 791.8 | 1921.3
f16 B=64, M=512, H=16, K=16 | 683.6 | 676.5 | 1893.0
f32 B=64, M=512, H=16, K=16 | 2246.6 | 2161.4 | 6273.1
f16 B=64, M=512, H=16, K=32 | 693.6 | 684.8 | 2096.1
f32 B=64, M=512, H=16, K=32 | 2279.6 | 2190.8 | 6822.3
f16 B=64, M=512, H=16, K=64 | 787.5 | 787.1 | 2495.9
f32 B=64, M=512, H=16, K=64 | 2748.7 | 2635.4 | 8108.3
f16 B=64, M=512, H=16, K=128 | 1416.0 | 1380.4 | 3264.1
f32 B=64, M=512, H=16, K=128 | 6080.3 | 5895.7 | 10731.7
f16 B=64, M=512, H=16, K=256 | 3214.8 | 3181.7 | 4760.3
f32 B=64, M=512, H=16, K=256 | 12433.4 | 11825.6 | 19614.2
f16 B=64, M=1024, H=16, K=16 | 2611.2 | 2593.0 | 7293.8
f32 B=64, M=1024, H=16, K=16 | 8726.2 | 8424.4 | 24237.0
f16 B=64, M=1024, H=16, K=32 | 2632.2 | 2612.0 | 7675.3
f32 B=64, M=1024, H=16, K=32 | 8812.5 | 8494.5 | 26228.7
f16 B=64, M=1024, H=16, K=64 | 2980.8 | 2991.8 | 8669.3
f32 B=64, M=1024, H=16, K=64 | 10665.4 | 10271.7 | 30612.5
f16 B=64, M=1024, H=16, K=128 | 5393.5 | 5266.6 | 10313.9
f32 B=64, M=1024, H=16, K=128 | 23526.3 | 22940.7 | 39341.6
f16 B=64, M=1024, H=16, K=256 | 12266.5 | 12354.8 | 13558.9
f32 B=64, M=1024, H=16, K=256 | 48552.5 | 46488.2 | 71460.8
Times are in microseconds (us).
[ attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
| pr587_0062 | main | eager
1 threads: -----------------------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 95.1 | 92.0 | 918.3
f32 B=384, M=197, H=1, K=88 | 336.6 | 321.6 | 779.4
f16 B=384, M=197, H=1, K=80 | 90.4 | 86.8 | 812.3
f32 B=384, M=197, H=1, K=80 | 333.1 | 318.7 | 752.1
f16 B=384, M=197, H=1, K=64 | 67.2 | 65.3 | 749.9
f32 B=384, M=197, H=1, K=64 | 203.3 | 191.7 | 705.0
f16 B=1024, M=197, H=1, K=88 | 231.4 | 223.7 | 2389.0
f32 B=1024, M=197, H=1, K=88 | 855.0 | 818.3 | 1907.8
f16 B=1024, M=197, H=1, K=80 | 219.2 | 211.0 | 2105.1
f32 B=1024, M=197, H=1, K=80 | 846.8 | 810.7 | 1843.8
f16 B=1024, M=197, H=1, K=64 | 153.0 | 149.8 | 1945.4
f32 B=1024, M=197, H=1, K=64 | 484.7 | 457.8 | 1719.9
f16 B=512, M=197, H=1, K=80 | 116.2 | 111.6 | 1071.6
f32 B=512, M=197, H=1, K=80 | 435.9 | 417.0 | 961.4
f16 B=32, M=197, H=16, K=80 | 116.6 | 112.3 | 1154.2
f32 B=32, M=197, H=16, K=80 | 437.2 | 418.4 | 1131.0
f16 B=32, M=197, H=16, K=64 | 85.3 | 83.3 | 1068.3
f32 B=32, M=197, H=16, K=64 | 260.1 | 245.5 | 1027.6
f16 B=32, M=197, H=16, K=128 | 127.3 | 123.0 | 1802.6
f32 B=32, M=197, H=16, K=128 | 490.9 | 467.4 | 1403.9
f16 B=256, M=197, H=1, K=88 | 67.6 | 65.3 | 622.5
f32 B=256, M=197, H=1, K=88 | 232.2 | 222.5 | 523.1
f16 B=16, M=197, H=16, K=88 | 67.8 | 65.6 | 674.0
f32 B=16, M=197, H=16, K=88 | 233.0 | 223.4 | 618.4
f16 B=16, M=197, H=16, K=64 | 64.5 | 60.7 | 552.6
f32 B=16, M=197, H=16, K=64 | 147.1 | 139.1 | 542.7
f16 B=16, M=197, H=16, K=128 | 70.7 | 68.2 | 923.9
f32 B=16, M=197, H=16, K=128 | 258.8 | 246.7 | 737.2
f16 B=1, M=4096, H=160, K=128 | 7808.8 | 7595.2 | 38531.2
f32 B=1, M=4096, H=160, K=128 | 30015.9 | 29331.8 | 109017.4
f16 B=2, M=4096, H=160, K=128 | 15492.1 | 15094.0 | 78631.6
f32 B=2, M=4096, H=160, K=128 | 59694.2 | 58316.8 |
f16 B=1, M=8192, H=160, K=128 | 30803.3 | 29983.0 |
f32 B=1, M=8192, H=160, K=128 | 117655.3 | 115038.2 |
f16 B=2, M=8192, H=160, K=128 | 61373.3 | 59734.6 |
f32 B=2, M=8192, H=160, K=128 | 234708.5 | 229411.1 |
f16 B=1024, M=82, H=8, K=64 | 381.5 | 370.1 | 1997.2
f32 B=1024, M=82, H=8, K=64 | 1145.7 | 1063.2 | 3984.2
f16 B=150, M=256, H=16, K=64 | 358.8 | 352.6 | 2604.2
f32 B=150, M=256, H=16, K=64 | 1139.8 | 1071.0 | 5906.0
f16 B=64, M=256, H=12, K=64 | 124.8 | 122.8 | 873.6
f32 B=64, M=256, H=12, K=64 | 388.6 | 365.7 | 1963.7
f16 B=1, M=4096, H=16, K=40 | 529.2 | 538.7 | 3768.9
f32 B=1, M=4096, H=16, K=40 | 1699.7 | 1652.4 | 8888.6
f16 B=1, M=16384, H=16, K=40 | 6541.3 | 6637.7 | 59019.0
f32 B=1, M=16384, H=16, K=40 | 22413.7 | 21682.4 |
f16 B=256, M=4096, H=16, K=64 | 93055.1 | 94157.2 |
f32 B=256, M=4096, H=16, K=64 | 340536.6 | 328323.8 |
f16 B=16, M=128, H=16, K=16 | 57.8 | 60.9 | 161.1
f32 B=16, M=128, H=16, K=16 | 58.0 | 61.0 | 171.2
f16 B=16, M=128, H=16, K=32 | 57.9 | 60.7 | 160.9
f32 B=16, M=128, H=16, K=32 | 58.7 | 60.6 | 192.6
f16 B=16, M=128, H=16, K=64 | 57.6 | 60.8 | 161.7
f32 B=16, M=128, H=16, K=64 | 63.6 | 61.0 | 244.1
f16 B=16, M=128, H=16, K=128 | 57.8 | 60.4 | 166.6
f32 B=16, M=128, H=16, K=128 | 114.1 | 108.0 | 341.2
f16 B=16, M=128, H=16, K=256 | 70.1 | 68.1 | 283.8
f32 B=16, M=128, H=16, K=256 | 201.6 | 190.6 | 562.6
f16 B=16, M=512, H=16, K=16 | 117.4 | 115.7 | 824.6
f32 B=16, M=512, H=16, K=16 | 368.9 | 352.6 | 2052.2
f16 B=16, M=512, H=16, K=32 | 122.3 | 120.0 | 869.4
f32 B=16, M=512, H=16, K=32 | 377.7 | 360.1 | 2129.5
f16 B=16, M=512, H=16, K=64 | 140.4 | 139.4 | 961.3
f32 B=16, M=512, H=16, K=64 | 454.7 | 433.3 | 2362.8
f16 B=16, M=512, H=16, K=128 | 238.1 | 230.7 | 1131.0
f32 B=16, M=512, H=16, K=128 | 971.1 | 936.1 | 2994.9
f16 B=16, M=512, H=16, K=256 | 514.0 | 509.1 | 1503.4
f32 B=16, M=512, H=16, K=256 | 1971.1 | 1875.0 | 5159.1
f16 B=16, M=1024, H=16, K=16 | 389.8 | 385.9 | 3008.8
f32 B=16, M=1024, H=16, K=16 | 1259.5 | 1211.4 | 8011.4
f16 B=16, M=1024, H=16, K=32 | 394.5 | 390.2 | 3109.4
f32 B=16, M=1024, H=16, K=32 | 1275.1 | 1224.9 | 8143.3
f16 B=16, M=1024, H=16, K=64 | 447.3 | 448.2 | 3304.5
f32 B=16, M=1024, H=16, K=64 | 1540.9 | 1477.3 | 8691.8
f16 B=16, M=1024, H=16, K=128 | 788.6 | 765.3 | 3685.4
f32 B=16, M=1024, H=16, K=128 | 3354.7 | 3257.0 | 10811.9
f16 B=16, M=1024, H=16, K=256 | 1761.0 | 1749.4 | 4464.5
f32 B=16, M=1024, H=16, K=256 | 6908.9 | 6593.8 | 18779.5
f16 B=64, M=128, H=16, K=16 | 57.0 | 60.8 | 277.1
f32 B=64, M=128, H=16, K=16 | 138.8 | 130.5 | 600.0
f16 B=64, M=128, H=16, K=32 | 57.7 | 60.9 | 332.3
f32 B=64, M=128, H=16, K=32 | 146.7 | 138.2 | 672.4
f16 B=64, M=128, H=16, K=64 | 67.7 | 66.1 | 426.4
f32 B=64, M=128, H=16, K=64 | 179.3 | 169.3 | 819.9
f16 B=64, M=128, H=16, K=128 | 117.1 | 114.5 | 607.5
f32 B=64, M=128, H=16, K=128 | 404.7 | 383.1 | 1139.5
f16 B=64, M=128, H=16, K=256 | 232.0 | 228.8 | 988.8
f32 B=64, M=128, H=16, K=256 | 734.4 | 695.7 | 1991.6
f16 B=64, M=512, H=16, K=16 | 419.4 | 412.5 | 3068.8
f32 B=64, M=512, H=16, K=16 | 1332.0 | 1274.7 | 7904.7
f16 B=64, M=512, H=16, K=32 | 426.5 | 419.1 | 3256.0
f32 B=64, M=512, H=16, K=32 | 1360.2 | 1297.8 | 8190.3
f16 B=64, M=512, H=16, K=64 | 490.8 | 485.2 | 3628.3
f32 B=64, M=512, H=16, K=64 | 1638.6 | 1560.6 | 9073.3
f16 B=64, M=512, H=16, K=128 | 898.9 | 873.2 | 4331.2
f32 B=64, M=512, H=16, K=128 | 3732.3 | 3597.2 | 11582.6
f16 B=64, M=512, H=16, K=256 | 1968.1 | 1955.4 | 5832.9
f32 B=64, M=512, H=16, K=256 | 7635.7 | 7274.0 | 20462.0
f16 B=64, M=1024, H=16, K=16 | 1455.2 | 1441.0 | 11836.9
f32 B=64, M=1024, H=16, K=16 | 4773.1 | 4594.4 | 31900.4
f16 B=64, M=1024, H=16, K=32 | 1471.2 | 1454.8 | 12237.2
f32 B=64, M=1024, H=16, K=32 | 4834.2 | 4644.8 | 32409.2
f16 B=64, M=1024, H=16, K=64 | 1666.1 | 1668.8 | 13076.3
f32 B=64, M=1024, H=16, K=64 | 5841.9 | 5601.8 | 34561.1
f16 B=64, M=1024, H=16, K=128 | 3077.4 | 2981.2 | 14547.6
f32 B=64, M=1024, H=16, K=128 | 13149.3 | 12773.4 | 42996.6
f16 B=64, M=1024, H=16, K=256 | 6988.3 | 6942.8 | 17618.8
f32 B=64, M=1024, H=16, K=256 | 27189.1 | 25945.4 | 74993.3
Times are in microseconds (us).

A100 bw - roughly equivalent, somehow faster now with `causal=True` on f16 (?)

[------------ attention backward (attn_bias=<class 'NoneType'>) ------------]
| pr587_0062 | main | vanilla
1 threads: ------------------------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 716.7 | 651.2 | 2260.9
f32 B=384, M=197, H=1, K=88 | 2371.8 | 2330.6 | 1841.9
f16 B=384, M=197, H=1, K=80 | 688.7 | 621.9 | 1916.9
f32 B=384, M=197, H=1, K=80 | 2263.2 | 2229.2 | 1785.9
f16 B=384, M=197, H=1, K=64 | 423.7 | 459.5 | 1808.0
f32 B=384, M=197, H=1, K=64 | 1282.9 | 1262.5 | 1673.1
f16 B=1024, M=197, H=1, K=88 | 1819.1 | 1609.8 | 5941.3
f32 B=1024, M=197, H=1, K=88 | 6129.9 | 6051.2 | 4553.4
f16 B=1024, M=197, H=1, K=80 | 1731.1 | 1536.3 | 5022.0
f32 B=1024, M=197, H=1, K=80 | 5850.0 | 5778.6 | 4405.0
f16 B=1024, M=197, H=1, K=64 | 965.3 | 1037.3 | 4732.1
f32 B=1024, M=197, H=1, K=64 | 3345.0 | 3295.5 | 4112.7
f16 B=512, M=197, H=1, K=80 | 876.9 | 785.1 | 2533.9
f32 B=512, M=197, H=1, K=80 | 2899.7 | 2857.5 | 2283.9
f16 B=32, M=197, H=16, K=80 | 875.9 | 787.2 | 2568.0
f32 B=32, M=197, H=16, K=80 | 2895.1 | 2834.4 | 2351.5
f16 B=32, M=197, H=16, K=64 | 496.2 | 538.7 | 2430.3
f32 B=32, M=197, H=16, K=64 | 1821.3 | 1777.6 | 2195.1
f16 B=32, M=197, H=16, K=128 | 1035.3 | 928.2 | 4486.7
f32 B=32, M=197, H=16, K=128 | 3596.3 | 3544.5 | 2803.4
f16 B=256, M=197, H=1, K=88 | 515.3 | 477.7 | 1521.7
f32 B=256, M=197, H=1, K=88 | 1700.3 | 1675.8 | 1206.7
f16 B=16, M=197, H=16, K=88 | 513.2 | 473.4 | 1539.3
f32 B=16, M=197, H=16, K=88 | 1691.6 | 1664.6 | 1249.0
f16 B=16, M=197, H=16, K=64 | 253.0 | 276.0 | 1242.9
f32 B=16, M=197, H=16, K=64 | 1075.4 | 1060.9 | 1124.1
f16 B=16, M=197, H=16, K=128 | 575.2 | 526.5 | 2266.7
f32 B=16, M=197, H=16, K=128 | 1961.6 | 1935.8 | 1444.7
f16 B=1, M=4096, H=160, K=128 | 62894.9 | 67019.0 | 46384.8
f32 B=1, M=4096, H=160, K=128 | 237534.2 | 222376.4 |
f16 B=2, M=4096, H=160, K=128 | 106164.4 | 110240.7 |
f32 B=2, M=4096, H=160, K=128 | 374929.4 | 351572.8 |
f16 B=1, M=8192, H=160, K=128 | 245856.4 | 267465.8 |
f32 B=1, M=8192, H=160, K=128 | 942689.4 | 881885.6 |
f16 B=2, M=8192, H=160, K=128 | 419550.7 | 433848.1 |
f32 B=2, M=8192, H=160, K=128 | 1490911.6 | 1398395.3 |
f16 B=1024, M=82, H=8, K=64 | 2039.9 | 2111.2 | 3823.5
f32 B=1024, M=82, H=8, K=64 | 8516.3 | 8376.5 | 8720.2
f16 B=150, M=256, H=16, K=64 | 2341.4 | 2537.9 | 4560.5
f32 B=150, M=256, H=16, K=64 | 6266.0 | 6269.7 | 12921.7
f16 B=64, M=256, H=12, K=64 | 794.6 | 875.9 | 1499.4
f32 B=64, M=256, H=12, K=64 | 2149.5 | 2153.6 | 4260.9
f16 B=1, M=4096, H=16, K=40 | 23841.0 | 25712.6 | 4235.3
f32 B=1, M=4096, H=16, K=40 | 73752.6 | 73180.5 | 17706.1
f16 B=1, M=16384, H=16, K=40 | 397392.7 | 430370.9 |
f32 B=1, M=16384, H=16, K=40 | 1197343.6 | 1187422.5 |
f16 B=256, M=4096, H=16, K=64 | 742700.6 | 801632.2 |
f16 B=16, M=128, H=16, K=16 | 207.9 | 189.3 | 306.7
f32 B=16, M=128, H=16, K=16 | 248.0 | 231.0 | 373.0
f16 B=16, M=128, H=16, K=32 | 203.0 | 182.4 | 302.1
f32 B=16, M=128, H=16, K=32 | 246.3 | 226.6 | 413.2
f16 B=16, M=128, H=16, K=64 | 202.8 | 182.9 | 301.7
f32 B=16, M=128, H=16, K=64 | 277.7 | 273.2 | 499.2
f16 B=16, M=128, H=16, K=128 | 200.5 | 209.6 | 304.6
f32 B=16, M=128, H=16, K=128 | 510.1 | 488.3 | 672.2
f16 B=16, M=128, H=16, K=256 | 786.2 | 777.0 | 544.9
f32 B=16, M=128, H=16, K=256 | 974.6 | 937.4 | 1162.6
f16 B=16, M=512, H=16, K=16 | 640.5 | 713.5 | 1203.7
f32 B=16, M=512, H=16, K=16 | 2173.5 | 2150.9 | 4409.0
f16 B=16, M=512, H=16, K=32 | 723.6 | 805.8 | 1306.9
f32 B=16, M=512, H=16, K=32 | 2354.7 | 2343.7 | 4633.5
f16 B=16, M=512, H=16, K=64 | 927.5 | 1019.2 | 1544.1
f32 B=16, M=512, H=16, K=64 | 2990.9 | 2981.1 | 5115.9
f16 B=16, M=512, H=16, K=128 | 1842.4 | 1958.5 | 1984.9
f32 B=16, M=512, H=16, K=128 | 6131.2 | 5800.4 | 6086.4
f16 B=16, M=512, H=16, K=256 | 8430.5 | 8490.1 | 2902.9
f32 B=16, M=512, H=16, K=256 | 11834.2 | 11313.2 | 10617.2
f16 B=16, M=1024, H=16, K=16 | 2477.5 | 2809.0 | 4262.6
f32 B=16, M=1024, H=16, K=16 | 8526.8 | 8520.4 | 16608.1
f16 B=16, M=1024, H=16, K=32 | 2736.0 | 3086.4 | 4485.7
f32 B=16, M=1024, H=16, K=32 | 9032.9 | 9040.3 | 17262.9
f16 B=16, M=1024, H=16, K=64 | 3361.6 | 3721.9 | 4991.7
f32 B=16, M=1024, H=16, K=64 | 11625.7 | 11677.5 | 18670.5
f16 B=16, M=1024, H=16, K=128 | 6566.2 | 7003.0 | 5949.2
f32 B=16, M=1024, H=16, K=128 | 23315.6 | 21954.0 | 21480.0
f16 B=16, M=1024, H=16, K=256 | 31674.1 | 32062.5 | 7897.9
f32 B=16, M=1024, H=16, K=256 | 45039.1 | 42840.9 | 37951.9
f16 B=64, M=128, H=16, K=16 | 200.6 | 184.8 | 439.3
f32 B=64, M=128, H=16, K=16 | 497.2 | 495.2 | 1268.7
f16 B=64, M=128, H=16, K=32 | 262.4 | 241.0 | 545.3
f32 B=64, M=128, H=16, K=32 | 604.9 | 603.1 | 1425.5
f16 B=64, M=128, H=16, K=64 | 334.6 | 369.2 | 767.2
f32 B=64, M=128, H=16, K=64 | 873.3 | 871.9 | 1743.4
f16 B=64, M=128, H=16, K=128 | 698.0 | 723.9 | 1228.2
f32 B=64, M=128, H=16, K=128 | 1771.2 | 1699.5 | 2383.6
f16 B=64, M=128, H=16, K=256 | 2850.2 | 2888.7 | 2129.9
f32 B=64, M=128, H=16, K=256 | 3415.2 | 3289.3 | 4314.9
f16 B=64, M=512, H=16, K=16 | 2385.3 | 2629.4 | 4486.1
f32 B=64, M=512, H=16, K=16 | 6698.9 | 6719.5 | 16963.1
f16 B=64, M=512, H=16, K=32 | 2751.9 | 3005.5 | 4975.9
f32 B=64, M=512, H=16, K=32 | 7491.6 | 7497.9 | 17823.2
f16 B=64, M=512, H=16, K=64 | 3533.0 | 3876.7 | 5893.6
f32 B=64, M=512, H=16, K=64 | 9617.1 | 9634.2 | 19731.2
f16 B=64, M=512, H=16, K=128 | 6635.8 | 6871.8 | 7707.9
f32 B=64, M=512, H=16, K=128 | 21317.5 | 20087.6 | 23584.0
f16 B=64, M=512, H=16, K=256 | 31162.9 | 30844.5 | 11501.6
f32 B=64, M=512, H=16, K=256 | 40918.1 | 38994.0 | 42386.4
f16 B=64, M=1024, H=16, K=16 | 9388.2 | 10399.8 | 16846.5
f32 B=64, M=1024, H=16, K=16 | 26568.8 | 26744.9 | 66205.9
f16 B=64, M=1024, H=16, K=32 | 10683.5 | 11750.7 | 17866.1
f32 B=64, M=1024, H=16, K=32 | 28430.2 | 28477.5 | 68832.9
f16 B=64, M=1024, H=16, K=64 | 13117.1 | 14436.5 | 19915.5
f32 B=64, M=1024, H=16, K=64 | 35834.8 | 35988.5 | 74463.8
f16 B=64, M=1024, H=16, K=128 | 23610.9 | 24519.0 | 23742.3
f32 B=64, M=1024, H=16, K=128 | 80716.8 | 75406.6 | 85733.5
f16 B=64, M=1024, H=16, K=256 | 114888.4 | 115626.1 | 32765.2
f32 B=64, M=1024, H=16, K=256 | 155081.1 | 147906.3 | 152428.4
Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
| pr587_0062 | main | vanilla
1 threads: -----------------------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 565.6 | 527.5 | 2261.1
f32 B=384, M=197, H=1, K=88 | 1853.5 | 1791.2 | 1841.0
f16 B=384, M=197, H=1, K=80 | 538.6 | 501.6 | 1915.5
f32 B=384, M=197, H=1, K=80 | 1787.3 | 1722.7 | 1786.5
f16 B=384, M=197, H=1, K=64 | 284.1 | 325.0 | 1810.7
f32 B=384, M=197, H=1, K=64 | 979.8 | 974.5 | 1674.6
f16 B=1024, M=197, H=1, K=88 | 1425.0 | 1302.1 | 5939.6
f32 B=1024, M=197, H=1, K=88 | 4696.6 | 4595.7 | 4552.5
f16 B=1024, M=197, H=1, K=80 | 1360.8 | 1237.0 | 5019.9
f32 B=1024, M=197, H=1, K=80 | 4521.6 | 4435.3 | 4406.9
f16 B=1024, M=197, H=1, K=64 | 645.9 | 725.8 | 4732.0
f32 B=1024, M=197, H=1, K=64 | 2605.8 | 2548.6 | 4112.1
f16 B=512, M=197, H=1, K=80 | 688.9 | 634.7 | 2535.1
f32 B=512, M=197, H=1, K=80 | 2234.7 | 2198.3 | 2283.7
f16 B=32, M=197, H=16, K=80 | 695.4 | 640.3 | 2570.0
f32 B=32, M=197, H=16, K=80 | 2227.4 | 2196.0 | 2351.8
f16 B=32, M=197, H=16, K=64 | 339.2 | 378.3 | 2428.0
f32 B=32, M=197, H=16, K=64 | 1330.9 | 1371.3 | 2193.7
f16 B=32, M=197, H=16, K=128 | 832.6 | 765.2 | 4489.1
f32 B=32, M=197, H=16, K=128 | 2724.8 | 2684.6 | 2802.6
f16 B=256, M=197, H=1, K=88 | 406.9 | 386.0 | 1526.5
f32 B=256, M=197, H=1, K=88 | 1315.9 | 1291.3 | 1210.1
f16 B=16, M=197, H=16, K=88 | 407.3 | 385.3 | 1539.7
f32 B=16, M=197, H=16, K=88 | 1307.9 | 1282.6 | 1251.5
f16 B=16, M=197, H=16, K=64 | 200.6 | 192.4 | 1243.9
f32 B=16, M=197, H=16, K=64 | 812.8 | 809.9 | 1126.4
f16 B=16, M=197, H=16, K=128 | 460.6 | 431.8 | 2268.8
f32 B=16, M=197, H=16, K=128 | 1547.5 | 1484.0 | 1445.7
f16 B=1, M=4096, H=160, K=128 | 33562.5 | 36002.6 | 46369.0
f32 B=1, M=4096, H=160, K=128 | 123681.2 | 117500.2 |
f16 B=2, M=4096, H=160, K=128 | 56569.5 | 58882.2 |
f32 B=2, M=4096, H=160, K=128 | 196332.7 | 185992.3 |
f16 B=1, M=8192, H=160, K=128 | 128776.5 | 138568.8 |
f32 B=1, M=8192, H=160, K=128 | 482544.1 | 455362.6 |
f16 B=2, M=8192, H=160, K=128 | 217573.6 | 225203.4 |
f32 B=2, M=8192, H=160, K=128 | 763604.3 | 722918.1 |
f16 B=1024, M=82, H=8, K=64 | 1653.0 | 1726.6 | 3822.7
f32 B=1024, M=82, H=8, K=64 | 7709.0 | 7623.2 | 8710.4
f16 B=150, M=256, H=16, K=64 | 1651.4 | 1826.8 | 4561.7
f32 B=150, M=256, H=16, K=64 | 4488.8 | 4477.1 | 12926.9
f16 B=64, M=256, H=12, K=64 | 568.4 | 632.6 | 1500.9
f32 B=64, M=256, H=12, K=64 | 1539.8 | 1534.2 | 4260.9
f16 B=1, M=4096, H=16, K=40 | 11162.6 | 12100.5 | 4237.0
f32 B=1, M=4096, H=16, K=40 | 35687.6 | 35281.9 | 17692.4
f16 B=1, M=16384, H=16, K=40 | 198363.5 | 221542.6 |
f32 B=1, M=16384, H=16, K=40 | 597947.5 | 592061.1 |
f16 B=256, M=4096, H=16, K=64 | 389118.2 | 424073.1 |
f16 B=16, M=128, H=16, K=16 | 202.8 | 183.9 | 289.0
f32 B=16, M=128, H=16, K=16 | 245.7 | 227.3 | 373.5
f16 B=16, M=128, H=16, K=32 | 204.2 | 182.8 | 286.4
f32 B=16, M=128, H=16, K=32 | 243.5 | 227.3 | 415.3
f16 B=16, M=128, H=16, K=64 | 202.3 | 184.3 | 287.4
f32 B=16, M=128, H=16, K=64 | 241.6 | 231.1 | 502.4
f16 B=16, M=128, H=16, K=128 | 200.4 | 210.2 | 301.0
f32 B=16, M=128, H=16, K=128 | 509.8 | 489.2 | 679.4
f16 B=16, M=128, H=16, K=256 | 790.9 | 777.2 | 555.4
f32 B=16, M=128, H=16, K=256 | 975.4 | 939.9 | 1163.0
f16 B=16, M=512, H=16, K=16 | 360.4 | 413.9 | 1199.9
f32 B=16, M=512, H=16, K=16 | 1261.1 | 1242.3 | 4408.4
f16 B=16, M=512, H=16, K=32 | 424.1 | 484.5 | 1305.6
f32 B=16, M=512, H=16, K=32 | 1412.7 | 1400.3 | 4633.7
f16 B=16, M=512, H=16, K=64 | 577.0 | 641.4 | 1544.3
f32 B=16, M=512, H=16, K=64 | 1850.4 | 1833.5 | 5117.6
f16 B=16, M=512, H=16, K=128 | 1286.1 | 1375.8 | 1986.1
f32 B=16, M=512, H=16, K=128 | 4045.2 | 3852.5 | 6086.9
f16 B=16, M=512, H=16, K=256 | 5719.9 | 5757.7 | 2903.4
f32 B=16, M=512, H=16, K=256 | 7844.9 | 7501.2 | 10619.7
f16 B=16, M=1024, H=16, K=16 | 1317.8 | 1522.2 | 4256.2
f32 B=16, M=1024, H=16, K=16 | 4591.7 | 4583.4 | 16612.4
f16 B=16, M=1024, H=16, K=32 | 1486.2 | 1702.7 | 4478.4
f32 B=16, M=1024, H=16, K=32 | 4971.2 | 4968.6 | 17261.8
f16 B=16, M=1024, H=16, K=64 | 1914.7 | 2123.9 | 4987.3
f32 B=16, M=1024, H=16, K=64 | 6380.9 | 6376.7 | 18674.5
f16 B=16, M=1024, H=16, K=128 | 4028.0 | 4296.1 | 5947.0
f32 B=16, M=1024, H=16, K=128 | 13653.7 | 12937.8 | 21481.2
f16 B=16, M=1024, H=16, K=256 | 18675.2 | 19016.0 | 7896.4
f32 B=16, M=1024, H=16, K=256 | 26325.6 | 25151.9 | 37929.3
f16 B=64, M=128, H=16, K=16 | 200.7 | 184.6 | 440.2
f32 B=64, M=128, H=16, K=16 | 405.3 | 402.4 | 1270.6
f16 B=64, M=128, H=16, K=32 | 228.4 | 204.3 | 545.0
f32 B=64, M=128, H=16, K=32 | 512.3 | 508.2 | 1427.3
f16 B=64, M=128, H=16, K=64 | 288.0 | 312.9 | 773.7
f32 B=64, M=128, H=16, K=64 | 741.6 | 737.1 | 1743.0
f16 B=64, M=128, H=16, K=128 | 703.0 | 723.6 | 1226.1
f32 B=64, M=128, H=16, K=128 | 1774.8 | 1703.0 | 2383.4
f16 B=64, M=128, H=16, K=256 | 2854.0 | 2888.6 | 2129.0
f32 B=64, M=128, H=16, K=256 | 3410.0 | 3294.6 | 4314.2
f16 B=64, M=512, H=16, K=16 | 1315.1 | 1522.1 | 4483.1
f32 B=64, M=512, H=16, K=16 | 3871.6 | 3864.4 | 16965.0
f16 B=64, M=512, H=16, K=32 | 1609.5 | 1810.4 | 4972.8
f32 B=64, M=512, H=16, K=32 | 4508.8 | 4501.9 | 17822.1
f16 B=64, M=512, H=16, K=64 | 2225.9 | 2484.9 | 5891.3
f32 B=64, M=512, H=16, K=64 | 5978.6 | 5975.5 | 19736.5
f16 B=64, M=512, H=16, K=128 | 4688.2 | 4853.1 | 7704.5
f32 B=64, M=512, H=16, K=128 | 14127.6 | 13458.0 | 23594.8
f16 B=64, M=512, H=16, K=256 | 21160.4 | 21087.3 | 11491.5
f32 B=64, M=512, H=16, K=256 | 27188.5 | 25985.1 | 42300.0
f16 B=64, M=1024, H=16, K=16 | 4880.6 | 5585.3 | 16841.1
f32 B=64, M=1024, H=16, K=16 | 14349.5 | 14389.0 | 66224.9
f16 B=64, M=1024, H=16, K=32 | 5786.9 | 6465.5 | 17853.1
f32 B=64, M=1024, H=16, K=32 | 15835.9 | 15835.2 | 68841.8
f16 B=64, M=1024, H=16, K=64 | 7456.0 | 8286.4 | 19909.4
f32 B=64, M=1024, H=16, K=64 | 20260.5 | 20341.4 | 74454.6
f16 B=64, M=1024, H=16, K=128 | 14640.3 | 15119.5 | 23731.7
f32 B=64, M=1024, H=16, K=128 | 47175.5 | 44702.4 | 85699.6
f16 B=64, M=1024, H=16, K=256 | 68925.8 | 69018.2 | 32542.6
f32 B=64, M=1024, H=16, K=256 | 91071.8 | 87053.4 | 152328.7
Times are in microseconds (us).

P100/V100 fw

[------------------- attention (attn_bias=<class 'NoneType'>) -------------------]
| main | eager
1 threads: -----------------------------------------------------------------------
(Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 1510.0 | 1397.8
f32 B=384, M=197, H=1, K=88 | 1434.3 | 1509.2
f16 B=384, M=197, H=1, K=80 | 1457.3 | 1347.1
f32 B=384, M=197, H=1, K=80 | 1388.8 | 1459.8
f16 B=384, M=197, H=1, K=64 | 1079.7 | 1245.4
f32 B=384, M=197, H=1, K=64 | 988.8 | 1344.6
f16 B=1024, M=197, H=1, K=88 | 4065.9 | 3705.0
f32 B=1024, M=197, H=1, K=88 | 3873.3 | 3997.2
f16 B=1024, M=197, H=1, K=80 | 3904.2 | 3603.3
f32 B=1024, M=197, H=1, K=80 | 3784.7 | 3880.4
f16 B=1024, M=197, H=1, K=64 | 2880.6 | 3349.5
f32 B=1024, M=197, H=1, K=64 | 2634.3 | 3593.6
f16 B=512, M=197, H=1, K=80 | 1984.0 | 1846.5
f32 B=512, M=197, H=1, K=80 | 1924.1 | 1975.0
f16 B=32, M=197, H=16, K=80 | 2028.9 | 2209.5
f32 B=32, M=197, H=16, K=80 | 1958.0 | 2388.4
f16 B=32, M=197, H=16, K=64 | 1477.0 | 1990.0
f32 B=32, M=197, H=16, K=64 | 1350.7 | 2163.4
f16 B=32, M=197, H=16, K=128 | 2530.0 | 2878.7
f32 B=32, M=197, H=16, K=128 | 2429.7 | 3099.2
f16 B=256, M=197, H=1, K=88 | 1070.5 | 990.1
f32 B=256, M=197, H=1, K=88 | 1024.0 | 1053.3
f16 B=16, M=197, H=16, K=88 | 1072.3 | 1194.3
f32 B=16, M=197, H=16, K=88 | 1029.1 | 1282.4
f16 B=16, M=197, H=16, K=64 | 764.4 | 1028.8
f32 B=16, M=197, H=16, K=64 | 696.7 | 1108.2
f16 B=16, M=197, H=16, K=128 | 1293.0 | 1470.2
f32 B=16, M=197, H=16, K=128 | 1230.2 | 1592.4
f16 B=1, M=4096, H=160, K=128 | 252185.1 | 201705.3
f32 B=1, M=4096, H=160, K=128 | 241164.8 |
f16 B=2, M=4096, H=160, K=128 | 500526.4 |
f32 B=2, M=4096, H=160, K=128 | 485637.0 |
f16 B=1, M=8192, H=160, K=128 | 1015719.4 |
f32 B=1, M=8192, H=160, K=128 | 996319.1 |
f16 B=2, M=8192, H=160, K=128 | 2037663.8 |
f32 B=2, M=8192, H=160, K=128 | 1997234.2 |
f16 B=1024, M=82, H=8, K=64 | 5752.5 | 8562.6
f32 B=1024, M=82, H=8, K=64 | 5329.1 | 9109.9
f16 B=150, M=256, H=16, K=64 | 7729.5 | 11332.5
f32 B=150, M=256, H=16, K=64 | 7054.9 | 12679.5
f16 B=64, M=256, H=12, K=64 | 2511.9 | 3674.3
f32 B=64, M=256, H=12, K=64 | 2308.0 | 4095.5
f16 B=1, M=4096, H=16, K=40 | 11231.9 | 14588.4
f32 B=1, M=4096, H=16, K=40 | 9929.1 | 17735.1
f16 B=1, M=16384, H=16, K=40 | 170210.6 |
f32 B=1, M=16384, H=16, K=40 | 155516.2 |
f16 B=256, M=4096, H=16, K=64 | 3252023.7 |
f16 B=16, M=128, H=16, K=16 | 150.9 | 265.4
f32 B=16, M=128, H=16, K=16 | 141.5 | 300.4
f16 B=16, M=128, H=16, K=32 | 179.2 | 311.5
f32 B=16, M=128, H=16, K=32 | 166.3 | 357.9
f16 B=16, M=128, H=16, K=64 | 231.8 | 414.1
f32 B=16, M=128, H=16, K=64 | 227.0 | 462.4
f16 B=16, M=128, H=16, K=128 | 437.0 | 599.3
f32 B=16, M=128, H=16, K=128 | 452.5 | 685.6
f16 B=16, M=128, H=16, K=256 | 835.5 | 1116.1
f32 B=16, M=128, H=16, K=256 | 899.6 | 1358.8
f16 B=16, M=512, H=16, K=16 | 2150.5 | 3183.3
f32 B=16, M=512, H=16, K=16 | 1960.3 | 3646.8
f16 B=16, M=512, H=16, K=32 | 2548.4 | 3576.9
f32 B=16, M=512, H=16, K=32 | 2256.6 | 4017.2
f16 B=16, M=512, H=16, K=64 | 3323.5 | 4343.4
f32 B=16, M=512, H=16, K=64 | 2990.8 | 4862.4
f16 B=16, M=512, H=16, K=128 | 6408.5 | 5901.5
f32 B=16, M=512, H=16, K=128 | 6033.5 | 6559.6
f16 B=16, M=512, H=16, K=256 | 12975.4 | 10746.3
f32 B=16, M=512, H=16, K=256 | 12683.6 | 11748.1
f16 B=16, M=1024, H=16, K=16 | 8366.1 | 12316.8
f32 B=16, M=1024, H=16, K=16 | 7490.7 | 13839.4
f16 B=16, M=1024, H=16, K=32 | 9894.1 | 13713.1
f32 B=16, M=1024, H=16, K=32 | 8854.9 | 15169.7
f16 B=16, M=1024, H=16, K=64 | 12859.7 | 16233.0
f32 B=16, M=1024, H=16, K=64 | 11548.4 | 18025.9
f16 B=16, M=1024, H=16, K=128 | 25314.9 | 21504.5
f32 B=16, M=1024, H=16, K=128 | 23507.9 | 23364.1
f16 B=16, M=1024, H=16, K=256 | 51510.7 | 38749.2
f32 B=16, M=1024, H=16, K=256 | 50812.9 | 41896.2
f16 B=64, M=128, H=16, K=16 | 565.0 | 961.0
f32 B=64, M=128, H=16, K=16 | 524.3 | 1091.2
f16 B=64, M=128, H=16, K=32 | 665.8 | 1138.3
f32 B=64, M=128, H=16, K=32 | 613.0 | 1300.2
f16 B=64, M=128, H=16, K=64 | 866.5 | 1524.1
f32 B=64, M=128, H=16, K=64 | 825.7 | 1723.0
f16 B=64, M=128, H=16, K=128 | 1682.9 | 2270.2
f32 B=64, M=128, H=16, K=128 | 1707.2 | 2591.8
f16 B=64, M=128, H=16, K=256 | 3262.1 | 4236.3
f32 B=64, M=128, H=16, K=256 | 3443.2 | 5390.7
f16 B=64, M=512, H=16, K=16 | 8419.1 | 12452.5
f32 B=64, M=512, H=16, K=16 | 7496.2 | 14249.4
f16 B=64, M=512, H=16, K=32 | 9984.7 | 14043.7
f32 B=64, M=512, H=16, K=32 | 8902.3 | 15898.0
f16 B=64, M=512, H=16, K=64 | 13064.5 | 17211.0
f32 B=64, M=512, H=16, K=64 | 11695.4 | 19154.0
f16 B=64, M=512, H=16, K=128 | 25396.8 | 23241.0
f32 B=64, M=512, H=16, K=128 | 24079.8 | 25745.1
f16 B=64, M=512, H=16, K=256 | 51471.0 | 43083.7
f32 B=64, M=512, H=16, K=256 | 50596.1 | 47045.4
f16 B=64, M=1024, H=16, K=16 | 32927.4 | 49295.4
f32 B=64, M=1024, H=16, K=16 | 29545.4 | 55489.4
f16 B=64, M=1024, H=16, K=32 | 39128.1 | 54820.3
f32 B=64, M=1024, H=16, K=32 | 34791.5 | 59901.5
f16 B=64, M=1024, H=16, K=64 | 51089.9 | 65794.8
f32 B=64, M=1024, H=16, K=64 | 45600.0 | 72379.6
f16 B=64, M=1024, H=16, K=128 | 100577.7 | 85978.0
f32 B=64, M=1024, H=16, K=128 | 94634.7 | 94077.1
f16 B=64, M=1024, H=16, K=256 | 205416.2 | 156833.3
f32 B=64, M=1024, H=16, K=256 | 204171.7 |
(Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 280.5 | 550.9
f32 B=384, M=197, H=1, K=88 | 780.4 | 888.6
f16 B=384, M=197, H=1, K=80 | 270.8 | 526.0
f32 B=384, M=197, H=1, K=80 | 755.2 | 857.3
f16 B=384, M=197, H=1, K=64 | 191.4 | 418.9
f32 B=384, M=197, H=1, K=64 | 567.8 | 714.1
f16 B=1024, M=197, H=1, K=88 | 742.6 | 1425.3
f32 B=1024, M=197, H=1, K=88 | 2075.7 | 2352.3
f16 B=1024, M=197, H=1, K=80 | 711.3 | 1360.5
f32 B=1024, M=197, H=1, K=80 | 1994.1 | 2264.5
f16 B=1024, M=197, H=1, K=64 | 498.9 | 1075.2
f32 B=1024, M=197, H=1, K=64 | 1493.6 | 1866.1
f16 B=512, M=197, H=1, K=80 | 359.6 | 690.8
f32 B=512, M=197, H=1, K=80 | 1011.7 | 1147.9
f16 B=32, M=197, H=16, K=80 | 374.6 | 841.7
f32 B=32, M=197, H=16, K=80 | 1031.9 | 1383.7
f16 B=32, M=197, H=16, K=64 | 255.3 | 675.1
f32 B=32, M=197, H=16, K=64 | 763.7 | 1133.3
f16 B=32, M=197, H=16, K=128 | 425.2 | 1030.2
f32 B=32, M=197, H=16, K=128 | 1314.9 | 1902.7
f16 B=256, M=197, H=1, K=88 | 193.7 | 378.2
f32 B=256, M=197, H=1, K=88 | 540.4 | 612.0
f16 B=16, M=197, H=16, K=88 | 196.0 | 461.6
f32 B=16, M=197, H=16, K=88 | 542.0 | 746.9
f16 B=16, M=197, H=16, K=64 | 171.2 | 429.8
f32 B=16, M=197, H=16, K=64 | 397.3 | 609.0
f16 B=16, M=197, H=16, K=128 | 218.1 | 536.7
f32 B=16, M=197, H=16, K=128 | 673.3 | 966.1
f16 B=1, M=4096, H=160, K=128 | 35200.8 | 44509.1
f32 B=1, M=4096, H=160, K=128 | 139848.7 |
f16 B=2, M=4096, H=160, K=128 | 70836.1 |
f32 B=2, M=4096, H=160, K=128 | 279696.9 |
f16 B=1, M=8192, H=160, K=128 | 144262.0 |
f32 B=1, M=8192, H=160, K=128 | 578509.8 |
f16 B=2, M=8192, H=160, K=128 | 289830.3 |
f32 B=2, M=8192, H=160, K=128 | 1163174.7 |
f16 B=1024, M=82, H=8, K=64 | 1095.4 | 2496.5
f32 B=1024, M=82, H=8, K=64 | 2915.9 | 4424.4
f16 B=150, M=256, H=16, K=64 | 1295.8 | 3131.0
f32 B=150, M=256, H=16, K=64 | 4037.3 | 6739.0
f16 B=64, M=256, H=12, K=64 | 416.9 | 1025.5
f32 B=64, M=256, H=12, K=64 | 1304.1 | 2254.9
f16 B=1, M=4096, H=16, K=40 | 1996.3 | 4067.3
f32 B=1, M=4096, H=16, K=40 | 5851.3 | 8249.3
f16 B=1, M=16384, H=16, K=40 | 29646.1 |
f32 B=1, M=16384, H=16, K=40 | 86661.6 |
f16 B=256, M=4096, H=16, K=64 | 462010.1 |
f16 B=16, M=128, H=16, K=16 | 148.0 | 356.6
f32 B=16, M=128, H=16, K=16 | 144.8 | 357.5
f16 B=16, M=128, H=16, K=32 | 147.6 | 345.4
f32 B=16, M=128, H=16, K=32 | 149.3 | 335.8
f16 B=16, M=128, H=16, K=64 | 143.7 | 346.2
f32 B=16, M=128, H=16, K=64 | 151.9 | 350.6
f16 B=16, M=128, H=16, K=128 | 144.9 | 349.5
f32 B=16, M=128, H=16, K=128 | 241.8 | 462.1
f16 B=16, M=128, H=16, K=256 | 167.3 | 367.1
f32 B=16, M=128, H=16, K=256 | 477.2 | 862.7
f16 B=16, M=512, H=16, K=16 | 399.8 | 845.5
f32 B=16, M=512, H=16, K=16 | 1088.7 | 1830.8
f16 B=16, M=512, H=16, K=32 | 418.4 | 937.4
f32 B=16, M=512, H=16, K=32 | 1264.9 | 2095.9
f16 B=16, M=512, H=16, K=64 | 511.8 | 1143.3
f32 B=16, M=512, H=16, K=64 | 1681.4 | 2501.4
f16 B=16, M=512, H=16, K=128 | 965.5 | 1458.1
f32 B=16, M=512, H=16, K=128 | 3330.1 | 4066.9
f16 B=16, M=512, H=16, K=256 | 2761.3 | 2316.0
f32 B=16, M=512, H=16, K=256 | 7204.6 | 7218.3
f16 B=16, M=1024, H=16, K=16 | 1537.0 | 3433.3
f32 B=16, M=1024, H=16, K=16 | 4195.9 | 7165.5
f16 B=16, M=1024, H=16, K=32 | 1589.4 | 3673.3
f32 B=16, M=1024, H=16, K=32 | 4958.9 | 7966.2
f16 B=16, M=1024, H=16, K=64 | 1945.7 | 4154.1
f32 B=16, M=1024, H=16, K=64 | 6528.0 | 9394.8
f16 B=16, M=1024, H=16, K=128 | 3621.1 | 4855.3
f32 B=16, M=1024, H=16, K=128 | 13031.0 | 15398.9
f16 B=16, M=1024, H=16, K=256 | 11071.2 | 7464.6
f32 B=16, M=1024, H=16, K=256 | 28171.8 | 26984.5
f16 B=64, M=128, H=16, K=16 | 139.2 | 364.1
f32 B=64, M=128, H=16, K=16 | 288.3 | 517.2
f16 B=64, M=128, H=16, K=32 | 172.0 | 355.8
f32 B=64, M=128, H=16, K=32 | 339.6 | 657.9
f16 B=64, M=128, H=16, K=64 | 170.1 | 492.8
f32 B=64, M=128, H=16, K=64 | 465.8 | 935.6
f16 B=64, M=128, H=16, K=128 | 323.9 | 765.4
f32 B=64, M=128, H=16, K=128 | 915.7 | 1600.0
f16 B=64, M=128, H=16, K=256 | 624.2 | 1326.2
f32 B=64, M=128, H=16, K=256 | 1825.8 | 2868.1
f16 B=64, M=512, H=16, K=16 | 1576.0 | 3237.9
f32 B=64, M=512, H=16, K=16 | 4193.3 | 7437.9
f16 B=64, M=512, H=16, K=32 | 1649.4 | 3603.9
f32 B=64, M=512, H=16, K=32 | 4969.9 | 8514.8
f16 B=64, M=512, H=16, K=64 | 2015.4 | 4469.6
f32 B=64, M=512, H=16, K=64 | 6611.6 | 10198.4
f16 B=64, M=512, H=16, K=128 | 3794.5 | 5712.6
f32 B=64, M=512, H=16, K=128 | 13156.5 | 16556.8
f16 B=64, M=512, H=16, K=256 | 10970.2 | 9144.8
f32 B=64, M=512, H=16, K=256 | 28069.9 | 30192.7
f16 B=64, M=1024, H=16, K=16 | 6039.1 | 14015.6
f32 B=64, M=1024, H=16, K=16 | 16379.4 | 28076.7
f16 B=64, M=1024, H=16, K=32 | 6222.1 | 14556.7
f32 B=64, M=1024, H=16, K=32 | 19341.7 | 30747.7
f16 B=64, M=1024, H=16, K=64 | 7476.1 | 17005.3
f32 B=64, M=1024, H=16, K=64 | 25863.7 | 38337.4
f16 B=64, M=1024, H=16, K=128 | 14205.3 | 19211.8
f32 B=64, M=1024, H=16, K=128 | 51596.0 | 60226.9
f16 B=64, M=1024, H=16, K=256 | 43848.4 | 29946.7
f32 B=64, M=1024, H=16, K=256 | 111474.6 |
Times are in microseconds (us).
[- attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) -]
| main | eager
1 threads: -----------------------------------------------------------------------
(Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 1063.0 | 1641.4
f32 B=384, M=197, H=1, K=88 | 1040.8 | 1722.3
f16 B=384, M=197, H=1, K=80 | 1020.5 | 1588.0
f32 B=384, M=197, H=1, K=80 | 1005.1 | 1675.2
f16 B=384, M=197, H=1, K=64 | 758.3 | 1493.9
f32 B=384, M=197, H=1, K=64 | 683.5 | 1570.4
f16 B=1024, M=197, H=1, K=88 | 2817.5 | 4363.1
f32 B=1024, M=197, H=1, K=88 | 2779.7 | 4576.2
f16 B=1024, M=197, H=1, K=80 | 2722.3 | 4233.4
f32 B=1024, M=197, H=1, K=80 | 2694.0 | 4471.1
f16 B=1024, M=197, H=1, K=64 | 1997.4 | 4000.6
f32 B=1024, M=197, H=1, K=64 | 1808.7 | 4174.4
f16 B=512, M=197, H=1, K=80 | 1386.9 | 2166.3
f32 B=512, M=197, H=1, K=80 | 1371.1 | 2283.7
f16 B=32, M=197, H=16, K=80 | 1395.5 | 2543.1
f32 B=32, M=197, H=16, K=80 | 1383.4 | 2674.6
f16 B=32, M=197, H=16, K=64 | 1030.3 | 2327.5
f32 B=32, M=197, H=16, K=64 | 933.1 | 2465.0
f16 B=32, M=197, H=16, K=128 | 1760.5 | 3212.4
f32 B=32, M=197, H=16, K=128 | 1745.0 | 3387.1
f16 B=256, M=197, H=1, K=88 | 753.8 | 1151.7
f32 B=256, M=197, H=1, K=88 | 729.1 | 1206.3
f16 B=16, M=197, H=16, K=88 | 751.0 | 1361.1
f32 B=16, M=197, H=16, K=88 | 738.2 | 1446.1
f16 B=16, M=197, H=16, K=64 | 538.5 | 1198.4
f32 B=16, M=197, H=16, K=64 | 490.5 | 1269.8
f16 B=16, M=197, H=16, K=128 | 915.2 | 1650.7
f32 B=16, M=197, H=16, K=128 | 890.3 | 1730.6
f16 B=1, M=4096, H=160, K=128 | 128716.3 |
f32 B=1, M=4096, H=160, K=128 | 123542.7 |
f16 B=2, M=4096, H=160, K=128 | 256642.2 |
f32 B=2, M=4096, H=160, K=128 | 245869.1 |
f16 B=1, M=8192, H=160, K=128 | 507149.2 |
f32 B=1, M=8192, H=160, K=128 | 501763.7 |
f16 B=2, M=8192, H=160, K=128 | 1012807.4 |
f32 B=2, M=8192, H=160, K=128 | 1008124.8 |
f16 B=1024, M=82, H=8, K=64 | 4756.3 | 9552.2
f32 B=1024, M=82, H=8, K=64 | 4346.6 | 10218.6
f16 B=150, M=256, H=16, K=64 | 5001.0 | 13734.1
f32 B=150, M=256, H=16, K=64 | 4644.7 | 14799.1
f16 B=64, M=256, H=12, K=64 | 1640.0 | 4390.8
f32 B=64, M=256, H=12, K=64 | 1532.0 | 4743.6
f16 B=1, M=4096, H=16, K=40 | 5921.3 | 18371.5
f32 B=1, M=4096, H=16, K=40 | 5263.0 | 22564.1
f16 B=1, M=16384, H=16, K=40 | 88312.9 |
f32 B=1, M=16384, H=16, K=40 | 81014.7 |
f16 B=256, M=4096, H=16, K=64 | 1657734.4 |
f16 B=16, M=128, H=16, K=16 | 126.8 | 339.5
f32 B=16, M=128, H=16, K=16 | 120.6 | 373.7
f16 B=16, M=128, H=16, K=32 | 147.7 | 381.4
f32 B=16, M=128, H=16, K=32 | 136.6 | 426.2
f16 B=16, M=128, H=16, K=64 | 189.5 | 477.6
f32 B=16, M=128, H=16, K=64 | 184.7 | 530.4
f16 B=16, M=128, H=16, K=128 | 367.7 | 667.8
f32 B=16, M=128, H=16, K=128 | 374.2 | 749.5
f16 B=16, M=128, H=16, K=256 | 701.9 | 1173.9
f32 B=16, M=128, H=16, K=256 | 714.6 | 1419.3
f16 B=16, M=512, H=16, K=16 | 1272.5 | 4180.7
f32 B=16, M=512, H=16, K=16 | 1148.0 | 4759.7
f16 B=16, M=512, H=16, K=32 | 1502.7 | 4573.0
f32 B=16, M=512, H=16, K=32 | 1350.9 | 5028.5
f16 B=16, M=512, H=16, K=64 | 1925.4 | 5352.1
f32 B=16, M=512, H=16, K=64 | 1770.0 | 5705.0
f16 B=16, M=512, H=16, K=128 | 3844.9 | 6844.9
f32 B=16, M=512, H=16, K=128 | 3692.4 | 7410.2
f16 B=16, M=512, H=16, K=256 | 7677.2 | 11763.5
f32 B=16, M=512, H=16, K=256 | 7520.1 | 12469.9
f16 B=16, M=1024, H=16, K=16 | 4571.1 | 16594.1
f32 B=16, M=1024, H=16, K=16 | 4116.6 | 19868.4
f16 B=16, M=1024, H=16, K=32 | 5390.4 | 17848.7
f32 B=16, M=1024, H=16, K=32 | 4856.3 | 20576.0
f16 B=16, M=1024, H=16, K=64 | 7004.4 | 20169.9
f32 B=16, M=1024, H=16, K=64 | 6332.3 | 22698.9
f16 B=16, M=1024, H=16, K=128 | 13939.1 | 25708.1
f32 B=16, M=1024, H=16, K=128 | 13245.0 | 28058.6
f16 B=16, M=1024, H=16, K=256 | 28277.4 | 43267.2
f32 B=16, M=1024, H=16, K=256 | 27680.0 | 46053.4
f16 B=64, M=128, H=16, K=16 | 454.9 | 1231.9
f32 B=64, M=128, H=16, K=16 | 418.2 | 1369.6
f16 B=64, M=128, H=16, K=32 | 531.9 | 1408.3
f32 B=64, M=128, H=16, K=32 | 494.8 | 1570.4
f16 B=64, M=128, H=16, K=64 | 689.3 | 1777.4
f32 B=64, M=128, H=16, K=64 | 649.5 | 1968.9
f16 B=64, M=128, H=16, K=128 | 1417.5 | 2529.1
f32 B=64, M=128, H=16, K=128 | 1437.2 | 2821.2
f16 B=64, M=128, H=16, K=256 | 2680.2 | 4523.1
f32 B=64, M=128, H=16, K=256 | 2752.2 | 5524.1
f16 B=64, M=512, H=16, K=16 | 4880.6 | 16602.6
f32 B=64, M=512, H=16, K=16 | 4416.0 | 18778.5
f16 B=64, M=512, H=16, K=32 | 5769.7 | 18066.7
f32 B=64, M=512, H=16, K=32 | 5192.4 | 19832.4
f16 B=64, M=512, H=16, K=64 | 7508.1 | 21208.7
f32 B=64, M=512, H=16, K=64 | 6824.1 | 22507.2
f16 B=64, M=512, H=16, K=128 | 15071.8 | 27390.6
f32 B=64, M=512, H=16, K=128 | 14547.0 | 29146.1
f16 B=64, M=512, H=16, K=256 | 30054.6 | 46986.3
f32 B=64, M=512, H=16, K=256 | 29807.5 | 50467.8
f16 B=64, M=1024, H=16, K=16 | 17881.9 | 67019.4
f32 B=64, M=1024, H=16, K=16 | 16107.3 | 79155.4
f16 B=64, M=1024, H=16, K=32 | 21207.7 | 72235.5
f32 B=64, M=1024, H=16, K=32 | 19099.1 | 82273.1
f16 B=64, M=1024, H=16, K=64 | 27819.1 | 83763.2
f32 B=64, M=1024, H=16, K=64 | 24771.3 | 91611.2
f16 B=64, M=1024, H=16, K=128 | 55023.8 | 104599.3
f32 B=64, M=1024, H=16, K=128 | 52263.7 |
f16 B=64, M=1024, H=16, K=256 | 111388.9 | 175503.2
f32 B=64, M=1024, H=16, K=256 | 109982.0 |
(Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 210.6 | 643.6
f32 B=384, M=197, H=1, K=88 | 545.5 | 1029.5
f16 B=384, M=197, H=1, K=80 | 202.4 | 623.9
f32 B=384, M=197, H=1, K=80 | 526.7 | 999.5
f16 B=384, M=197, H=1, K=64 | 144.6 | 515.1
f32 B=384, M=197, H=1, K=64 | 398.2 | 857.8
f16 B=1024, M=197, H=1, K=88 | 552.3 | 1661.8
f32 B=1024, M=197, H=1, K=88 | 1438.1 | 2720.2
f16 B=1024, M=197, H=1, K=80 | 530.0 | 1605.3
f32 B=1024, M=197, H=1, K=80 | 1378.9 | 2626.9
f16 B=1024, M=197, H=1, K=64 | 368.1 | 1316.1
f32 B=1024, M=197, H=1, K=64 | 1024.9 | 2240.6
f16 B=512, M=197, H=1, K=80 | 270.4 | 823.3
f32 B=512, M=197, H=1, K=80 | 704.4 | 1332.7
f16 B=32, M=197, H=16, K=80 | 278.4 | 972.7
f32 B=32, M=197, H=16, K=80 | 720.9 | 1568.6
f16 B=32, M=197, H=16, K=64 | 192.9 | 801.1
f32 B=32, M=197, H=16, K=64 | 529.1 | 1325.7
f16 B=32, M=197, H=16, K=128 | 318.0 | 1146.5
f32 B=32, M=197, H=16, K=128 | 917.5 | 2045.7
f16 B=256, M=197, H=1, K=88 | 146.0 | 442.7
f32 B=256, M=197, H=1, K=88 | 381.0 | 706.5
f16 B=16, M=197, H=16, K=88 | 146.8 | 526.1
f32 B=16, M=197, H=16, K=88 | 383.3 | 839.9
f16 B=16, M=197, H=16, K=64 | 151.9 | 439.0
f32 B=16, M=197, H=16, K=64 | 280.0 | 707.9
f16 B=16, M=197, H=16, K=128 | 164.5 | 598.4
f32 B=16, M=197, H=16, K=128 | 472.0 | 1044.9
f16 B=1, M=4096, H=160, K=128 | 18344.5 |
f32 B=1, M=4096, H=160, K=128 | 71344.4 |
f16 B=2, M=4096, H=160, K=128 | 36563.5 |
f32 B=2, M=4096, H=160, K=128 | 141996.6 |
f16 B=1, M=8192, H=160, K=128 | 72900.7 |
f32 B=1, M=8192, H=160, K=128 | 284758.3 |
f16 B=2, M=8192, H=160, K=128 | 145710.6 |
f32 B=2, M=8192, H=160, K=128 | 568170.0 |
f16 B=1024, M=82, H=8, K=64 | 961.6 | 2798.5
f32 B=1024, M=82, H=8, K=64 | 2373.1 | 4951.5
f16 B=150, M=256, H=16, K=64 | 926.2 | 3821.3
f32 B=150, M=256, H=16, K=64 | 2624.9 | 7918.7
f16 B=64, M=256, H=12, K=64 | 303.3 | 1254.6
f32 B=64, M=256, H=12, K=64 | 858.8 | 2645.0
f16 B=1, M=4096, H=16, K=40 | 1072.2 | 5996.5
f32 B=1, M=4096, H=16, K=40 | 3053.7 | 11513.7
f16 B=1, M=16384, H=16, K=40 | 15403.6 |
f32 B=1, M=16384, H=16, K=40 | 44929.7 |
f16 B=256, M=4096, H=16, K=64 | 240596.0 |
f16 B=16, M=128, H=16, K=16 | 151.9 | 351.0
f32 B=16, M=128, H=16, K=16 | 148.8 | 344.7
f16 B=16, M=128, H=16, K=32 | 177.0 | 394.4
f32 B=16, M=128, H=16, K=32 | 184.5 | 340.0
f16 B=16, M=128, H=16, K=64 | 145.4 | 351.3
f32 B=16, M=128, H=16, K=64 | 171.4 | 352.3
f16 B=16, M=128, H=16, K=128 | 171.6 | 345.0
f32 B=16, M=128, H=16, K=128 | 196.7 | 503.1
f16 B=16, M=128, H=16, K=256 | 156.8 | 388.2
f32 B=16, M=128, H=16, K=256 | 388.2 | 900.9
f16 B=16, M=512, H=16, K=16 | 261.1 | 1191.4
f32 B=16, M=512, H=16, K=16 | 642.7 | 2520.3
f16 B=16, M=512, H=16, K=32 | 268.5 | 1287.8
f32 B=16, M=512, H=16, K=32 | 748.9 | 2752.7
f16 B=16, M=512, H=16, K=64 | 327.2 | 1437.7
f32 B=16, M=512, H=16, K=64 | 1004.8 | 3000.1
f16 B=16, M=512, H=16, K=128 | 614.1 | 1757.2
f32 B=16, M=512, H=16, K=128 | 2001.5 | 4532.4
f16 B=16, M=512, H=16, K=256 | 1525.0 | 2612.2
f32 B=16, M=512, H=16, K=256 | 4174.8 | 7621.3
f16 B=16, M=1024, H=16, K=16 | 876.6 | 5043.8
f32 B=16, M=1024, H=16, K=16 | 2286.3 | 10962.5
f16 B=16, M=1024, H=16, K=32 | 908.2 | 5179.5
f32 B=16, M=1024, H=16, K=32 | 2674.6 | 11440.7
f16 B=16, M=1024, H=16, K=64 | 1100.6 | 5472.6
f32 B=16, M=1024, H=16, K=64 | 3600.6 | 12196.1
f16 B=16, M=1024, H=16, K=128 | 2062.6 | 6217.9
f32 B=16, M=1024, H=16, K=128 | 7193.1 | 17906.8
f16 B=16, M=1024, H=16, K=256 | 5683.6 | 8804.5
f32 B=16, M=1024, H=16, K=256 | 15334.5 | 29569.0
f16 B=64, M=128, H=16, K=16 | 148.6 | 381.3
f32 B=64, M=128, H=16, K=16 | 230.5 | 689.7
f16 B=64, M=128, H=16, K=32 | 144.0 | 451.7
f32 B=64, M=128, H=16, K=32 | 270.6 | 813.7
f16 B=64, M=128, H=16, K=64 | 153.4 | 587.2
f32 B=64, M=128, H=16, K=64 | 370.7 | 1085.4
f16 B=64, M=128, H=16, K=128 | 284.4 | 862.4
f32 B=64, M=128, H=16, K=128 | 740.5 | 1720.8
f16 B=64, M=128, H=16, K=256 | 529.9 | 1408.8
f32 B=64, M=128, H=16, K=256 | 1460.1 | 2982.0
f16 B=64, M=512, H=16, K=16 | 987.3 | 4603.3
f32 B=64, M=512, H=16, K=16 | 2408.6 | 10136.8
f16 B=64, M=512, H=16, K=32 | 1027.3 | 4933.2
f32 B=64, M=512, H=16, K=32 | 2862.2 | 11037.5
f16 B=64, M=512, H=16, K=64 | 1255.5 | 5622.9
f32 B=64, M=512, H=16, K=64 | 3861.0 | 12125.9
f16 B=64, M=512, H=16, K=128 | 2405.2 | 6901.4
f32 B=64, M=512, H=16, K=128 | 7843.0 | 18276.2
f16 B=64, M=512, H=16, K=256 | 5983.3 | 10359.0
f32 B=64, M=512, H=16, K=256 | 16371.9 | 31659.4
f16 B=64, M=1024, H=16, K=16 | 3408.7 | 20126.8
f32 B=64, M=1024, H=16, K=16 | 8819.8 | 43940.0
f16 B=64, M=1024, H=16, K=32 | 3531.5 | 20639.3
f32 B=64, M=1024, H=16, K=32 | 10498.0 | 46144.9
f16 B=64, M=1024, H=16, K=64 | 4277.0 | 22159.0
f32 B=64, M=1024, H=16, K=64 | 14147.1 | 50392.1
f16 B=64, M=1024, H=16, K=128 | 8174.7 | 24830.6
f32 B=64, M=1024, H=16, K=128 | 28411.4 |
f16 B=64, M=1024, H=16, K=256 | 22481.4 | 35353.6
f32 B=64, M=1024, H=16, K=256 | 60469.3 |
Times are in microseconds (us).

P100/V100 bw

[-------------- attention backward (attn_bias=<class 'NoneType'>) ---------------]
| main | vanilla
1 threads: -----------------------------------------------------------------------
(Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6662.4 | 3591.8
f32 B=384, M=197, H=1, K=88 | 9584.0 | 4337.0
f16 B=384, M=197, H=1, K=80 | 6192.6 | 3437.6
f32 B=384, M=197, H=1, K=80 | 9158.2 | 4107.5
f16 B=384, M=197, H=1, K=64 | 3518.0 | 2927.8
f32 B=384, M=197, H=1, K=64 | 6136.2 | 3451.8
f16 B=1024, M=197, H=1, K=88 | 16310.1 | 9852.8
f32 B=1024, M=197, H=1, K=88 | 25756.1 | 12151.9
f16 B=1024, M=197, H=1, K=80 | 15522.3 | 9330.4
f32 B=1024, M=197, H=1, K=80 | 24578.1 | 11356.0
f16 B=1024, M=197, H=1, K=64 | 8935.0 | 7719.5
f32 B=1024, M=197, H=1, K=64 | 16599.7 | 9475.8
f16 B=512, M=197, H=1, K=80 | 7927.7 | 4632.1
f32 B=512, M=197, H=1, K=80 | 12806.3 | 5525.4
f16 B=32, M=197, H=16, K=80 | 8129.2 | 4891.3
f32 B=32, M=197, H=16, K=80 | 12774.1 | 5811.9
f16 B=32, M=197, H=16, K=64 | 4506.7 | 4068.1
f32 B=32, M=197, H=16, K=64 | 8681.2 | 4838.7
f16 B=32, M=197, H=16, K=128 | 9626.9 | 5991.5
f32 B=32, M=197, H=16, K=128 | 15544.2 | 7540.1
f16 B=256, M=197, H=1, K=88 | 4769.0 | 2451.1
f32 B=256, M=197, H=1, K=88 | 6682.9 | 2906.0
f16 B=16, M=197, H=16, K=88 | 4781.1 | 2549.4
f32 B=16, M=197, H=16, K=88 | 6629.7 | 3063.5
f16 B=16, M=197, H=16, K=64 | 2609.3 | 2042.1
f32 B=16, M=197, H=16, K=64 | 4322.7 | 2445.7
f16 B=16, M=197, H=16, K=128 | 5432.6 | 3014.7
f32 B=16, M=197, H=16, K=128 | 7794.2 | 3670.1
f16 B=1, M=4096, H=160, K=128 | 1033138.6 |
f32 B=1, M=4096, H=160, K=128 | 1264717.2 |
f16 B=2, M=4096, H=160, K=128 | 1689231.2 |
f32 B=2, M=4096, H=160, K=128 | 2511754.6 |
f16 B=1, M=8192, H=160, K=128 | 4110718.8 |
f32 B=1, M=8192, H=160, K=128 | 5051277.9 |
f16 B=2, M=8192, H=160, K=128 | 6751365.7 |
f16 B=1024, M=82, H=8, K=64 | 22967.1 | 18046.4
f32 B=1024, M=82, H=8, K=64 | 43698.8 | 22978.7
f16 B=150, M=256, H=16, K=64 | 23440.9 | 24551.6
f32 B=150, M=256, H=16, K=64 | 37480.4 | 32205.0
f16 B=64, M=256, H=12, K=64 | 7491.8 | 7716.8
f32 B=64, M=256, H=12, K=64 | 12214.8 | 9890.6
f16 B=1, M=4096, H=16, K=40 | 135707.0 | 29317.2
f32 B=1, M=4096, H=16, K=40 | 145042.0 | 37192.7
f16 B=1, M=16384, H=16, K=40 | 2150814.2 |
f32 B=1, M=16384, H=16, K=40 | 2295614.2 |
f16 B=16, M=128, H=16, K=16 | 517.6 | 572.7
f32 B=16, M=128, H=16, K=16 | 652.2 | 691.7
f16 B=16, M=128, H=16, K=32 | 601.6 | 677.2
f32 B=16, M=128, H=16, K=32 | 813.6 | 828.2
f16 B=16, M=128, H=16, K=64 | 778.9 | 891.9
f32 B=16, M=128, H=16, K=64 | 1163.4 | 1088.5
f16 B=16, M=128, H=16, K=128 | 1607.0 | 1337.7
f32 B=16, M=128, H=16, K=128 | 2259.3 | 1666.7
f16 B=16, M=128, H=16, K=256 | 4062.0 | 2507.3
f32 B=16, M=128, H=16, K=256 | 4647.4 | 3356.5
f16 B=16, M=512, H=16, K=16 | 7866.8 | 6958.1
f32 B=16, M=512, H=16, K=16 | 9792.9 | 8610.8
f16 B=16, M=512, H=16, K=32 | 9111.2 | 7500.4
f32 B=16, M=512, H=16, K=32 | 11388.9 | 9295.5
f16 B=16, M=512, H=16, K=64 | 11402.2 | 8911.2
f32 B=16, M=512, H=16, K=64 | 16094.5 | 11084.9
f16 B=16, M=512, H=16, K=128 | 24449.4 | 12629.6
f32 B=16, M=512, H=16, K=128 | 32234.3 | 15264.6
f16 B=16, M=512, H=16, K=256 | 52619.0 | 23373.4
f32 B=16, M=512, H=16, K=256 | 65241.9 | 27094.9
f16 B=16, M=1024, H=16, K=16 | 31510.4 | 26565.4
f32 B=16, M=1024, H=16, K=16 | 38369.3 | 32614.4
f16 B=16, M=1024, H=16, K=32 | 36294.3 | 28420.7
f32 B=16, M=1024, H=16, K=32 | 44377.6 | 35432.3
f16 B=16, M=1024, H=16, K=64 | 45366.8 | 32269.4
f32 B=16, M=1024, H=16, K=64 | 62745.1 | 39776.9
f16 B=16, M=1024, H=16, K=128 | 99353.4 | 43627.4
f32 B=16, M=1024, H=16, K=128 | 127366.1 | 51474.4
f16 B=16, M=1024, H=16, K=256 | 204810.4 | 81201.6
f32 B=16, M=1024, H=16, K=256 | 258126.0 | 92288.2
f16 B=64, M=128, H=16, K=16 | 1730.2 | 2117.6
f32 B=64, M=128, H=16, K=16 | 2428.4 | 2576.5
f16 B=64, M=128, H=16, K=32 | 2070.4 | 2487.9
f32 B=64, M=128, H=16, K=32 | 3084.5 | 3078.0
f16 B=64, M=128, H=16, K=64 | 2718.5 | 3317.9
f32 B=64, M=128, H=16, K=64 | 4421.9 | 4237.9
f16 B=64, M=128, H=16, K=128 | 5646.9 | 5284.4
f32 B=64, M=128, H=16, K=128 | 8635.8 | 6958.5
f16 B=64, M=128, H=16, K=256 | 13961.0 | 10316.2
f32 B=64, M=128, H=16, K=256 | 17417.3 | 13584.2
f16 B=64, M=512, H=16, K=16 | 26936.5 | 27427.8
f32 B=64, M=512, H=16, K=16 | 36403.9 | 33753.3
f16 B=64, M=512, H=16, K=32 | 31542.1 | 30266.4
f32 B=64, M=512, H=16, K=32 | 42935.1 | 37398.0
f16 B=64, M=512, H=16, K=64 | 39718.3 | 36109.8
f32 B=64, M=512, H=16, K=64 | 61577.7 | 43677.3
f16 B=64, M=512, H=16, K=128 | 86608.8 | 51294.6
f32 B=64, M=512, H=16, K=128 | 123085.0 | 61843.3
f16 B=64, M=512, H=16, K=256 | 179902.4 | 99364.5
f32 B=64, M=512, H=16, K=256 | 250051.9 | 111501.9
f16 B=64, M=1024, H=16, K=16 | 107724.9 | 106757.7
f32 B=64, M=1024, H=16, K=16 | 144482.4 |
f16 B=64, M=1024, H=16, K=32 | 124733.4 | 114732.4
f32 B=64, M=1024, H=16, K=32 | 168876.8 |
f16 B=64, M=1024, H=16, K=64 | 157059.1 | 131304.1
f32 B=64, M=1024, H=16, K=64 | 241476.9 |
f16 B=64, M=1024, H=16, K=128 | 334298.9 | 179659.1
f32 B=64, M=1024, H=16, K=128 | 483706.8 |
f16 B=64, M=1024, H=16, K=256 | 692904.0 |
f32 B=64, M=1024, H=16, K=256 | 982044.1 |
(Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1809.0 | 1374.5
f32 B=384, M=197, H=1, K=88 | 4340.1 | 2247.7
f16 B=384, M=197, H=1, K=80 | 1732.4 | 1282.2
f32 B=384, M=197, H=1, K=80 | 3974.2 | 2163.4
f16 B=384, M=197, H=1, K=64 | 1134.9 | 1044.1
f32 B=384, M=197, H=1, K=64 | 2689.9 | 1741.7
f16 B=1024, M=197, H=1, K=88 | 4707.0 | 3724.7
f32 B=1024, M=197, H=1, K=88 | 10546.7 | 6061.7
f16 B=1024, M=197, H=1, K=80 | 4523.2 | 3330.1
f32 B=1024, M=197, H=1, K=80 | 9609.2 | 5719.5
f16 B=1024, M=197, H=1, K=64 | 2799.2 | 2675.3
f32 B=1024, M=197, H=1, K=64 | 6586.8 | 4507.3
f16 B=512, M=197, H=1, K=80 | 2380.2 | 1684.0
f32 B=512, M=197, H=1, K=80 | 5267.1 | 2874.5
f16 B=32, M=197, H=16, K=80 | 2393.9 | 1800.3
f32 B=32, M=197, H=16, K=80 | 5392.4 | 3029.9
f16 B=32, M=197, H=16, K=64 | 1558.5 | 1450.9
f32 B=32, M=197, H=16, K=64 | 3636.3 | 2410.5
f16 B=32, M=197, H=16, K=128 | 2782.4 | 2211.6
f32 B=32, M=197, H=16, K=128 | 6643.3 | 4061.8
f16 B=256, M=197, H=1, K=88 | 1357.9 | 947.7
f32 B=256, M=197, H=1, K=88 | 2884.8 | 1533.7
f16 B=16, M=197, H=16, K=88 | 1346.8 | 970.9
f32 B=16, M=197, H=16, K=88 | 2801.2 | 1629.0
f16 B=16, M=197, H=16, K=64 | 766.2 | 931.7
f32 B=16, M=197, H=16, K=64 | 1838.3 | 1287.4
f16 B=16, M=197, H=16, K=128 | 1513.8 | 1135.0
f32 B=16, M=197, H=16, K=128 | 3407.8 | 2034.9
f16 B=1, M=4096, H=160, K=128 | 169073.2 |
f32 B=1, M=4096, H=160, K=128 | 550508.2 |
f16 B=2, M=4096, H=160, K=128 | 340149.9 |
f32 B=2, M=4096, H=160, K=128 | 1102674.8 |
f16 B=1, M=8192, H=160, K=128 | 681002.9 |
f32 B=1, M=8192, H=160, K=128 | 2200639.4 |
f16 B=2, M=8192, H=160, K=128 | 1364914.7 |
f16 B=1024, M=82, H=8, K=64 | 9059.6 | 5802.6
f32 B=1024, M=82, H=8, K=64 | 14694.7 | 11037.4
f16 B=150, M=256, H=16, K=64 | 5693.1 | 7563.8
f32 B=150, M=256, H=16, K=64 | 16696.3 | 16305.4
f16 B=64, M=256, H=12, K=64 | 1852.1 | 2386.4
f32 B=64, M=256, H=12, K=64 | 5462.4 | 4969.2
f16 B=1, M=4096, H=16, K=40 | 47164.3 | 8362.1
f32 B=1, M=4096, H=16, K=40 | 113058.1 | 19476.1
f16 B=1, M=16384, H=16, K=40 | 759023.3 |
f32 B=1, M=16384, H=16, K=40 | 1804493.4 |
f16 B=16, M=128, H=16, K=16 | 476.6 | 712.1
f32 B=16, M=128, H=16, K=16 | 619.0 | 651.7
f16 B=16, M=128, H=16, K=32 | 445.6 | 776.2
f32 B=16, M=128, H=16, K=32 | 555.9 | 662.4
f16 B=16, M=128, H=16, K=64 | 517.8 | 680.9
f32 B=16, M=128, H=16, K=64 | 601.7 | 736.3
f16 B=16, M=128, H=16, K=128 | 451.1 | 686.1
f32 B=16, M=128, H=16, K=128 | 1105.7 | 1007.0
f16 B=16, M=128, H=16, K=256 | 1049.9 | 888.0
f32 B=16, M=128, H=16, K=256 | 2192.3 | 1855.9
f16 B=16, M=512, H=16, K=16 | 1731.3 | 1896.6
f32 B=16, M=512, H=16, K=16 | 4476.5 | 4249.5
f16 B=16, M=512, H=16, K=32 | 1948.7 | 2095.1
f32 B=16, M=512, H=16, K=32 | 5679.8 | 4600.4
f16 B=16, M=512, H=16, K=64 | 2448.1 | 2577.1
f32 B=16, M=512, H=16, K=64 | 7617.6 | 5491.4
f16 B=16, M=512, H=16, K=128 | 4891.9 | 3380.6
f32 B=16, M=512, H=16, K=128 | 15084.0 | 8860.2
f16 B=16, M=512, H=16, K=256 | 12952.6 | 5381.5
f32 B=16, M=512, H=16, K=256 | 29870.0 | 16766.3
f16 B=16, M=1024, H=16, K=16 | 6817.0 | 6986.0
f32 B=16, M=1024, H=16, K=16 | 18132.2 | 16098.9
f16 B=16, M=1024, H=16, K=32 | 7568.5 | 7399.8
f32 B=16, M=1024, H=16, K=32 | 22038.8 | 17093.0
f16 B=16, M=1024, H=16, K=64 | 9320.6 | 8623.2
f32 B=16, M=1024, H=16, K=64 | 29998.6 | 20238.1
f16 B=16, M=1024, H=16, K=128 | 18972.4 | 10503.1
f32 B=16, M=1024, H=16, K=128 | 58953.5 | 33141.1
f16 B=16, M=1024, H=16, K=256 | 49804.3 | 17122.0
f32 B=16, M=1024, H=16, K=256 | 116887.9 | 60004.3
f16 B=64, M=128, H=16, K=16 | 509.3 | 673.2
f32 B=64, M=128, H=16, K=16 | 1029.9 | 1234.0
f16 B=64, M=128, H=16, K=32 | 546.7 | 813.7
f32 B=64, M=128, H=16, K=32 | 1408.3 | 1533.5
f16 B=64, M=128, H=16, K=64 | 745.2 | 1186.2
f32 B=64, M=128, H=16, K=64 | 2019.2 | 2154.9
f16 B=64, M=128, H=16, K=128 | 1417.3 | 1916.9
f32 B=64, M=128, H=16, K=128 | 3950.5 | 3779.3
f16 B=64, M=128, H=16, K=256 | 3808.4 | 3450.8
f32 B=64, M=128, H=16, K=256 | 7983.2 | 7252.1
f16 B=64, M=512, H=16, K=16 | 6187.6 | 7461.6
f32 B=64, M=512, H=16, K=16 | 16328.7 | 16558.3
f16 B=64, M=512, H=16, K=32 | 7026.3 | 8314.3
f32 B=64, M=512, H=16, K=32 | 20583.0 | 18328.0
f16 B=64, M=512, H=16, K=64 | 9087.4 | 10425.2
f32 B=64, M=512, H=16, K=64 | 27696.9 | 22791.7
f16 B=64, M=512, H=16, K=128 | 17574.7 | 14673.6
f32 B=64, M=512, H=16, K=128 | 54678.0 | 39872.5
f16 B=64, M=512, H=16, K=256 | 47507.7 | 26896.4
f32 B=64, M=512, H=16, K=256 | 109608.7 | 75908.0
f16 B=64, M=1024, H=16, K=16 | 24447.9 | 28512.3
f32 B=64, M=1024, H=16, K=16 | 65064.6 |
f16 B=64, M=1024, H=16, K=32 | 27254.5 | 30504.4
f32 B=64, M=1024, H=16, K=32 | 80142.4 |
f16 B=64, M=1024, H=16, K=64 | 34677.9 | 37021.6
f32 B=64, M=1024, H=16, K=64 | 108919.4 |
f16 B=64, M=1024, H=16, K=128 | 68389.8 | 49203.3
f32 B=64, M=1024, H=16, K=128 | 214535.3 |
f16 B=64, M=1024, H=16, K=256 | 183195.8 |
f32 B=64, M=1024, H=16, K=256 | 425804.3 |
Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
| main | vanilla
1 threads: -----------------------------------------------------------------------
(Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 4252.9 | 3568.0
f32 B=384, M=197, H=1, K=88 | 6516.4 | 4266.8
f16 B=384, M=197, H=1, K=80 | 4024.4 | 3422.3
f32 B=384, M=197, H=1, K=80 | 6216.6 | 4078.8
f16 B=384, M=197, H=1, K=64 | 2367.5 | 2914.9
f32 B=384, M=197, H=1, K=64 | 4350.7 | 3435.6
f16 B=1024, M=197, H=1, K=88 | 10541.5 | 9757.7
f32 B=1024, M=197, H=1, K=88 | 17610.3 | 12024.4
f16 B=1024, M=197, H=1, K=80 | 9913.7 | 9288.2
f32 B=1024, M=197, H=1, K=80 | 16804.2 | 11179.8
f16 B=1024, M=197, H=1, K=64 | 5802.8 | 7663.7
f32 B=1024, M=197, H=1, K=64 | 11778.8 | 9444.2
f16 B=512, M=197, H=1, K=80 | 5037.0 | 4611.7
f32 B=512, M=197, H=1, K=80 | 8749.3 | 5465.3
f16 B=32, M=197, H=16, K=80 | 5118.0 | 4819.7
f32 B=32, M=197, H=16, K=80 | 8713.1 | 5732.0
f16 B=32, M=197, H=16, K=64 | 2979.9 | 4031.1
f32 B=32, M=197, H=16, K=64 | 6085.4 | 4790.2
f16 B=32, M=197, H=16, K=128 | 6053.1 | 5955.1
f32 B=32, M=197, H=16, K=128 | 10682.1 | 7341.6
f16 B=256, M=197, H=1, K=88 | 3074.8 | 2440.7
f32 B=256, M=197, H=1, K=88 | 4561.4 | 2860.9
f16 B=16, M=197, H=16, K=88 | 3082.1 | 2523.8
f32 B=16, M=197, H=16, K=88 | 4517.3 | 3017.4
f16 B=16, M=197, H=16, K=64 | 1774.8 | 2029.5
f32 B=16, M=197, H=16, K=64 | 3061.0 | 2429.9
f16 B=16, M=197, H=16, K=128 | 3489.2 | 3004.0
f32 B=16, M=197, H=16, K=128 | 5322.8 | 3613.6
f16 B=1, M=4096, H=160, K=128 | 533256.0 |
f32 B=1, M=4096, H=160, K=128 | 644655.4 |
f16 B=2, M=4096, H=160, K=128 | 868892.7 |
f32 B=2, M=4096, H=160, K=128 | 1281898.4 |
f16 B=1, M=8192, H=160, K=128 | 2085529.2 |
f32 B=1, M=8192, H=160, K=128 | 2548207.5 |
f16 B=2, M=8192, H=160, K=128 | 3437383.9 |
f16 B=1024, M=82, H=8, K=64 | 20984.0 | 18067.6
f32 B=1024, M=82, H=8, K=64 | 37605.3 | 22806.3
f16 B=150, M=256, H=16, K=64 | 15327.7 | 24392.8
f32 B=150, M=256, H=16, K=64 | 24784.3 | 31835.0
f16 B=64, M=256, H=12, K=64 | 4922.7 | 7678.4
f32 B=64, M=256, H=12, K=64 | 8068.8 | 9808.9
f16 B=1, M=4096, H=16, K=40 | 69262.6 | 29178.9
f32 B=1, M=4096, H=16, K=40 | 73372.5 | 37290.0
f16 B=1, M=16384, H=16, K=40 | 1082724.4 |
f32 B=1, M=16384, H=16, K=40 | 1156356.8 |
f16 B=16, M=128, H=16, K=16 | 403.8 | 573.1
f32 B=16, M=128, H=16, K=16 | 514.4 | 693.1
f16 B=16, M=128, H=16, K=32 | 454.5 | 670.3
f32 B=16, M=128, H=16, K=32 | 642.2 | 821.0
f16 B=16, M=128, H=16, K=64 | 613.9 | 885.4
f32 B=16, M=128, H=16, K=64 | 922.2 | 1080.5
f16 B=16, M=128, H=16, K=128 | 1239.4 | 1329.3
f32 B=16, M=128, H=16, K=128 | 1777.7 | 1662.2
f16 B=16, M=128, H=16, K=256 | 3354.0 | 2500.6
f32 B=16, M=128, H=16, K=256 | 3651.8 | 3291.1
f16 B=16, M=512, H=16, K=16 | 4427.2 | 6857.0
f32 B=16, M=512, H=16, K=16 | 5531.2 | 8419.4
f16 B=16, M=512, H=16, K=32 | 5193.8 | 7481.4
f32 B=16, M=512, H=16, K=32 | 6608.2 | 9166.2
f16 B=16, M=512, H=16, K=64 | 6536.6 | 8855.1
f32 B=16, M=512, H=16, K=64 | 9423.3 | 10849.5
f16 B=16, M=512, H=16, K=128 | 13962.6 | 12345.1
f32 B=16, M=512, H=16, K=128 | 18679.2 | 15003.2
f16 B=16, M=512, H=16, K=256 | 31425.8 | 23147.7
f32 B=16, M=512, H=16, K=256 | 37686.5 | 26873.0
f16 B=16, M=1024, H=16, K=16 | 16928.6 | 26395.9
f32 B=16, M=1024, H=16, K=16 | 20647.1 | 32762.1
f16 B=16, M=1024, H=16, K=32 | 19584.6 | 28100.2
f32 B=16, M=1024, H=16, K=32 | 24153.9 | 35231.2
f16 B=16, M=1024, H=16, K=64 | 24358.8 | 31949.6
f32 B=16, M=1024, H=16, K=64 | 34135.4 | 39247.4
f16 B=16, M=1024, H=16, K=128 | 52553.6 | 42857.2
f32 B=16, M=1024, H=16, K=128 | 68490.7 | 50818.2
f16 B=16, M=1024, H=16, K=256 | 113179.1 | 79246.2
f32 B=16, M=1024, H=16, K=256 | 138958.7 | 90470.1
f16 B=64, M=128, H=16, K=16 | 1313.4 | 2093.5
f32 B=64, M=128, H=16, K=16 | 1912.5 | 2551.8
f16 B=64, M=128, H=16, K=32 | 1605.9 | 2479.4
f32 B=64, M=128, H=16, K=32 | 2424.6 | 3055.7
f16 B=64, M=128, H=16, K=64 | 2135.5 | 3306.4
f32 B=64, M=128, H=16, K=64 | 3512.2 | 4184.7
f16 B=64, M=128, H=16, K=128 | 4349.4 | 5250.9
f32 B=64, M=128, H=16, K=128 | 6734.9 | 6860.6
f16 B=64, M=128, H=16, K=256 | 11412.6 | 10225.4
f32 B=64, M=128, H=16, K=256 | 13715.3 | 13386.5
f16 B=64, M=512, H=16, K=16 | 15298.3 | 27164.6
f32 B=64, M=512, H=16, K=16 | 20818.0 | 33373.3
f16 B=64, M=512, H=16, K=32 | 18168.1 | 29831.6
f32 B=64, M=512, H=16, K=32 | 25124.1 | 37340.7
f16 B=64, M=512, H=16, K=64 | 22989.5 | 35792.1
f32 B=64, M=512, H=16, K=64 | 36008.3 | 43156.8
f16 B=64, M=512, H=16, K=128 | 48567.6 | 50699.7
f32 B=64, M=512, H=16, K=128 | 70832.3 | 60779.7
f16 B=64, M=512, H=16, K=256 | 107837.9 | 97016.9
f32 B=64, M=512, H=16, K=256 | 144816.6 | 109729.1
f16 B=64, M=1024, H=16, K=16 | 57449.7 | 105449.7
f32 B=64, M=1024, H=16, K=16 | 77196.8 |
f16 B=64, M=1024, H=16, K=32 | 67469.2 | 113569.9
f32 B=64, M=1024, H=16, K=32 | 92452.9 |
f16 B=64, M=1024, H=16, K=64 | 85985.5 | 129934.5
f32 B=64, M=1024, H=16, K=64 | 131842.4 |
f16 B=64, M=1024, H=16, K=128 | 180353.2 | 176085.7
f32 B=64, M=1024, H=16, K=128 | 261345.7 |
f16 B=64, M=1024, H=16, K=256 | 380736.0 |
f32 B=64, M=1024, H=16, K=256 | 530228.5 |
(Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1502.5 | 1373.5
f32 B=384, M=197, H=1, K=88 | 2885.5 | 2234.0
f16 B=384, M=197, H=1, K=80 | 1440.5 | 1282.0
f32 B=384, M=197, H=1, K=80 | 2606.4 | 2148.4
f16 B=384, M=197, H=1, K=64 | 824.4 | 1044.6
f32 B=384, M=197, H=1, K=64 | 1888.7 | 1737.6
f16 B=1024, M=197, H=1, K=88 | 3916.4 | 3731.4
f32 B=1024, M=197, H=1, K=88 | 7123.7 | 6025.3
f16 B=1024, M=197, H=1, K=80 | 3751.2 | 3329.3
f32 B=1024, M=197, H=1, K=80 | 6440.6 | 5673.2
f16 B=1024, M=197, H=1, K=64 | 2033.2 | 2674.8
f32 B=1024, M=197, H=1, K=64 | 4637.9 | 4491.2
f16 B=512, M=197, H=1, K=80 | 1980.8 | 1678.7
f32 B=512, M=197, H=1, K=80 | 3457.6 | 2856.9
f16 B=32, M=197, H=16, K=80 | 1972.6 | 1799.2
f32 B=32, M=197, H=16, K=80 | 3514.2 | 3015.5
f16 B=32, M=197, H=16, K=64 | 1119.0 | 1450.0
f32 B=32, M=197, H=16, K=64 | 2554.1 | 2415.0
f16 B=32, M=197, H=16, K=128 | 2290.4 | 2213.7
f32 B=32, M=197, H=16, K=128 | 4547.6 | 4023.2
f16 B=256, M=197, H=1, K=88 | 1138.3 | 941.5
f32 B=256, M=197, H=1, K=88 | 1926.5 | 1521.0
f16 B=16, M=197, H=16, K=88 | 1116.3 | 970.6
f32 B=16, M=197, H=16, K=88 | 1862.4 | 1606.0
f16 B=16, M=197, H=16, K=64 | 557.9 | 795.7
f32 B=16, M=197, H=16, K=64 | 1275.8 | 1285.0
f16 B=16, M=197, H=16, K=128 | 1245.8 | 1136.1
f32 B=16, M=197, H=16, K=128 | 2292.3 | 2016.9
f16 B=1, M=4096, H=160, K=128 | 87327.9 |
f32 B=1, M=4096, H=160, K=128 | 281487.9 |
f16 B=2, M=4096, H=160, K=128 | 176796.0 |
f32 B=2, M=4096, H=160, K=128 | 564008.1 |
f16 B=1, M=8192, H=160, K=128 | 347318.8 |
f32 B=1, M=8192, H=160, K=128 | 1111934.5 |
f16 B=2, M=8192, H=160, K=128 | 696927.1 |
f16 B=1024, M=82, H=8, K=64 | 7899.4 | 5814.6
f32 B=1024, M=82, H=8, K=64 | 12697.1 | 11005.3
f16 B=150, M=256, H=16, K=64 | 3963.8 | 7590.3
f32 B=150, M=256, H=16, K=64 | 11184.3 | 16357.7
f16 B=64, M=256, H=12, K=64 | 1293.5 | 2385.3
f32 B=64, M=256, H=12, K=64 | 3633.0 | 4970.4
f16 B=1, M=4096, H=16, K=40 | 24253.3 | 8364.6
f32 B=1, M=4096, H=16, K=40 | 57122.2 | 19517.0
f16 B=1, M=16384, H=16, K=40 | 386768.6 |
f32 B=1, M=16384, H=16, K=40 | 909207.0 |
f16 B=16, M=128, H=16, K=16 | 500.4 | 633.7
f32 B=16, M=128, H=16, K=16 | 546.9 | 610.3
f16 B=16, M=128, H=16, K=32 | 575.3 | 670.2
f32 B=16, M=128, H=16, K=32 | 519.1 | 618.9
f16 B=16, M=128, H=16, K=64 | 461.2 | 648.9
f32 B=16, M=128, H=16, K=64 | 575.0 | 615.2
f16 B=16, M=128, H=16, K=128 | 515.3 | 690.1
f32 B=16, M=128, H=16, K=128 | 875.9 | 1006.7
f16 B=16, M=128, H=16, K=256 | 1052.4 | 888.7
f32 B=16, M=128, H=16, K=256 | 1740.9 | 1854.9
f16 B=16, M=512, H=16, K=16 | 1015.1 | 1918.9
f32 B=16, M=512, H=16, K=16 | 2540.9 | 4288.5
f16 B=16, M=512, H=16, K=32 | 1158.3 | 2128.9
f32 B=16, M=512, H=16, K=32 | 3260.8 | 4634.9
f16 B=16, M=512, H=16, K=64 | 1490.7 | 2560.9
f32 B=16, M=512, H=16, K=64 | 4449.8 | 5479.5
f16 B=16, M=512, H=16, K=128 | 3212.9 | 3377.7
f32 B=16, M=512, H=16, K=128 | 8759.7 | 8724.9
f16 B=16, M=512, H=16, K=256 | 8505.5 | 5348.4
f32 B=16, M=512, H=16, K=256 | 17494.2 | 16621.6
f16 B=16, M=1024, H=16, K=16 | 3717.3 | 7286.8
f32 B=16, M=1024, H=16, K=16 | 9676.0 | 16131.3
f16 B=16, M=1024, H=16, K=32 | 4170.7 | 7662.0
f32 B=16, M=1024, H=16, K=32 | 12001.4 | 17151.3
f16 B=16, M=1024, H=16, K=64 | 5203.6 | 8637.0
f32 B=16, M=1024, H=16, K=64 | 16305.8 | 19853.1
f16 B=16, M=1024, H=16, K=128 | 11030.4 | 10478.4
f32 B=16, M=1024, H=16, K=128 | 32050.9 | 32589.2
f16 B=16, M=1024, H=16, K=256 | 28891.7 | 16874.2
f32 B=16, M=1024, H=16, K=256 | 63284.2 | 58763.0
f16 B=64, M=128, H=16, K=16 | 508.8 | 651.3
f32 B=64, M=128, H=16, K=16 | 795.8 | 1240.6
f16 B=64, M=128, H=16, K=32 | 469.8 | 814.0
f32 B=64, M=128, H=16, K=32 | 1115.5 | 1530.4
f16 B=64, M=128, H=16, K=64 | 623.9 | 1185.2
f32 B=64, M=128, H=16, K=64 | 1612.9 | 2154.2
f16 B=64, M=128, H=16, K=128 | 1419.7 | 1918.1
f32 B=64, M=128, H=16, K=128 | 3171.7 | 3761.3
f16 B=64, M=128, H=16, K=256 | 3810.6 | 3445.0
f32 B=64, M=128, H=16, K=256 | 6416.3 | 7248.7
f16 B=64, M=512, H=16, K=16 | 3596.6 | 7506.6
f32 B=64, M=512, H=16, K=16 | 9262.7 | 16742.1
f16 B=64, M=512, H=16, K=32 | 4189.1 | 8365.1
f32 B=64, M=512, H=16, K=32 | 11914.3 | 18455.8
f16 B=64, M=512, H=16, K=64 | 5510.7 | 10294.1
f32 B=64, M=512, H=16, K=64 | 16312.6 | 22721.7
f16 B=64, M=512, H=16, K=128 | 11544.8 | 14667.4
f32 B=64, M=512, H=16, K=128 | 31957.9 | 39608.1
f16 B=64, M=512, H=16, K=256 | 31179.8 | 26641.1
f32 B=64, M=512, H=16, K=256 | 63495.1 | 74597.9
f16 B=64, M=1024, H=16, K=16 | 13245.5 | 29173.0
f32 B=64, M=1024, H=16, K=16 | 34907.6 |
f16 B=64, M=1024, H=16, K=32 | 15063.4 | 31256.2
f32 B=64, M=1024, H=16, K=32 | 43456.3 |
f16 B=64, M=1024, H=16, K=64 | 19322.4 | 37280.5
f32 B=64, M=1024, H=16, K=64 | 58997.7 |
f16 B=64, M=1024, H=16, K=128 | 39771.6 | 49110.0
f32 B=64, M=1024, H=16, K=128 | 116715.2 |
f16 B=64, M=1024, H=16, K=256 | 106294.7 |
f32 B=64, M=1024, H=16, K=256 | 231929.0 |
Times are in microseconds (us).

TESTS

Looks like tests are failing on A100:

A100 backward test

$ CUDA_LAUNCH_BLOCKING=1 python -m pytest /scratch/XXXX/xformers/tests/test_mem_eff_attention.py -k "test_backward" -x -s -v --pdb
=========================================================================================================================================== test session starts ============================================================================================================================================
platform linux -- Python 3.10.8, pytest-7.2.0, pluggy-1.0.0 -- /scratch/XXXX/lxformers/bin/python
cachedir: .pytest_cache
rootdir: /scratch/XXXX/xformers
plugins: mpi-0.4, timeout-1.4.2, hydra-core-1.2.0, cov-2.10.0, typeguard-2.13.3
collected 34016 items / 16368 deselected / 17648 selected
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg0-BMK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg0-BMHK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg1-BMK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg1-BMHK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg2-BMK] FAILED
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> traceback >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
op_device_dtype_B_Mq_Mkv_H_K_Kv = (<class 'xformers.ops.fmha.cutlass.BwOp'>, 'cuda', torch.bfloat16, 1, 32, 32, ...), grad_out_contiguous = False, attn_bias_cfg = (<class 'torch.Tensor'>, True), fmt = 'BMK'
@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@pytest.mark.parametrize(
"attn_bias_cfg", # (type(bias), bias.requires_grad)
[
(None, False),
(xformers.ops.LowerTriangularMask, False),
(torch.Tensor, True),
(torch.Tensor, False),
],
)
@pytest.mark.parametrize("grad_out_contiguous", [False, True])
@pytest.mark.parametrize(
"op_device_dtype_B_Mq_Mkv_H_K_Kv",
_opBW_device_dtype_B_Mq_Mkv_H_K_Kv,
ids=_opBW_device_dtype_B_Mq_Mkv_H_K_Kv_ids,
)
def test_backward(
op_device_dtype_B_Mq_Mkv_H_K_Kv,
grad_out_contiguous,
attn_bias_cfg,
fmt,
):
attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
(
op_bw,
device,
dtype,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = op_device_dtype_B_Mq_Mkv_H_K_Kv
query, key, value, attn_bias = create_tensors(
*op_device_dtype_B_Mq_Mkv_H_K_Kv,
attn_bias_type=attn_bias_type,
fmt=fmt,
)
op_fw = (
sample_random_supported_fw(
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
seed=q_len * kv + kv_len * k,
)
if op_bw != fmha.cutlass.BwOp
else fmha.cutlass.FwOp
)
qkv = None
if (
fmt == "BMHK"
and query.shape[3] == value.shape[3]
and query.shape[1] == value.shape[1]
):
qkv = torch.stack([query, key, value], 2)
qkv.requires_grad_(True)
# bm3hk -> 3 x bmhk
query, key, value = xformers.ops.unbind(qkv, 2)
assert not query.is_contiguous()
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
if isinstance(attn_bias, torch.Tensor):
attn_bias.requires_grad_(attn_bias_requires_grad)
if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)):
pytest.skip("inputs not supported")
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=(op_fw, op_bw)
)
grad_out = torch.ones_like(out)
if grad_out_contiguous is False:
grad_out = torch.tensor([1.0], dtype=query.dtype, device=device)[
None, None, :
].expand_as(out)
out.backward(grad_out)
del out
if qkv is None and op_bw == fmha.cutlass.BwOp:
assert query.stride() == query.grad.stride()
grads = []
if qkv is None:
grads = [query.grad, key.grad, value.grad]
query.grad = None
key.grad = None
value.grad = None
else:
grads = [qkv.grad]
qkv.grad = None
if attn_bias_requires_grad:
grads.append(attn_bias.grad)
attn_bias.grad = None
ref = ref_attention(query, key, value, attn_bias)
ref.backward(grad_out)
del grad_out
del ref
atol = op_bw.ERROR_ATOL[dtype]
rtol = op_bw.ERROR_RTOL[dtype]
grads_ref = []
grads_name = []
if qkv is None:
assert isinstance(query.grad, torch.Tensor)
assert isinstance(key.grad, torch.Tensor)
assert isinstance(value.grad, torch.Tensor)
grads_ref = [query.grad, key.grad, value.grad]
grads_name = ["query", "key", "value"]
else:
assert isinstance(qkv.grad, torch.Tensor)
grads_ref = [qkv.grad]
grads_name = ["qkv"]
if attn_bias_requires_grad:
assert isinstance(attn_bias.grad, torch.Tensor)
grads_ref.append(attn_bias.grad)
grads_name.append("bias")
del query
del key
del value
del qkv
for name, calc_grad, ref_grad in zip(grads_name, grads, grads_ref):
> assert_allclose(
calc_grad,
ref_grad,
msg=f"{op_fw.NAME}+{op_bw.NAME}:{name}",
atol=atol,
rtol=rtol,
)
tests/test_mem_eff_attention.py:686:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
out = tensor([[[ 2.8519e+12, -1.1141e-11, 1.0522e+29, ..., 1.6022e-03,
1.8014e+16, 1.2360e-03],
[-1....2461e-01, -2.8712e+24, ..., -1.0156e+00,
-4.9720e+18, -4.2969e-01]]], device='cuda:0', dtype=torch.bfloat16)
ref = tensor([[[ 4.5013e-04, 6.3324e-04, 3.0365e-03, ..., -7.0190e-03,
-2.4414e-03, 7.0953e-04],
[ 1....2734e-01, 8.0078e-01, ..., -1.9336e-01,
-8.5938e-01, 4.5508e-01]]], device='cuda:0', dtype=torch.bfloat16), msg = 'cutlassF+cutlassB:query', atol = 0.7
rtol = 0.1
def assert_allclose(
out: torch.Tensor,
ref: torch.Tensor,
msg: str = "failed",
atol: float = 1e-8,
rtol: float = 1e-5,
) -> None:
assert out.shape == ref.shape
flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
max_pos = flatten_diff.argmax()
max_diff = flatten_diff[max_pos]
num_different = torch.count_nonzero(flatten_diff > 0)
percentage = num_different / flatten_diff.numel()
del flatten_diff
> assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
f"/ atol={atol}, rtol={rtol}"
f"/ total failing elements: {num_different}, percentage={percentage}"
)
E AssertionError: cutlassF+cutlassB:query: out=-8.048060130728983e+35 and ref=0.67578125 (diff=8.048060130728983e+35 > 0)/ atol=0.7, rtol=0.1/ total failing elements: 684, percentage=0.66796875
E assert False
E + where False = <built-in method allclose of type object at 0x7f52fa29f200>(tensor([[[ 2.8519e+12, -1.1141e-11, 1.0522e+29, ..., 1.6022e-03,\n 1.8014e+16, 1.2360e-03],\n [-1.9241e+12, -2.0385e-05, -5.5535e+24, ..., -3.0884e-02,\n -4.3580e+20, 2.0996e-02],\n [-8.3317e+16, 6.5484e-11, -2.2849e+26, ..., 1.2398e-04,\n -1.9922e+22, 9.3460e-05],\n ...,\n [ 2.6828e+14, -1.3097e-09, -1.6712e+28, ..., -2.5558e-04,\n -3.6605e+19, -1.9646e-04],\n [ 4.5425e+20, -2.4214e-07, 2.8531e+26, ..., -6.9618e-05,\n -3.5654e+23, -2.8729e-05],\n [ 4.6043e+22, -2.2461e-01, -2.8712e+24, ..., -1.0156e+00,\n -4.9720e+18, -4.2969e-01]]], device='cuda:0', dtype=torch.bfloat16), tensor([[[ 4.5013e-04, 6.3324e-04, 3.0365e-03, ..., -7.0190e-03,\n -2.4414e-03, 7.0953e-04],\n [ 1.8234e-03, -3.3417e-03, 1.3855e-02, ..., -3.4668e-02,\n -2.4719e-03, 3.7231e-03],\n [-5.8984e-01, -2.3535e-01, -8.5547e-01, ..., -7.6172e-01,\n 4.0625e-01, -2.9883e-01],\n ...,\n [ 6.2012e-02, 1.9238e-01, 2.0508e-01, ..., -4.2725e-02,\n -5.3516e-01, 8.9062e-01],\n [ 5.0000e-01, 3.9219e+00, 4.0430e-01, ..., 2.3594e+00,\n 9.0625e-01, -2.8750e+00],\n [-3.8672e-01, -5.2734e-01, 8.0078e-01, ..., -1.9336e-01,\n -8.5938e-01, 4.5508e-01]]], device='cuda:0', dtype=torch.bfloat16), rtol=0.1, atol=0.7)
E + where <built-in method allclose of type object at 0x7f52fa29f200> = torch.allclose
tests/test_mem_eff_attention.py:182: AssertionError

They seem to pass on V100/P100 :)
cool! that's about all i could hope for on A100. i'll see if i can find an A100 today or tomorrow and start debugging
Thanks a lot for addressing / answering the comments! As the code looks fairly clean, I'm happy to merge once the following conditions are met:

Also heads-up as we will move all of the C++ files around (as in #579). Might create conflicts in git as you rebase, but we won't change the content of the files you are touching.
Tests seem to pass on my A100! Congratulations!
I just opened a draft PR with these changes to get our CI to test them:
#606
int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise

// dropout
bool use_dropout;
do we need this additional variable? Can't we just compare "dropout_prob != 0" ?
i'll see if i can get rid of it; it was done this way because i was worried about dropout_prob != 0 being a floating point comparison.
EDIT: if you're only talking about backward here, we can definitely get rid of it since it's only used once to dispatch, and we can use std::fpclassify there
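for illustration, here's a minimal host-side sketch of that kind of dispatch; `Params` and `launch_backward` here are made-up placeholders, not the actual kernel interface:

```
// Illustrative only: dispatch at launch time on whether dropout is enabled,
// without storing a separate use_dropout flag.
#include <cmath>
#include <cstdio>

struct Params {
  float dropout_prob = 0.0f;  // probability of dropping an attention weight
};

// Assumed launcher, specialized at compile time on whether dropout is applied.
template <bool kApplyDropout>
void launch_backward(const Params& p) {
  std::printf("dropout %s (p=%.2f)\n", kApplyDropout ? "on" : "off", p.dropout_prob);
}

void dispatch_backward(const Params& p) {
  // FP_ZERO matches +0.0 / -0.0 exactly, so no tolerance-based float compare is needed.
  if (std::fpclassify(p.dropout_prob) == FP_ZERO) {
    launch_backward<false>(p);
  } else {
    launch_backward<true>(p);
  }
}

int main() {
  dispatch_backward(Params{0.0f});  // dropout disabled path
  dispatch_backward(Params{0.3f});  // dropout enabled path
}
```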
kPreloadMmas && kApplyDropout ?
    cutlass::const_min(2, DefaultConfig::kStages) :
    DefaultConfig::kStages, // Stages
When I last tested, it seemed that cutlass::gemm::threadblock::DefaultMma would match the Pipelined implementation on A100 when using kStages=2, instead of the Mma implementation (cc @hwu36). This makes performance much worse on A100 when dropout is enabled - but let's keep it like this for now, we can address that in a later PR (I also have plans to reduce shmem usage, so this might no longer be needed in the future).
I also see that you are using cutlass::gemm::kernel::DefaultGemm now, which might not have this issue.
from looking at it, it does seem like it selects the pipelined implementation. here are some benchmarks to show the impact:
[---------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------]
| optimized[flshattB] | vanilla
1 threads: ---------------------------------------------------------------------------------------------------
b16 B=8, M=512, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False | 1.1 | 2.9
b16 B=8, M=512, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False | 1.2 | 3.4
b16 B=8, M=512, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False | 2.6 | 3.9
b16 B=8, M=512, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False | 2.7 | 4.3
b16 B=8, M=1024, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False | 4.1 | 10.0
b16 B=8, M=1024, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False | 4.4 | 12.0
b16 B=8, M=1024, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False | 9.1 | 12.0
b16 B=8, M=1024, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False | 9.4 | 13.9
b16 B=16, M=512, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False | 2.2 | 5.7
b16 B=16, M=512, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False | 2.4 | 6.7
b16 B=16, M=512, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False | 5.1 | 7.7
b16 B=16, M=512, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False | 5.3 | 8.7
b16 B=16, M=1024, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False | 8.2 | 20.1
b16 B=16, M=1024, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False | 8.7 | 23.9
b16 B=16, M=1024, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False | 18.2 | 24.0
b16 B=16, M=1024, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False | 18.8 | 27.8
Times are in milliseconds (ms).
[------------------------- attention backward (attn_bias=<class 'torch.Tensor'>) --------------------------]
| optimized[cutlassB] | vanilla
1 threads: -------------------------------------------------------------------------------------------------
b16 B=8, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False | 2.1 | 2.9
b16 B=8, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True | 2.1 | 2.9
b16 B=8, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False | 3.1 | 3.4
b16 B=8, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True | 3.1 | 3.4
b16 B=8, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False | 3.6 | 3.9
b16 B=8, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True | 3.6 | 3.9
b16 B=8, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False | 4.5 | 4.3
b16 B=8, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True | 4.5 | 4.3
b16 B=8, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False | 7.6 | 10.0
b16 B=8, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True | 7.6 | 10.0
b16 B=8, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False | 11.7 | 11.9
b16 B=8, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True | 11.7 | 11.9
b16 B=8, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False | 13.1 | 12.0
b16 B=8, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True | 13.1 | 12.0
b16 B=8, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False | 16.4 | 13.9
b16 B=8, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True | 16.4 | 13.9
b16 B=16, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False | 3.9 | 5.7
b16 B=16, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True | 4.0 | 5.7
b16 B=16, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False | 6.0 | 6.7
b16 B=16, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True | 6.0 | 6.7
b16 B=16, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False | 7.3 | 7.7
b16 B=16, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True | 7.3 | 7.7
b16 B=16, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False | 8.9 | 8.7
b16 B=16, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True | 8.9 | 8.7
b16 B=16, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False | 14.8 | 20.1
b16 B=16, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True | 14.7 | 20.0
b16 B=16, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False | 23.1 | 23.9
b16 B=16, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True | 23.1 | 23.9
b16 B=16, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False | 25.9 | 24.0
b16 B=16, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True | 25.9 | 24.0
b16 B=16, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False | 32.4 | 27.8
b16 B=16, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True | 32.4 | 27.8
just curious, what are the plans to cut shmem usage?
Which GPU did you run these benchmarks on? It's ~20-50% slower depending on the cases - might be worth investigating later if you want.
just curious, what are the plans to cut shmem usage?
We are loading Q,K,dO from global memory twice. We could reuse the shared-memory to avoid that
Some illustration of the BW pass:
(1) Q@K, dO@V, gradV matmuls
(2) gradQ, gradK
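A rough CUDA sketch of the idea, with made-up tile sizes and names (not the actual SharedStorage layout): keep the tiles both phases need resident in shared memory, and only overlap the phase-specific scratch, so K/dO are read from global memory once.

```
#include <cuda_fp16.h>

constexpr int kBlockM = 64;
constexpr int kBlockN = 64;
constexpr int kHeadDim = 64;

struct PersistentTiles {  // loaded once, reused by phase (1) and phase (2)
  __half k_tile[kBlockN][kHeadDim];
  __half do_tile[kBlockM][kHeadDim];
};

union PhaseScratch {  // phase-specific buffers share the same bytes
  __half q_tile[kBlockM][kHeadDim];  // phase (1): Q@K, dO@V, gradV
  __half ds_tile[kBlockM][kBlockN];  // phase (2): gradQ, gradK
};

struct SharedStorage {
  PersistentTiles persistent;
  PhaseScratch scratch;
};

__global__ void bw_kernel_sketch() {
  __shared__ SharedStorage smem;
  // phase (1): load K/dO into smem.persistent and Q into smem.scratch.q_tile,
  //            then compute Q@K, dO@V and the gradV contribution.
  // phase (2): overwrite smem.scratch with dS and compute gradQ/gradK,
  //            reusing smem.persistent.k_tile / do_tile instead of re-loading
  //            them from global memory.
  (void)smem;
}
```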
A100 Shmem occupancy - current situation (max ~160kb)
A100 Shmem occupancy - after potential changes (max ~131kb)
(1) setting kNumStages=4 so the entire block of matrices fits in shared-memory
(2) loading dO/K from same shared-memory location (rather than re-loading them from global memory)
Also I expect this to be faster
I don't have immediate plans to work on this, but let me know if you would like to contribute :)
Just saw your question in NVIDIA/cutlass#744, I believe it's related to this :) Let's do that as part of a different PR if possible to make things easier to review.
Also as a heads-up for this PR, I won't be available next week but hopefully we can get this merged early january!
yes it was, thanks for the hints there! yes definitely we can put it in another PR, just wanted to start experimenting.
and sounds good, thanks for the quick reviews so far, happy holidays!
xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_forward.h
adds attn bias (including bias grad) and dropout support to CUTLASS flashattn implementation

[-------------------------------------------- attn --------------------------------------------]
                                                              | reference | cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0) | 12.7 | 7.5
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5) | 15.5 | 9.1
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0) | 12.7 | 7.6
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5) | 15.6 | 9.1
(8, 512, 64, 128, torch.float16, None, False, 0.0) | 10.1 | 6.0
(8, 512, 64, 128, torch.float16, None, False, 0.5) | 12.7 | 7.5
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0) | 44.3 | 29.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5) | 55.0 | 35.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0) | 45.1 | 29.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5) | 55.6 | 35.3
(8, 1024, 64, 128, torch.float16, None, False, 0.0) | 37.0 | 22.6
(8, 1024, 64, 128, torch.float16, None, False, 0.5) | 46.8 | 29.0
Times are in milliseconds (ms).

[------------------------------------------ attn-bwd ------------------------------------------]
                                                              | reference | cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0) | 19.3 | 24.1
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5) | 19.4 | 24.6
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0) | 22.3 | 28.7
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5) | 22.4 | 29.0
(8, 512, 64, 128, torch.float16, None, False, 0.0) | 19.5 | 22.7
(8, 512, 64, 128, torch.float16, None, False, 0.5) | 19.5 | 23.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0) | 62.7 | 91.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5) | 63.4 | 93.7
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0) | 74.8 | 109.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5) | 75.1 | 111.1
(8, 1024, 64, 128, torch.float16, None, False, 0.0) | 63.2 | 85.5
(8, 1024, 64, 128, torch.float16, None, False, 0.5) | 64.0 | 90.1
Times are in milliseconds (ms).
Force-pushed from b9337bd to 5f86d95
rebased and addressed the format/lint issues. hoping CI will pass now 🙏 (minus windows build, will look at that next) |
BEFORE

[------------------------------------------ attn-bwd ------------------------------------------]
                                                              | reference | cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0) | 2.8 | 2.4
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5) | 2.8 | 3.3
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0) | 3.4 | 3.2
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5) | 3.4 | 4.2
(8, 512, 64, 64, torch.float16, None, False, 0.0) | 2.8 | 2.0
(8, 512, 64, 64, torch.float16, None, False, 0.5) | 2.8 | 2.9
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0) | 3.6 | 3.9
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5) | 3.6 | 4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0) | 4.2 | 4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5) | 4.2 | 5.6
(8, 512, 64, 128, torch.float16, None, False, 0.0) | 3.6 | 3.4
(8, 512, 64, 128, torch.float16, None, False, 0.5) | 3.6 | 4.4
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0) | 9.7 | 8.8
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5) | 9.7 | 12.6
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0) | 12.0 | 12.1
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5) | 12.1 | 16.1
(8, 1024, 64, 64, torch.float16, None, False, 0.0) | 9.7 | 7.4
(8, 1024, 64, 64, torch.float16, None, False, 0.5) | 9.7 | 10.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0) | 11.3 | 14.0
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5) | 11.3 | 17.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0) | 13.6 | 17.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5) | 13.6 | 20.9
(8, 1024, 64, 128, torch.float16, None, False, 0.0) | 11.3 | 12.1
(8, 1024, 64, 128, torch.float16, None, False, 0.5) | 11.3 | 15.8

AFTER

[------------------------------------------ attn-bwd ------------------------------------------]
                                                              | reference | cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0) | 2.8 | 2.4
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5) | 2.8 | 3.0
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0) | 3.4 | 3.2
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5) | 3.4 | 3.8
(8, 512, 64, 64, torch.float16, None, False, 0.0) | 2.8 | 2.0
(8, 512, 64, 64, torch.float16, None, False, 0.5) | 2.8 | 2.6
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0) | 3.6 | 3.9
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5) | 3.6 | 4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0) | 4.2 | 4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5) | 4.2 | 5.6
(8, 512, 64, 128, torch.float16, None, False, 0.0) | 3.6 | 3.4
(8, 512, 64, 128, torch.float16, None, False, 0.5) | 3.6 | 4.4
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0) | 9.7 | 8.8
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5) | 9.7 | 11.4
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0) | 12.0 | 12.1
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5) | 12.1 | 14.6
(8, 1024, 64, 64, torch.float16, None, False, 0.0) | 9.7 | 7.4
(8, 1024, 64, 64, torch.float16, None, False, 0.5) | 9.7 | 9.6
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0) | 11.3 | 14.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5) | 11.3 | 17.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0) | 13.6 | 17.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5) | 13.6 | 20.9
(8, 1024, 64, 128, torch.float16, None, False, 0.0) | 11.3 | 12.1
(8, 1024, 64, 128, torch.float16, None, False, 0.5) | 11.3 | 15.8
fixed the windows build, should be good to rerun the tests when you get back. i'm prototyping some of the shared memory changes you suggested, but can't guarantee i'll end up finishing it and opening a PR. |
not sure what to do about the "Could not find a usable config.yml, you may have revoked the CircleCI OAuth app." error.
edit: seems like the #include order in the swiglu file was intentional, build fails after allowing formatter to reorder them |
The CI looks good to me! Thanks for fixing the windows build :) will merge early next week once I re-run the benchmarks on A100/v100 to ensure recent changes didn't bring any regression. |
awesome! and that sounds good to me. in general we are using this for distributed training of large language models using ZeRO-style data parallelism. in pytorch, running close to GPU memory limit really messes with performance of collectives because of the way the caching allocator works. we can definitely discuss more over slack if you are interested in the details. |
Perf measurements look good. I'm just a bit worried about the stable diffusion case on V100.

A100 fw
A100 bw
V100/P100 fw
V100/P100 bw
edit: are you talking about these forward measurements?
I was mentioning this for V100: 30ms before, and 37ms after
hm, not sure. i tried to repro on T4 but no major gap. I don't have a V100 to test with, but i'll go through the code and see if i can make any guesses.

pr
main
Okay that's fine - let's not worry too much about it, we can create an issue and document the regression, that should be good enough. |
See also #64

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:

```
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0

0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *

0 * * * *
0 0 * * *
* * * * *
* * * 0 *

0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
See also #64

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0

# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *

# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *

# A + causal (what most ppl want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
See also #640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0

# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *

# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *

# A + causal (what most ppl want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
- FW: Drop support for dropout if pytorch is not installed
- Use rng_state in Context to store seed/offset for dropout
- Add test to ensure we can't combine flash+cutlass's dropouts

ghstack-source-id: c5e05a1994b9c20fc27b071c3bfefbb4174987a2
Pull Request resolved: https://github.com/fairinternal/xformers/pull/434

__original_commit__ = fairinternal/xformers@408fefe5506c92b9c58444620c45bf5159b7fb39
See also #640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0

# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *

# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *

# A + causal (what most ppl want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```

ghstack-source-id: 44740f71132fa76226fd4c559cc3f09732ff139b
Pull Request resolved: https://github.com/fairinternal/xformers/pull/435

__original_commit__ = fairinternal/xformers@be55fcd21c5dd621831245c5995e1c6fb49d9b77
What does this PR do?
Adds support for attention bias, bias gradient, and dropout to CUTLASS FlashAttention.
i'm mostly new to CUDA programming and totally new to CUTLASS so please let me know if i'm doing anything that doesn't make sense, is slow, or is otherwise weird :)
one note: i noticed the CPU and pure CUDA implementations also support attention bias, but they expect the bias to be the same across queries. Bias here is implemented to accept different values along rows.
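to illustrate the difference, here's a small CPU reference sketch (names and layout are just for illustration, not the kernel code): a broadcast bias adds the same value to every query row, while the bias supported in this PR can differ per (query, key) pair.

```
// scores[i][j] = scale * dot(Q[i], K[j]) + bias(i, j)
#include <cmath>
#include <vector>

std::vector<float> attention_scores(
    const std::vector<float>& Q,     // M x K, row-major
    const std::vector<float>& Kmat,  // N x K, row-major
    const std::vector<float>& bias,  // M x N if per-row, or N if broadcast
    int M, int N, int K, bool bias_per_row) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(K));
  std::vector<float> scores(static_cast<size_t>(M) * N);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += Q[i * K + k] * Kmat[j * K + k];
      // broadcast bias: same value for every query row i
      // per-row bias:   value depends on both the query i and the key j
      const float b = bias_per_row ? bias[i * N + j] : bias[j];
      scores[i * N + j] = scale * acc + b;
    }
  }
  return scores;
}
```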
TODOs
there's a bug in the backward benchmarks that prevents me from using them to test performance (doesn't seem to be caused by my PR, as the same thing happens on main). will see if i can find a fix.
but here are some results from the benchmarks i wrote for myself. the labels take the form (B, M, H, K, dtype, attn_bias shape, attn_bias requires_grad, dropout probability).
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.