Misc performance improvements for generic mem-efficient attention (#361)
* 3% speedup by calculating mi from registers

* Also compute m_prime/s_prime and exponentiate from registers

* Support for Simt tiles

* Fix TensorOp for V100

* Fix for A100

* Fix Simt alignment calculation

* clang-format

* WarpReduction before atomic call for Simt

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
2 people authored and fmassa committed Aug 10, 2022
1 parent 573ed14 commit 36cf435
Showing 2 changed files with 581 additions and 88 deletions.
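
For context on the terms used in the description above: `mi` is the per-query maximum of the current Q.K' tile, and `m_prime` / `s_prime` are the running maximum and running softmax denominator of the online-softmax recurrence this kernel uses (the diff below writes it as `v* <- v* . exp(m* - mi) + v_i . exp(si - mi)`, with `logsumexp = m_prime + log(s_prime)`). The following is a minimal single-threaded sketch of one update step, for illustration only; it is not code from this commit, which performs the equivalent update per warp on CUTLASS accumulator fragments held in registers.

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative sketch: one online-softmax update step for a single query row.
// `partial` stands in for this iteration's scaled Q.K' dot-products;
// `m_prime` / `s_prime` are the running max and softmax denominator.
void update_scaling_coefs(
    const std::vector<float>& partial,
    float& m_prime,
    float& s_prime) {
  // New row max `mi` over the previous running max and this tile's values.
  float mi = m_prime;
  for (float v : partial)
    mi = std::max(mi, v);
  // Rescale the old denominator to the new max, then add the new exponentials.
  float sp = s_prime * std::exp(m_prime - mi);
  for (float v : partial)
    sp += std::exp(v - mi);
  m_prime = mi;
  s_prime = sp;
}

Previously the kernel wrote the Q.K' accumulator tile to shared memory and re-read it to compute these coefficients; the change keeps the accumulator in registers and folds the update into the new `ScalingCoefsUpdater`, which the description credits with a roughly 3% speedup.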
138 changes: 50 additions & 88 deletions xformers/components/attention/csrc/cuda/attention_forward_generic.cu
@@ -13,6 +13,7 @@
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"

#include "attention_scaling_coefs_updater.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
@@ -189,6 +190,11 @@ struct AttentionKernel {
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
typename Mma::Operator::IteratorC,
accum_t,
kQueriesPerBlock,
kWarpSize>::Updater;
};

struct MM1 {
@@ -412,10 +418,9 @@ struct AttentionKernel {
auto& mi = shared_storage.after_mm0.mi;

if (warp_id == 0) {
for (int32_t q = 0; q + lane_id < kQueriesPerBlock; q += kWarpSize) {
s_prime[q + lane_id] = accum_t(0);
m_prime[q + lane_id] = -std::numeric_limits<accum_t>::infinity();
}
static_assert(kQueriesPerBlock == kWarpSize);
s_prime[lane_id] = accum_t(0);
m_prime[lane_id] = -std::numeric_limits<accum_t>::infinity();
}

// Iterate through keys
@@ -425,39 +430,11 @@
// updated from end of prev iter

// 1. Compute dot-product into shared memory for each query
// also calculates `mi`, and updates `m_prime` / `s_prime`
compute_dot_product_qk(
iter_key_start, query, key, m_prime, shared_storage);

__syncthreads(); // `mi` calculation done based on block data. `mi[a][i]
// == mi[a][j]` for all (a, i, j)

// WARNING: This modifies `si` and `m_prime` to store the precalculated
// exp version so we can reuse it later in `compute_dot_product_att_value`
static_assert(
kQueriesPerBlock % kNumWarpsPerBlock == 0,
".. or add a condition to loop below");
for (int32_t q = warp_id; q < kQueriesPerBlock;
q += kNumWarpsPerBlock) { // parallel warps
// 3. Update s_prime
accum_t sp = accum_t(0);
accum_t my_mi = mi[q];
static_assert(
kNumWarpsPerBlock * kWarpSize % kWarpSize == 0,
".. or add a condition to loop below");
for (int32_t key_id = lane_id; key_id < kNumWarpsPerBlock * kWarpSize;
key_id += kWarpSize) { // parallel lanes
accum_t si_exp = math<accum_t>::exp(si[q][key_id] - my_mi);
si_exp *= accum_t(key_id + iter_key_start < num_keys);
sp += si_exp;
si[q][key_id] = si_exp;
}
accum_t m_prime_exp = math<accum_t>::exp(m_prime[q] - my_mi);
sp = warpSum(sp) + s_prime[q] * m_prime_exp;
iter_key_start, query, key, m_prime, s_prime, shared_storage);

m_prime[q] = m_prime_exp;
s_prime[q] = sp;
}
__syncthreads(); // `s_prime` done
__syncthreads();

// 4. Partial matmul with the values we have and V
// `v* <- v* . exp(m* - mi) + v_i . exp(si - mi)`
@@ -471,10 +448,11 @@
__syncthreads(); // we modify `m_prime` after

// 5. `m_prime` <- `mi`
for (int64_t q = thread_id(); q < kQueriesPerBlock;
q += kWarpSize * kNumWarpsPerBlock) { // parallel lanes
m_prime[q] = mi[q];
if (warp_id == 0) {
static_assert(kQueriesPerBlock == kWarpSize);
m_prime[lane_id] = mi[lane_id];
}
__syncthreads();
}

// 6. Divide all of the values by s_prime
@@ -500,13 +478,11 @@
}

// 7. Calculate logsumexp
if (logsumexp.size(0)) {
iter_query_last = std::min<int32_t>(
(int32_t)kQueriesPerBlock, int32_t(num_queries - query_start()));
for (int64_t q = thread_id(); q < iter_query_last;
q += kNumWarpsPerBlock * kWarpSize) {
*(logsumexp.data() + query_start() + q) =
accum_t(m_prime[q]) + std::log(accum_t(s_prime[q]));
if (logsumexp.size(0) && warp_id == 0) {
static_assert(kQueriesPerBlock == kWarpSize);
if (query_start() + lane_id < num_queries) {
logsumexp[query_start() + lane_id] =
accum_t(m_prime[lane_id]) + std::log(accum_t(s_prime[lane_id]));
}
}
}
@@ -554,6 +530,7 @@ struct AttentionKernel {
at::TensorAccessor<scalar_t, 2, at::DefaultPtrTraits, int32_t>& query,
at::TensorAccessor<scalar_t, 2, at::DefaultPtrTraits, int32_t>& key,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
SharedStorage& shared_storage) {
/*
Computes the block-matrix product of:
@@ -630,57 +607,42 @@
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);

__syncthreads(); // because we need to use shared memory as `si` now

__syncthreads();
auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.after_mm0.mi;
if (my_warp_id == 0) {
static_assert(kQueriesPerBlock == kWarpSize);
mi[my_lane_id] = m_prime[my_lane_id];
}
__syncthreads();

// Scale
accum_t scale = accum_t(1.0 / std::sqrt(float(K)));
accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);

typename Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = {
(tb_tile_offset.m() * Mma::WarpCount::kM) +
(my_warp_id % Mma::WarpCount::kM),
(tb_tile_offset.n() * Mma::WarpCount::kN) +
(my_warp_id / Mma::WarpCount::kM)};
// Update `mi` from accum stored in registers
typename MM0::ScalingCoefsUpdater updater;
updater.update(
accum,
mi,
m_prime,
s_prime,
my_lane_id,
my_warp_id,
key.size(0) - iter_key_start,
iteratorC_tile_offset);

// Output results
typename Mma::Operator::IteratorC iterator_C(
{&si[0][0], kNumWarpsPerBlock * kWarpSize}, my_lane_id);

iterator_C.add_tile_offset(
{(tb_tile_offset.m() * Mma::WarpCount::kM) +
(my_warp_id % Mma::WarpCount::kM),
(tb_tile_offset.n() * Mma::WarpCount::kN) +
(my_warp_id / Mma::WarpCount::kM)});

iterator_C.add_tile_offset(iteratorC_tile_offset);
iterator_C.store(accum);

for (int32_t q = 0; q + my_lane_id < kQueriesPerBlock; q += kWarpSize) {
mi[q + my_lane_id] = -std::numeric_limits<accum_t>::infinity();
}
__syncthreads();

// 2. Update `mi`
int64_t num_keys = key.size(0);
accum_t scale = accum_t(1.0 / std::sqrt(float(K)));
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0);
for (int16_t q = 0; q < kQueriesPerBlock;
q += kNumWarpsPerBlock) { // parallel warps
if (query_start() + q + warp_id() >= num_queries) {
continue;
}
accum_t currentMax = m_prime[q + warp_id()];
CUTLASS_PRAGMA_UNROLL
for (int64_t key_id = 0; key_id < kNumWarpsPerBlock * kWarpSize;
key_id += kWarpSize) { // parallel lanes
bool within_bounds = iter_key_start + key_id + lane_id() < num_keys;

// TODO: Scaling could be done as part of an epilogue
// in the cutlass calculation above
accum_t dot_product = si[q + warp_id()][key_id + lane_id()];
dot_product *= scale;
si[q + warp_id()][key_id + lane_id()] =
within_bounds ? dot_product : accum_t(0.0);
if (within_bounds) {
currentMax = std::max(currentMax, dot_product);
}
}

currentMax = warpMax(currentMax);
mi[q + warp_id()] = currentMax;
}
}

static __device__ __forceinline__ accum_t warpMax(accum_t val) {
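The bodies of the `warpMax` / `warpSum` helpers referenced above, and the new `attention_scaling_coefs_updater.h` header providing `DefaultAttentionScalingCoefsUpdater` (the bulk of the 581 added lines), are not shown in this view. As a rough sketch only, assuming the usual 32-lane warp with all lanes participating, a shuffle-based warp reduction of the kind those names suggest could look like this; the commit's actual implementation may differ.

// Sketch, not the commit's code: butterfly reductions after which every lane
// of the warp holds the reduced value.
__device__ __forceinline__ float warpMaxSketch(float val) {
  for (int offset = 16; offset > 0; offset /= 2)
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
  return val;
}

__device__ __forceinline__ float warpSumSketch(float val) {
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_xor_sync(0xffffffff, val, offset);
  return val;
}

Doing such a reduction before an atomic update, as the "WarpReduction before atomic call for Simt" item describes, means only one value per warp has to hit the shared accumulator instead of one per lane, which is the usual motivation on the SIMT tile path.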
