diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index c69e87999ae71..26f7423fd7455 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -1,3 +1,5 @@ +#pragma once + #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" /* diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index fcc17c7727f94..c723adf126422 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,3 +1,5 @@ +#pragma once + #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" /* diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index 75681f7f37820..f2fae4b66d651 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -28,9 +28,9 @@ using namespace cute; /* - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm80EVT, + Epilogues defined in, + csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp + must contain a public type named EVTCompute of type Sm80EVT, as well as a static prepare_args function that constructs an EVTCompute::Arguments struct. */ diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 8190277997161..123f4359c0d1a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,385 +1,18 @@ -// clang-format will break include orders -// clang-format off #include #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -#include + #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh" + #include "scaled_mm_c3x_sm90_int8_dispatch.cuh" -#include - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "core/math.hpp" -#include "cutlass_extensions/common.hpp" -// clang-format on - -using namespace cute; + #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm90EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. */ -namespace { - -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); - #endif - } -}; -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; - - struct GemmKernel : public KernelType {}; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; - - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_default { - // For M > 128 and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M128 { - // For M in (64, 128] and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M64 { - // For M in (32, 64] and any N - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NBig { - // For M in [1, 32] and N >= 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NSmall { - // For M in [1, 32] and N < 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -} // namespace - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - - if (mp2 <= 64) { - // m in [1, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); - - using Cutlass3xGemmDefault = - typename sm90_int8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_int8_config_M128::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_int8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM32NBig = - typename sm90_int8_config_M32_NBig::Cutlass3xGemm; - using Cutlass3xGemmM32NSmall = - typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; - - uint32_t const n = out.size(1); - bool const is_small_n = n < 8192; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 - - if (mp2 <= 32) { - // m in [1, 32] - if (is_small_n) { - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } - } else if (mp2 <= 64) { - // m in (32, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - template