From bc08bbc631348913a3c37b4e09832973ff93a398 Mon Sep 17 00:00:00 2001
From: danthe3rd
Date: Fri, 20 Jan 2023 12:33:35 +0000
Subject: [PATCH] fMHA: Make dropout/bias configurable at build-time

# Performance on A100 (FW)
```
------------------------------ attention (attn_bias=) -----------------------------]
                                      | disable_bias_dropout | main | eager | pr587_d1b0fa
1 threads: --------------------------------------------------------------------------------------------
      f16 B=1, M=4096, H=160, K=128 | 14929.6 | 14853.7 | 21510.6 | 15245.8
      f32 B=1, M=4096, H=160, K=128 | 56159.3 | 56691.3 | 91407.5 | 57963.3
      f16 B=2, M=4096, H=160, K=128 | 29784.5 | 29633.5 | 43562.2 | 30427.9
      f32 B=2, M=4096, H=160, K=128 | 112125.7 | 113237.0 | | 115784.8
      f16 B=1, M=8192, H=160, K=128 | 59600.3 | 59240.6 | | 60820.3
      f32 B=1, M=8192, H=160, K=128 | 224734.5 | 226792.5 | | 232806.2
      f16 B=2, M=8192, H=160, K=128 | 119145.7 | 118414.6 | | 121599.1
      f32 B=2, M=8192, H=160, K=128 | 449501.3 | 453321.1 | | 466094.3
      f16 B=1, M=4096, H=16, K=40 | 863.6 | 870.6 | 1958.3 | 857.0
      f32 B=1, M=4096, H=16, K=40 | 2771.4 | 2765.6 | 6914.2 | 2866.1
      f16 B=1, M=16384, H=16, K=40 | 12251.9 | 12346.3 | 30460.6 | 12159.2
      f32 B=1, M=16384, H=16, K=40 | 40643.3 | 40533.2 | 123282.2 | 41947.4
      f16 B=256, M=4096, H=16, K=64 | 182110.1 | 183488.9 | | 181204.2
      f32 B=256, M=4096, H=16, K=64 | 644969.5 | 642960.3 | | 665520.1
      f16 B=64, M=1024, H=16, K=16 | 2577.2 | 2593.0 | 7293.8 | 2611.3
      f32 B=64, M=1024, H=16, K=16 | 8440.7 | 8424.4 | 24237.0 | 8722.5
      f16 B=64, M=1024, H=16, K=32 | 2596.5 | 2612.0 | 7675.3 | 2634.5
      f32 B=64, M=1024, H=16, K=32 | 8513.5 | 8494.5 | 26228.7 | 8811.5
      f16 B=64, M=1024, H=16, K=64 | 2972.8 | 2991.8 | 8669.3 | 2993.2
      f32 B=64, M=1024, H=16, K=64 | 10304.0 | 10271.7 | 30612.5 | 10667.6
      f16 B=64, M=1024, H=16, K=128 | 5282.0 | 5266.6 | 10313.9 | 5398.9
      f32 B=64, M=1024, H=16, K=128 | 22730.2 | 22940.7 | 39341.6 | 23517.9
      f16 B=64, M=1024, H=16, K=256 | 12337.4 | 12354.8 | 13558.9 | 12272.2
      f32 B=64, M=1024, H=16, K=256 | 45885.8 | 46488.2 | 71460.8 | 48531.7
```
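
Usage sketch (illustrative only): the two new trailing template parameters can be set to `false` to compile out the dropout and attn_bias code paths. The leading arguments below (scalar type, arch tag, alignment, tile sizes) are assumptions based on the existing parameter list in `kernel_forward.h`, not values required by this change:
```
// Sketch, not part of the patch: compile out dropout/bias support at build time.
// The first six template arguments are assumed/illustrative; only the last
// two (`kSupportsDropout`, `kSupportsBias`) are introduced by this change.
#include "xformers/csrc/attention/cuda/fmha/kernel_forward.h"

using FmhaNoDropoutNoBias = AttentionKernel<
    cutlass::half_t,     // scalar_t_: f16 queries/keys/values
    cutlass::arch::Sm80, // ArchTag (assumed A100 here)
    true,                // isAligned_
    64,                  // kQueriesPerBlock
    64,                  // kKeysPerBlock
    true,                // kSingleValueIteration (= value.shape[-1] <= kKeysPerBlock)
    false,               // kSupportsDropout: dropout path compiled out
    false>;              // kSupportsBias: attn_bias path compiled out
```
Both parameters default to `true`, so existing instantiations continue to build and behave exactly as before.
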
ghstack-source-id: dd55058ad3e1715515475029c8cd360be43d374d
Pull Request resolved: https://github.com/fairinternal/xformers/pull/441

__original_commit__ = fairinternal/xformers@a264f71c0de5c5474e6ec63b7a3b5d9e4a2ec98c
---
 .../csrc/attention/cuda/fmha/kernel_forward.h | 57 +++++++++++--------
 1 file changed, 34 insertions(+), 23 deletions(-)

diff --git a/xformers/csrc/attention/cuda/fmha/kernel_forward.h b/xformers/csrc/attention/cuda/fmha/kernel_forward.h
index c7f82ff266..d49254e653 100644
--- a/xformers/csrc/attention/cuda/fmha/kernel_forward.h
+++ b/xformers/csrc/attention/cuda/fmha/kernel_forward.h
@@ -68,8 +68,11 @@ template <
     bool isAligned_,
     int kQueriesPerBlock,
     int kKeysPerBlock,
-    bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
-    >
+    bool kSingleValueIteration, // = `value.shape[-1] <= kKeysPerBlock`
+    // This is noticeably slower on V100 for some reason
+    // Set to false if you know at compile time that you will never need dropout
+    bool kSupportsDropout = true,
+    bool kSupportsBias = true>
 struct AttentionKernel {
   using scalar_t = scalar_t_;
   using accum_t = float;
@@ -199,7 +202,7 @@ struct AttentionKernel {
       output_ptr += int64_t(q_start + query_start) * o_strideM() +
           head_id * head_dim_value;
 
-      if (attn_bias_ptr != nullptr) {
+      if (kSupportsBias && attn_bias_ptr != nullptr) {
         attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
       }
       if (output_accum_ptr != nullptr) {
@@ -215,9 +218,11 @@ struct AttentionKernel {
           batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
       }
 
-      dropout_batch_head_rng_offset =
-          batch_id * num_heads * num_queries * num_keys +
-          head_id * num_queries * num_keys;
+      if (kSupportsDropout) {
+        dropout_batch_head_rng_offset =
+            batch_id * num_heads * num_queries * num_keys +
+            head_id * num_queries * num_keys;
+      }
 
       num_queries -= query_start;
       if (causal) {
@@ -231,7 +236,9 @@ struct AttentionKernel {
       query_ptr = warp_uniform(query_ptr);
       key_ptr = warp_uniform(key_ptr);
       value_ptr = warp_uniform(value_ptr);
-      attn_bias_ptr = warp_uniform(attn_bias_ptr);
+      if (kSupportsBias) {
+        attn_bias_ptr = warp_uniform(attn_bias_ptr);
+      }
       output_ptr = warp_uniform(output_ptr);
       output_accum_ptr = warp_uniform(output_accum_ptr);
       logsumexp_ptr = warp_uniform(logsumexp_ptr);
@@ -469,7 +476,18 @@ struct AttentionKernel {
     CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
     CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
     CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
-    CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
+    if (kSupportsBias) {
+      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
+      XFORMERS_CHECK(
+          p.bias_strideB % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+      XFORMERS_CHECK(
+          p.bias_strideH % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+      XFORMERS_CHECK(
+          p.bias_strideM % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+    }
     XFORMERS_CHECK(
         p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
     XFORMERS_CHECK(
@@ -482,15 +500,6 @@ struct AttentionKernel {
         p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
     XFORMERS_CHECK(
         p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideB % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideH % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideM % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
 
     return true;
   }
@@ -542,7 +551,7 @@ struct AttentionKernel {
 
 #ifdef HAS_PYTORCH
     curandStatePhilox4_32_10_t curand_state_init;
-    if (p.use_dropout) {
+    if (kSupportsDropout && p.use_dropout) {
       const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
 
       // each element of the attention matrix P with shape
@@ -654,11 +663,13 @@ struct AttentionKernel {
           (my_warp_id / MM0::Mma::WarpCount::kM)};
 
       // multiply by scaling factor
-      accum =
-          cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
+      if (kSupportsBias) {
+        accum =
+            cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
+      }
 
       // apply attention bias if applicable
-      if (p.attn_bias_ptr != nullptr) {
+      if (kSupportsBias && p.attn_bias_ptr != nullptr) {
        // load bias tile Bij into shared memory
        typename MM0::BiasLoader::GmemTileIterator bias_iter(
            {cutlass::layout::RowMajor(p.bias_strideM)},
@@ -729,7 +740,7 @@ struct AttentionKernel {
                     warp_id(),
                     p.num_keys - iter_key_start,
                     iteratorC_tile_offset,
-                    1.0f);
+                    kSupportsBias ? 1.0f : p.scale);
               }));
         }));
 
@@ -758,7 +769,7 @@ struct AttentionKernel {
       // it ends up being very slow because each thread having noncontiguous
       // strips of the Pij tile means we have to skip around a lot, and also
      // have to generate a single random number at a time
-      if (p.use_dropout) {
+      if (kSupportsDropout && p.use_dropout) {
        auto si = shared_storage.after_mm0.si.accum_ref();
        // each thread handles a contiguous sequence of elements from Sij, all
        // coming from the same row. the reason they have to come from the same