diff --git a/CMakeLists.txt b/CMakeLists.txt
index 801429096eaa..ede9192cd1db 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21)
 
 project(vllm_extensions LANGUAGES CXX)
 
-option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
+# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
+set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
 
 message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
 message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
diff --git a/csrc/ops.h b/csrc/ops.h
index 6f0a7143c916..ae04150eaf75 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -96,7 +96,8 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
-                       torch::Tensor const& b_scales);
+                       torch::Tensor const& b_scales,
+                       c10::optional<torch::Tensor> const& bias);
 
 #endif
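Note on the new `bias` parameter: it is optional at the API boundary, so existing call sites are unaffected. As a reference point, here is a minimal libtorch sketch of the math the fused kernel is expected to compute — the helper name, shapes, and dtype handling are illustrative assumptions, not part of this diff; the authoritative semantics are the epilogue comment in scaled_mm_c2x.cu below.

```cpp
#include <torch/torch.h>

// Sketch only: D = (a_scales * A) @ (b_scales * B) + bias, with numpy-style
// broadcasting. Assumed shapes: a_scales is a scalar or [M, 1], b_scales is a
// scalar or [1, N], bias is [N] in the output dtype (per the checks below).
torch::Tensor scaled_mm_ref(torch::Tensor const& a, torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias,
                            torch::ScalarType out_dtype) {
  auto out = torch::matmul(a_scales * a.to(torch::kFloat32),
                           b_scales * b.to(torch::kFloat32));
  if (bias) {
    out = out + bias->to(torch::kFloat32);
  }
  return out.to(out_dtype);
}
```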
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index 38a20a1727d1..6ce25c5ac897 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -77,24 +77,12 @@ struct enable_sm89_to_sm90 : Kernel {
 };
 
 /*
-   This epilogue function defines a quantized GEMM operation similar to
-   torch._scaled_mm.
-
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
-*/
+ * This class provides the common ScaleA and ScaleB descriptors for the
+ * ScaledEpilogue and ScaledEpilogueBias classes.
+ */
 template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogue {
- private:
+struct ScaledEpilogueBase {
+ protected:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
 
   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@@ -102,6 +90,32 @@ struct ScaledEpilogue {
   using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+};
+
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.
+
+   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+   per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or column is supported.
+   A and B must have symmetric quantization (zero point == 0).
+
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.
+
+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor or
+   per row or column.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
 
   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
       cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
@@ -134,6 +148,53 @@ struct ScaledEpilogue {
   }
 };
 
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
+
+  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
+                                                             EVTCompute0, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+    using BiasArgs = typename Bias::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
+                                                    bias_args};
+    return evt_compute_args;
+  }
+};
+
 template <typename Arch, template <typename> typename ArchGuard,
           typename ElementAB_, typename ElementD_,
           template <typename, typename> typename Epilogue_, typename TileShape,
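To make the visitor tree above concrete: `EVTCompute0` computes `b_scale * accum` with `cutlass::multiplies`, and the outer `EVTCompute` node fuses the bias via `cutlass::multiply_add`, rounding once into `ElementD`. A scalar model of what `ScaledEpilogueBias` produces per output element — a sketch under the same broadcasting rules `prepare_args` encodes, not the actual CUTLASS code path:

```cpp
#include <cstdint>

// Scalar sketch of ScaledEpilogueBias for output element D[i][j].
// numel != 1 selects the per-row / per-column broadcast path, matching the
// flags set in prepare_args above.
template <typename ElementD>
ElementD scaled_epilogue_bias_ref(float accum,  // accumulator value for D[i][j]
                                  float const* a_scales, int64_t a_numel, int64_t i,
                                  float const* b_scales, int64_t b_numel, int64_t j,
                                  ElementD const* bias) {  // length N, row-broadcast
  float a_scale = (a_numel != 1) ? a_scales[i] : a_scales[0];
  float b_scale = (b_numel != 1) ? b_scales[j] : b_scales[0];
  float tmp = b_scale * accum;  // EVTCompute0: cutlass::multiplies in float
  // EVTCompute: cutlass::multiply_add, rounded to ElementD
  return static_cast<ElementD>(a_scale * tmp + static_cast<float>(bias[j]));
}
```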
@@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
-  using KernelType = 
+  using KernelType =
       ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
@@ ... @@
-void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -420,78 +480,130 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
   if (out.dtype() == torch::kBFloat16) {
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
 
-void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template
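For completeness, a hypothetical call-site sketch for the reworked dispatch — shapes and dtypes inferred from the TORCH_CHECKs above, not part of this diff. Passing `c10::nullopt` keeps the old unbiased path; a bias tensor must have N elements in the output dtype so the row-broadcast epilogue can apply it per output column:

```cpp
#include <torch/torch.h>
#include "ops.h"  // assumed include providing the cutlass_scaled_mm declaration

// Hypothetical usage of the extended entry point.
void example(torch::Tensor const& a,         // [M, K] int8
             torch::Tensor const& b,         // [K, N] int8
             torch::Tensor const& a_scales,  // fp32, numel 1 or M
             torch::Tensor const& b_scales,  // fp32, numel 1 or N
             torch::Tensor const& bias) {    // [N], same dtype as out
  auto out = torch::empty({a.size(0), b.size(1)},
                          torch::dtype(torch::kBFloat16).device(a.device()));
  cutlass_scaled_mm(out, a, b, a_scales, b_scales, c10::nullopt);  // unbiased
  cutlass_scaled_mm(out, a, b, a_scales, b_scales, bias);          // fused bias
}
```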