fMHA: Make dropout/bias configurable at build-time
# Performance on A100 (FW)

```
[------------------------------ attention (attn_bias=<class 'NoneType'>) -----------------------------]
                                     |  disable_bias_dropout  |    main    |   eager    |  pr587_d1b0fa
1 threads: --------------------------------------------------------------------------------------------
      f16 B=1, M=4096, H=160, K=128  |         14929.6        |   14853.7  |   21510.6  |     15245.8
      f32 B=1, M=4096, H=160, K=128  |         56159.3        |   56691.3  |   91407.5  |     57963.3
      f16 B=2, M=4096, H=160, K=128  |         29784.5        |   29633.5  |   43562.2  |     30427.9
      f32 B=2, M=4096, H=160, K=128  |        112125.7        |  113237.0  |            |    115784.8
      f16 B=1, M=8192, H=160, K=128  |         59600.3        |   59240.6  |            |     60820.3
      f32 B=1, M=8192, H=160, K=128  |        224734.5        |  226792.5  |            |    232806.2
      f16 B=2, M=8192, H=160, K=128  |        119145.7        |  118414.6  |            |    121599.1
      f32 B=2, M=8192, H=160, K=128  |        449501.3        |  453321.1  |            |    466094.3
      f16 B=1, M=4096, H=16, K=40    |           863.6        |     870.6  |    1958.3  |       857.0
      f32 B=1, M=4096, H=16, K=40    |          2771.4        |    2765.6  |    6914.2  |      2866.1
      f16 B=1, M=16384, H=16, K=40   |         12251.9        |   12346.3  |   30460.6  |     12159.2
      f32 B=1, M=16384, H=16, K=40   |         40643.3        |   40533.2  |  123282.2  |     41947.4
      f16 B=256, M=4096, H=16, K=64  |        182110.1        |  183488.9  |            |    181204.2
      f32 B=256, M=4096, H=16, K=64  |        644969.5        |  642960.3  |            |    665520.1
      f16 B=64, M=1024, H=16, K=16   |          2577.2        |    2593.0  |    7293.8  |      2611.3
      f32 B=64, M=1024, H=16, K=16   |          8440.7        |    8424.4  |   24237.0  |      8722.5
      f16 B=64, M=1024, H=16, K=32   |          2596.5        |    2612.0  |    7675.3  |      2634.5
      f32 B=64, M=1024, H=16, K=32   |          8513.5        |    8494.5  |   26228.7  |      8811.5
      f16 B=64, M=1024, H=16, K=64   |          2972.8        |    2991.8  |    8669.3  |      2993.2
      f32 B=64, M=1024, H=16, K=64   |         10304.0        |   10271.7  |   30612.5  |     10667.6
      f16 B=64, M=1024, H=16, K=128  |          5282.0        |    5266.6  |   10313.9  |      5398.9
      f32 B=64, M=1024, H=16, K=128  |         22730.2        |   22940.7  |   39341.6  |     23517.9
      f16 B=64, M=1024, H=16, K=256  |         12337.4        |   12354.8  |   13558.9  |     12272.2
      f32 B=64, M=1024, H=16, K=256  |         45885.8        |   46488.2  |   71460.8  |     48531.7
```
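The `disable_bias_dropout` column above presumably corresponds to building the kernel with both of the new flags set to `false`. Below is a minimal sketch of such an instantiation; everything before the two new flags (scalar type, arch tag, alignment, tile sizes) is an assumption based on the surrounding file and is not part of this diff:

```cpp
#include "kernel_forward.h"  // defines AttentionKernel

// Hypothetical build-time specialization with dropout and attn_bias support
// compiled out. The leading template arguments are illustrative guesses.
using FmhaForwardNoDropoutNoBias = AttentionKernel<
    cutlass::half_t,      // scalar_t_ (f16)
    cutlass::arch::Sm80,  // arch tag (assumed parameter; A100)
    true,                 // isAligned_
    64,                   // kQueriesPerBlock
    64,                   // kKeysPerBlock
    true,                 // kSingleValueIteration (= value.shape[-1] <= kKeysPerBlock)
    false,                // kSupportsDropout: dropout paths removed at compile time
    false>;               // kSupportsBias: attn_bias paths removed at compile time
```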

ghstack-source-id: dd55058ad3e1715515475029c8cd360be43d374d
Pull Request resolved: https://github.com/fairinternal/xformers/pull/441

__original_commit__ = fairinternal/xformers@a264f71c0de5c5474e6ec63b7a3b5d9e4a2ec98c
danthe3rd authored and xFormers Bot committed Jan 20, 2023
1 parent 83a567f commit bc08bbc
Showing 1 changed file with 34 additions and 23 deletions.
xformers/csrc/attention/cuda/fmha/kernel_forward.h
```diff
@@ -68,8 +68,11 @@ template <
     bool isAligned_,
     int kQueriesPerBlock,
     int kKeysPerBlock,
-    bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
-    >
+    bool kSingleValueIteration, // = `value.shape[-1] <= kKeysPerBlock`
+    // This is quite slower on V100 for some reason
+    // Set to false if you know at compile-time you will never need dropout
+    bool kSupportsDropout = true,
+    bool kSupportsBias = true>
 struct AttentionKernel {
   using scalar_t = scalar_t_;
   using accum_t = float;
@@ -199,7 +202,7 @@ struct AttentionKernel {
     output_ptr += int64_t(q_start + query_start) * o_strideM() +
         head_id * head_dim_value;
 
-    if (attn_bias_ptr != nullptr) {
+    if (kSupportsBias && attn_bias_ptr != nullptr) {
       attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
     }
     if (output_accum_ptr != nullptr) {
@@ -215,9 +218,11 @@ struct AttentionKernel {
           batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
     }
 
-    dropout_batch_head_rng_offset =
-        batch_id * num_heads * num_queries * num_keys +
-        head_id * num_queries * num_keys;
+    if (kSupportsDropout) {
+      dropout_batch_head_rng_offset =
+          batch_id * num_heads * num_queries * num_keys +
+          head_id * num_queries * num_keys;
+    }
 
     num_queries -= query_start;
     if (causal) {
@@ -231,7 +236,9 @@ struct AttentionKernel {
     query_ptr = warp_uniform(query_ptr);
     key_ptr = warp_uniform(key_ptr);
     value_ptr = warp_uniform(value_ptr);
-    attn_bias_ptr = warp_uniform(attn_bias_ptr);
+    if (kSupportsBias) {
+      attn_bias_ptr = warp_uniform(attn_bias_ptr);
+    }
     output_ptr = warp_uniform(output_ptr);
     output_accum_ptr = warp_uniform(output_accum_ptr);
     logsumexp_ptr = warp_uniform(logsumexp_ptr);
@@ -469,7 +476,18 @@ struct AttentionKernel {
     CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
     CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
     CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
-    CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
+    if (kSupportsBias) {
+      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
+      XFORMERS_CHECK(
+          p.bias_strideB % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+      XFORMERS_CHECK(
+          p.bias_strideH % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+      XFORMERS_CHECK(
+          p.bias_strideM % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
+    }
     XFORMERS_CHECK(
         p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
     XFORMERS_CHECK(
@@ -482,15 +500,6 @@ struct AttentionKernel {
         p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
     XFORMERS_CHECK(
         p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideB % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideH % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
-    XFORMERS_CHECK(
-        p.bias_strideM % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
     return true;
   }
 
@@ -542,7 +551,7 @@ struct AttentionKernel {
 
 #ifdef HAS_PYTORCH
     curandStatePhilox4_32_10_t curand_state_init;
-    if (p.use_dropout) {
+    if (kSupportsDropout && p.use_dropout) {
       const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
 
       // each element of the attention matrix P with shape
@@ -654,11 +663,13 @@ struct AttentionKernel {
           (my_warp_id / MM0::Mma::WarpCount::kM)};
 
       // multiply by scaling factor
-      accum =
-          cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
+      if (kSupportsBias) {
+        accum =
+            cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
+      }
 
       // apply attention bias if applicable
-      if (p.attn_bias_ptr != nullptr) {
+      if (kSupportsBias && p.attn_bias_ptr != nullptr) {
         // load bias tile Bij into shared memory
         typename MM0::BiasLoader::GmemTileIterator bias_iter(
             {cutlass::layout::RowMajor(p.bias_strideM)},
@@ -729,7 +740,7 @@ struct AttentionKernel {
               warp_id(),
               p.num_keys - iter_key_start,
               iteratorC_tile_offset,
-              1.0f);
+              kSupportsBias ? 1.0f : p.scale);
         }));
       }));
 
@@ -758,7 +769,7 @@ struct AttentionKernel {
     // it ends up being very slow because each thread having noncontiguous
     // strips of the Pij tile means we have to skip around a lot, and also
     // have to generate a single random number at a time
-    if (p.use_dropout) {
+    if (kSupportsDropout && p.use_dropout) {
       auto si = shared_storage.after_mm0.si.accum_ref();
       // each thread handles a contiguous sequence of elements from Sij, all
      // coming from the same row. the reason they have to come from the same
```
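One non-obvious piece of the change is the epilogue scaling factor (`1.0f` becomes `kSupportsBias ? 1.0f : p.scale`): with bias support compiled in, `accum` is pre-multiplied by `p.scale` before the bias is added, so the later step keeps a factor of `1.0f`; with bias compiled out, that pre-multiply is skipped and `p.scale` is applied in the later step instead. A toy standalone check of the equivalence (names here are illustrative and not taken from `kernel_forward.h`):

```cpp
#include <cassert>

int main() {
  // Illustrative values: `scale` plays the role of p.scale, `qk` of one accum element.
  const float qk = 2.0f;
  const float scale = 0.125f;

  // kSupportsBias == true: accum is pre-multiplied by p.scale (any bias is then
  // added to already-scaled scores) and the later scaling factor stays 1.0f.
  const float with_bias_path = (scale * qk) * 1.0f;

  // kSupportsBias == false: the pre-multiply is compiled out and p.scale is
  // passed as the scaling factor further down instead.
  const float without_bias_path = qk * scale;

  // Either way, p.scale is applied to the q.k product exactly once.
  assert(with_bias_path == without_bias_path);
  return 0;
}
```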
