From 36cf435dceefb1a8220fd59c23b7974a27433353 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Fri, 8 Jul 2022 22:58:42 +0200 Subject: [PATCH] Misc performance improvements for generic mem-efficient attention (#361) * 3% speedup by calculating mi from registers * Also compute m_prime/s_prime and exponentiate from registers * Support for Simt tiles * Fix TensorOp for V100 * Fix for A100 * Fix Simt alignment calculation * clang-format * WarpReduction before atomic call for Simt Co-authored-by: danthe3rd Co-authored-by: danthe3rd --- .../csrc/cuda/attention_forward_generic.cu | 138 ++--- .../cuda/attention_scaling_coefs_updater.h | 531 ++++++++++++++++++ 2 files changed, 581 insertions(+), 88 deletions(-) create mode 100644 xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h diff --git a/xformers/components/attention/csrc/cuda/attention_forward_generic.cu b/xformers/components/attention/csrc/cuda/attention_forward_generic.cu index a42cda4f0b..6b84c2a4d7 100644 --- a/xformers/components/attention/csrc/cuda/attention_forward_generic.cu +++ b/xformers/components/attention/csrc/cuda/attention_forward_generic.cu @@ -13,6 +13,7 @@ #include "cutlass/layout/vector.h" #include "cutlass/numeric_types.h" +#include "attention_scaling_coefs_updater.h" #include "cutlass/gemm/device/default_gemm_configuration.h" #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass/gemm/threadblock/default_mma_core_simt.h" @@ -189,6 +190,11 @@ struct AttentionKernel { using IteratorA = typename DefaultMma::IteratorA; using IteratorB = typename DefaultMma::IteratorB; using Mma = typename DefaultMma::ThreadblockMma; + using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + typename Mma::Operator::IteratorC, + accum_t, + kQueriesPerBlock, + kWarpSize>::Updater; }; struct MM1 { @@ -412,10 +418,9 @@ struct AttentionKernel { auto& mi = shared_storage.after_mm0.mi; if (warp_id == 0) { - for (int32_t q = 0; q + lane_id < kQueriesPerBlock; q += kWarpSize) { - s_prime[q + lane_id] = accum_t(0); - m_prime[q + lane_id] = -std::numeric_limits::infinity(); - } + static_assert(kQueriesPerBlock == kWarpSize); + s_prime[lane_id] = accum_t(0); + m_prime[lane_id] = -std::numeric_limits::infinity(); } // Iterate through keys @@ -425,39 +430,11 @@ struct AttentionKernel { // updated from end of prev iter // 1. Compute dot-product into shared memory for each query + // also calculates `mi`, and updates `m_prime` / `s_prime` compute_dot_product_qk( - iter_key_start, query, key, m_prime, shared_storage); - - __syncthreads(); // `mi` calculation done based on block data. `mi[a][i] - // == mi[a][j]` for all (a, i, j) - - // WARNING: This modifies `si` and `m_prime` to store the precalculated - // exp version so we can reuse it later in `compute_dot_product_att_value` - static_assert( - kQueriesPerBlock % kNumWarpsPerBlock == 0, - ".. or add a condition to loop below"); - for (int32_t q = warp_id; q < kQueriesPerBlock; - q += kNumWarpsPerBlock) { // parallel warps - // 3. Update s_prime - accum_t sp = accum_t(0); - accum_t my_mi = mi[q]; - static_assert( - kNumWarpsPerBlock * kWarpSize % kWarpSize == 0, - ".. 
or add a condition to loop below"); - for (int32_t key_id = lane_id; key_id < kNumWarpsPerBlock * kWarpSize; - key_id += kWarpSize) { // parallel lanes - accum_t si_exp = math::exp(si[q][key_id] - my_mi); - si_exp *= accum_t(key_id + iter_key_start < num_keys); - sp += si_exp; - si[q][key_id] = si_exp; - } - accum_t m_prime_exp = math::exp(m_prime[q] - my_mi); - sp = warpSum(sp) + s_prime[q] * m_prime_exp; + iter_key_start, query, key, m_prime, s_prime, shared_storage); - m_prime[q] = m_prime_exp; - s_prime[q] = sp; - } - __syncthreads(); // `s_prime` done + __syncthreads(); // 4. Partial matmull with the values we have and V // `v* <- v* . exp(m* - mi) + v_i . exp(si - mi)` @@ -471,10 +448,11 @@ struct AttentionKernel { __syncthreads(); // we modify `m_prime` after // 5. `m_prime` <- `mi` - for (int64_t q = thread_id(); q < kQueriesPerBlock; - q += kWarpSize * kNumWarpsPerBlock) { // parallel lanes - m_prime[q] = mi[q]; + if (warp_id == 0) { + static_assert(kQueriesPerBlock == kWarpSize); + m_prime[lane_id] = mi[lane_id]; } + __syncthreads(); } // 6. Divide by s_prime all of the values @@ -500,13 +478,11 @@ struct AttentionKernel { } // 7. Calculate logsumexp - if (logsumexp.size(0)) { - iter_query_last = std::min( - (int32_t)kQueriesPerBlock, int32_t(num_queries - query_start())); - for (int64_t q = thread_id(); q < iter_query_last; - q += kNumWarpsPerBlock * kWarpSize) { - *(logsumexp.data() + query_start() + q) = - accum_t(m_prime[q]) + std::log(accum_t(s_prime[q])); + if (logsumexp.size(0) && warp_id == 0) { + static_assert(kQueriesPerBlock == kWarpSize); + if (query_start() + lane_id < num_queries) { + logsumexp[query_start() + lane_id] = + accum_t(m_prime[lane_id]) + std::log(accum_t(s_prime[lane_id])); } } } @@ -554,6 +530,7 @@ struct AttentionKernel { at::TensorAccessor& query, at::TensorAccessor& key, cutlass::Array& m_prime, + cutlass::Array& s_prime, SharedStorage& shared_storage) { /* Computes the block-matrix product of: @@ -630,57 +607,42 @@ struct AttentionKernel { // Compute threadblock-scoped matrix multiply-add mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - __syncthreads(); // because we need to use shared memory as `si` now - + __syncthreads(); auto& si = shared_storage.after_mm0.si; auto& mi = shared_storage.after_mm0.mi; + if (my_warp_id == 0) { + static_assert(kQueriesPerBlock == kWarpSize); + mi[my_lane_id] = m_prime[my_lane_id]; + } + __syncthreads(); + + // Scale + accum_t scale = accum_t(1.0 / std::sqrt(float(K))); + accum = cutlass::multiplies()(scale, accum); + + typename Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = { + (tb_tile_offset.m() * Mma::WarpCount::kM) + + (my_warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (my_warp_id / Mma::WarpCount::kM)}; + // Update `mi` from accum stored in registers + typename MM0::ScalingCoefsUpdater updater; + updater.update( + accum, + mi, + m_prime, + s_prime, + my_lane_id, + my_warp_id, + key.size(0) - iter_key_start, + iteratorC_tile_offset); // Output results typename Mma::Operator::IteratorC iterator_C( {&si[0][0], kNumWarpsPerBlock * kWarpSize}, my_lane_id); - iterator_C.add_tile_offset( - {(tb_tile_offset.m() * Mma::WarpCount::kM) + - (my_warp_id % Mma::WarpCount::kM), - (tb_tile_offset.n() * Mma::WarpCount::kN) + - (my_warp_id / Mma::WarpCount::kM)}); - + iterator_C.add_tile_offset(iteratorC_tile_offset); iterator_C.store(accum); - - for (int32_t q = 0; q + my_lane_id < kQueriesPerBlock; q += kWarpSize) { - mi[q + my_lane_id] = 
-std::numeric_limits<accum_t>::infinity();
-    }
-    __syncthreads();
-
-    // 2. Update `mi`
-    int64_t num_keys = key.size(0);
-    accum_t scale = accum_t(1.0 / std::sqrt(float(K)));
-    static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0);
-    for (int16_t q = 0; q < kQueriesPerBlock;
-         q += kNumWarpsPerBlock) { // parallel warps
-      if (query_start() + q + warp_id() >= num_queries) {
-        continue;
-      }
-      accum_t currentMax = m_prime[q + warp_id()];
-      CUTLASS_PRAGMA_UNROLL
-      for (int64_t key_id = 0; key_id < kNumWarpsPerBlock * kWarpSize;
-           key_id += kWarpSize) { // parallel lanes
-        bool within_bounds = iter_key_start + key_id + lane_id() < num_keys;
-
-        // TODO: Scaling could be done as part of an epilogue
-        // in the cutlass calculation above
-        accum_t dot_product = si[q + warp_id()][key_id + lane_id()];
-        dot_product *= scale;
-        si[q + warp_id()][key_id + lane_id()] =
-            within_bounds ? dot_product : accum_t(0.0);
-        if (within_bounds) {
-          currentMax = std::max(currentMax, dot_product);
-        }
-      }
-
-      currentMax = warpMax(currentMax);
-      mi[q + warp_id()] = currentMax;
-    }
   }
 
   static __device__ __forceinline__ accum_t warpMax(accum_t val) {
diff --git a/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h b/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h
new file mode 100644
index 0000000000..35d95e432d
--- /dev/null
+++ b/xformers/components/attention/csrc/cuda/attention_scaling_coefs_updater.h
@@ -0,0 +1,531 @@
+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
+#include "cutlass/matrix_shape.h"
+
+namespace {
+
+static __device__ __forceinline__ float atomicMaxFloat(
+    float* addr,
+    float value) {
+  // source: https://stackoverflow.com/a/51549250
+  return (value >= 0)
+      ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
+      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
+}
+} // namespace
+
+/* Iterates on the accumulator and corresponding position on result matrix
+
+(1) Update `mi[r]` to the max value of the row `r`
+(2) In a second iteration do the following:
+    (a) accum   <- exp(accum - mi)
+    (b) m_prime <- exp(m_prime - mi)
+    (c) s_prime <- s_prime * m_prime + sum(accum)
+
+All of this is done on registers, before we store all of this
+on shared memory for the next matmul with Value.
+*/
+
+template <typename T, typename accum_t, int kQueriesPerBlock, int kWarpSize>
+struct AttentionScalingCoefsUpdaterSm80 {
+  static_assert(
+      std::is_same::value);
+
+  using Policy = typename T::Policy;
+  using InstructionShape = typename T::InstructionShape;
+  using OpDelta = typename T::OpDelta;
+  using Shape = typename T::Shape;
+  static int const kElementsPerAccess = InstructionShape::kN / 4;
+  static int const kRowsPerTile = 8;
+  static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
+
+  static cutlass::MatrixCoord __device__
+  get_lane_offset(int8_t lane_id, int8_t warp_id) {
+    int quad = (lane_id >> 2);
+    int lane_in_quad = (lane_id & 3);
+    return cutlass::MatrixCoord(quad, lane_in_quad * kElementsPerAccess);
+  }
+
+  static __device__ __forceinline__ void update(
+      typename T::Fragment& frag,
+      cutlass::Array<accum_t, kQueriesPerBlock>& mi,
+      cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
+      cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
+      int8_t lane_id,
+      int8_t warp_id,
+      int16_t max_col,
+      typename T::TensorCoord const& tile_offset) {
+    // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
+
+    cutlass::MatrixCoord lane_offset(
+        tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
+    lane_offset += get_lane_offset(lane_id, warp_id);
+    int lane_in_quad = (lane_id & 3);
+
+    // In each warp, 4 threads will work on the same row
+    // - the ones with the same `quad`
+    auto reduceSameRow = [&](auto& myValue, auto fn) {
+      auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
+      myValue = fn(myValue, otherV);
+      otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
+      myValue = fn(myValue, otherV);
+      return lane_in_quad == 0;
+    };
+
+    auto iterateRowFirst = [&](auto beginRow, auto op, auto endRow) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int row = 0; row < kAccumulatorRows; ++row) {
+          int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
+              row * kRowsPerTile + lane_offset.row();
+          beginRow(accum_m);
+
+          CUTLASS_PRAGMA_UNROLL
+          for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
+            int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
+                (mma_n * Policy::MmaIterations::kRow + mma_m);
+            CUTLASS_PRAGMA_UNROLL
+            for (int col = 0; col < kElementsPerAccess; ++col) {
+              int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
+                  col + lane_offset.column();
+              int idx = mma_accum_start + row * kElementsPerAccess + col;
+              op(accum_m, accum_n, idx);
+            }
+          }
+
+          endRow(accum_m);
+        }
+      }
+    };
+
+    // First update `mi` to the max per-row
+    {
+      accum_t max;
+      iterateRowFirst(
+          [&](int accum_m) { max = -std::numeric_limits<accum_t>::infinity(); },
+          [&](int accum_m, int accum_n, int idx) {
+            if (accum_n < max_col) {
+              max = std::max(max, frag[idx]);
+            }
+          },
+          [&](int accum_m) {
+            // Having 4x atomicMax seems faster than reduce within warp
+            // first...
+            atomicMaxFloat(&mi[accum_m], max);
+          });
+    }
+
+    // Make sure we all share the update values for `mi`
+    __syncthreads();
+
+    if (warp_id == 0) {
+      static_assert(kQueriesPerBlock == kWarpSize);
+      auto m_prime_exp = expf(m_prime[lane_id] - mi[lane_id]);
+      m_prime[lane_id] = m_prime_exp;
+      s_prime[lane_id] *= m_prime_exp;
+    }
+    __syncthreads();
+
+    // Update accum_m, accum_n, ...
+    {
+      accum_t mi_row, total_row;
+      iterateRowFirst(
+          [&](int accum_m) {
+            mi_row = mi[accum_m];
+            total_row = 0.0;
+          },
+          [&](int accum_m, int accum_n, int idx) {
+            frag[idx] =
+                (accum_n < max_col) ?
expf(frag[idx] - mi_row) : accum_t(0.0); + total_row += frag[idx]; + }, + [&](int accum_m) { + if (reduceSameRow( + total_row, [](accum_t a, accum_t b) { return a + b; })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } +}; + +// cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator, float, cutlass::layout::RowMajor, cutlass::gemm::GemmShape<16, 16, 4>, +// cutlass::MatrixShape<1, 1>> See +// cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +template +struct AttentionScalingCoefsUpdaterVolta { + static_assert( + std::is_same::value); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord __device__ + get_lane_offset(int8_t lane_id, int8_t warp_id) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord(accum_m, accum_n); + } + + static __device__ __forceinline__ void update( + typename T::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t warp_id, + int16_t max_col, + typename T::TensorCoord const& tile_offset) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h + + cutlass::MatrixCoord lane_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + lane_offset += get_lane_offset(lane_id, warp_id); + + auto iterateRowFirst = [&](auto beginRow, auto op, auto endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = 
mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + }; + + // First update `mi` to the max per-row + { + accum_t max; + iterateRowFirst( + [&](int accum_m) { max = -std::numeric_limits::infinity(); }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = std::max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (warp_id == 0) { + static_assert(kQueriesPerBlock == kWarpSize); + auto m_prime_exp = expf(m_prime[lane_id] - mi[lane_id]); + m_prime[lane_id] = m_prime_exp; + s_prime[lane_id] *= m_prime_exp; + } + __syncthreads(); + + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + iterateRowFirst( + [&](int accum_m) { + mi_row = mi[accum_m]; + total_row = 0.0; + }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? expf(frag[idx] - mi_row) : accum_t(0.0); + total_row += frag[idx]; + }, + [&](int accum_m) { atomicAdd(&s_prime[accum_m], total_row); }); + } + } +}; + +template +struct AttentionScalingCoefsUpdaterSimt { + static_assert( + std::is_same::value); + static_assert( + std::is_same::value); + + static __device__ __forceinline__ void update( + typename T::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t warp_id, + int16_t max_col, + typename T::TensorCoord const& tile_offset) { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + + // compute offset based on thread ID and lane layout + /* Policy= + cutlass::gemm::warp::MmaSimtPolicy< + WarpShape_=cutlass::MatrixShape<4, 8>, + LaneLayout_=cutlass::layout::RowMajorInterleaved<1>, + LaneMmaShape_=cutlass::gemm::GemmShape<8, 4, 1>> + */ + static_assert(std::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + lane_offset += + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + + auto reduceSameRow = [&](auto& myValue, auto fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + }; + + auto iterateRowFirst = [&](auto beginRow, auto op, auto endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n 
+ n, idx); + } + } + endRow(accum_m); + } + } + }; + + // First update `mi` to the max per-row + { + accum_t max; + iterateRowFirst( + [&](int accum_m) { max = -std::numeric_limits::infinity(); }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = std::max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (warp_id == 0) { + static_assert(kQueriesPerBlock == kWarpSize); + auto m_prime_exp = expf(m_prime[lane_id] - mi[lane_id]); + m_prime[lane_id] = m_prime_exp; + s_prime[lane_id] *= m_prime_exp; + } + __syncthreads(); + + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + iterateRowFirst( + [&](int accum_m) { + mi_row = mi[accum_m]; + total_row = 0.0; + }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? expf(frag[idx] - mi_row) : accum_t(0.0); + total_row += frag[idx]; + }, + [&](int accum_m) { + if (reduceSameRow( + total_row, [](accum_t a, accum_t b) { return a + b; })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } +}; + +template +struct DefaultAttentionScalingCoefsUpdater; + +// Simt +template < + typename S, + typename P, + typename accum_t, + int kQueriesPerBlock, + int kWarpSize> +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kQueriesPerBlock, + kWarpSize> { + using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Updater = AttentionScalingCoefsUpdaterSimt< + Iterator, + accum_t, + kQueriesPerBlock, + kWarpSize>; +}; + +// Volta +template < + typename S1, + typename S2, + typename accum_t, + int kQueriesPerBlock, + int kWarpSize> +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kQueriesPerBlock, + kWarpSize> { + using Iterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Updater = AttentionScalingCoefsUpdaterVolta< + Iterator, + accum_t, + kQueriesPerBlock, + kWarpSize>; +}; + +// TensorOp - Sm80 +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kQueriesPerBlock, + int kWarpSize> +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kQueriesPerBlock, + kWarpSize> { + using Iterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Updater = AttentionScalingCoefsUpdaterSm80< + Iterator, + accum_t, + kQueriesPerBlock, + kWarpSize>; +};
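
For reference, the per-row bookkeeping that the ScalingCoefsUpdater variants above perform on registers reduces to the streaming-softmax recurrence sketched below. This is a minimal host-side C++ illustration only, not code from this patch or from xformers; `RowState` and `update_row` are invented names.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

struct RowState {
  float m_prime = -std::numeric_limits<float>::infinity(); // running row max
  float s_prime = 0.f;                                      // running sum of exp
};

// Consumes one block of attention logits for a single query row.
// Returns the factor by which the previously accumulated partial output
// (the `v*` accumulator in the kernel) must be rescaled.
inline float update_row(RowState& row, std::vector<float>& block) {
  // (1) new running max over everything seen so far
  float mi = row.m_prime;
  for (float v : block) {
    mi = std::max(mi, v);
  }
  // (2a/2b) exponentiate the block and the old max against the new max
  float rescale = std::exp(row.m_prime - mi);
  float block_sum = 0.f;
  for (float& v : block) {
    v = std::exp(v - mi); // stored back, reused by the next matmul with V
    block_sum += v;
  }
  // (2c) update the running softmax denominator
  row.s_prime = row.s_prime * rescale + block_sum;
  row.m_prime = mi;
  return rescale;
}

Dividing the accumulated output by `s_prime` once all key blocks have been processed (step 6 in attention_forward_generic.cu) then yields the softmax-normalized attention output.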