Parametrize cutlassB kernel on block size + preloadMma.
ghstack-source-id: fb283a9390624c5e49972e92ffbf4d788ff1c3c1
Pull Request resolved: https://github.com/fairinternal/xformers/pull/452

__original_commit__ = fairinternal/xformers@723c0c6e4562b0088faff77065d5230c516c5f46
danthe3rd authored and xFormers Bot committed Feb 2, 2023
1 parent 7ba4c98 commit 8b82140
Showing 55 changed files with 1,257 additions and 1,442 deletions.
15 changes: 11 additions & 4 deletions xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
@@ -118,8 +118,13 @@ mem_efficient_attention_backward_cutlass(
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
at::PhiloxCudaState rng_engine_inputs(rng_seed, rng_offset);

cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

bool kernel_launched = false;
const auto maxK = std::max(query.size(3), value.size(3));
const auto maxShmem =
getMaximumSharedMemoryPerBlockKb(computeCapability) * 1024;

auto launchKernel = [&](auto _k, auto kernel_fn) {
using Kernel = decltype(_k);
@@ -133,6 +138,7 @@ mem_efficient_attention_backward_cutlass(
if (Kernel::kMaxK < maxK) {
return;
}
// Dropout must be supported if we need it
if (use_dropout && !Kernel::kApplyDropout) {
return;
}
@@ -142,9 +148,13 @@
(value.stride(2) % Kernel::kMinimumAlignment)) {
return;
}
// Uses too much shmem
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
if (smem_bytes > maxShmem) {
return;
}

kernel_launched = true;
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);

// TODO: Fuse this into a kernel?
// This is a bottleneck for smaller sequences (M <= 128)
@@ -289,9 +299,6 @@ mem_efficient_attention_backward_cutlass(
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
};

cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

DISPATCH_TYPES(query, ([&]() {
dispatch_cutlassB<scalar_t>(launchKernel, computeCapability);
}));
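
For context, the change in this file hoists the device query (computeCapability, maxShmem) above the selection lambda and adds an early-out for kernel instantiations whose SharedStorage no longer fits on the device. The stand-alone sketch below illustrates that selection pattern only; KernelA, KernelB, and dispatch_all are hypothetical stand-ins, not the xformers dispatcher.

// Simplified illustration of the "reject if it uses too much shmem" filter.
#include <cstddef>
#include <cstdio>

struct KernelA { struct SharedStorage { char buf[48 * 1024]; }; };
struct KernelB { struct SharedStorage { char buf[96 * 1024]; }; };

template <typename F>
void dispatch_all(F&& try_kernel) {
  // Offer the larger-tile kernel first, then fall back to the smaller one.
  try_kernel(KernelB{});
  try_kernel(KernelA{});
}

int main() {
  const size_t maxShmem = 64 * 1024;  // e.g. 64 KB per block on sm_61
  bool kernel_launched = false;
  auto launchKernel = [&](auto _k) {
    using Kernel = decltype(_k);
    if (kernel_launched) {
      return;  // a suitable kernel was already picked
    }
    size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
    if (smem_bytes > maxShmem) {
      return;  // uses too much shmem
    }
    kernel_launched = true;
    std::printf("launching kernel needing %zu bytes of shmem\n", smem_bytes);
  };
  dispatch_all(launchKernel);
  return kernel_launched ? 0 : 1;
}
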
15 changes: 11 additions & 4 deletions xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
@@ -165,7 +165,13 @@ efficient_attention_forward_cutlass(
rng_engine_inputs = gen->philox_cuda_state(B * num_heads * M * N);
}

cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

bool kernel_launched = false;
const auto maxShmem =
getMaximumSharedMemoryPerBlockKb(computeCapability) * 1024;

auto launchKernel = [&](auto _k, auto kernel_fn) {
using Kernel = decltype(_k);
using scalar_t = typename Kernel::scalar_t;
@@ -191,6 +197,11 @@ efficient_attention_forward_cutlass(
(value.stride(2) % Kernel::kAlignmentV)) {
return;
}
// Uses too much shmem
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
if (smem_bytes > maxShmem) {
return;
}
kernel_launched = true;

res = at::empty(
@@ -274,7 +285,6 @@ efficient_attention_forward_cutlass(
p.dropout_prob = dropout_p;
}

size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
if (smem_bytes > 0xc000) {
auto err = cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
@@ -290,9 +300,6 @@ efficient_attention_forward_cutlass(
};

// Dispatch to the right kernel
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

DISPATCH_TYPES(query, ([&]() {
dispatch_cutlassF<scalar_t>(launchKernel, computeCapability);
}));
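
The unchanged context above keeps the usual opt-in path: anything above the default 48 KB (0xc000 bytes) of dynamic shared memory must raise cudaFuncAttributeMaxDynamicSharedMemorySize before launch. Below is a minimal, self-contained sketch of that opt-in, assuming a hypothetical dummy_kernel in place of the selected attention kernel.

// Opt in to more than 48 KB of dynamic shared memory before launching.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel() {
  extern __shared__ char smem[];
  smem[threadIdx.x] = 0;
}

int main() {
  size_t smem_bytes = 96 * 1024;  // assume the chosen kernel needs 96 KB
  if (smem_bytes > 0xc000) {
    auto err = cudaFuncSetAttribute(
        dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    if (err != cudaSuccess) {
      std::printf(
          "cannot request %zu bytes of shmem: %s\n",
          smem_bytes,
          cudaGetErrorString(err));
      return 1;
    }
  }
  dummy_kernel<<<1, 32, smem_bytes>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
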
31 changes: 31 additions & 0 deletions xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h
@@ -110,6 +110,37 @@ constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
return ((n + m - 1) / m) * m;
}

inline int32_t getMaximumSharedMemoryPerBlockKb(int cc) {
// from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
switch (cc) {
case 50:
case 52:
case 53:
case 60:
case 61:
case 62:
return 64;
case 70:
case 72:
return 96;
case 75:
return 64;
case 80:
return 163;
case 86:
return 99;
case 87:
return 163;
case 89:
return 99;
case 90:
return 227;
default:
return 0;
}
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
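
The values returned above are the maximum opt-in shared memory per block, in KB, for each compute capability. As a cross-check outside this commit, the same limit can be read from the runtime; the short sketch below queries it via cudaDevAttrMaxSharedMemoryPerBlockOptin (values come back in bytes, e.g. 166912 = 163 KB on sm_80).

// Query the per-block opt-in shared-memory limit at runtime.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  int smem_optin_bytes = 0;
  if (cudaGetDevice(&device) != cudaSuccess) {
    return 1;
  }
  if (cudaDeviceGetAttribute(
          &smem_optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, device) !=
      cudaSuccess) {
    return 1;
  }
  std::printf(
      "max opt-in shared memory per block: %d bytes (%d KB)\n",
      smem_optin_bytes,
      smem_optin_bytes / 1024);
  return 0;
}
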
23 changes: 12 additions & 11 deletions xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -161,6 +161,11 @@ template <
bool kIsAligned_,
// use dropout if enabled
bool kApplyDropout_,
// when doing a GEMM, preload the next one (uses more shmem)
bool kPreloadMmas_,
// block dimensions
int kBlockSizeI_,
int kBlockSizeJ_,
// upperbound on `max(value.shape[-1], query.shape[-1])`
int kMaxK_ = std::numeric_limits<int>::max()>
struct AttentionBackwardKernel {
@@ -172,6 +177,9 @@ struct AttentionBackwardKernel {
using ArchTag = ArchTag_;
static constexpr bool kIsAligned = kIsAligned_;
static constexpr bool kApplyDropout = kApplyDropout_;
static constexpr bool kPreloadMmas = kPreloadMmas_;
static constexpr int kBlockSizeI = kBlockSizeI_;
static constexpr int kBlockSizeJ = kBlockSizeJ_;
static constexpr int kMaxK = kMaxK_;

struct Params {
@@ -374,29 +382,22 @@
}
};

// Block I
static constexpr bool kSupports64x128 =
ArchTag::kMinComputeCapability >= 80 ||
(ArchTag::kMinComputeCapability >= 70 &&
cutlass::sizeof_bits<scalar_t>::value <= 16);
static constexpr int64_t kWarpSize = 32;
static constexpr int64_t kBlockSizeI =
kSupports64x128 && kMaxK > 64 ? 128 : 64;

// If this is true, we store and accumulate dK/dV in RF
// rather than going back to gmem everytime
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
static constexpr bool kPreloadMmas =
kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF;
static_assert(
!kPreloadMmas ||
(kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF),
"preload MMA not supported");
static constexpr bool kPrologueQK = kPreloadMmas;
static constexpr bool kPrologueGV = kPreloadMmas;
static constexpr bool kPrologueDOV = kPreloadMmas;
static constexpr bool kPrologueGQ = kPreloadMmas;
static constexpr bool kPrologueGK = kPreloadMmas;

// Block J
static constexpr int64_t kBlockSizeJ = kPreloadMmas && kMaxK > 64 ? 128 : 64;
static constexpr int64_t kNumWarpsPerBlock =
(kBlockSizeI * kBlockSizeJ) / (32 * 32);

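
With kBlockSizeI, kBlockSizeJ, and kPreloadMmas promoted to template parameters, the constexpr heuristics deleted above must now be applied by whoever instantiates the kernel. The host-side sketch below simply mirrors that removed logic for illustration; BlockConfig and chooseBlockConfig are hypothetical names, and the actual selection made by dispatch_cutlassB is not shown in this diff.

// Mirror of the removed in-kernel heuristics, hoisted to the host side.
struct BlockConfig {
  int blockSizeI;
  int blockSizeJ;
  bool preloadMmas;
};

inline BlockConfig chooseBlockConfig(
    int computeCapability, int scalarBits, int maxK) {
  const bool isHalf = scalarBits <= 16;
  // 64x128 tiles need sm80+, or sm70+ with 16-bit inputs.
  const bool supports64x128 =
      computeCapability >= 80 || (computeCapability >= 70 && isHalf);
  const int blockSizeI = (supports64x128 && maxK > 64) ? 128 : 64;
  // dK/dV can stay in registers only if one block covers the whole head dim.
  const bool outputInRF = isHalf && maxK <= blockSizeI;
  const bool preloadMmas = isHalf && computeCapability >= 80 && outputInRF;
  const int blockSizeJ = (preloadMmas && maxK > 64) ? 128 : 64;
  return {blockSizeI, blockSizeJ, preloadMmas};
}

int main() {
  // Example: fp16 backward on sm_80 with head dim <= 64 keeps 64x64 tiles
  // and enables preloading.
  BlockConfig cfg = chooseBlockConfig(
      /*computeCapability=*/80, /*scalarBits=*/16, /*maxK=*/64);
  return (cfg.blockSizeI == 64 && cfg.blockSizeJ == 64 && cfg.preloadMmas)
      ? 0
      : 1;
}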