Cleanup fMHA code
ghstack-source-id: 48000b65896f42a7876b05a123f1ad73bf9b7e5a
Pull Request resolved: https://github.com/fairinternal/xformers/pull/440

__original_commit__ = fairinternal/xformers@af62d3696d4fa3afb64ff8a34fb418a0db2dd9e4
danthe3rd authored and xFormers Bot committed Jan 20, 2023
1 parent bac8718 commit 83a567f
Showing 58 changed files with 406 additions and 494 deletions.
4 changes: 2 additions & 2 deletions xformers/csrc/attention/attention.cpp
@@ -15,11 +15,11 @@ PyMODINIT_FUNC PyInit__C(void) {

TORCH_LIBRARY_FRAGMENT(xformers, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)"));
"xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, bool compute_logsumexp, bool causal, float? scale) -> (Tensor, Tensor, int, int)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)"));
"xformers::efficient_attention_backward_small_k(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_backward_cutlass(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, bool causal, float? scale) -> (Tensor, Tensor, Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
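The rename above means call sites must resolve the new operator names through the PyTorch dispatcher. Below is a minimal call-site sketch (an editor's illustration, not code from this commit; the wrapper name call_forward_small_k is hypothetical):

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <tuple>

// Editor's sketch: look up the renamed forward op by its new schema name.
// A caller still using the old "xformers::efficient_attention" name would
// now fail this lookup.
std::tuple<at::Tensor, at::Tensor, int64_t, int64_t> call_forward_small_k(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    bool compute_logsumexp,
    const c10::optional<at::Tensor>& attn_bias,
    double p) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("xformers::efficient_attention_forward_small_k", "")
          .typed<std::tuple<at::Tensor, at::Tensor, int64_t, int64_t>(
              const at::Tensor&, const at::Tensor&, const at::Tensor&,
              bool, const c10::optional<at::Tensor>&, double)>();
  return op.call(query, key, value, compute_logsumexp, attn_bias, p);
}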
@@ -355,9 +355,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> attention_backward(

TORCH_LIBRARY_IMPL(xformers, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention"),
TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_small_k"),
TORCH_FN(attention));
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"),
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_small_k"),
TORCH_FN(attention_backward));
}
4 changes: 2 additions & 2 deletions xformers/csrc/attention/cuda/fmha/debug_utils.h
@@ -163,7 +163,7 @@ constexpr __string_view __get_type_name() {
int(ps.n()), \
int(ps.k()))

- template <typename Iterator, typename LaneOffsetT, typename AccumT>
+ template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(
AccumT accum,
LaneOffsetT lane_offset,
@@ -179,7 +179,7 @@ CUTLASS_DEVICE void print_warp_accum(
}
__syncthreads();
}
- Iterator::iterateRows(
+ LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
@@ -5,137 +5,15 @@
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
#include "gemm_kernel_utils.h"

namespace {

static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
} // namespace
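The bit-reinterpretation in atomicMaxFloat above works because non-negative IEEE-754 floats order like their signed-integer bit patterns, while negative floats order in reverse under unsigned comparison, so atomicMin on the raw bits yields the float maximum. A minimal standalone kernel built around the same helper (an editor's sketch, not part of this diff; the kernel name rowMax is made up, and *out must be initialized to -INFINITY before launch):

#include <cuda_runtime.h>

__device__ float atomicMaxFloat(float* addr, float value) {
  // Same trick as above: signed-int max for non-negative values,
  // unsigned-int min for negative values.
  return (value >= 0)
      ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}

// Reduce one row of `n` floats into a single maximum at *out.
__global__ void rowMax(const float* row, int n, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicMaxFloat(out, row[i]);
  }
}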

- /* Iterates on the accumulator and corresponding position on result matrix
- (1) Update `mi[r]` to the max value of the row `r`
- (2) In a second iteration do the following:
- (a) accum <- exp(accum - mi)
- (b) m_prime <- exp(m_prime - mi)
- (c) s_prime <- s_prime * m_prime + sum(accum)
- All of this is done on registers, before we store all of this
- on shared memory for the next matmul with Value.
- We have multiple implementations, because each configuration has a different way
- of iterating in the accumulators.
+ /*
+ TensorCores have different accumulator layouts.
+ This file provides a class to easily map the accumulator
+ i-th element with the corresponding matrix row/col.
*/
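Concretely, each such class walks a thread's register fragment and hands every element to user-supplied lambdas together with its logical (row, column) position. A host-side toy model of that callback contract (an editor's sketch; the real classes are CUTLASS device code, also thread a lane_offset through, and the kRows/kCols values below are invented):

#include <algorithm>
#include <cstdio>

struct ToyAccumLambdaIterator {
  static constexpr int kRows = 2;  // accumulator rows owned by this lane (assumed)
  static constexpr int kCols = 4;  // fragment elements per row (assumed)

  // Contract: beginRow(m) once per row, then op(m, n, idx) for every element
  // (idx is the linear index into the register fragment), then endRow(m).
  template <typename FnBeginRow, typename FnElement, typename FnEndRow>
  static void iterateRows(FnBeginRow beginRow, FnElement op, FnEndRow endRow) {
    int idx = 0;
    for (int m = 0; m < kRows; ++m) {
      beginRow(m);
      for (int n = 0; n < kCols; ++n) {
        op(m, n, idx++);
      }
      endRow(m);
    }
  }
};

int main() {
  float frag[8] = {1, 5, 2, 4, 3, 7, 0, 6};  // pretend register fragment
  float row_max[2];
  ToyAccumLambdaIterator::iterateRows(
      [&](int m) { row_max[m] = -1e30f; },
      [&](int m, int n, int idx) { row_max[m] = std::max(row_max[m], frag[idx]); },
      [&](int m) { std::printf("row %d max = %f\n", m, row_max[m]); });
  return 0;
}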

template <typename BASE, typename T, typename accum_t, int kWarpSize>
struct RegisterOps {
template <
int kQueriesPerBlock,
bool kFullColumns,
bool kIsFirst,
bool kKeepOutputInRF>
CUTLASS_DEVICE static void update(
typename T::Fragment& frag_o, // output so far
typename T::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename T::TensorCoord const& tile_offset,
float scaling) {
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}

auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset);

// First update `mi` to the max per-row
{
accum_t max;
BASE::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);

// Make sure we all share the update values for `mi`
__syncthreads();

if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
BASE::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (BASE::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
};
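The update() member above, removed from this header by the diff, maintains the streaming-softmax statistics described in the deleted comment: a running row max mi, a rescale factor m_prime, and a running denominator s_prime, so earlier chunks never need to be revisited. A scalar, single-row model of that bookkeeping (an editor's sketch, not repository code; the attention weights are applied to a value of 1 for brevity):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One softmax row arriving in chunks, as successive key blocks would.
  std::vector<std::vector<float>> chunks = {{1.f, 3.f}, {2.f, 5.f}, {0.f, 4.f}};

  float mi = -INFINITY;  // running row maximum
  float s_prime = 0.f;   // running sum of exp(x - mi)
  float acc = 0.f;       // running output accumulator (V entries taken as 1)

  for (const auto& chunk : chunks) {
    float m_prev = mi;
    for (float x : chunk) mi = std::max(mi, x);  // (1) update the row max

    float m_prime = std::exp(m_prev - mi);       // (2b) rescale factor for old stats
    s_prime *= m_prime;                          //      rescale the denominator
    acc *= m_prime;                              //      rescale the partial output

    for (float x : chunk) {
      float p = std::exp(x - mi);                // (2a) exp(accum - mi)
      s_prime += p;                              // (2c) accumulate the denominator
      acc += p;                                  //      accumulate p * V, with V = 1
    }
  }
  // s_prime now equals sum(exp(x - max(x))) over the whole row.
  std::printf("denominator = %f, output = %f\n", s_prime, acc / s_prime);
  return 0;
}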

template <typename T, typename accum_t, int kWarpSize>
- struct AttentionScalingCoefsUpdaterSm80
- : RegisterOps<
- AttentionScalingCoefsUpdaterSm80<T, accum_t, kWarpSize>,
- T,
- accum_t,
- kWarpSize> {
+ struct AccumLambdaIteratorSm80 {
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
@@ -208,12 +86,7 @@ struct AttentionScalingCoefsUpdaterSm80
};

template <typename T, typename accum_t, int kWarpSize>
- struct AttentionScalingCoefsUpdaterVolta
- : RegisterOps<
- AttentionScalingCoefsUpdaterVolta<T, accum_t, kWarpSize>,
- T,
- accum_t,
- kWarpSize> {
+ struct AccumLambdaIteratorSm70 {
static_assert(
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
@@ -326,12 +199,7 @@ struct AttentionScalingCoefsUpdaterVolta
};

template <typename T, typename accum_t, int kWarpSize>
- struct AttentionScalingCoefsUpdaterSimt
- : RegisterOps<
- AttentionScalingCoefsUpdaterSimt<T, accum_t, kWarpSize>,
- T,
- accum_t,
- kWarpSize> {
+ struct AccumLambdaIteratorSimt {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
@@ -405,11 +273,11 @@ struct AttentionScalingCoefsUpdaterSimt
};

template <typename T, typename accum_t, int kWarpSize>
- struct DefaultAttentionScalingCoefsUpdater;
+ struct DefaultMmaAccumLambdaIterator;

// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
- struct DefaultAttentionScalingCoefsUpdater<
+ struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
@@ -420,21 +288,20 @@ struct DefaultAttentionScalingCoefsUpdater<
1>,
accum_t,
kWarpSize> {
- using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
+ using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>;
- using Updater =
- AttentionScalingCoefsUpdaterSimt<Iterator, accum_t, kWarpSize>;
+ using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};

// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
- struct DefaultAttentionScalingCoefsUpdater<
+ struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
@@ -443,15 +310,14 @@ struct DefaultAttentionScalingCoefsUpdater<
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
- using Iterator =
+ using WarpIterator =
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
- using Updater =
- AttentionScalingCoefsUpdaterVolta<Iterator, accum_t, kWarpSize>;
+ using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};

// TensorOp - Sm75+
@@ -461,7 +327,7 @@ template <
typename S3,
typename accum_t,
int kWarpSize>
- struct DefaultAttentionScalingCoefsUpdater<
+ struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
@@ -470,13 +336,12 @@ struct DefaultAttentionScalingCoefsUpdater<
S3>,
accum_t,
kWarpSize> {
- using Iterator =
+ using WarpIterator =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>;
- using Updater =
- AttentionScalingCoefsUpdaterSm80<Iterator, accum_t, kWarpSize>;
+ using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};
@@ -52,17 +52,17 @@
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/vector_iterator.h"

#include "attention_scaling_coefs_updater.h"
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
#include "../gemm/mma_accum_lambda_iterator.h"
#include "../gemm_kernel_utils.h"
#include "../iterators/make_residual_last.h"
#include "../iterators/transpose_warp_iterator.h"
#include "../iterators/warp_iterator_from_smem.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
#include "epilogue_thread_apply_logsumexp.h"
#include "gemm_kernel_utils.h"
#include "iterators/make_residual_last.h"
#include "iterators/transpose_warp_iterator.h"
#include "iterators/warp_iterator_from_smem.h"

namespace cutlass {
namespace gemm {
@@ -1879,18 +1879,17 @@ struct B2bGemm<
// NOTE: accum is attn.T
// TODO: Optimize for each architecture
static constexpr int WarpSize = 32;
- using RegistersIter = typename DefaultAttentionScalingCoefsUpdater<
- IteratorC,
- accum_t,
- WarpSize>::Updater;
+ using AccumLambdaIterator =
+ typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
+ Iterator;
auto lane_offset =
- RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords);
+ AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);

cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
- RegistersIter::iterateRows(
+ AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;
@@ -2019,18 +2018,17 @@ struct B2bGemm<
// NOTE: accum is attn.T
// TODO: Optimize for each architecture
static constexpr int WarpSize = 32;
- using RegistersIter = typename DefaultAttentionScalingCoefsUpdater<
- IteratorC,
- accum_t,
- WarpSize>::Updater;
+ using AccumLambdaIterator =
+ typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
+ Iterator;
auto lane_offset =
- RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords);
+ AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);

cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
- RegistersIter::iterateRows(
+ AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;