From 83a567f11f9e10da0ca1fead2e9812a7ffe0c9a4 Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Fri, 20 Jan 2023 12:33:35 +0000 Subject: [PATCH] Cleanup fMHA code ghstack-source-id: 48000b65896f42a7876b05a123f1ad73bf9b7e5a Pull Request resolved: https://github.com/fairinternal/xformers/pull/440 __original_commit__ = fairinternal/xformers@af62d3696d4fa3afb64ff8a34fb418a0db2dd9e4 --- xformers/csrc/attention/attention.cpp | 4 +- .../cpu/{attention.cpp => small_k.cpp} | 4 +- .../csrc/attention/cuda/fmha/debug_utils.h | 4 +- .../fmha/{ => epilogue}/epilogue_pipelined.h | 0 .../{ => epilogue}/epilogue_rescale_output.h | 0 .../epilogue_thread_apply_logsumexp.h | 0 .../cuda/fmha/{ => gemm}/find_default_mma.h | 0 .../mma_accum_lambda_iterator.h} | 169 ++----------- .../cuda/fmha/{ => gemm}/mma_from_smem.h | 34 ++- .../attention/cuda/fmha/kernel_backward.h | 107 +++----- .../csrc/attention/cuda/fmha/kernel_forward.h | 228 ++++++++++-------- .../attention/cuda/fmha/kernels/backward.h | 68 ++++++ .../cuda/fmha/kernels/backward_bf16.cu | 4 +- .../fmha/kernels/backward_bf16_aligned.cu | 4 +- .../kernels/backward_bf16_aligned_dropout.cu | 4 +- .../backward_bf16_aligned_dropout_k128.cu | 4 +- .../backward_bf16_aligned_dropout_k64.cu | 4 +- .../kernels/backward_bf16_aligned_k128.cu | 4 +- .../fmha/kernels/backward_bf16_aligned_k64.cu | 4 +- .../fmha/kernels/backward_bf16_dropout.cu | 4 +- .../kernels/backward_bf16_dropout_k128.cu | 4 +- .../fmha/kernels/backward_bf16_dropout_k64.cu | 4 +- .../cuda/fmha/kernels/backward_bf16_k128.cu | 4 +- .../cuda/fmha/kernels/backward_bf16_k64.cu | 4 +- .../cuda/fmha/kernels/backward_f16.cu | 4 +- .../cuda/fmha/kernels/backward_f16_aligned.cu | 4 +- .../kernels/backward_f16_aligned_dropout.cu | 4 +- .../backward_f16_aligned_dropout_k128.cu | 4 +- .../backward_f16_aligned_dropout_k64.cu | 4 +- .../fmha/kernels/backward_f16_aligned_k128.cu | 4 +- .../fmha/kernels/backward_f16_aligned_k64.cu | 4 +- .../cuda/fmha/kernels/backward_f16_dropout.cu | 4 +- .../fmha/kernels/backward_f16_dropout_k128.cu | 4 +- .../fmha/kernels/backward_f16_dropout_k64.cu | 4 +- .../cuda/fmha/kernels/backward_f16_k128.cu | 4 +- .../cuda/fmha/kernels/backward_f16_k64.cu | 4 +- .../cuda/fmha/kernels/backward_f32.cu | 4 +- .../cuda/fmha/kernels/backward_f32_aligned.cu | 4 +- .../kernels/backward_f32_aligned_dropout.cu | 4 +- .../backward_f32_aligned_dropout_k128.cu | 4 +- .../backward_f32_aligned_dropout_k64.cu | 4 +- .../fmha/kernels/backward_f32_aligned_k128.cu | 4 +- .../fmha/kernels/backward_f32_aligned_k64.cu | 4 +- .../cuda/fmha/kernels/backward_f32_dropout.cu | 4 +- .../fmha/kernels/backward_f32_dropout_k128.cu | 4 +- .../fmha/kernels/backward_f32_dropout_k64.cu | 4 +- .../cuda/fmha/kernels/backward_f32_k128.cu | 4 +- .../cuda/fmha/kernels/backward_f32_k64.cu | 4 +- .../attention/cuda/fmha/kernels/forward.h | 92 +++++++ .../cuda/fmha/kernels/forward_bf16.cu | 4 +- .../cuda/fmha/kernels/forward_bf16_aligned.cu | 4 +- .../cuda/fmha/kernels/forward_f16.cu | 4 +- .../cuda/fmha/kernels/forward_f16_aligned.cu | 4 +- .../cuda/fmha/kernels/forward_f32.cu | 4 +- .../cuda/fmha/kernels/forward_f32_aligned.cu | 4 +- .../cuda/fmha/kernels/generate_kernels.sh | 14 +- .../cuda/fmha/{attention.cu => small_k.cu} | 4 +- xformers/ops/fmha/small_k.py | 4 +- 58 files changed, 406 insertions(+), 494 deletions(-) rename xformers/csrc/attention/cpu/{attention.cpp => small_k.cpp} (99%) rename xformers/csrc/attention/cuda/fmha/{ => epilogue}/epilogue_pipelined.h (100%) rename xformers/csrc/attention/cuda/fmha/{ => 
epilogue}/epilogue_rescale_output.h (100%) rename xformers/csrc/attention/cuda/fmha/{ => epilogue}/epilogue_thread_apply_logsumexp.h (100%) rename xformers/csrc/attention/cuda/fmha/{ => gemm}/find_default_mma.h (100%) rename xformers/csrc/attention/cuda/fmha/{attention_scaling_coefs_updater.h => gemm/mma_accum_lambda_iterator.h} (68%) rename xformers/csrc/attention/cuda/fmha/{ => gemm}/mma_from_smem.h (98%) create mode 100644 xformers/csrc/attention/cuda/fmha/kernels/backward.h create mode 100644 xformers/csrc/attention/cuda/fmha/kernels/forward.h rename xformers/csrc/attention/cuda/fmha/{attention.cu => small_k.cu} (99%) diff --git a/xformers/csrc/attention/attention.cpp b/xformers/csrc/attention/attention.cpp index 8b01485e73..af2dc4501a 100644 --- a/xformers/csrc/attention/attention.cpp +++ b/xformers/csrc/attention/attention.cpp @@ -15,11 +15,11 @@ PyMODINIT_FUNC PyInit__C(void) { TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)")); + "xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, bool compute_logsumexp, bool causal, float? scale) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)")); + "xformers::efficient_attention_backward_small_k(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_backward_cutlass(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, bool causal, float? 
scale) -> (Tensor, Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( diff --git a/xformers/csrc/attention/cpu/attention.cpp b/xformers/csrc/attention/cpu/small_k.cpp similarity index 99% rename from xformers/csrc/attention/cpu/attention.cpp rename to xformers/csrc/attention/cpu/small_k.cpp index 7052eacca4..e494f619d5 100644 --- a/xformers/csrc/attention/cpu/attention.cpp +++ b/xformers/csrc/attention/cpu/small_k.cpp @@ -355,9 +355,9 @@ std::tuple attention_backward( TORCH_LIBRARY_IMPL(xformers, CPU, m) { m.impl( - TORCH_SELECTIVE_NAME("xformers::efficient_attention"), + TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_small_k"), TORCH_FN(attention)); m.impl( - TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), + TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_small_k"), TORCH_FN(attention_backward)); } diff --git a/xformers/csrc/attention/cuda/fmha/debug_utils.h b/xformers/csrc/attention/cuda/fmha/debug_utils.h index add5b5a064..98038544cb 100644 --- a/xformers/csrc/attention/cuda/fmha/debug_utils.h +++ b/xformers/csrc/attention/cuda/fmha/debug_utils.h @@ -163,7 +163,7 @@ constexpr __string_view __get_type_name() { int(ps.n()), \ int(ps.k())) -template +template CUTLASS_DEVICE void print_warp_accum( AccumT accum, LaneOffsetT lane_offset, @@ -179,7 +179,7 @@ CUTLASS_DEVICE void print_warp_accum( } __syncthreads(); } - Iterator::iterateRows( + LambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { diff --git a/xformers/csrc/attention/cuda/fmha/epilogue_pipelined.h b/xformers/csrc/attention/cuda/fmha/epilogue/epilogue_pipelined.h similarity index 100% rename from xformers/csrc/attention/cuda/fmha/epilogue_pipelined.h rename to xformers/csrc/attention/cuda/fmha/epilogue/epilogue_pipelined.h diff --git a/xformers/csrc/attention/cuda/fmha/epilogue_rescale_output.h b/xformers/csrc/attention/cuda/fmha/epilogue/epilogue_rescale_output.h similarity index 100% rename from xformers/csrc/attention/cuda/fmha/epilogue_rescale_output.h rename to xformers/csrc/attention/cuda/fmha/epilogue/epilogue_rescale_output.h diff --git a/xformers/csrc/attention/cuda/fmha/epilogue_thread_apply_logsumexp.h b/xformers/csrc/attention/cuda/fmha/epilogue/epilogue_thread_apply_logsumexp.h similarity index 100% rename from xformers/csrc/attention/cuda/fmha/epilogue_thread_apply_logsumexp.h rename to xformers/csrc/attention/cuda/fmha/epilogue/epilogue_thread_apply_logsumexp.h diff --git a/xformers/csrc/attention/cuda/fmha/find_default_mma.h b/xformers/csrc/attention/cuda/fmha/gemm/find_default_mma.h similarity index 100% rename from xformers/csrc/attention/cuda/fmha/find_default_mma.h rename to xformers/csrc/attention/cuda/fmha/gemm/find_default_mma.h diff --git a/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h b/xformers/csrc/attention/cuda/fmha/gemm/mma_accum_lambda_iterator.h similarity index 68% rename from xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h rename to xformers/csrc/attention/cuda/fmha/gemm/mma_accum_lambda_iterator.h index 9265b52b3c..8c33c1afc0 100644 --- a/xformers/csrc/attention/cuda/fmha/attention_scaling_coefs_updater.h +++ b/xformers/csrc/attention/cuda/fmha/gemm/mma_accum_lambda_iterator.h @@ -5,137 +5,15 @@ #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" #include "cutlass/matrix_shape.h" -#include "gemm_kernel_utils.h" -namespace { - -static CUTLASS_DEVICE float atomicMaxFloat(float* addr, 
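For context on the operator renames above: the schema string registered through TORCH_LIBRARY_FRAGMENT and the per-backend binding registered through TORCH_LIBRARY_IMPL are matched by operator name, so renaming xformers::efficient_attention to xformers::efficient_attention_forward_small_k has to be applied to the schema, to every TORCH_SELECTIVE_NAME registration (CPU here, CUDA in small_k.cu), and to the torch.ops.xformers call sites in xformers/ops/fmha/small_k.py. A minimal sketch of that registration pattern, using a hypothetical toy_double operator in a throwaway namespace rather than the real xformers schemas (the real code additionally wraps the strings in TORCH_SELECTIVE_SCHEMA / TORCH_SELECTIVE_NAME for selective-build support):

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical example operator; not part of this patch.
static at::Tensor toy_double_cpu(const at::Tensor& x) {
  return x + x;
}

// Schema definition: this string fixes the operator's public name and signature.
TORCH_LIBRARY_FRAGMENT(toy_ns, m) {
  m.def("toy_double(Tensor x) -> Tensor");
}

// Backend binding: must use exactly the same operator name as the schema above.
TORCH_LIBRARY_IMPL(toy_ns, CPU, m) {
  m.impl("toy_double", TORCH_FN(toy_double_cpu));
}

Once both registrations are loaded, the operator resolves as torch.ops.toy_ns.toy_double(x); a stale name on either side fails at dispatch time, which is why attention.cpp, small_k.cpp, small_k.cu and small_k.py are renamed together in this patch.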
float value) { - // source: https://stackoverflow.com/a/51549250 - return (value >= 0) - ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) - : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); -} -} // namespace - -/* Iterates on the accumulator and corresponding position on result matrix - -(1) Update `mi[r]` to the max value of the row `r` -(2) In a second iteration do the following: - (a) accum <- exp(accum - mi) - (b) m_prime <- exp(m_prime - mi) - (c) s_prime <- s_prime * m_prime + sum(accum) - -All of this is done on registers, before we store all of this -on shared memory for the next matmul with Value. - -We have multiple implementations, because each configuration has a different way -of iterating in the accumulators. +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. */ -template -struct RegisterOps { - template < - int kQueriesPerBlock, - bool kFullColumns, - bool kIsFirst, - bool kKeepOutputInRF> - CUTLASS_DEVICE static void update( - typename T::Fragment& frag_o, // output so far - typename T::Fragment& frag, - cutlass::Array& mi, - cutlass::Array& m_prime, - cutlass::Array& s_prime, - int8_t lane_id, - int8_t thread_id, - int8_t warp_id, - int16_t max_col, - typename T::TensorCoord const& tile_offset, - float scaling) { - // Convert to `accum_t` (rather than double) - constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E - if (!kIsFirst) { - if (thread_id < kQueriesPerBlock) { - m_prime[thread_id] = mi[thread_id]; - } - __syncthreads(); - } - - auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset); - - // First update `mi` to the max per-row - { - accum_t max; - BASE::iterateRows( - lane_offset, - [&](int accum_m) { - max = -cutlass::platform::numeric_limits::infinity(); - }, - [&](int accum_m, int accum_n, int idx) { - if (kFullColumns || accum_n < max_col) { - max = cutlass::fast_max(max, frag[idx]); - } - }, - [&](int accum_m) { - // Having 4x atomicMax seems faster than reduce within warp - // first... - atomicMaxFloat(&mi[accum_m], max * scaling); - }); - } - frag = cutlass::multiplies()(scaling * kLog2e, frag); - - // Make sure we all share the update values for `mi` - __syncthreads(); - - if (thread_id < kQueriesPerBlock) { - auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); - m_prime[thread_id] = m_prime_exp; - s_prime[thread_id] *= m_prime_exp; - } - __syncthreads(); // Update output fragments - if (kKeepOutputInRF && !kIsFirst) { - accum_t mp; - BASE::iterateRows( - lane_offset, - [&](int accum_m) { mp = m_prime[accum_m]; }, - [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, - [&](int accum_m) {}); - __syncthreads(); - } - // Update accum_m, accum_n, ... - { - accum_t mi_row, total_row; - BASE::iterateRows( - lane_offset, - [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, - [&](int accum_m, int accum_n, int idx) { - frag[idx] = (kFullColumns || accum_n < max_col) - ? 
exp2f(frag[idx] - mi_row) - : accum_t(0.0); - }, - [&](int accum_m) {}); - BASE::iterateRows( - lane_offset, - [&](int accum_m) { total_row = 0.0; }, - [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, - [&](int accum_m) { - if (BASE::reduceSameRow( - lane_id, total_row, [](accum_t a, accum_t b) { - return a + b; - })) { - atomicAdd(&s_prime[accum_m], total_row); - } - }); - } - } -}; - template -struct AttentionScalingCoefsUpdaterSm80 - : RegisterOps< - AttentionScalingCoefsUpdaterSm80, - T, - accum_t, - kWarpSize> { +struct AccumLambdaIteratorSm80 { static_assert( cutlass::platform:: is_same::value, @@ -208,12 +86,7 @@ struct AttentionScalingCoefsUpdaterSm80 }; template -struct AttentionScalingCoefsUpdaterVolta - : RegisterOps< - AttentionScalingCoefsUpdaterVolta, - T, - accum_t, - kWarpSize> { +struct AccumLambdaIteratorSm70 { static_assert( cutlass::platform:: is_same::value, @@ -326,12 +199,7 @@ struct AttentionScalingCoefsUpdaterVolta }; template -struct AttentionScalingCoefsUpdaterSimt - : RegisterOps< - AttentionScalingCoefsUpdaterSimt, - T, - accum_t, - kWarpSize> { +struct AccumLambdaIteratorSimt { using Policy = typename T::Policy; using Iterations = typename T::Iterations; using Element = typename T::Element; @@ -405,11 +273,11 @@ struct AttentionScalingCoefsUpdaterSimt }; template -struct DefaultAttentionScalingCoefsUpdater; +struct DefaultMmaAccumLambdaIterator; // Simt template -struct DefaultAttentionScalingCoefsUpdater< +struct DefaultMmaAccumLambdaIterator< cutlass::gemm::warp::MmaSimtTileIterator< S, cutlass::gemm::Operand::kC, @@ -420,7 +288,7 @@ struct DefaultAttentionScalingCoefsUpdater< 1>, accum_t, kWarpSize> { - using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< S, cutlass::gemm::Operand::kC, accum_t, @@ -428,13 +296,12 @@ struct DefaultAttentionScalingCoefsUpdater< P, 1, 1>; - using Updater = - AttentionScalingCoefsUpdaterSimt; + using Iterator = AccumLambdaIteratorSimt; }; // TensorOp - Volta template -struct DefaultAttentionScalingCoefsUpdater< +struct DefaultMmaAccumLambdaIterator< cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< S1, accum_t, @@ -443,15 +310,14 @@ struct DefaultAttentionScalingCoefsUpdater< cutlass::MatrixShape<1, 1>>, accum_t, kWarpSize> { - using Iterator = + using WarpIterator = typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< S1, accum_t, cutlass::layout::RowMajor, S2, cutlass::MatrixShape<1, 1>>; - using Updater = - AttentionScalingCoefsUpdaterVolta; + using Iterator = AccumLambdaIteratorSm70; }; // TensorOp - Sm75+ @@ -461,7 +327,7 @@ template < typename S3, typename accum_t, int kWarpSize> -struct DefaultAttentionScalingCoefsUpdater< +struct DefaultMmaAccumLambdaIterator< cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< S1, accum_t, @@ -470,13 +336,12 @@ struct DefaultAttentionScalingCoefsUpdater< S3>, accum_t, kWarpSize> { - using Iterator = + using WarpIterator = typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< S1, accum_t, cutlass::layout::RowMajor, S2, S3>; - using Updater = - AttentionScalingCoefsUpdaterSm80; + using Iterator = AccumLambdaIteratorSm80; }; diff --git a/xformers/csrc/attention/cuda/fmha/mma_from_smem.h b/xformers/csrc/attention/cuda/fmha/gemm/mma_from_smem.h similarity index 98% rename from xformers/csrc/attention/cuda/fmha/mma_from_smem.h rename to xformers/csrc/attention/cuda/fmha/gemm/mma_from_smem.h index b280d81476..5e14176436 100644 --- 
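The renamed header's comment above ("TensorCores have different accumulator layouts...") describes a small callback-based abstraction: DefaultMmaAccumLambdaIterator<WarpIteratorC, accum_t, kWarpSize>::Iterator exposes get_lane_offset plus iterateRows, which walks a warp's accumulator registers while reporting the logical (row, column, register index) of each element. Below is a host-side analogue of that callback shape, not the CUTLASS code itself; the 2x4 row-major fragment and the simplified iterateRows signature are assumptions made for illustration:

#include <algorithm>
#include <array>
#include <cstdio>

// Host-side stand-in for the AccumLambdaIterator callback pattern: visit every
// element of a per-thread fragment, telling the caller which logical (row, col)
// each register index corresponds to, with begin-row / end-row hooks.
template <int kRows, int kCols, typename Fragment,
          typename FnBeginRow, typename FnOp, typename FnEndRow>
void iterateRows(Fragment& frag, FnBeginRow beginRow, FnOp op, FnEndRow endRow) {
  for (int m = 0; m < kRows; ++m) {
    beginRow(m);
    for (int n = 0; n < kCols; ++n) {
      int idx = m * kCols + n;  // register index -> (row, col) mapping
      op(m, n, idx);
    }
    endRow(m);
  }
}

int main() {
  std::array<float, 8> frag = {1, 5, 2, 0, 3, 3, 9, 1};
  float row_max = 0.0f;
  iterateRows<2, 4>(
      frag,
      [&](int m) { row_max = frag[m * 4]; },  // per-row init
      [&](int m, int n, int idx) { row_max = std::max(row_max, frag[idx]); },
      [&](int m) { std::printf("row %d max = %g\n", m, row_max); });
  return 0;
}

In the device code this same three-lambda shape is what lets the softmax update and the bias / logsumexp handling be written once and then specialized for SIMT, Volta and Sm80 tensor-core accumulator layouts via the partial specializations above.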
a/xformers/csrc/attention/cuda/fmha/mma_from_smem.h +++ b/xformers/csrc/attention/cuda/fmha/gemm/mma_from_smem.h @@ -52,17 +52,17 @@ #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/vector_iterator.h" -#include "attention_scaling_coefs_updater.h" +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" #include "cutlass/gemm/threadblock/mma_base.h" #include "cutlass/gemm/threadblock/mma_multistage.h" #include "cutlass/gemm/threadblock/mma_pipelined.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" -#include "epilogue_thread_apply_logsumexp.h" -#include "gemm_kernel_utils.h" -#include "iterators/make_residual_last.h" -#include "iterators/transpose_warp_iterator.h" -#include "iterators/warp_iterator_from_smem.h" namespace cutlass { namespace gemm { @@ -1879,18 +1879,17 @@ struct B2bGemm< // NOTE: accum is attn.T // TODO: Optimize for each architecture static constexpr int WarpSize = 32; - using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< - IteratorC, - accum_t, - WarpSize>::Updater; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; auto lane_offset = - RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); cutlass::Array lse_prefetched; lse_prefetched.clear(); int rowIdx = 0; int colIdx = 0; - RegistersIter::iterateRows( + AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { ++rowIdx; @@ -2019,18 +2018,17 @@ struct B2bGemm< // NOTE: accum is attn.T // TODO: Optimize for each architecture static constexpr int WarpSize = 32; - using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< - IteratorC, - accum_t, - WarpSize>::Updater; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; auto lane_offset = - RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); cutlass::Array lse_prefetched; lse_prefetched.clear(); int rowIdx = 0; int colIdx = 0; - RegistersIter::iterateRows( + AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { ++rowIdx; diff --git a/xformers/csrc/attention/cuda/fmha/kernel_backward.h b/xformers/csrc/attention/cuda/fmha/kernel_backward.h index 9626fd03a9..e0c03e4aaf 100644 --- a/xformers/csrc/attention/cuda/fmha/kernel_backward.h +++ b/xformers/csrc/attention/cuda/fmha/kernel_backward.h @@ -7,10 +7,12 @@ #include #include +#ifdef HAS_PYTORCH #include #include #include #include +#endif #include "cutlass/cutlass.h" #include "cutlass/epilogue/thread/linear_combination.h" @@ -42,12 +44,13 @@ #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/vector_iterator.h" -#include "epilogue_pipelined.h" +#include "epilogue/epilogue_pipelined.h" #include "iterators/epilogue_predicated_tile_iterator.h" -#include "find_default_mma.h" #include "gemm/custom_mma.h" -#include "mma_from_smem.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "gemm/mma_from_smem.h" #include "transform/tile_smem_loader.h" #include @@ -213,8 +216,10 @@ struct 
AttentionBackwardKernel { int32_t gB_strideM; int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise +#ifdef HAS_PYTORCH // dropout at::PhiloxCudaState rng_engine_inputs; +#endif // RNG sequence offset based on batch_id and head_id unsigned long long dropout_batch_head_rng_offset; float dropout_prob; @@ -480,10 +485,10 @@ struct AttentionBackwardKernel { scalar_t, WarpShape, ThreadblockShape>; - using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, accum_t, - kWarpSize>::Updater; + kWarpSize>::Iterator; using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; }; @@ -1065,6 +1070,7 @@ struct AttentionBackwardKernel { OutputFragments output_frags; curandStatePhilox4_32_10_t rng_state_init; +#ifdef HAS_PYTORCH if (kApplyDropout) { auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); // each element of the attention matrix P with shape @@ -1081,6 +1087,7 @@ struct AttentionBackwardKernel { std::get<1>(seeds) + p.dropout_batch_head_rng_offset, &rng_state_init); } +#endif int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; @@ -1307,9 +1314,9 @@ struct AttentionBackwardKernel { MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); // Pij += Bij, where Pij is in register fragment and Bij is in shmem - auto lane_offset = MatmulQK::ScalingCoefsUpdater::get_lane_offset( + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); - MatmulQK::ScalingCoefsUpdater::iterateRows( + MatmulQK::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_n) {}, [&](int accum_m, int accum_n, int idx) { @@ -1325,9 +1332,9 @@ struct AttentionBackwardKernel { // Apply mask if (p.causal) { - auto lane_offset = MatmulQK::ScalingCoefsUpdater::get_lane_offset( + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); - MatmulQK::ScalingCoefsUpdater::iterateRows( + MatmulQK::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { @@ -1545,11 +1552,11 @@ struct AttentionBackwardKernel { // attn_shared_storage [smem] <- tmp.T // tmp_shared_storage [smem] <- tmp { - using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, typename MatmulDOIVJ::ElementAccum, - kWarpSize>::Updater; - auto lane_offset = RegistersIter::get_lane_offset( + kWarpSize>::Iterator; + auto lane_offset = LambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); // if dropout was used, compute dPij = dPij_dropped * Zij @@ -1558,7 +1565,7 @@ struct AttentionBackwardKernel { if (kApplyDropout) { const auto zij = shared_storage.zij().accum_ref(); - RegistersIter::iterateRows( + LambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { @@ -1577,7 +1584,7 @@ struct AttentionBackwardKernel { auto attn_T = shared_storage.attn_shared_storage().accum_ref(); accum_t current_di; typename Mma::FragmentC fragment_attn, fragment_di; - RegistersIter::iterateRows( + LambdaIterator::iterateRows( lane_offset, [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, [&](int accum_m, int accum_n, int idx) { @@ -1628,7 +1635,7 @@ struct AttentionBackwardKernel { if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { auto tmpT = 
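The dropout path above leans on Philox being a counter-based generator: the backward kernel calls curand_init with the same seed and the same per-(batch, head) offset that the forward pass used, so each element of the attention matrix maps to a fixed position in the random sequence and Zij can be regenerated on the fly instead of being stored. A small standalone sketch of that reproducibility property; the seed and offset values are arbitrary placeholders, not the ones produced by at::PhiloxCudaState:

#include <cstdio>
#include <curand_kernel.h>

// Re-initializing a Philox state with the same (seed, subsequence, offset)
// regenerates the same draw, which is what lets backward recompute the
// dropout mask used in forward.
__global__ void redraw(unsigned long long seed, unsigned long long offset,
                       float* out) {
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/0, offset, &state);  // "forward" draw
  out[0] = curand_uniform(&state);
  curand_init(seed, /*subsequence=*/0, offset, &state);  // "backward" re-draw
  out[1] = curand_uniform(&state);
}

int main() {
  float* out = nullptr;
  cudaMallocManaged(&out, 2 * sizeof(float));
  redraw<<<1, 1>>>(42ULL, 1234ULL, out);
  cudaDeviceSynchronize();
  std::printf("forward=%f backward=%f identical=%d\n", out[0], out[1],
              out[0] == out[1]);
  cudaFree(out);
  return 0;
}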
shared_storage.tmpT_shared_storage().accum_ref(); // attn <- attn_T.T - RegistersIter::iterateRows( + LambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { @@ -2076,67 +2083,11 @@ struct AttentionBackwardKernel { template __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) - attention_kernel_backward_batched(typename AK::Params params); - -#define _ATTENTION_KERNEL_BACKWARD_BEGIN(...) \ - template <> \ - __global__ void __launch_bounds__( \ - __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ - attention_kernel_backward_batched<__VA_ARGS__>( \ - typename __VA_ARGS__::Params p) { \ - using Kernel = __VA_ARGS__; -#define _ATTENTION_KERNEL_BACKWARD_END() } - -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ - _ATTENTION_KERNEL_BACKWARD_BEGIN( \ - AttentionBackwardKernel) \ - p.advance_to_block(); \ - Kernel::kernel(p); \ - _ATTENTION_KERNEL_BACKWARD_END(); - -#ifdef __CUDA_ARCH__ -#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ -#else -#define __CUDA_ARCH_OR_ZERO__ 0 -#endif + attention_kernel_backward_batched_impl(typename AK::Params p) { + p.advance_to_block(); + AK::attention_kernel(p); +} -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(ARCH, ...) \ - _ATTENTION_KERNEL_BACKWARD_BEGIN( \ - AttentionBackwardKernel) \ - printf( \ - "FATAL: this function is for sm%d, but was built with __CUDA_ARCH__=%d\n", \ - int(ARCH), \ - int(__CUDA_ARCH_OR_ZERO__)); \ - _ATTENTION_KERNEL_BACKWARD_END(); - -// All kernels are disabled by default -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(50, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(70, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(75, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(80, __VA_ARGS__) - -// Enable the right one based on __CUDA_ARCH__ -#ifndef __CUDA_ARCH__ -#elif __CUDA_ARCH__ < 500 -#error "Need cuda arch at least 5.0" -#elif __CUDA_ARCH__ < 700 -#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50 -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD(50, __VA_ARGS__) -#elif __CUDA_ARCH__ < 750 -#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70 -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD(70, __VA_ARGS__) -#elif __CUDA_ARCH__ < 800 -#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75 -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD(75, __VA_ARGS__) -#elif __CUDA_ARCH__ >= 800 -#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80 -#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(...) 
\ - INSTANTIATE_ATTENTION_KERNEL_BACKWARD(80, __VA_ARGS__) -#endif +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); diff --git a/xformers/csrc/attention/cuda/fmha/kernel_forward.h b/xformers/csrc/attention/cuda/fmha/kernel_forward.h index 11f90fdaff..c7f82ff266 100644 --- a/xformers/csrc/attention/cuda/fmha/kernel_forward.h +++ b/xformers/csrc/attention/cuda/fmha/kernel_forward.h @@ -16,7 +16,6 @@ #include "cutlass/numeric_types.h" #include "cutlass/tensor_ref.h" -#include "attention_scaling_coefs_updater.h" #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" @@ -32,11 +31,11 @@ #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "debug_utils.h" -#include "epilogue_pipelined.h" -#include "epilogue_rescale_output.h" -#include "find_default_mma.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" #include "gemm_kernel_utils.h" -#include "mma_from_smem.h" #include "transform/tile_smem_loader.h" #include @@ -52,6 +51,12 @@ constexpr int getWarpsPerSm() { ? 16 : 12); } +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} } // namespace template < @@ -297,10 +302,10 @@ struct AttentionKernel { using IteratorA = typename DefaultMma::IteratorA; using IteratorB = typename DefaultMma::IteratorB; using Mma = typename DefaultMma::ThreadblockMma; - using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< typename Mma::Operator::IteratorC, accum_t, - kWarpSize>::Updater; + kWarpSize>::Iterator; static_assert( MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * MmaCore::WarpCount::kK == @@ -670,9 +675,9 @@ struct AttentionKernel { MM0::BiasLoader::load(bias_iter, smem_tile_iter); // Pij += Bij, Pij is in register fragment and Bij is in shared memory - auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( lane_id(), warp_id(), iteratorC_tile_offset); - MM0::ScalingCoefsUpdater::iterateRows( + MM0::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { @@ -686,10 +691,10 @@ struct AttentionKernel { // Mask out last if causal if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { auto query_start = blockIdx.x * kQueriesPerBlock; - auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( lane_id(), warp_id(), iteratorC_tile_offset); int32_t last_col; - MM0::ScalingCoefsUpdater::iterateRows( + MM0::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_m) { last_col = query_start + accum_m - iter_key_start; @@ -709,8 +714,8 @@ struct AttentionKernel { ([&] { // Update `mi` from accum stored in registers // Also does accum[i] <- exp(accum[i] - mi) - MM0::ScalingCoefsUpdater::update< - kQueriesPerBlock, + iterative_softmax< + typename MM0::Mma::Operator::IteratorC, 
kFullColumns, kIsFirst, kKeepOutputInRF>( @@ -974,6 +979,117 @@ struct AttentionKernel { } } + template < + typename WarpIteratorC, + bool kFullColumns, + bool kIsFirst, + bool kKeepOutputInRF> + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? 
exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } + static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; } @@ -997,89 +1113,3 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) template __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) attention_kernel_batched(typename AK::Params params); - -#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \ - template <> \ - __global__ void __launch_bounds__( \ - __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ - attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ - using Kernel = __VA_ARGS__; -#define _ATTENTION_KERNEL_FORWARD_END() } - -#ifdef __CUDA_ARCH__ -#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ -#else -#define __CUDA_ARCH_OR_ZERO__ 0 -#endif - -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ - ARCH, \ - SCALAR_T, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER) \ - _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ - SCALAR_T, \ - cutlass::arch::Sm##ARCH, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER>) \ - if (!p.advance_to_block()) { \ - return; \ - } \ - Kernel::attention_kernel(p); \ - _ATTENTION_KERNEL_FORWARD_END(); - -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ - ARCH, \ - SCALAR_T, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER) \ - _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ - SCALAR_T, \ - cutlass::arch::Sm##ARCH, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER>) \ - printf( \ - "FATAL: this function is for sm%d, but was built for sm%d\n", \ - int(ARCH), \ - int(__CUDA_ARCH_OR_ZERO__)); \ - _ATTENTION_KERNEL_FORWARD_END(); - -// All kernels are disabled by default -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) - -// Enable the right one based on __CUDA_ARCH__ -#ifndef __CUDA_ARCH__ -#elif __CUDA_ARCH__ < 500 -#error "Need cuda arch at least 5.0" -#elif __CUDA_ARCH__ < 700 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) -#elif __CUDA_ARCH__ < 750 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) -#elif __CUDA_ARCH__ < 800 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) -#elif __CUDA_ARCH__ >= 800 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) 
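The comment block inside iterative_softmax is the standard streaming-softmax recurrence. As a plain host-side illustration (made-up scores, one row split into two key blocks, and without the log2(e)/exp2f trick the kernel uses), the running (mi, m_prime, s_prime) bookkeeping reproduces the full-row softmax denominator without ever holding the whole row:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar walk-through of the per-row update performed by iterative_softmax,
// processed in two key blocks. Scores are made up; the point is only that the
// streamed result matches a single pass over the concatenated row.
int main() {
  std::vector<std::vector<float>> blocks = {{1.0f, 3.0f}, {2.0f, 5.0f}};
  float mi = -INFINITY;   // running row max
  float s_prime = 0.0f;   // running sum of exp(score - mi)

  for (const auto& block : blocks) {
    float new_mi = mi;
    for (float s : block) new_mi = std::max(new_mi, s);       // (1) row max
    float m_prime = std::exp(mi - new_mi);                    // (2b) rescale old sum
    float block_sum = 0.0f;
    for (float s : block) block_sum += std::exp(s - new_mi);  // (2a) exp(accum - mi)
    s_prime = s_prime * m_prime + block_sum;                  // (2c)
    mi = new_mi;
  }

  // Single-pass reference over the whole row.
  float ref = 0.0f;
  for (const auto& block : blocks)
    for (float s : block) ref += std::exp(s - mi);
  std::printf("streamed=%f reference=%f\n", s_prime, ref);
  return 0;
}

When kKeepOutputInRF is set, frag_o is rescaled by the same m_prime factor, so the partial output stays consistent with the running denominator.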
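atomicMaxFloat, used above for the mi update, relies on a property of IEEE-754 encodings rather than on a native float atomic max: for non-negative floats the bit patterns compare in the same order as the values, while for negative floats the bit-pattern order is reversed, which is why the negative branch uses atomicMin on an unsigned view of the address. A quick host-side check of that ordering fact (a standalone illustration, not the device code):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-side stand-in for __float_as_int: reinterpret a float's bits as int32.
static int32_t float_bits(float f) {
  int32_t i;
  std::memcpy(&i, &f, sizeof(f));
  return i;
}

int main() {
  // Non-negative values: bit-pattern order matches value order.
  std::printf("0.5 < 2.0 : floats=%d bits=%d\n",
              0.5f < 2.0f, float_bits(0.5f) < float_bits(2.0f));
  // Negative values: bit-pattern order is the reverse of value order, so
  // taking the *minimum* bit pattern selects the *maximum* float.
  std::printf("-2.0 < -0.5 : floats=%d bits=%d\n",
              -2.0f < -0.5f, float_bits(-2.0f) < float_bits(-0.5f));
  return 0;
}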
\ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward.h b/xformers/csrc/attention/cuda/fmha/kernels/backward.h new file mode 100644 index 0000000000..be59536f57 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward.h @@ -0,0 +1,68 @@ +#pragma once + +// All kernels are disabled by default +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(50, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(70, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(75, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(80, __VA_ARGS__) + +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD +#include "../kernel_backward.h" + +#define _ATTENTION_KERNEL_BACKWARD_BEGIN(...) \ + template <> \ + __global__ void __launch_bounds__( \ + __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ + attention_kernel_backward_batched<__VA_ARGS__>( \ + typename __VA_ARGS__::Params p) { \ + using Kernel = __VA_ARGS__; +#define _ATTENTION_KERNEL_BACKWARD_END() } + +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ + _ATTENTION_KERNEL_BACKWARD_BEGIN( \ + AttentionBackwardKernel) \ + p.advance_to_block(); \ + Kernel::kernel(p); \ + _ATTENTION_KERNEL_BACKWARD_END(); + +#ifdef __CUDA_ARCH__ +#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ +#else +#define __CUDA_ARCH_OR_ZERO__ 0 +#endif + +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_DISABLED(ARCH, ...) \ + _ATTENTION_KERNEL_BACKWARD_BEGIN( \ + AttentionBackwardKernel) \ + printf( \ + "FATAL: this function is for sm%d, but was built with __CUDA_ARCH__=%d\n", \ + int(ARCH), \ + int(__CUDA_ARCH_OR_ZERO__)); \ + _ATTENTION_KERNEL_BACKWARD_END(); + +// Enable the right one based on __CUDA_ARCH__ +#ifndef __CUDA_ARCH__ +#elif __CUDA_ARCH__ < 500 +#error "Need cuda arch at least 5.0" +#elif __CUDA_ARCH__ < 700 +#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50 +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD(50, __VA_ARGS__) +#elif __CUDA_ARCH__ < 750 +#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70 +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD(70, __VA_ARGS__) +#elif __CUDA_ARCH__ < 800 +#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75 +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD(75, __VA_ARGS__) +#elif __CUDA_ARCH__ >= 800 +#undef INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80 +#define INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(...) \ + INSTANTIATE_ATTENTION_KERNEL_BACKWARD(80, __VA_ARGS__) +#endif +#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu index c3d95b04b7..b3e2788ed9 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
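The new kernels/backward.h centralizes a pattern worth spelling out: every INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM* macro defaults to a stub that prints a fatal message, and the #if ladder on __CUDA_ARCH__ swaps in the real instantiation only for the architecture the current device pass is compiling, so one generated .cu can be built for several -gencode targets without duplicating source. A stripped-down sketch of the same mechanism with a hypothetical toy kernel (names and bodies are illustrative only, not the xformers macros):

#include <cstdio>

// Hypothetical toy kernel: the same translation unit is compiled once per
// -gencode target, and __CUDA_ARCH__ selects which SM-specific macro expands
// to a real body while the other expands to a loud stub.
#define TOY_KERNEL_REAL(ARCH)                                     \
  __global__ void toy_kernel_sm##ARCH() {                         \
    printf("toy_kernel running in an sm%d build\n", int(ARCH));   \
  }
#define TOY_KERNEL_STUB(ARCH)                                     \
  __global__ void toy_kernel_sm##ARCH() {                         \
    printf("FATAL: toy_kernel_sm%d used in the wrong build\n",    \
           int(ARCH));                                            \
  }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
TOY_KERNEL_REAL(80)
TOY_KERNEL_STUB(70)
#else
TOY_KERNEL_STUB(80)
TOY_KERNEL_REAL(70)
#endif

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  if (prop.major >= 8) {
    toy_kernel_sm80<<<1, 1>>>();
  } else {
    toy_kernel_sm70<<<1, 1>>>();
  }
  cudaDeviceSynchronize();
  return 0;
}

With the macros hoisted into this header, each generated backward_*.cu reduces to #include "backward.h" plus its INSTANTIATE_* lines, and the XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD guard lives in one place instead of being repeated in every generated file.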
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false, false); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu index e6a55db9d6..c8e2146181 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, false); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu index e1877331b3..d894c8ddcf 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu index 01f48d47a0..5d88becfd7 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k128.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, true, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( true, true, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu index 4483098da6..dc2eada660 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_dropout_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, true, true, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu index b1310420c1..753b410b23 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k128.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, true, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( true, false, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu index 444325a965..30dbf897d2 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_aligned_k64.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, true, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( true, false, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu index 60cd1fd603..6846be1d3c 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::bfloat16_t, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::bfloat16_t, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::bfloat16_t, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::bfloat16_t, false, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu index 3c29ccd19e..51d8f507e9 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k128.cu @@ -1,6 +1,5 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, false, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( false, true, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu index 3e0ee619bc..d8d6db9ccb 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_dropout_k64.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, false, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( false, true, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu index dc8a811b75..b7de6e20f8 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k128.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, false, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( false, false, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu index 6181148ebf..2e8a1e3586 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_bf16_k64.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50( cutlass::bfloat16_t, false, @@ -21,4 +20,3 @@ INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80( false, false, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu index 08e72e38a6..436d6e70ef 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, false); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu index 9764de537d..ec07a70b90 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, false); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu index f68340c2b5..46a018ce7d 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu index 7992f7f21f..796b592cff 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, true, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu index 4d0867f23c..04348458d0 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_dropout_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, true, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu index 19d6115030..e7b6c0fedc 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, false, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu index b0e177d03f..dca706d51b 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_aligned_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, true, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, true, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, true, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, true, false, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu index 1bef0d76ab..8212bc91e2 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu index 66b91286d0..55bf42cb96 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, true, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu index 661c4b97a1..6462ce1419 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_dropout_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, true, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu index f4485ff6c8..4527a8e8fe 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, false, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu index d6ff34ea13..8f94d9f935 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f16_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(cutlass::half_t, false, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(cutlass::half_t, false, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(cutlass::half_t, false, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(cutlass::half_t, false, false, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu index 4477366a6d..bf25e3127a 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, false); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu index b3bf2768f4..22c3412b62 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, false); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, false); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu index 4dbdc6db48..3fbe4860cc 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu index 1ffc2806ea..0096a0a429 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, true, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu index 8ba17928d5..9c5f8e051f 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_dropout_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, true, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu index 6d2d903a85..8364faf2c2 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, false, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu index 7906cae500..8b06e82557 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_aligned_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, true, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, true, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, true, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, true, false, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu index ce50ee776a..e41474c765 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, true); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu index e7b4da246c..f98faec8ea 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, true, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, true, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu index de6c8b0999..5246f6240a 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_dropout_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, true, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, true, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu index 8428f9ddbd..c743b1c1d0 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k128.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, false, 128); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, false, 128); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu index 476685bf41..f9e23eb47e 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/backward_f32_k64.cu @@ -1,8 +1,6 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM50(float, false, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM70(float, false, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM75(float, false, false, 64); INSTANTIATE_ATTENTION_KERNEL_BACKWARD_SM80(float, false, false, 64); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward.h b/xformers/csrc/attention/cuda/fmha/kernels/forward.h new file mode 100644 index 0000000000..be33d82867 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward.h @@ -0,0 +1,92 @@ +#pragma once + +// All kernels are disabled by default +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) + +#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD +#include "../kernel_forward.h" + +#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) 
\ + template <> \ + __global__ void __launch_bounds__( \ + __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ + attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ + using Kernel = __VA_ARGS__; +#define _ATTENTION_KERNEL_FORWARD_END() } + +#ifdef __CUDA_ARCH__ +#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ +#else +#define __CUDA_ARCH_OR_ZERO__ 0 +#endif + +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ + ARCH, \ + SCALAR_T, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER) \ + _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ + SCALAR_T, \ + cutlass::arch::Sm##ARCH, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER>) \ + if (!p.advance_to_block()) { \ + return; \ + } \ + Kernel::attention_kernel(p); \ + _ATTENTION_KERNEL_FORWARD_END(); + +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ + ARCH, \ + SCALAR_T, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER) \ + _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ + SCALAR_T, \ + cutlass::arch::Sm##ARCH, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER>) \ + printf( \ + "FATAL: this function is for sm%d, but was built for sm%d\n", \ + int(ARCH), \ + int(__CUDA_ARCH_OR_ZERO__)); \ + _ATTENTION_KERNEL_FORWARD_END(); + +// Enable the right one based on __CUDA_ARCH__ +#ifndef __CUDA_ARCH__ +#elif __CUDA_ARCH__ < 500 +#error "Need cuda arch at least 5.0" +#elif __CUDA_ARCH__ < 700 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) +#elif __CUDA_ARCH__ < 750 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) +#elif __CUDA_ARCH__ < 800 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) +#elif __CUDA_ARCH__ >= 800 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) +#endif + +#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16.cu b/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16.cu index cef063c02b..b662137fc0 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50( cutlass::bfloat16_t, false, @@ -73,4 +72,3 @@ INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80( 64, 64, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16_aligned.cu index 4f5a9551af..0d0d24d3ec 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward_bf16_aligned.cu @@ -1,6 +1,5 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50( cutlass::bfloat16_t, true, @@ -73,4 +72,3 @@ INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80( 64, 64, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward_f16.cu b/xformers/csrc/attention/cuda/fmha/kernels/forward_f16.cu index 0d1481d8a0..6059141d8a 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/forward_f16.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward_f16.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50( cutlass::half_t, false, @@ -53,4 +52,3 @@ INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80( 128, false); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(cutlass::half_t, false, 64, 64, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward_f16_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/forward_f16_aligned.cu index 1b7145cfac..1f2f795a2d 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/forward_f16_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward_f16_aligned.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(cutlass::half_t, true, 32, 128, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50( cutlass::half_t, @@ -33,4 +32,3 @@ INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80( 128, false); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(cutlass::half_t, true, 64, 64, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward_f32.cu b/xformers/csrc/attention/cuda/fmha/kernels/forward_f32.cu index 21fc692b97..5d11448d9d 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/forward_f32.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward_f32.cu @@ -1,6 +1,5 @@ // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(float, false, 32, 128, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(float, false, 32, 128, false); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(float, false, 64, 64, true); @@ -13,4 +12,3 @@ INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(float, false, 64, 64, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(float, false, 32, 128, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(float, false, 32, 128, false); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(float, false, 64, 64, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/forward_f32_aligned.cu b/xformers/csrc/attention/cuda/fmha/kernels/forward_f32_aligned.cu index fad1be03f1..c97f125606 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/forward_f32_aligned.cu +++ b/xformers/csrc/attention/cuda/fmha/kernels/forward_f32_aligned.cu @@ -1,6 +1,5 @@ // This file is auto-generated. 
See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(float, true, 32, 128, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(float, true, 32, 128, false); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(float, true, 64, 64, true); @@ -13,4 +12,3 @@ INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(float, true, 64, 64, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(float, true, 32, 128, true); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(float, true, 32, 128, false); INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(float, true, 64, 64, true); -#endif diff --git a/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh b/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh index 08625b7654..3bffffc7de 100755 --- a/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh +++ b/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.sh @@ -1,6 +1,6 @@ #!/bin/bash set -ex -rm -f *.cu +rm -f forward_*.cu backward_*.cu IFS="," # BACKWARD @@ -23,15 +23,11 @@ for enable_dropout in "false" "true"; do echo $FNAME cat <<EOF > $FNAME // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD -#include "../kernel_backward.h" +#include "backward.h" EOF for sm in 50 70 75 80; do echo "INSTANTIATE_ATTENTION_KERNEL_${kernel}_SM${sm}($dtype, $aligned, $enable_dropout$maxk_code);" >> $FNAME done; - cat <<EOF >> $FNAME -#endif -EOF done; done; done; @@ -52,16 +48,12 @@ for aligned in "false" "true"; do echo $FNAME cat <<EOF > $FNAME // This file is auto-generated. See "generate_kernels.sh" -#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD -#include "../kernel_forward.h" +#include "forward.h" EOF for sm in 50 70 75 80; do echo "INSTANTIATE_ATTENTION_KERNEL_${kernel}_SM${sm}($dtype, $aligned, 32, 128, true);" >> $FNAME echo "INSTANTIATE_ATTENTION_KERNEL_${kernel}_SM${sm}($dtype, $aligned, 32, 128, false);" >> $FNAME echo "INSTANTIATE_ATTENTION_KERNEL_${kernel}_SM${sm}($dtype, $aligned, 64, 64, true);" >> $FNAME done; - cat <<EOF >> $FNAME -#endif -EOF done; done diff --git a/xformers/csrc/attention/cuda/fmha/attention.cu b/xformers/csrc/attention/cuda/fmha/small_k.cu similarity index 99% rename from xformers/csrc/attention/cuda/fmha/attention.cu rename to xformers/csrc/attention/cuda/fmha/small_k.cu index 558a9ab0f2..b9156c7f08 100644 --- a/xformers/csrc/attention/cuda/fmha/attention.cu +++ b/xformers/csrc/attention/cuda/fmha/small_k.cu @@ -1448,10 +1448,10 @@ at::Tensor _dropout_mask(at::Tensor output, double p) { TORCH_LIBRARY_IMPL(xformers, CUDA, m) { m.impl( - TORCH_SELECTIVE_NAME("xformers::efficient_attention"), + TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_small_k"), TORCH_FN(attention)); m.impl( - TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward"), + TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_small_k"), TORCH_FN(attention_backward)); m.impl( TORCH_SELECTIVE_NAME("xformers::_temp_dropout"), TORCH_FN(_dropout_mask)); diff --git a/xformers/ops/fmha/small_k.py b/xformers/ops/fmha/small_k.py index f655a3729f..218ae85de8 100644 --- a/xformers/ops/fmha/small_k.py +++ b/xformers/ops/fmha/small_k.py @@ -52,7 +52,7 @@ class FwOp(AttentionFwOpBase): This operator is deprecated and should not be used in new code """ - OPERATOR = get_xformers_operator("efficient_attention") + OPERATOR = get_xformers_operator("efficient_attention_forward_small_k") SUPPORTED_DEVICES = {"cuda", "cpu"} SUPPORTED_DTYPES = {torch.float}
SUPPORTED_MAX_K: float = 32 @@ -113,7 +113,7 @@ def apply( @register_operator class BwOp(AttentionBwOpBase): - OPERATOR = get_xformers_operator("efficient_attention_backward") + OPERATOR = get_xformers_operator("efficient_attention_backward_small_k") SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
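
Note on the new instantiation macros in kernels/forward.h above: each line in a generated file, e.g. INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(cutlass::half_t, true, 64, 64, true);, expands through _ATTENTION_KERNEL_FORWARD_BEGIN/_END into an explicit specialization of attention_kernel_batched. The sketch below shows the approximate expansion; it is illustrative only (not verbatim preprocessor output), the alias Fwd80 is introduced here purely for readability, and AttentionKernel / attention_kernel_batched are assumed to come from kernel_forward.h.

// Approximate expansion of one SM80 forward instantiation (sketch only).
// Fwd80 is a readability alias, not part of the generated code.
using Fwd80 =
    AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, true, 64, 64, true>;

template <>
__global__ void __launch_bounds__(Fwd80::kNumThreads, Fwd80::kMinBlocksPerSm)
    attention_kernel_batched<Fwd80>(typename Fwd80::Params p) {
  using Kernel = Fwd80;
  // Each thread block positions itself on its slice of the problem;
  // blocks with nothing to do exit immediately.
  if (!p.advance_to_block()) {
    return;
  }
  Kernel::attention_kernel(p);
}

When the same translation unit is compiled for a different -gencode target, INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 resolves to the _DISABLED variant instead, so the symbol still exists for every target but its body only reports the arch mismatch at runtime.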
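
The per-arch selection is driven entirely by the __CUDA_ARCH__-guarded #undef/#define chain: each .cu is compiled once per SM target, every INSTANTIATE_*_SMxx call always emits a kernel symbol, and only the call matching the current compilation pass receives the real body. A minimal self-contained sketch of that pattern, assuming nothing from xformers (demo_kernel and the INSTANTIATE_DEMO* names are made up for illustration):

#include <cstdio>

// One kernel symbol per (element type, SM architecture) pair.
template <typename T, int kArch>
__global__ void demo_kernel(T* out);

// Real body: used only when the compilation pass targets this arch.
#define INSTANTIATE_DEMO(ARCH, T)                 \
  template <>                                     \
  __global__ void demo_kernel<T, ARCH>(T * out) { \
    *out = static_cast<T>(ARCH);                  \
  }

// Stub body: emitted on every other arch pass; reports the mismatch.
#define INSTANTIATE_DEMO_DISABLED(ARCH, T)                              \
  template <>                                                           \
  __global__ void demo_kernel<T, ARCH>(T* /*out*/) {                    \
    printf("FATAL: kernel built for sm%d, not this arch\n", int(ARCH)); \
  }

// Disabled by default; enabled only when __CUDA_ARCH__ matches.
#define INSTANTIATE_DEMO_SM80(...) INSTANTIATE_DEMO_DISABLED(80, __VA_ARGS__)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#undef INSTANTIATE_DEMO_SM80
#define INSTANTIATE_DEMO_SM80(...) INSTANTIATE_DEMO(80, __VA_ARGS__)
#endif

// A generated file would then contain only instantiation lines like this:
INSTANTIATE_DEMO_SM80(float);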