fMHA: Kernels registry
Kernels are ordered (see `sort_index`), and when dispatching we select
the first kernel in the list that supports the inputs.

This will allow adding more kernels in the future, without necessarily
having to support every possible combination.
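As an illustration of the pattern (toy kernel structs and a toy dispatcher only; the real dispatchers such as `dispatch_cutlassF` / `dispatch_cutlassB` are autogenerated under `xformers/csrc/attention/cuda/fmha/kernels/` and also pass the kernel function pointer):

```cpp
// Self-contained sketch: a registry is just an enumeration of kernels in
// priority order, plus a callback that launches the first compatible one.
#include <iostream>

// Toy "kernels" exposing compile-time traits, listed in `sort_index` order.
struct Kernel64NoDropout {
  static constexpr int kMaxK = 64;
  static constexpr bool kApplyDropout = false;
};
struct Kernel128Dropout {
  static constexpr int kMaxK = 128;
  static constexpr bool kApplyDropout = true;
};

// A generated dispatcher hands every registered kernel to the callback.
template <typename Callback>
void dispatch_sketch(Callback cb) {
  cb(Kernel64NoDropout{});
  cb(Kernel128Dropout{});
}

int main() {
  const int maxK = 96;           // stands in for max(query.size(3), value.size(3))
  const bool use_dropout = true;
  bool kernel_launched = false;

  dispatch_sketch([&](auto _k) {
    using Kernel = decltype(_k);
    (void)_k;
    if (kernel_launched) return;                        // first match wins
    if (Kernel::kMaxK < maxK) return;                   // head dim too large
    if (use_dropout && !Kernel::kApplyDropout) return;  // feature not compiled in
    kernel_launched = true;
    std::cout << "launching kernel with kMaxK=" << Kernel::kMaxK << "\n";
  });
  if (!kernel_launched) {
    std::cout << "no kernel found to launch!\n";
  }
  return 0;
}
```

Because later invocations of the callback return immediately once `kernel_launched` is set, the registration order (`sort_index`) doubles as the performance priority.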

**TEST PLAN**
Made sure performance didn't regress; we're still selecting the fastest available kernel.

ghstack-source-id: fdbc72b8469a2791b98500d784a330e941989e01
Pull Request resolved: https://github.com/fairinternal/xformers/pull/448

__original_commit__ = fairinternal/xformers@7dae8135433e2f80d55321d7d92ce27c79a73a4e
danthe3rd authored and xFormers Bot committed Feb 2, 2023
1 parent 5df1f0b commit 87dc3a7
Showing 109 changed files with 5,594 additions and 963 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/linters_reusable.yml
@@ -45,5 +45,5 @@ jobs:
sudo apt-get install clang-format
clang-format --version
# apply to our files
./.circleci/run-clang-format.py -r xformers/csrc
# apply to our files - excluding autogenerated files
./.circleci/run-clang-format.py -e "*fmha/kernels" -r xformers/csrc
110 changes: 40 additions & 70 deletions xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
@@ -12,66 +12,9 @@

#include "gemm_kernel_utils.h"
#include "kernel_backward.h"
#include "kernels/cutlassB.h"
#include "pytorch_utils.h"

#define DISPATCH_MAXK(func) \
{ \
const auto maxK = std::max(query.size(3), value.size(3)); \
if (maxK <= 64) { \
constexpr int kMaxK = 64; \
func(); \
} else if (maxK <= 128) { \
constexpr int kMaxK = 128; \
func(); \
} else { \
constexpr int kMaxK = std::numeric_limits<int>::max(); \
func(); \
} \
}

#define DISPATCH_KERNEL(QUERY, KEY, VALUE, USE_DROPOUT, FUNC) \
{ \
cudaDeviceProp* properties = \
at::cuda::getDeviceProperties(QUERY.device().index()); \
const int computeCapability = properties->major * 10 + properties->minor; \
DISPATCH_MAXK(([&] { \
DISPATCH_TYPES( \
QUERY, ([&]() { \
DISPATCH_BOOL( \
USE_DROPOUT, kApplyDropout, ([&]() { \
DISPATCH_ARCHTAG( \
computeCapability, ([&]() { \
using AlignedAK = AttentionBackwardKernel< \
ArchTag, \
scalar_t, \
true, \
kApplyDropout, \
kMaxK>; \
bool isAligned = \
(QUERY.stride(2) % \
AlignedAK::kOptimalAlignement == \
0 && \
KEY.stride(2) % AlignedAK::kOptimalAlignement == \
0 && \
VALUE.stride(2) % \
AlignedAK::kOptimalAlignement == \
0); \
DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \
using Kernel = \
AttentionBackwardKernel< \
ArchTag, \
scalar_t, \
kIsAligned, \
kApplyDropout, \
kMaxK>; \
FUNC(); \
})) \
})) \
})) \
})) \
})); \
}

namespace {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mem_efficient_attention_backward_cutlass(
@@ -175,11 +118,32 @@ mem_efficient_attention_backward_cutlass(
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
at::PhiloxCudaState rng_engine_inputs(rng_seed, rng_offset);

auto launchKernel = [&](auto _k, int computeCapability) {
bool kernel_launched = false;
const auto maxK = std::max(query.size(3), value.size(3));

auto launchKernel = [&](auto _k, auto kernel_fn) {
using Kernel = decltype(_k);
using scalar_t = typename Kernel::scalar_t;
(void)_k;

if (kernel_launched) {
return;
}
// Check if this kernel is compatible
if (Kernel::kMaxK < maxK) {
return;
}
if (use_dropout && !Kernel::kApplyDropout) {
return;
}
// Alignment
if ((query.stride(2) % Kernel::kMinimumAlignment) ||
(key.stride(2) % Kernel::kMinimumAlignment) ||
(value.stride(2) % Kernel::kMinimumAlignment)) {
return;
}

kernel_launched = true;
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);

// TODO: Fuse this into a kernel?
@@ -290,14 +254,16 @@ mem_efficient_attention_backward_cutlass(
}
Kernel::check_supported(p);

constexpr auto kernel_fn = attention_kernel_backward_batched<Kernel>;

if (smem_bytes > 0xc000) {
TORCH_INTERNAL_ASSERT(
computeCapability >= 70,
"This kernel requires too much shared memory on this machine!");
AT_CUDA_CHECK(cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
auto err = cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
XFORMERS_CHECK(
err != cudaErrorInvalidValue,
"This GPU does not have enough shared-memory (kernel requires ",
smem_bytes / 1024,
" kb)");
AT_CUDA_CHECK(err);
}

// second syntax resulted in the error below on windows
@@ -323,13 +289,17 @@ mem_efficient_attention_backward_cutlass(
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};

DISPATCH_KERNEL(query, key, value, use_dropout, ([&] {
launchKernel(Kernel{}, computeCapability);
}));
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

DISPATCH_TYPES(query, ([&]() {
dispatch_cutlassB<scalar_t>(launchKernel, computeCapability);
}));
TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!");
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
#endif
} // namespace
}

} // namespace

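One more change visible in the launcher above (and in the forward launcher below): when a kernel needs more than the default 48 KB (0xc000) of dynamic shared memory, the code now checks the return value of `cudaFuncSetAttribute` for `cudaErrorInvalidValue` instead of asserting on compute capability. A standalone sketch of that CUDA pattern, with a hypothetical kernel name and an illustrative size:

```cuda
// Illustrative only: opting a kernel into >48 KB of dynamic shared memory.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void example_kernel() {
  extern __shared__ char smem_buffer[];
  smem_buffer[threadIdx.x] = 0;  // touch shared memory so it is actually used
}

int main() {
  size_t smem_bytes = 64 * 1024;  // above the 0xc000 (48 KB) default limit
  cudaError_t err = cudaFuncSetAttribute(
      example_kernel,
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      static_cast<int>(smem_bytes));
  if (err == cudaErrorInvalidValue) {
    std::printf("This GPU does not have enough shared-memory (kernel requires %zu kb)\n",
                smem_bytes / 1024);
    return 1;
  }
  example_kernel<<<1, 32, smem_bytes>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```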
111 changes: 40 additions & 71 deletions xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
@@ -13,69 +13,9 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh>

#include "kernel_forward.h"
#include "kernels/cutlassF.h"
#include "pytorch_utils.h"

#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \
{ \
if (VALUE_HEAD_DIM <= 64) { \
constexpr bool kIs64x64 = true; \
constexpr bool kSingleValueIteration = true; \
FN(); \
} else { \
constexpr bool kIs64x64 = false; \
if (VALUE_HEAD_DIM <= 128) { \
constexpr bool kSingleValueIteration = true; \
FN(); \
} else { \
constexpr bool kSingleValueIteration = false; \
FN(); \
} \
} \
}

#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \
{ \
cudaDeviceProp* properties = \
at::cuda::getDeviceProperties(QUERY.device().index()); \
const int computeCapability = properties->major * 10 + properties->minor; \
DISPATCH_BLOCKSIZE( \
VALUE.size(-1), ([&]() { \
static constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; \
static constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; \
DISPATCH_TYPES( \
QUERY, ([&]() { \
DISPATCH_ARCHTAG( \
computeCapability, ([&]() { \
using AlignedAK = AttentionKernel< \
scalar_t, \
ArchTag, \
true, \
kQueriesPerBlock, \
kKeysPerBlock, \
kSingleValueIteration>; \
/* Run a more efficient kernel (with `isAligned=True`) \
if memory is correctly aligned*/ \
bool isAligned = \
(QUERY.stride(2) % AlignedAK::kAlignmentQ == 0 && \
KEY.stride(2) % AlignedAK::kAlignmentK == 0 && \
VALUE.stride(2) % AlignedAK::kAlignmentV == 0); \
/* TODO: Should we warn or log somewhere when we use a \
less efficient kernel due to wrong alignment? */ \
DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \
using Kernel = AttentionKernel< \
scalar_t, \
ArchTag, \
kIsAligned, \
kQueriesPerBlock, \
kKeysPerBlock, \
kSingleValueIteration>; \
FUNC(); \
})) \
})) \
})); \
})); \
}

namespace {
template <typename scalar_t>
struct TypeTraits;
@@ -225,11 +165,34 @@ efficient_attention_forward_cutlass(
rng_engine_inputs = gen->philox_cuda_state(B * num_heads * M * N);
}

auto launchKernel = [&](auto _k, int computeCapability) {
bool kernel_launched = false;
auto launchKernel = [&](auto _k, auto kernel_fn) {
using Kernel = decltype(_k);
using scalar_t = typename Kernel::scalar_t;
(void)_k;

if (kernel_launched) {
return;
}
// Check if this kernel is compatible
if (!Kernel::kSupportsDropout && use_dropout) {
return;
}
if (!Kernel::kSupportsBias && bias.has_value()) {
return;
}
if (Kernel::kSingleValueIteration &&
Kernel::kKeysPerBlock < value.size(3)) {
return;
}
// Alignment
if ((query.stride(2) % Kernel::kAlignmentQ) ||
(key.stride(2) % Kernel::kAlignmentK) ||
(value.stride(2) % Kernel::kAlignmentV)) {
return;
}
kernel_launched = true;

res = at::empty(
{B, M, num_heads, Kv},
query.options().dtype(
@@ -311,23 +274,29 @@ efficient_attention_forward_cutlass(
p.dropout_prob = dropout_p;
}

constexpr auto kernel_fn = attention_kernel_batched<Kernel>;
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
if (smem_bytes > 0xc000) {
TORCH_INTERNAL_ASSERT(
computeCapability >= 70,
"This kernel requires too much shared memory on this machine!");
AT_CUDA_CHECK(cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
auto err = cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
XFORMERS_CHECK(
err != cudaErrorInvalidValue,
"This GPU does not have enough shared-memory (kernel requires ",
smem_bytes / 1024,
" kb)");
AT_CUDA_CHECK(err);
}
Kernel::check_supported(p);
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};

// Dispatch to the right kernel
DISPATCH_KERNEL(query, key, value, ([&]() {
launchKernel(Kernel{}, computeCapability);
}));
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

DISPATCH_TYPES(query, ([&]() {
dispatch_cutlassF<scalar_t>(launchKernel, computeCapability);
}));
TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!");
AT_CUDA_CHECK(cudaGetLastError());

// uint64_t -> int64_t bitwise casting as PyTorch don't support uint64_t
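The forward launcher's compatibility checks differ slightly from the backward ones: dropout and bias support are compile-time traits, and a `kSingleValueIteration` kernel only handles value head dims up to its `kKeysPerBlock`. Restated as a standalone predicate over an illustrative traits struct (this struct is not part of the codebase; the real checks read `static constexpr` members of the kernel type):

```cpp
// Illustrative only: the forward-path compatibility checks in isolation.
#include <cstdint>

struct ForwardKernelTraits {
  bool kSupportsDropout;
  bool kSupportsBias;
  bool kSingleValueIteration;
  int64_t kKeysPerBlock;
  int64_t kAlignmentQ, kAlignmentK, kAlignmentV;
};

bool supports(const ForwardKernelTraits& k,
              bool use_dropout, bool has_bias, int64_t value_head_dim,
              int64_t q_stride, int64_t k_stride, int64_t v_stride) {
  if (use_dropout && !k.kSupportsDropout) return false;
  if (has_bias && !k.kSupportsBias) return false;
  // A single-value-iteration kernel covers value head dims up to kKeysPerBlock.
  if (k.kSingleValueIteration && k.kKeysPerBlock < value_head_dim) return false;
  // Alignment requirements, mirroring the stride checks in the launcher.
  if (q_stride % k.kAlignmentQ || k_stride % k.kAlignmentK ||
      v_stride % k.kAlignmentV) return false;
  return true;
}
```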
2 changes: 2 additions & 0 deletions xformers/csrc/attention/cuda/fmha/gemm/find_default_mma.h
@@ -11,6 +11,8 @@
This is really only for the FastF32 case - aka using TensorCores with fp32.
*/

#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
16 changes: 11 additions & 5 deletions xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -160,9 +160,9 @@ template <
// run optimized kernel because memory accesses will be aligned
bool kIsAligned_,
// use dropout if enabled
bool kApplyDropout,
bool kApplyDropout_,
// upperbound on `max(value.shape[-1], query.shape[-1])`
int kMaxK = std::numeric_limits<int>::max()>
int kMaxK_ = std::numeric_limits<int>::max()>
struct AttentionBackwardKernel {
using scalar_t = scalar_t_;
using output_t = scalar_t;
@@ -171,6 +171,8 @@ struct AttentionBackwardKernel {
using accum_t = float;
using ArchTag = ArchTag_;
static constexpr bool kIsAligned = kIsAligned_;
static constexpr bool kApplyDropout = kApplyDropout_;
static constexpr int kMaxK = kMaxK_;

struct Params {
// Input tensors
@@ -263,7 +265,7 @@ struct AttentionBackwardKernel {
int64_t gV_strideH;
int64_t gB_strideH;

CUTLASS_DEVICE void advance_to_block() {
CUTLASS_DEVICE bool advance_to_block() {
int64_t batch_id = blockIdx.z;
int32_t head_id = blockIdx.y;

@@ -325,6 +327,8 @@ struct AttentionBackwardKernel {
} else {
workspace = nullptr;
}

return true;
}

__host__ dim3 getBlocksGrid() const {
Expand Down Expand Up @@ -1041,7 +1045,7 @@ struct AttentionBackwardKernel {
return true;
}

static CUTLASS_DEVICE void kernel(Params const& p) {
static CUTLASS_DEVICE void attention_kernel(Params const& p) {
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);

@@ -2084,7 +2088,9 @@
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_backward_batched_impl(typename AK::Params p) {
p.advance_to_block();
if (!p.advance_to_block()) {
return;
}
AK::attention_kernel(p);
}

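A note on the `kernel_backward.h` change above: `kApplyDropout_` and `kMaxK_` are now re-exported as `static constexpr` members so that host-side dispatch code can read a kernel's capabilities off the type handed to the callback. A minimal illustration of that trick (the struct below is hypothetical, not from the codebase):

```cpp
// Illustrative only: template parameters re-exposed as static constexpr
// members are queryable from the type (or from an instance via decltype),
// which is what the registry's compatibility checks rely on.
template <bool kApplyDropout_, int kMaxK_>
struct ExampleKernelTraits {
  static constexpr bool kApplyDropout = kApplyDropout_;
  static constexpr int kMaxK = kMaxK_;
};

static_assert(ExampleKernelTraits<true, 128>::kMaxK == 128,
              "kMaxK is visible without re-spelling the template arguments");
static_assert(!ExampleKernelTraits<false, 64>::kApplyDropout,
              "so is kApplyDropout");
```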