From 071c841f723c8f08a672d80becc8d83f1493ff08 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 12:40:07 +0000 Subject: [PATCH 1/7] Add fallback gemms based on SharedMemory requirements --- csrc/quantization/cutlass_w8a8/common.hpp | 8 +++++ .../cutlass_w8a8/scaled_mm_c2x.cu | 35 ++++++++++++++++--- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp index 23d0587bbdc5d..bd01104fe6689 100644 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ b/csrc/quantization/cutlass_w8a8/common.hpp @@ -17,3 +17,11 @@ inline uint32_t next_pow_2(uint32_t const num) { return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + 0); + return max_shared_mem_per_block_opt_in; +} + diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 740b9fb64a754..e59319834a83c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -183,6 +183,9 @@ struct cutlass_2x_gemm { >::GemmKernel>; // clang-format on + static constexpr size_t kRequiredSharedMemSize = + sizeof(typename KernelType::SharedStorage); + using Op = cutlass::gemm::device::GemmUniversalAdapter; }; @@ -250,6 +253,26 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, CUTLASS_CHECK(status); } +template +void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + // In some cases, the GPU isn't able to accomodate the + // shared memory requirements of the Gemm. In such cases, use + // the FallbackGemm instead. + static const int max_shared_mem_per_block_opt_in = + get_cuda_max_shared_memory_per_block_opt_in(0); + if (Gemm::kRequiredSharedMemSize <= max_shared_mem_per_block_opt_in) { + return cutlass_gemm_caller(out, a, b, + std::forward(args)...); + } else { + TORCH_CHECK(FallbackGemm::kRequiredSharedMemSize <= + max_shared_mem_per_block_opt_in); + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } +} + template typename Epilogue> struct sm80_config_default { @@ -336,25 +359,27 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] - return cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 32) { // M in (16, 32] - return cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 64) { // M in (32, 64] - return cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 128) { // M in (64, 128] uint32_t const n = out.size(1); bool const small_n = n < 8192; if (small_n) { - return cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { - return cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } } else { From 028eadfb53515b97bbe7009c3912b968c6ad957b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 12:53:25 +0000 Subject: [PATCH 2/7] avoid compiler warnings --- csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index e59319834a83c..8871a14285925 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -183,9 +183,6 @@ struct cutlass_2x_gemm { >::GemmKernel>; // clang-format on - static constexpr size_t kRequiredSharedMemSize = - sizeof(typename KernelType::SharedStorage); - using Op = cutlass::gemm::device::GemmUniversalAdapter; }; @@ -262,12 +259,17 @@ void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, // the FallbackGemm instead. static const int max_shared_mem_per_block_opt_in = get_cuda_max_shared_memory_per_block_opt_in(0); - if (Gemm::kRequiredSharedMemSize <= max_shared_mem_per_block_opt_in) { + + size_t const gemm_shared_mem_size = + sizeof(typename Gemm::KernelType::SharedStorage); + size_t const fallback_gemm_shared_mem_size = + sizeof(typename FallbackGemm::KernelType::SharedStorage); + + if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) { return cutlass_gemm_caller(out, a, b, std::forward(args)...); } else { - TORCH_CHECK(FallbackGemm::kRequiredSharedMemSize <= - max_shared_mem_per_block_opt_in); + TORCH_CHECK(fallback_gemm_shared_mem_size <= max_shared_mem_per_block_opt_in); return cutlass_gemm_caller( out, a, b, std::forward(args)...); } From aa4e5464cea4da10fd4d1e35a2def19cd673070c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 12:57:15 +0000 Subject: [PATCH 3/7] format --- csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 8871a14285925..40965653ba1bf 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -262,14 +262,15 @@ void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, size_t const gemm_shared_mem_size = sizeof(typename Gemm::KernelType::SharedStorage); - size_t const fallback_gemm_shared_mem_size = + size_t const fallback_gemm_shared_mem_size = sizeof(typename FallbackGemm::KernelType::SharedStorage); if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) { return cutlass_gemm_caller(out, a, b, std::forward(args)...); } else { - TORCH_CHECK(fallback_gemm_shared_mem_size <= max_shared_mem_per_block_opt_in); + TORCH_CHECK(fallback_gemm_shared_mem_size <= + max_shared_mem_per_block_opt_in); return cutlass_gemm_caller( out, a, b, std::forward(args)...); } From a1eaf19c20a06b927685382c0c45afdfd7ef073a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 13:11:46 +0000 Subject: [PATCH 4/7] Add an explicit FallbackGemm --- .../cutlass_w8a8/scaled_mm_c2x.cu | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 40965653ba1bf..77f90fa88a611 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -282,6 +282,7 @@ struct sm80_config_default { // This config is used in 2 cases, // - M in (128, inf) // - M in (64, 128] and N >= 8192 + // Shared Memory required by this Gemm - 81920 bytes static_assert(std::is_same()); using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; @@ -297,6 +298,7 @@ struct sm80_config_M64 { // This config is used in 2 cases, // - M in (32, 64] // - M in (64, 128] and N < 8192 + // Shared Memory required by this Gemm - 122880 bytes static_assert(std::is_same()); using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; @@ -310,6 +312,7 @@ template typename Epilogue> struct sm80_config_M32 { // M in (16, 32] + // Shared Memory required by this Gemm - 61440 bytes static_assert(std::is_same()); using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; @@ -323,6 +326,7 @@ template typename Epilogue> struct sm80_config_M16 { // M in [1, 16] + // Shared Memory required by this Gemm - 51200 bytes static_assert(std::is_same()); using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; @@ -357,20 +361,29 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, using Cutlass2xGemmM16 = typename sm80_config_M16::Cutlass2xGemm; + // Due to shared memory requirements, some Gemms may fail to run on some + // GPUs. As the name indicates, the Fallback Gemm is used as an alternative + // in such cases. + // sm80_config_M16 has the least shared-memory requirement. However, + // based on some profiling, we select sm80_config_M32 as a better alternative + // performance wise. + using FallbackGemm = + typename sm80_config_M32::Cutlass2xGemm; + uint32_t const m = a.size(0); uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] - return fallback_cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 32) { // M in (16, 32] - return fallback_cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 64) { // M in (32, 64] - return fallback_cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 128) { // M in (64, 128] @@ -378,16 +391,16 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, bool const small_n = n < 8192; if (small_n) { return fallback_cutlass_gemm_caller( + FallbackGemm>( out, a, b, std::forward(args)...); } else { return fallback_cutlass_gemm_caller( + FallbackGemm>( out, a, b, std::forward(args)...); } } else { // M in (128, inf) - return cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } } From 833d0186ce4bafb565e7428d5bfdf6fbf26cc23e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 13:12:05 +0000 Subject: [PATCH 5/7] format --- csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 77f90fa88a611..0c6d237bcd752 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -367,7 +367,7 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, // sm80_config_M16 has the least shared-memory requirement. However, // based on some profiling, we select sm80_config_M32 as a better alternative // performance wise. - using FallbackGemm = + using FallbackGemm = typename sm80_config_M32::Cutlass2xGemm; uint32_t const m = a.size(0); @@ -394,8 +394,7 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, FallbackGemm>( out, a, b, std::forward(args)...); } else { - return fallback_cutlass_gemm_caller( + return fallback_cutlass_gemm_caller( out, a, b, std::forward(args)...); } } else { From 8a20c16b0d554182610a73a86bee7adb4ddbac97 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 13:33:48 +0000 Subject: [PATCH 6/7] codespell --- csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 0c6d237bcd752..38a20a1727d18 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -254,7 +254,7 @@ template void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... args) { - // In some cases, the GPU isn't able to accomodate the + // In some cases, the GPU isn't able to accommodate the // shared memory requirements of the Gemm. In such cases, use // the FallbackGemm instead. static const int max_shared_mem_per_block_opt_in = From 796701b0f28b171d8bb3e9dd32ce602e1a5f7497 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 21 Jun 2024 13:46:27 +0000 Subject: [PATCH 7/7] fix : use the passed in device-id --- csrc/quantization/cutlass_w8a8/common.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp index bd01104fe6689..bf04bb400790f 100644 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ b/csrc/quantization/cutlass_w8a8/common.hpp @@ -21,7 +21,7 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, cudaDevAttrMaxSharedMemoryPerBlockOptin, - 0); + device); return max_shared_mem_per_block_opt_in; }