diff --git a/xformers/csrc/attention/cuda/fmha/kernel_forward.h b/xformers/csrc/attention/cuda/fmha/kernel_forward.h
index c7f82ff266..d49254e653 100644
--- a/xformers/csrc/attention/cuda/fmha/kernel_forward.h
+++ b/xformers/csrc/attention/cuda/fmha/kernel_forward.h
@@ -68,8 +68,11 @@ template <
     bool isAligned_,
     int kQueriesPerBlock,
     int kKeysPerBlock,
-    bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
-    >
+    bool kSingleValueIteration, // = `value.shape[-1] <= kKeysPerBlock`
+    // This is quite slower on V100 for some reason
+    // Set to false if you know at compile-time you will never need dropout
+    bool kSupportsDropout = true,
+    bool kSupportsBias = true>
 struct AttentionKernel {
   using scalar_t = scalar_t_;
   using accum_t = float;
@@ -199,7 +202,7 @@ struct AttentionKernel {
       output_ptr += int64_t(q_start + query_start) * o_strideM() +
           head_id * head_dim_value;
 
-      if (attn_bias_ptr != nullptr) {
+      if (kSupportsBias && attn_bias_ptr != nullptr) {
         attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
       }
       if (output_accum_ptr != nullptr) {
@@ -215,9 +218,11 @@ struct AttentionKernel {
             batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
       }
 
-      dropout_batch_head_rng_offset =
-          batch_id * num_heads * num_queries * num_keys +
-          head_id * num_queries * num_keys;
+      if (kSupportsDropout) {
+        dropout_batch_head_rng_offset =
+            batch_id * num_heads * num_queries * num_keys +
+            head_id * num_queries * num_keys;
+      }
 
       num_queries -= query_start;
       if (causal) {
@@ -231,7 +236,9 @@ struct AttentionKernel {
       query_ptr = warp_uniform(query_ptr);
       key_ptr = warp_uniform(key_ptr);
       value_ptr = warp_uniform(value_ptr);
-      attn_bias_ptr = warp_uniform(attn_bias_ptr);
+      if (kSupportsBias) {
+        attn_bias_ptr = warp_uniform(attn_bias_ptr);
+      }
       output_ptr = warp_uniform(output_ptr);
       output_accum_ptr = warp_uniform(output_accum_ptr);
       logsumexp_ptr = warp_uniform(logsumexp_ptr);
@@ -469,7 +476,18 @@ struct AttentionKernel {
     CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
     CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
     CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
-    CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
+    if (kSupportsBias) {
+      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
+      XFORMERS_CHECK(
+          p.bias_strideB % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+      XFORMERS_CHECK(
+          p.bias_strideH % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+      XFORMERS_CHECK(
+          p.bias_strideM % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+    }
     XFORMERS_CHECK(
         p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
     XFORMERS_CHECK(
@@ -482,15 +500,6 @@ struct AttentionKernel {
         p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
     XFORMERS_CHECK(
         p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideB % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideH % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideM % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
     return true;
   }
 
@@ -542,7 +551,7 @@ struct AttentionKernel {
 
 #ifdef HAS_PYTORCH
     curandStatePhilox4_32_10_t curand_state_init;
-    if (p.use_dropout) {
+    if (kSupportsDropout && p.use_dropout) {
       const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
 
       // each element of the attention matrix P with shape
@@ -654,11 +663,13 @@ struct AttentionKernel {
               (my_warp_id / MM0::Mma::WarpCount::kM)};
 
       // multiply by scaling factor
-      accum =
-          cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
+      if (kSupportsBias) {
+        accum =
+            cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
+      }
 
       // apply attention bias if applicable
-      if (p.attn_bias_ptr != nullptr) {
+      if (kSupportsBias && p.attn_bias_ptr != nullptr) {
         // load bias tile Bij into shared memory
         typename MM0::BiasLoader::GmemTileIterator bias_iter(
             {cutlass::layout::RowMajor(p.bias_strideM)},
@@ -729,7 +740,7 @@ struct AttentionKernel {
               warp_id(),
               p.num_keys - iter_key_start,
               iteratorC_tile_offset,
-              1.0f);
+              kSupportsBias ? 1.0f : p.scale);
         }));
       }));
 
@@ -758,7 +769,7 @@ struct AttentionKernel {
       // it ends up being very slow because each thread having noncontiguous
      // strips of the Pij tile means we have to skip around a lot, and also
      // have to generate a single random number at a time
-      if (p.use_dropout) {
+      if (kSupportsDropout && p.use_dropout) {
        auto si = shared_storage.after_mm0.si.accum_ref();
        // each thread handles a contiguous sequence of elements from Sij, all
        // coming from the same row. the reason they have to come from the same
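For context, a minimal, self-contained sketch (plain C++; `ToyKernel`, its parameters, and the numbers are invented for illustration and are not xformers code) of the technique this patch applies: boolean non-type template parameters guard whole features, so instantiations that disable dropout and/or attention bias compile those paths out entirely, and when bias support is off the softmax scale is deferred to the epilogue, mirroring the `kSupportsBias ? 1.0f : p.scale` change above.

// toy_feature_flags.cpp -- illustrative only, not xformers code.
#include <cstdio>

template <bool kSupportsDropout = true, bool kSupportsBias = true>
struct ToyKernel {
  static float run(float score, float scale, const float* bias, bool use_dropout) {
    if (kSupportsBias) {
      // With bias support compiled in, fold the scale in before adding the
      // bias (analogous to the guarded cutlass::multiplies call above).
      score *= scale;
      if (bias != nullptr) {
        score += *bias;
      }
    }
    if (kSupportsDropout && use_dropout) {
      // RNG setup and dropout masking would live here; when
      // kSupportsDropout == false the branch is removed at compile time.
      score *= 0.5f;
    }
    // Without bias support, apply the scale late instead -- the same idea as
    // passing `kSupportsBias ? 1.0f : p.scale` to the epilogue.
    return kSupportsBias ? score : score * scale;
  }
};

int main() {
  float bias = 1.0f;
  // Full-featured instantiation vs. one with both features compiled out.
  std::printf("%f\n", ToyKernel<true, true>::run(2.0f, 0.125f, &bias, false));
  std::printf("%f\n", ToyKernel<false, false>::run(2.0f, 0.125f, nullptr, false));
  return 0;
}

With the defaults kSupportsDropout = true and kSupportsBias = true, existing instantiations of AttentionKernel keep their current behavior; callers that know at compile time they never need dropout can pass false to avoid the V100 slowdown noted in the template-parameter comment.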