From 0abac6f6f01ff4c1fc275c13702e3269b5a9dfe3 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 30 Aug 2024 09:07:15 -0400 Subject: [PATCH 01/49] Enable 8-bit weights in Fused Marlin MoE --- csrc/moe/marlin_moe_ops.cu | 301 ++++++++++++------ csrc/moe/marlin_moe_ops.h | 9 +- csrc/moe/torch_bindings.cpp | 11 +- tests/kernels/test_moe.py | 225 ++++++++++++- vllm/_custom_ops.py | 2 +- .../layers/fused_moe/__init__.py | 16 +- .../layers/fused_moe/fused_moe.py | 133 ++------ .../layers/fused_moe/fused_moe_marlin.py | 245 ++++++++++++++ .../compressed_tensors_moe.py | 33 +- .../layers/quantization/utils/marlin_utils.py | 17 + .../quantization/utils/marlin_utils_test.py | 11 +- .../layers/quantization/utils/quant_utils.py | 19 +- 12 files changed, 775 insertions(+), 247 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_marlin.py diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1e170e80d2f7..e3c18ce5a50b 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; @@ -840,10 +893,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +917,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +941,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. 
We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1095,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1169,25 +1231,67 @@ __device__ inline void MarlinMoESingle( // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } } } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1331,8 @@ __device__ inline void MarlinMoESingle( } } -template ( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, 
num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1447,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -1494,42 +1601,43 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const 
void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1611,10 +1719,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1645,10 +1756,14 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1785,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1854,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 01ba8ff69850..adee8399a4d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d4d43e2c601b..d2352375de33 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -9,16 +9,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); -#ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); - + "g_idx, Tensor! perm, Tensor! workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); -#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f526c381b333..f7642bf02b05 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_moe.py`. 
""" +from typing import List + import pytest import torch from transformers import MixtralConfig @@ -9,7 +11,12 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.scalar_type import scalar_types def torch_moe(a, w1, w2, score, topk): @@ -29,6 +36,20 @@ def torch_moe(a, w1, w2, score, topk): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + + @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -43,11 +64,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) @@ -99,3 +120,199 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_fused_marlin_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, + num_bits: int, +): + torch.manual_seed(7) + + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size in (k, n): + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + for i in range(w2.shape[0]): + w2[0] = torch.eye(k, 
n, device="cuda", dtype=dtype) + + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref1_l.append(w_ref1) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) + sort_indices1 = stack_and_dev(sort_indices1_l) + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref2_l.append(w_ref2) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) + sort_indices2 = stack_and_dev(sort_indices2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_moe_marlin( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk, + renormalize=False, + w1_scale=scales1, + w2_scale=scales2, + num_bits=num_bits, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 + + +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_marlin_moe_mmm( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, + num_bits: int, +): + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == k: + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + g_idx_l = [] + sort_indices_l = [] + + for i in range(w.shape[0]): + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l) + g_idx = stack_and_dev(g_idx_l) + sort_indices = stack_and_dev(sort_indices_l) + + score = torch.randn((m, e), device="cuda", 
dtype=dtype) + marlin_output = single_moe_marlin(a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False, + num_bits=num_bits) + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe254732e730..51db8b34e291 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -314,7 +314,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index fd6f41b90042..65a9b78a118c 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,17 +1,23 @@ +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", + "fused_moe_marlin", + "single_moe_marlin", +] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ - "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d2b152320e11..613d67e64bff 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config(M: int, E: int, N: int, K: int, topk: int, - dtype: Optional[str], - is_marlin: bool) -> Dict[str, int]: +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } + # A heuristic: fused marlin works faster with this config for small M if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, @@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int, return config -def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, - Any]] = None, - is_marlin: bool = False): +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): if override_config: config = override_config else: @@ -391,6 +399,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) + ops.topk_softmax( topk_weights, topk_ids, @@ -437,108 +446,6 @@ def 
grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - - #TODO fp8 is not implemented yet - assert not use_fp8 - - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, 
g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py new file mode 100644 index 000000000000..40f9f66f1706 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -0,0 +1,245 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.scalar_type import scalar_types + +from .fused_moe import (fused_topk, moe_align_block_size, + try_get_optimal_moe_config) + + +def single_moe_marlin( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + rand_perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Marlin MoE MMM using weights w + and top-k gating mechanism. It is meant for testing and debugging. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w (torch.Tensor): The first set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + product for w. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
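+    # Packing note: each expert weight is expected as an int32 Marlin-packed
+    # tensor of shape (K // 16, N * (num_bits // 2)) per expert (cf. the
+    # output shape of gptq_marlin_moe_repack), so K is recovered as
+    # w.shape[1] * 16 and N as w.shape[2] // (num_bits // 2) below.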
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert num_bits in [4, 8] + # TODO support this + assert not use_fp8 + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // (num_bits // 2) + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, rand_perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_moe_marlin( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
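+    # Packing note: w1 and w2 are expected as int32 Marlin-packed tensors of
+    # shapes (E, K // 16, 2 * N * (num_bits // 2)) and
+    # (E, N // 16, K * (num_bits // 2)) respectively (cf. the output shape of
+    # gptq_marlin_moe_repack), which is what the hidden-size checks against
+    # w1.shape[1] * 16 and w2.shape[2] // (num_bits // 2) below assume.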
+ assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert num_bits in [4, 8] + # TODO support this + assert not use_fp8 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + rand_perm1, + workspace, + scalar_type, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + rand_perm2, + workspace, + scalar_type, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0e0ab9ce9169..ba4f719a3f97 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -266,18 +266,21 @@ def apply(self, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_marlin_moe) - - return fused_marlin_moe(x, - layer.w13_weight_packed, - layer.w2_weight_packed, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - renormalize=renormalize, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale) + from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin) + + return fused_moe_marlin( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + 
top_k, + renormalize=renormalize, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, + ) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f2..699d5f184414 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 7d08ac6f8746..4a06c5d63d52 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -1,6 +1,6 @@ """Utility functions used for tests and benchmarks""" -from typing import List +from typing import List, Optional import numpy as np import torch @@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, - act_order: bool): +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order) + w, quant_type, group_size, act_order, test_perm) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 33f24ff5d54d..bdfda31de852 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,5 @@ """This file is used for /tests and /benchmarks""" -from typing import List +from typing import List, Optional import numpy import torch @@ -53,7 +53,10 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): assert q_w.shape == w_ref.shape orig_device = q_w.device @@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K - rand_perm = torch.randperm(k_size) + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() @@ -164,8 +167,11 @@ def reshape_w(w): ) -def gptq_quantize_weights(w: 
torch.Tensor, quant_type: ScalarType, - group_size: int, act_order: bool): +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" @@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) return w_ref, w_q, w_s, g_idx, rand_perm From fdf69c2f5e6c4a3f5604d7a088abefd57a0a5508 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 30 Aug 2024 09:36:33 -0400 Subject: [PATCH 02/49] fix rocm --- csrc/moe/torch_bindings.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d2352375de33..e4fce091d24a 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -9,6 +9,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); +#ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " @@ -19,5 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); } +#endif REGISTER_EXTENSION(TORCH_EXTENSION_NAME) From 4da163b45096fb24ec62f30a26b7ecd4750bea67 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 30 Aug 2024 09:45:52 -0400 Subject: [PATCH 03/49] bad paste --- csrc/moe/torch_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index e4fce091d24a..cd65a8ee92b9 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -19,7 +19,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); -} #endif +} REGISTER_EXTENSION(TORCH_EXTENSION_NAME) From 21d2337a42e11fd16d9891b6bd959209b220aa16 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 30 Aug 2024 17:29:42 +0000 Subject: [PATCH 04/49] add test case; fix imports for tests --- tests/weight_loading/models.txt | 1 + vllm/model_executor/layers/fused_moe/__init__.py | 8 ++++---- vllm/model_executor/layers/fused_moe/fused_moe_marlin.py | 5 ++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index cbe30305c14f..7deb2880145c 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -15,6 +15,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/model_executor/layers/fused_moe/__init__.py 
b/vllm/model_executor/layers/fused_moe/__init__.py index 65a9b78a118c..06bd2706d7e4 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,3 @@ -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -8,16 +6,18 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "fused_moe_marlin", - "single_moe_marlin", ] if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) + from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) __all__ += [ + "fused_moe_marlin", + "single_moe_marlin", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 40f9f66f1706..40b409ebeb34 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -5,11 +5,10 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size, try_get_optimal_moe_config) from vllm.scalar_type import scalar_types -from .fused_moe import (fused_topk, moe_align_block_size, - try_get_optimal_moe_config) - def single_moe_marlin( hidden_states: torch.Tensor, From 638777a35922dfecbce7866547f5096539187603 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 30 Aug 2024 20:12:47 +0000 Subject: [PATCH 05/49] fix to adapt custom_routin_function --- .../layers/fused_moe/fused_moe_marlin.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 40b409ebeb34..8c49333f7c84 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -1,6 +1,6 @@ """Fused MoE utilities for GPTQ.""" import functools -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional import torch @@ -106,7 +106,8 @@ def fused_moe_marlin( rand_perm1: torch.Tensor, rand_perm2: torch.Tensor, topk: int, - renormalize: bool, + custom_routing_function: Optional[Callable] = None, + renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -161,8 +162,12 @@ def fused_moe_marlin( E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + if custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) get_config_func = functools.partial( try_get_optimal_moe_config, From bd4b84d92bfb33c3456a73b8dd951490a2ce11b0 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 2 Sep 2024 03:04:07 -0400 Subject: [PATCH 06/49] Use select_experts to compute top_k tensors in fused moe --- tests/kernels/test_moe.py | 7 ++++++- .../layers/fused_moe/fused_moe_marlin.py | 11 +++-------- .../compressed_tensors_moe.py | 18 
++++++++++++++---- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f7642bf02b05..2cfd76d1c780 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -218,6 +219,9 @@ def test_fused_marlin_moe( sort_indices2 = stack_and_dev(sort_indices2_l) score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + triton_output = fused_moe( a, w_ref1.transpose(1, 2).contiguous(), @@ -235,7 +239,8 @@ def test_fused_marlin_moe( g_idx2, sort_indices1, sort_indices2, - topk, + topk_weights, + topk_ids, renormalize=False, w1_scale=scales1, w2_scale=scales2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 8c49333f7c84..45dead9740f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -105,7 +105,8 @@ def fused_moe_marlin( g_idx2: torch.Tensor, rand_perm1: torch.Tensor, rand_perm2: torch.Tensor, - topk: int, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, custom_routing_function: Optional[Callable] = None, renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, @@ -161,13 +162,7 @@ def fused_moe_marlin( M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 - - if custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) + topk = topk_ids.shape[1] get_config_func = functools.partial( try_get_optimal_moe_config, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 9632dbbae395..53769cb73153 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -5,7 +5,7 @@ import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -272,6 +272,16 @@ def apply( from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( fused_moe_marlin) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + return fused_moe_marlin( x, layer.w13_weight_packed, @@ -281,10 +291,10 @@ def apply( layer.w2_g_idx, layer.w13_g_idx_sort_indices, layer.w2_g_idx_sort_indices, - top_k, - 
custom_routing_function=custom_routing_function, + topk_weights, + topk_ids, renormalize=renormalize, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, num_bits=self.num_bits, - ) \ No newline at end of file + ) From bef6b53fc2043f6e7de262f90b381797ee0574ad Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 3 Sep 2024 10:42:10 -0400 Subject: [PATCH 07/49] bring back fused_moe_marlin -> fused_marlin_moe --- tests/kernels/test_moe.py | 8 ++++---- vllm/model_executor/layers/fused_moe/__init__.py | 8 ++++---- .../{fused_moe_marlin.py => fused_marlin_moe.py} | 4 ++-- .../compressed_tensors/compressed_tensors_moe.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) rename vllm/model_executor/layers/fused_moe/{fused_moe_marlin.py => fused_marlin_moe.py} (99%) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2cfd76d1c780..606997843982 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -11,9 +11,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -230,7 +230,7 @@ def test_fused_marlin_moe( topk, renormalize=False, ) - marlin_output = fused_moe_marlin( + marlin_output = fused_marlin_moe( a, qweight1, qweight2, @@ -309,7 +309,7 @@ def test_marlin_moe_mmm( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_moe_marlin(a, + marlin_output = single_marlin_moe(a, qweight, scales, score, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 06bd2706d7e4..e9b5703ca28b 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -9,15 +9,15 @@ ] if HAS_TRITON: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) - from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin, single_moe_marlin) __all__ += [ - "fused_moe_marlin", - "single_moe_marlin", + "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py similarity index 99% rename from vllm/model_executor/layers/fused_moe/fused_moe_marlin.py rename to vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 45dead9740f4..5866c83cd9c8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -10,7 +10,7 @@ from vllm.scalar_type import scalar_types -def single_moe_marlin( +def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, @@ -96,7 +96,7 @@ def single_moe_marlin( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) -def fused_moe_marlin( +def fused_marlin_moe( 
hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 53769cb73153..b14ef433d539 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -269,8 +269,8 @@ def apply( custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin) + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -282,7 +282,7 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function) - return fused_moe_marlin( + return fused_marlin_moe( x, layer.w13_weight_packed, layer.w2_weight_packed, From b45594ccfc87097933850e553244dcad2645a3dc Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 4 Sep 2024 15:28:23 +0000 Subject: [PATCH 08/49] remove large model --- tests/weight_loading/models.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 5eee2cc53444..1dc529037a98 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -21,7 +21,6 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main From effd2cd5cd96dd5737d605941e7bdb6066ee2816 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 13:10:02 -0400 Subject: [PATCH 09/49] Cleanup, comments --- csrc/moe/marlin_moe_ops.cu | 4 +- tests/kernels/test_moe.py | 1 - .../layers/fused_moe/__init__.py | 8 +-- .../layers/fused_moe/fused_marlin_moe.py | 50 ++++++++----------- .../compressed_tensors_moe.py | 1 - 5 files changed, 28 insertions(+), 36 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e3c18ce5a50b..f6d475a56851 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1228,8 +1228,6 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { @@ -1237,6 +1235,8 @@ __device__ inline void MarlinMoESingle( } cp_async_fence(); } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out if (last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 606997843982..7e359ff08088 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -241,7 +241,6 @@ def test_fused_marlin_moe( sort_indices2, topk_weights, topk_ids, - renormalize=False, 
w1_scale=scales1, w2_scale=scales2, num_bits=num_bits, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index e9b5703ca28b..dea4a32aec4f 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,3 +1,5 @@ +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -6,18 +8,16 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "fused_marlin_moe", + "single_marlin_moe", ] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) __all__ += [ - "fused_marlin_moe", - "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 5866c83cd9c8..c7906205760f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -1,6 +1,6 @@ """Fused MoE utilities for GPTQ.""" import functools -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict, Optional import torch @@ -16,11 +16,10 @@ def single_marlin_moe( scales: torch.Tensor, gating_output: torch.Tensor, g_idx: torch.Tensor, - rand_perm: torch.Tensor, + perm: torch.Tensor, topk: int, renormalize: bool, override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, num_bits: int = 8, ) -> torch.Tensor: """ @@ -28,18 +27,18 @@ def single_marlin_moe( and top-k gating mechanism. It is meant for testing and debugging. Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w (torch.Tensor): The first set of expert weights. + - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. + - w (torch.Tensor): The set of expert weights. + - scales (torch.Tensor): The quantization scales. - gating_output (torch.Tensor): The output of the gating operation (before softmax). + - g_idx (torch.Tensor): The act_order indices. + - perm (torch.Tensor): The act_order input permutation. - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - product for w. Defaults to False. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
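Illustrative only: a minimal sketch of calling single_marlin_moe with the signature above (perm instead of rand_perm, no use_fp8). The per-expert tensors qweight, scales, g_idx and perm are assumed to be stacked Marlin-format tensors prepared as in tests/kernels/test_moe.py (e.g. via the marlin_quantize test helper); the sizes m, k, e are placeholders.

    import torch
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        single_marlin_moe)

    m, k, e, topk = 64, 128, 8, 2
    hidden_states = torch.randn((m, k), device="cuda", dtype=torch.half)
    gating_output = torch.randn((m, e), device="cuda", dtype=torch.half)
    # qweight, scales, g_idx, perm: stacked per-expert Marlin-format tensors
    # (assumed to be prepared as in tests/kernels/test_moe.py).
    out = single_marlin_moe(hidden_states, qweight, scales, gating_output,
                            g_idx, perm, topk, renormalize=True, num_bits=4)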
@@ -55,8 +54,6 @@ def single_marlin_moe( torch.float32, torch.float16, torch.bfloat16 ] assert num_bits in [4, 8] - # TODO support this - assert not use_fp8 M, K = hidden_states.shape E = w.shape[0] N = w.shape[1] * 16 @@ -70,7 +67,7 @@ def single_marlin_moe( w.shape, w.shape, topk_ids.shape[1], - "float8" if use_fp8 else None, + None, override_config=override_config, is_marlin=True) config = get_config_func(M) @@ -90,7 +87,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, rand_perm, workspace, scalar_type, M, N, K, True, E, topk, + g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -103,14 +100,11 @@ def fused_marlin_moe( gating_output: torch.Tensor, g_idx1: torch.Tensor, g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, + perm1: torch.Tensor, + perm2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - custom_routing_function: Optional[Callable] = None, - renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, num_bits: int = 8, @@ -125,18 +119,20 @@ def fused_marlin_moe( - w2 (torch.Tensor): The second set of expert weights. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - topk (int): The number of top-k experts to select. + - g_idx1 (torch.Tensor): The first set of act_order indices. + - g_idx2 (torch.Tensor): The second set of act_order indices. + - perm1 (torch.Tensor): The first act_order input permutation. + - perm2 (torch.Tensor): The second act_order input permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of top-k elements. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - num_bits (int): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer.
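Illustrative only: a minimal sketch of the new fused_marlin_moe calling convention, mirroring the updated test above. Routing is now computed by the caller and passed in as topk_weights/topk_ids; the stacked per-expert tensors qweight1/qweight2, scales1/scales2, g_idx1/g_idx2 and perm1/perm2 are assumed to be prepared as in tests/kernels/test_moe.py.

    import torch
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        fused_marlin_moe)
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

    m, k, e, topk = 64, 128, 8, 2
    hidden_states = torch.randn((m, k), device="cuda", dtype=torch.half)
    gating_output = torch.randn((m, e), device="cuda", dtype=torch.half)
    # Routing happens outside the kernel wrapper now.
    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                        False)
    # qweight1/2, scales1/2, g_idx1/2, perm1/2: stacked per-expert
    # Marlin-format tensors (assumed, see the test diff above).
    out = fused_marlin_moe(hidden_states, qweight1, qweight2, gating_output,
                           g_idx1, g_idx2, perm1, perm2, topk_weights,
                           topk_ids, w1_scale=scales1, w2_scale=scales2,
                           num_bits=8)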
@@ -156,8 +152,6 @@ def fused_marlin_moe( torch.float32, torch.float16, torch.bfloat16 ] assert num_bits in [4, 8] - # TODO support this - assert not use_fp8 M, K = hidden_states.shape E = w1.shape[0] @@ -169,7 +163,7 @@ def fused_marlin_moe( w1.shape, w2.shape, topk_ids.shape[1], - "float8" if use_fp8 else None, + None, override_config=override_config, is_marlin=True, ) @@ -202,7 +196,7 @@ def fused_marlin_moe( topk_ids, w1_scale, g_idx1, - rand_perm1, + perm1, workspace, scalar_type, M, @@ -226,7 +220,7 @@ def fused_marlin_moe( topk_ids, w2_scale, g_idx2, - rand_perm2, + perm2, workspace, scalar_type, M, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b14ef433d539..7dee2fca8115 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -293,7 +293,6 @@ def apply( layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - renormalize=renormalize, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, num_bits=self.num_bits, From 52c33539a4b38ba97223157374fb2243d1272988 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 13:28:01 -0400 Subject: [PATCH 10/49] fix moe init --- vllm/model_executor/layers/fused_moe/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index dea4a32aec4f..e9b5703ca28b 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,3 @@ -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -8,16 +6,18 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "fused_marlin_moe", - "single_marlin_moe", ] if HAS_TRITON: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) __all__ += [ + "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", From 882fd9c38e52163ac8db4a29c459e166ceea6816 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 4 Sep 2024 21:14:02 +0000 Subject: [PATCH 11/49] move larger models to an options larger test --- .buildkite/test-pipeline.yaml | 11 ++++++++++- tests/weight_loading/models-large.txt | 3 +++ tests/weight_loading/models.txt | 2 -- .../run_model_weight_loading_test.sh | 14 +++++++++++++- 4 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 tests/weight_loading/models-large.txt diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 86eddb576c42..bb71d4f4b9ac 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -375,7 +375,16 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + 
source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt new file mode 100644 index 000000000000..f997220554f3 --- /dev/null +++ b/tests/weight_loading/models-large.txt @@ -0,0 +1,3 @@ +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1dc529037a98..a3e382acf56b 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh index 0cb45d1780c2..a099ce56bcaf 100644 --- a/tests/weight_loading/run_model_weight_loading_test.sh +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -1,7 +1,19 @@ #!/bin/bash SUCCESS=0 -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" +while getopts "c:" OPT; do + case ${OPT} in + c ) + CONFIG="$OPTARG" + ;; + \? 
) + usage + exit 1 + ;; + esac +done + +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do From 973d914721912568641de9f39b298c4adac1a2b4 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 4 Sep 2024 21:51:04 +0000 Subject: [PATCH 12/49] add optional flag --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bb71d4f4b9ac..54fa3ba535af 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -380,6 +380,7 @@ steps: - label: Weight Loading Multiple GPU Test - Large Models # optional working_dir: "/vllm-workspace/tests" num_gpus: 2 + optional: true source_file_dependencies: - vllm/ - tests/weight_loading From 72bc8997fa4b32d42996255629038d6da33155c0 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 5 Sep 2024 02:39:08 +0000 Subject: [PATCH 13/49] swap gpu --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 54fa3ba535af..900dc72e7446 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -380,6 +380,7 @@ steps: - label: Weight Loading Multiple GPU Test - Large Models # optional working_dir: "/vllm-workspace/tests" num_gpus: 2 + gpu: a100 optional: true source_file_dependencies: - vllm/ From eea2bc3f38b366c4077ea273d72b47bfacf13330 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 5 Sep 2024 00:51:59 -0400 Subject: [PATCH 14/49] Temp disable part of moe tests to see what's breaking --- tests/kernels/test_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 7e359ff08088..daa9a2235b5b 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,7 +140,8 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("num_bits", [4]) +# @pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, From 9c29dc2733f29ac39c1614a94fa65a74947c76e2 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 5 Sep 2024 02:47:21 -0400 Subject: [PATCH 15/49] Fixes to act_order, make unit tests more robust --- csrc/moe/marlin_moe_ops.cu | 31 ++++++++++++++++++++----------- tests/kernels/test_moe.py | 5 +---- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index f6d475a56851..98306cb4707b 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -481,9 +481,10 @@ __device__ inline void MarlinMoESingle( // Scale sizes/strides without act_order int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? 
thread_k_blocks / group_blocks + : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; // Scale size/strides with act_order @@ -527,11 +528,13 @@ __device__ inline void MarlinMoESingle( // No act_order int s_gl_rd; - if constexpr (group_blocks == -1 || group_blocks == 0) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } } int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -776,6 +779,12 @@ __device__ inline void MarlinMoESingle( int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -1150,9 +1159,9 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - __syncthreads(); + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index daa9a2235b5b..9f2cc8693d43 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,8 +140,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4]) -# @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ -170,8 +169,6 @@ def test_fused_marlin_moe( a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] From 6d04dcdf3be4c65e49ec5c42a8ca7e7b733f98ba Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 5 Sep 2024 08:51:24 -0400 Subject: [PATCH 16/49] try to narrow down cuda error --- tests/kernels/test_moe.py | 16 ++++++++++++---- .../layers/fused_moe/fused_marlin_moe.py | 4 +--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 9f2cc8693d43..92d512a41f60 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -133,14 +133,22 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) +# @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +# @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +# @pytest.mark.parametrize("k", [128, 1024, 512]) +# @pytest.mark.parametrize("e", [4, 8, 64]) +# @pytest.mark.parametrize("topk", [2, 6]) +# @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +# @pytest.mark.parametrize("act_order", [True, False]) +# @pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) 
@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) @pytest.mark.parametrize("e", [4, 8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("group_size", [32, 64, 128]) +@pytest.mark.parametrize("act_order", [False]) +@pytest.mark.parametrize("num_bits", [8]) def test_fused_marlin_moe( m: int, n: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c7906205760f..4c82d59c0a79 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -148,9 +148,7 @@ def fused_marlin_moe( assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype == torch.float16 assert num_bits in [4, 8] M, K = hidden_states.shape From 83e799913a9d6b7eccffb629841eb5617c5153bb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Sep 2024 08:08:18 -0400 Subject: [PATCH 17/49] Try different subset of test params --- tests/kernels/test_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 92d512a41f60..7343cf28d9fd 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -144,9 +144,9 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [4]) @pytest.mark.parametrize("topk", [2]) -@pytest.mark.parametrize("group_size", [32, 64, 128]) +@pytest.mark.parametrize("group_size", [-1]) @pytest.mark.parametrize("act_order", [False]) @pytest.mark.parametrize("num_bits", [8]) def test_fused_marlin_moe( From 6a42eaf5d0eb49e35a013ae9aaf5bba4d0260248 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Sep 2024 09:11:58 -0400 Subject: [PATCH 18/49] . --- tests/kernels/test_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 7343cf28d9fd..381a42d3a368 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -141,9 +141,9 @@ def compute_max_diff(output, output_ref): # @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) # @pytest.mark.parametrize("act_order", [True, False]) # @pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("m", [1]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) @pytest.mark.parametrize("e", [4]) @pytest.mark.parametrize("topk", [2]) @pytest.mark.parametrize("group_size", [-1]) From 3288842c23aed83cc1d9acf0e29648b466cd7f1d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Sep 2024 10:24:28 -0400 Subject: [PATCH 19/49] . 
--- tests/kernels/test_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 381a42d3a368..283eb49a5c08 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -141,7 +141,7 @@ def compute_max_diff(output, output_ref): # @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) # @pytest.mark.parametrize("act_order", [True, False]) # @pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("m", [1]) +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128]) @pytest.mark.parametrize("k", [128]) @pytest.mark.parametrize("e", [4]) From 667d23e3f9d18babf731cf1cba0fb3f9c1f26ba5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 10 Sep 2024 03:08:11 -0400 Subject: [PATCH 20/49] fix and cleanup after merge --- tests/kernels/test_moe.py | 2 +- vllm/model_executor/layers/quantization/gptq_marlin.py | 1 + vllm/model_executor/model_loader/utils.py | 8 +------- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 9f2cc8693d43..8072cf09e5b6 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -257,7 +257,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) -def test_marlin_moe_mmm( +def test_single_marlin_moe_multiply( m: int, n: int, k: int, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3617a32f80fc..cc699f5b4554 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -611,4 +611,5 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0052489d99dc..2bfe6ea09bd6 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,13 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
- mixtral_supported = ["fp8", "compressed-tensors"] - # for gptq_marlin, only run fused MoE for int4 - if model_config.quantization == "gptq_marlin": - hf_quant_config = getattr(model_config.hf_config, - "quantization_config", None) - if hf_quant_config and hf_quant_config.get("bits") == 4: - mixtral_supported.append("gptq_marlin") + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From b16838e17002eaad85406fc6f208b73509350e8d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 10 Sep 2024 03:13:42 -0400 Subject: [PATCH 21/49] cleanup --- .../run_model_weight_loading_test.sh | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) mode change 100644 => 100755 tests/weight_loading/run_model_weight_loading_test.sh diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh old mode 100644 new mode 100755 index a099ce56bcaf..0cb45d1780c2 --- a/tests/weight_loading/run_model_weight_loading_test.sh +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -1,19 +1,7 @@ #!/bin/bash SUCCESS=0 -while getopts "c:" OPT; do - case ${OPT} in - c ) - CONFIG="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do From e53abb908dcbea747231338142c83cc8b9a0b2e5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 10 Sep 2024 12:31:43 -0400 Subject: [PATCH 22/49] validate cache for the kernel code --- csrc/moe/marlin_moe_ops.cu | 205 ++++++++++++++++++++++++++++--------- 1 file changed, 159 insertions(+), 46 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 98306cb4707b..be14166f235d 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1375,7 +1375,8 @@ __global__ void MarlinMoE( bool replicate_input, // do we use the same input for each expert? bool apply_weights, // apply weights to output int current_m_block, // current m block to start kernel computation from - int max_par // maximum parallelism + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks ) { int m_block_ctr = current_m_block; @@ -1396,14 +1397,15 @@ __global__ void MarlinMoE( prob_m = tot_its - 16 * m_block_ctr; int par = 1; - if (max_block > 4) { + if (max_block > cfg_max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * max_block - pad) / 64; - par = min((16 * max_block - pad) / 64, max_par); - prob_m = 64 * par; - m_block_ctr += 4 * (par - 1); - max_block = 4; + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + // par = min((16 * max_block - pad) / 64, max_par); + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; } if (max_block == 1) { @@ -1491,7 +1493,9 @@ __global__ void MarlinMoE( bool replicate_input, // do we use the same input for each expert? 
bool apply_weights, // apply weights to output int current_m_block, // current m block to start kernel computation from - int max_par // maximum parallelism + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks + ) { // Marlin is not implemented yet for SM < 8.0 assert(false); @@ -1530,7 +1534,8 @@ static constexpr int min_thread_k = 64; A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par); \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1539,6 +1544,11 @@ typedef struct { int num_threads; } thread_config_t; +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + thread_config_t small_batch_thread_configs[] = { // Ordered by priority @@ -1559,8 +1569,78 @@ thread_config_t large_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + // TORCH_CHECK(false, "m blocks failed = ", m_blocks); + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1588,26 +1668,49 @@ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that 
pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } #define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ @@ -1654,26 +1757,42 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = 
exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1707,11 +1826,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } } - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - int tot_m = prob_m; const int* topk_ids_ptr = (const int*)topk_ids; @@ -1756,12 +1870,11 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, A_ptr = a_tmp_ptr; } - int max_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { - // Define kernel configurations - + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { // make it max possible value - int thread_m_blocks = 4; + int thread_m_blocks = exec_cfg.max_m_blocks; if (false) { } From 2f82715a20ec12bf11aa2c4c8bde1e915a68216e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 11 Sep 2024 00:31:44 -0400 Subject: [PATCH 23/49] cleanup commented out code --- csrc/moe/marlin_moe_ops.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index be14166f235d..666d87eb9259 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1402,7 +1402,6 @@ __global__ void MarlinMoE( // padding par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); if (par > max_par) par = max_par; - // par = min((16 * max_block - pad) / 64, max_par); prob_m = (16 * cfg_max_m_blocks) * par; m_block_ctr += cfg_max_m_blocks * (par - 1); max_block = cfg_max_m_blocks; @@ -1621,7 +1620,6 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, break; } - // TORCH_CHECK(false, "m blocks failed = ", m_blocks); max_m_blocks--; if (max_m_blocks == 0) { TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); From 2cc7dcc003f31212d9294f03e9cc7b244c827244 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 12 Sep 2024 12:32:12 -0400 Subject: [PATCH 24/49] Zero point in fused Marlin MoE kernel --- csrc/moe/marlin_moe_ops.cu | 421 ++++++++++++++++++++++++++++-------- csrc/moe/marlin_moe_ops.h | 11 +- csrc/moe/torch_bindings.cpp | 6 +- 3 files changed, 343 insertions(+), 95 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 666d87eb9259..9108cea47428 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -58,6 +58,7 @@ using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; // quantization scales +using FragZP = Vec; // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. @@ -195,6 +196,46 @@ __device__ inline FragB dequant(int q) { return frag_b; } +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -203,6 +244,12 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { frag_b[1] = __hmul2(frag_b[1], s); } +__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { + half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + // Given 2 floats multiply by 2 scales (halves) __device__ inline void scale_float(float* c, FragS& s) { __half* s_ptr = reinterpret_cast<__half*>(&s); @@ -345,6 +392,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -356,6 +404,8 @@ __device__ inline void MarlinMoESingle( const float* __restrict__ topk_weights, // float topk weights const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel @@ -497,8 +547,12 @@ __device__ inline void MarlinMoESingle( int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - constexpr int sorted_sh_stride = threads; - constexpr int sorted_gl_stride = threads; + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; // Global A read index of current thread. 
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + @@ -539,6 +593,19 @@ __device__ inline void MarlinMoESingle( int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. @@ -550,6 +617,18 @@ __device__ inline void MarlinMoESingle( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + int sh_first_group_id = -1; int sh_num_groups = -1; constexpr int sh_max_num_groups = 32; @@ -565,8 +644,8 @@ __device__ inline void MarlinMoESingle( int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - int* sh_sorted = (int*)(sh_s + shs_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or @@ -622,8 +701,10 @@ __device__ inline void MarlinMoESingle( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 // Zero accumulators. 
auto zero_accums = [&]() { @@ -730,6 +811,28 @@ __device__ inline void MarlinMoESingle( } } } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -737,15 +840,9 @@ __device__ inline void MarlinMoESingle( cp_async_fence(); }; - // TODO we are currently hitting illegal memory accesses when fetching - // sorted_ids to shared data: fix this - auto fetch_sorted_ids_to_shared = [&]() { - const int mpt = ceildiv(prob_m, threads); - for (int i = 0; i < mpt; i++) { - if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { - sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = - sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; - } + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); } }; @@ -896,8 +993,83 @@ __device__ inline void MarlinMoESingle( } }; + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + // Execute the actual tensor core matmul of a sub-tile. 
auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -926,6 +1098,11 @@ __device__ inline void MarlinMoESingle( } } + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -1159,9 +1336,6 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { @@ -1172,6 +1346,12 @@ __device__ inline void MarlinMoESingle( } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } fetch_to_shared(i, i, i < slice_iters); } @@ -1180,6 +1360,7 @@ __device__ inline void MarlinMoESingle( init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; @@ -1199,6 +1380,7 @@ __device__ inline void MarlinMoESingle( for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -1333,6 +1515,7 @@ __device__ inline void MarlinMoESingle( } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } start_pipes(); } @@ -1350,6 +1533,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1361,6 +1545,8 @@ __global__ void MarlinMoE( const float* __restrict__ topk_weights, // float topk weights const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel @@ -1409,29 +1595,29 @@ __global__ void MarlinMoE( if (max_block == 1) { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { 
MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); @@ -1467,6 +1653,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1478,6 +1665,8 @@ __global__ void MarlinMoE( const float* __restrict__ topk_weights, // float topk weights const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel @@ -1516,22 +1705,23 @@ static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; #define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ NUM_THREADS) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ MarlinMoE \ + THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ replicate_input, apply_weights, m_block, max_par, \ exec_cfg.max_m_blocks); \ @@ -1711,43 +1901,65 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - 
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, 
N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, + const void* topk_ids, const void* s, void* zp, + const void* g_idx, const void* perm, void* a_tmp, + void* expert_offsets, int prob_m, int prob_n, + int prob_k, void* workspace, vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { + bool is_k_full, bool has_zp, int num_groups, + int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1855,6 +2067,11 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx; + const int4* zp_ptr = + (const int4*)zp + + (((group_size == -1 || group_size == 0) ? 
1 : prob_k / group_size) * + prob_n / 8) * + expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; @@ -1876,14 +2093,23 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + + AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) + AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1904,13 +2130,21 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights) { - TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + torch::Tensor& b_zeros, const torch::Tensor& g_idx, + const torch::Tensor& perm, torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, + int64_t topk, int64_t moe_block_size, bool replicate_input, + bool apply_weights) { + if (has_zp) { + TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", + b_q_type->str()); + } else { + TORCH_CHECK( + *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + } int pack_factor = 32 / b_q_type->size_bits(); @@ -1969,13 +2203,26 @@ torch::Tensor marlin_gemm_moe( } } + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); + TORCH_CHECK(b_zeros.size(0) == num_groups, + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", b_scales.size(1), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + marlin_moe::marlin_mm_moe_f16i4( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, - topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, max_par, replicate_input, apply_weights); + *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, + num_experts, topk, moe_block_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, + replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index adee8399a4d6..0a54d93cedeb 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -8,8 +8,9 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); + torch::Tensor& b_zeros, const torch::Tensor& g_idx, + const torch::Tensor& perm, torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, + int64_t topk, int64_t moe_block_size, bool replicate_input, + bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index cd65a8ee92b9..85098df34b2d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,10 +13,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! 
workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " - "int moe_block_size, bool replicate_input, bool apply_weights)" + "int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, " + "int topk, int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif From 50cc76659c6ab9f14237564a870cc04d8d03ccbe Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 13 Sep 2024 11:23:40 -0400 Subject: [PATCH 25/49] Split into multiple files for faster compilation (work in progress) --- csrc/moe/marlin_moe_kernel.cuh | 1664 ++++++++++++ csrc/moe/marlin_moe_kernel_ku4.cu | 75 + csrc/moe/marlin_moe_kernel_ku4b8.cu | 79 + csrc/moe/marlin_moe_kernel_ku8.cu | 75 + csrc/moe/marlin_moe_kernel_ku8b128.cu | 79 + csrc/moe/marlin_moe_ops.cu | 3381 +++++++++++++------------ 6 files changed, 3705 insertions(+), 1648 deletions(-) create mode 100644 csrc/moe/marlin_moe_kernel.cuh create mode 100644 csrc/moe/marlin_moe_kernel_ku4.cu create mode 100644 csrc/moe/marlin_moe_kernel_ku4b8.cu create mode 100644 csrc/moe/marlin_moe_kernel_ku8.cu create mode 100644 csrc/moe/marlin_moe_kernel_ku8b128.cu diff --git a/csrc/moe/marlin_moe_kernel.cuh b/csrc/moe/marlin_moe_kernel.cuh new file mode 100644 index 000000000000..5330cfca9751 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel.cuh @@ -0,0 +1,1664 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales +using FragZP = Vec; + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. 
+__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
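+  // The constants below implement the usual fp16 bit trick: after the LOP3s,
+  // `lo` holds fp16(1024 + q_lo) and `hi` holds fp16(1024 + 16 * q_hi).
+  // SUB = fp16(1032) turns the low half into q_lo - 8, while MUL = fp16(1/16)
+  // and ADD = fp16(-72) turn the high half into (1024 + 16 * q_hi) / 16 - 72
+  // = q_hi - 8.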
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
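+// The selected fp16 scale is broadcast to both lanes of a half2, so a single
+// scale value multiplies all four dequantized values of the fragment.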
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { + half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
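+    // fence.acq_rel orders all preceding global stores before the relaxed
+    // atomic add on the lock, so the block spinning in barrier_acquire()
+    // observes them once it sees the incremented count.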
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? 
stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + + // Zero accumulators. 
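+  // frag_c is thread_m_blocks x 4 x 2 FragC fragments of 4 floats each, i.e.
+  // exactly the thread_m_blocks * 4 * 2 * 4 floats cleared below.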
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < 
zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due 
to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
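+    // For 4-bit weights one 32-bit word packs the nibbles of both B-fragments
+    // of the sub-tile (the second fragment starts at bit 8), while for 8-bit
+    // weights two consecutive ints are consumed per iteration; hence the
+    // split below.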
+ #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. 
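+  // Blocks working on the same column slice take turns: barrier_acquire() in
+  // the main loop admits the block with a given slice_idx only after slice_idx
+  // earlier blocks have released the lock, so the fp16 partials in C are
+  // accumulated one block at a time.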
+ auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
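+  // This is also where channelwise (per-column) scales for 4-bit weights and
+  // the optional top-k routing weights are applied, right before the final
+  // fp16 stores.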
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. 
Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + 
slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each 
expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks + +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +// #define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ +// THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ +// NUM_THREADS) \ +// else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ +// thread_n_blocks == THREAD_N_BLOCKS && \ +// thread_k_blocks == THREAD_K_BLOCKS && \ +// has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ +// group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ +// cudaFuncSetAttribute(MarlinMoE, \ +// cudaFuncAttributeMaxDynamicSharedMemorySize, \ +// max_shared_mem); \ +// MarlinMoE \ +// <<>>( \ +// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ +// zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ +// num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ +// replicate_input, apply_weights, m_block, max_par, \ +// cfg_max_m_blocks); \ +// } + +// #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +// #define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ +// 
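// Editorial sketch (not part of the patch): the m-block partitioning performed
// at the top of MarlinMoE above, replayed on the host with assumed inputs
// (tot_its = 1024 rows routed to the expert, cfg_max_m_blocks = 4,
// max_par = 16, m_block_ctr = 0). With no padding, the 64 row-blocks are split
// into par = 16 pieces of cfg_max_m_blocks blocks each, consistent with the
// comment that parallel > 1 currently only works for unpadded inputs.
#include <cstdio>

static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int tot_its = 1024, cfg_max_m_blocks = 4, max_par = 16;
  int m_block_ctr = 0;

  int tot_m_blocks = ceildiv(tot_its, 16);     // 64
  int pad = 16 * tot_m_blocks - tot_its;       // 0
  int max_block = tot_m_blocks - m_block_ctr;  // 64
  int prob_m = tot_its - 16 * m_block_ctr;     // 1024

  int par = 1;
  if (max_block > cfg_max_m_blocks) {
    par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);  // 16
    if (par > max_par) par = max_par;
    prob_m = (16 * cfg_max_m_blocks) * par;       // 1024
    m_block_ctr += cfg_max_m_blocks * (par - 1);  // 60
    max_block = cfg_max_m_blocks;                 // 4
  }
  std::printf("par=%d prob_m=%d m_block_ctr=%d max_block=%d\n", par, prob_m,
              m_block_ctr, max_block);
  return 0;
}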
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_moe_kernel_ku4.cu new file mode 100644 index 000000000000..d50d8f14d785 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku4.cu @@ -0,0 +1,75 @@ +#include "marlin_moe_kernel.cuh" + +namespace marlin_moe { + +#define __CALL_IF_MOE_4(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + + +#define AWQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ 
+ __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_m_blocks, + int thread_n_blocks, int thread_k_blocks, bool has_act_order, + bool has_zp, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + AWQ_CALL_IF_MOE_4(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE_4(vllm::kU4, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_moe_kernel_ku4b8.cu new file mode 100644 index 000000000000..f5832b550a5d --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku4b8.cu @@ -0,0 +1,79 @@ +#include "marlin_moe_kernel.cuh" + +namespace marlin_moe { + +#define __CALL_IF_MOE_4_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + 
__CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_m_blocks, + int thread_n_blocks, int thread_k_blocks, bool has_act_order, + bool has_zp, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_moe_kernel_ku8.cu new file mode 100644 index 000000000000..b07491910002 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku8.cu @@ -0,0 +1,75 @@ +#include "marlin_moe_kernel.cuh" + +namespace marlin_moe { + +#define __CALL_IF_MOE_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + + +#define AWQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + 
__CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku8( + vllm::ScalarType const& q_type, int thread_m_blocks, + int thread_n_blocks, int thread_k_blocks, bool has_act_order, + bool has_zp, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + AWQ_CALL_IF_MOE_8(vllm::kU8, 16, 4, 256) + AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 8, 256) + AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 4, 128) + AWQ_CALL_IF_MOE_8(vllm::kU8, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_moe_kernel_ku8b128.cu new file mode 100644 index 000000000000..22f042f0d43a --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku8b128.cu @@ -0,0 +1,79 @@ +#include "marlin_moe_kernel.cuh" + +namespace marlin_moe { + +#define __CALL_IF_MOE_8_128(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, 
NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_m_blocks, + int thread_n_blocks, int thread_k_blocks, bool has_act_order, + bool has_zp, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 9108cea47428..dba94bde9fc1 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -26,6 +26,10 @@ #include #include "core/scalar_type.hpp" +#include "marlin_moe_kernel_ku4b8.cu" +#include "marlin_moe_kernel_ku8b128.cu" +#include "marlin_moe_kernel_ku4.cu" +#include "marlin_moe_kernel_ku8.cu" template inline std::string str(T x) { @@ -34,276 +38,290 @@ inline std::string str(T x) { namespace marlin_moe { -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } +// constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: +// // Instances of `Vec` are used to organize groups of >>registers<<, as needed +// // for instance as inputs to tensor core operations. 
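// Editorial sketch (not part of the patch): the dispatch pattern shared by the
// call_marlin_moe_kernel_ku4b8 / ku8b128 / ku4 / ku8 helpers above, reduced to
// a toy, host-only example. Each helper opens with an empty `if (false) {}` so
// that every macro-generated branch can be emitted as `else if`, and it
// returns false when no branch matched, presumably so a caller can chain the
// helpers as a sequence of if-elseif's. The names TRY_CONFIG, dispatch_a and
// dispatch_b below are hypothetical and do not appear in the patch.
#include <cstdio>

#define TRY_CONFIG(N_BLOCKS, K_BLOCKS)                          \
  else if (n_blocks == N_BLOCKS && k_blocks == K_BLOCKS) {      \
    std::printf("running %dx%d config\n", N_BLOCKS, K_BLOCKS);  \
  }

bool dispatch_a(int n_blocks, int k_blocks) {
  if (false) {
  }
  TRY_CONFIG(16, 4)
  TRY_CONFIG(8, 8)
  else {
    return false;  // nothing matched; let the next dispatcher try
  }
  return true;
}

bool dispatch_b(int n_blocks, int k_blocks) {
  if (false) {
  }
  TRY_CONFIG(8, 4)
  TRY_CONFIG(4, 8)
  else {
    return false;
  }
  return true;
}

int main() {
  int n_blocks = 4, k_blocks = 8;
  // Chain the dispatchers the way a caller would chain the per-type helpers.
  bool handled =
      dispatch_a(n_blocks, k_blocks) || dispatch_b(n_blocks, k_blocks);
  if (!handled) std::printf("unsupported configuration\n");
  return 0;
}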
Consequently, all +// // corresponding index accesses must be compile-time constants, which is why +// we +// // extensively use `#pragma unroll` throughout the kernel code to guarantee +// // this. +// template +// struct Vec { +// T elems[n]; +// __device__ T& operator[](int i) { return elems[i]; } +// }; + +// using I4 = Vec; + +// // Matrix fragments for tensor core instructions; their precise layout is +// // documented here: +// // // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales -using FragZP = Vec; - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. 
-template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: +// using FragA = Vec; +// using FragB = Vec; +// using FragC = Vec; +// using FragS = Vec; // quantization scales +// using FragZP = Vec; + +// // Predicated asynchronous global->shared copy; used for inputs A where we +// apply +// // predication to handle batchsizes that are not multiples of 16. +// __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, +// bool pred = true) { +// const int BYTES = 16; +// uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); +// asm volatile( +// "{\n" +// " .reg .pred p;\n" +// " setp.ne.b32 p, %0, 0;\n" +// " @p cp.async.cg.shared.global [%1], [%2], %3;\n" +// "}\n" ::"r"((int)pred), +// "r"(smem), "l"(glob_ptr), "n"(BYTES)); +// } + +// // Asynchronous global->shared copy +// __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { +// const int BYTES = 16; +// uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); +// asm volatile( +// "{\n" +// " cp.async.cg.shared.global [%0], [%1], %2;\n" +// "}\n" ::"r"(smem), +// "l"(glob_ptr), "n"(BYTES)); +// } + +// // Async copy fence. +// __device__ inline void cp_async_fence() { +// asm volatile("cp.async.commit_group;\n" ::); +// } + +// // Wait until at most `n` async copy stages are still pending. +// template +// __device__ inline void cp_async_wait() { +// asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +// } + +// // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// // output/accumulation. +// __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, +// FragC& frag_c) { +// const uint32_t* a = reinterpret_cast(&a_frag); +// const uint32_t* b = reinterpret_cast(&frag_b); +// float* c = reinterpret_cast(&frag_c); +// asm volatile( +// "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " +// "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" +// : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) +// : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), +// "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +// } + +// // Instruction for loading a full 16x16 matrix fragment of operand A from +// shared +// // memory, directly in tensor core layout. +// __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { +// uint32_t* a = reinterpret_cast(&frag_a); +// uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); +// asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, +// [%4];\n" +// : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) +// : "r"(smem)); +// } + +// // Lookup-table based 3-input logical operation; explicitly used for +// // dequantization as the compiler does not seem to automatically recognize it +// in +// // all cases. 
+// template +// __device__ inline int lop3(int a, int b, int c) { +// int res; +// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" +// : "=r"(res) +// : "r"(a), "r"(b), "r"(c), "n"(lut)); +// return res; +// } + +// // Constructs destination register by taking bytes from 2 sources (based on +// // mask) +// template +// __device__ inline uint32_t prmt(uint32_t a) { +// uint32_t res; +// asm volatile("prmt.b32 %0, %1, %2, %3;\n" +// : "=r"(res) +// : "r"(a), "n"(start_byte), "n"(mask)); +// return res; +// } + +// template +// __device__ inline FragB dequant(int q); + +// // Efficiently dequantize 4bit values packed in an int32 value into a full +// // B-fragment of 4 fp16 values. We mostly follow the strategy in the link +// below, +// // with some small changes: +// // // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: +// template <> +// __device__ inline FragB dequant(int q) { +// const int LO = 0x000f000f; +// const int HI = 0x00f000f0; +// const int EX = 0x64006400; +// // Guarantee that the `(a & b) | c` operations are LOP3s. +// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); +// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); +// // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point +// // directly into `SUB` and `ADD`. 
+// const int SUB = 0x64086408; +// const int MUL = 0x2c002c00; +// const int ADD = 0xd480d480; +// FragB frag_b; +// frag_b[0] = __hsub2(*reinterpret_cast(&lo), +// *reinterpret_cast(&SUB)); +// frag_b[1] = __hfma2(*reinterpret_cast(&hi), +// *reinterpret_cast(&MUL), +// *reinterpret_cast(&ADD)); +// return frag_b; +// } + +// // Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// // Reference: +// // // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. 
-__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { - half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} +// template <> +// __device__ inline FragB dequant(int q) { +// static constexpr uint32_t mask_for_elt_01 = 0x5250; +// static constexpr uint32_t mask_for_elt_23 = 0x5351; +// static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + +// uint32_t lo = prmt(q); +// uint32_t hi = prmt(q); + +// static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + +// FragB frag_b; +// frag_b[0] = __hsub2(*reinterpret_cast(&lo), +// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +// frag_b[1] = __hsub2(*reinterpret_cast(&hi), +// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +// return frag_b; +// } + +// template <> +// __device__ inline FragB dequant(int q) { +// const int LO = 0x000f000f; +// const int HI = 0x00f000f0; +// const int EX = 0x64006400; +// // Guarantee that the `(a & b) | c` operations are LOP3s. 
+// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); +// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + +// const int SUB = 0x64006400; +// const int MUL = 0x2c002c00; +// const int ADD = 0xd400d400; +// FragB frag_b; +// frag_b[0] = __hsub2(*reinterpret_cast(&lo), +// *reinterpret_cast(&SUB)); +// frag_b[1] = __hfma2(*reinterpret_cast(&hi), +// *reinterpret_cast(&MUL), +// *reinterpret_cast(&ADD)); +// return frag_b; +// } + +// template <> +// __device__ inline FragB dequant(int q) { +// static constexpr uint32_t mask_for_elt_01 = 0x5250; +// static constexpr uint32_t mask_for_elt_23 = 0x5351; +// static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + +// uint32_t lo = prmt(q); +// uint32_t hi = prmt(q); + +// static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + +// FragB frag_b; +// frag_b[0] = __hsub2(*reinterpret_cast(&lo), +// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +// frag_b[1] = __hsub2(*reinterpret_cast(&hi), +// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +// return frag_b; +// } + +// // Multiply dequantized values by the corresponding quantization scale; used +// // only for grouped quantization. +// __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { +// half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); +// frag_b[0] = __hmul2(frag_b[0], s); +// frag_b[1] = __hmul2(frag_b[1], s); +// } + +// __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { +// half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); +// frag_b[0] = __hsub2(frag_b[0], zp); +// frag_b[1] = __hsub2(frag_b[1], zp); +// } + +// // Given 2 floats multiply by 2 scales (halves) +// __device__ inline void scale_float(float* c, FragS& s) { +// __half* s_ptr = reinterpret_cast<__half*>(&s); +// c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); +// c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +// } + +// // Same as above, but for act_order (each K is multiplied individually) +// __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& +// frag_s_2, +// FragS& frag_s_3, FragS& frag_s_4, int i) { +// __half2 s_val_1_2; +// s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; +// s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + +// __half2 s_val_3_4; +// s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; +// s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + +// frag_b[0] = __hmul2(frag_b[0], s_val_1_2); +// frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +// } + +// // Wait until barrier reaches `count`, then lock for current threadblock. +// __device__ inline void barrier_acquire(int* lock, int count) { +// if (threadIdx.x == 0) { +// int state = -1; +// do +// // Guarantee that subsequent writes by this threadblock will be visible +// // globally. +// asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" +// : "=r"(state) +// : "l"(lock)); +// while (state != count); +// } +// __syncthreads(); +// } + +// // Release barrier and increment visitation count. +// __device__ inline void barrier_release(int* lock, bool reset = false) { +// __syncthreads(); +// if (threadIdx.x == 0) { +// if (reset) { +// lock[0] = 0; +// return; +// } +// int val = 1; +// // Make sure that all writes since acquiring this barrier are visible +// // globally, while releasing the barrier. 
+// asm volatile("fence.acq_rel.gpu;\n"); +// asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" +// : +// : "l"(lock), "r"(val)); +// } +// } // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. @@ -382,1247 +400,1284 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. 
- iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. 
- int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. 
- bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. 
- auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
- auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = 
sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. 
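Stripped of the packed half2 bit tricks, the dequant / sub_zp / scale sequence in the matmul lambda computes the usual per-group affine dequantization w = s * (q - z) for every 4-bit or 8-bit weight code (with z fixed at the symmetric midpoint when no zero-points are supplied). A scalar reference with toy values, purely for intuition:

#include <cstdio>

int main() {
  const int group_size = 4;                     // toy group size
  const int q[8] = {3, 7, 0, 15, 8, 8, 1, 14};  // 4-bit weight codes
  const float s[2] = {0.05f, 0.12f};            // one scale per group
  const int z[2] = {8, 7};                      // one zero-point per group
  for (int i = 0; i < 8; ++i) {
    int g = i / group_size;                     // group this weight belongs to
    float w = s[g] * (q[i] - z[g]);             // affine dequantization
    printf("q=%2d group=%d -> w=% .3f\n", q[i], g, w);
  }
  return 0;
}

The kernel performs the same arithmetic on half2 fragments, two output values at a time, with the midpoint subtraction folded into the fp16 reinterpretation constants of the dequant routines.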
- #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. 
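Because each output tile is k-sliced across several warps, `thread_block_reduce` above must merge per-warp partial sums through shared memory in log2(number of co-resident warps) rounds. A standalone toy kernel with the same reduction shape (but none of the Marlin fragment bookkeeping) is sketched below; launch it with a power-of-two number of warps, at most eight.

__global__ void k_slice_reduce_sketch(float* out) {
  __shared__ float partials[8][32];  // up to 8 warps, one float per lane
  int lane = threadIdx.x % 32;
  int warp = threadIdx.x / 32;
  int n_warps = blockDim.x / 32;     // assumed to be a power of two

  float acc = float(warp + 1);       // stand-in for a register-resident partial

  for (int off = n_warps / 2; off > 0; off /= 2) {
    // The upper half of the still-active warps publishes, the lower half adds.
    if (warp >= off && warp < 2 * off) partials[warp - off][lane] = acc;
    __syncthreads();
    if (warp < off) acc += partials[warp][lane];
    __syncthreads();
  }
  if (warp == 0) out[lane] = acc;    // warp 0 now holds the fully reduced sum
}

The real reduction reuses the kernel's main shared-memory buffer, moves whole int4 fragments per round, and skips the redundant first-round read mentioned in the comment above.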
- auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
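`global_reduce` above relies on the `locks` array to serialize the thread blocks that share one column slice: block `slice_idx` waits for its predecessor, folds its fp32 partials into the fp16 output buffer, and releases the next block (the barrier_acquire / barrier_release helpers used above handle the memory-ordering details). The toy kernel below sketches only that serialization idea; the one-float-per-thread payload, the ordering of blocks by blockIdx.x, and the host-side zero-initialized lock are assumptions of the sketch, not the kernel's actual layout.

#include <cuda_fp16.h>

__global__ void serial_slice_reduce_sketch(__half* out, int* lock,
                                           int blocks_per_slice) {
  int slice_idx = blockIdx.x;  // this block's position within the slice
  float partial = 1.0f;        // stand-in for this block's fp32 result

  // Acquire: spin until every earlier block of the slice has released.
  if (threadIdx.x == 0)
    while (atomicAdd(lock, 0) != slice_idx) { /* spin */ }
  __syncthreads();

  bool first = (slice_idx == 0);
  bool last = (slice_idx == blocks_per_slice - 1);

  // Reduce in fp32 but store intermediates in fp16, as the comment explains.
  float acc = first ? 0.0f : __half2float(out[threadIdx.x]);
  acc += partial;
  out[threadIdx.x] = __float2half(acc);

  // Release: make the stores visible, then hand the lock to the next block
  // (or reset it after the last block, mirroring barrier_release(..., last)).
  __syncthreads();
  __threadfence();
  if (threadIdx.x == 0) atomicExch(lock, last ? 0 : slice_idx + 1);
}

Serializing the slice this way keeps the partially reduced output resident in L2, which is why the comment above prefers it for the typically small outputs.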
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. 
Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - 
slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? 
- bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} +// template shared +// // fetch pipeline +// const bool has_act_order, // whether act_order is enabled +// const bool has_zp, // whether zero-points are enabled +// const int group_blocks = -1 // number of consecutive 16x16 blocks +// // with a separate quantization scale +// > +// __device__ inline void MarlinMoESingle( +// const int4* __restrict__ A, // fp16 input matrix of shape mxk +// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn +// int4* __restrict__ C, // fp16 output buffer of shape mxn +// const int* __restrict__ sorted_ids, // int32 sorted ids of experts +// const float* __restrict__ topk_weights, // float topk weights +// const int4* __restrict__ scales_ptr, // fp16 quantization scales of +// shape +// // (k/groupsize)xn +// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape +// // (k/groupsize)x(n/pack_factor) +// const int* __restrict__ g_idx, // int32 group indices of shape k +// const int* __restrict__ expert_offsets, +// int num_groups, // number of scale groups per output channel +// int expert_idx, // idx of current expert +// int num_experts, // number of experts +// int topk, // topk parameter of moe +// int prob_m, // batch dimension m +// int prob_n, // output dimension n +// int prob_k, // reduction dimension k +// int tot_m, // total number of rows in A and C +// int* locks, // extra global storage for barrier +// 
synchronization bool replicate_input, // do we use the same input for +// each expert? bool apply_weights, // apply weights to output int +// current_m_block // current m block to start kernel computation from +// ) { +// static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); +// constexpr int pack_factor = 32 / w_type.size_bits(); + +// // For larger GEMMs we run multiple batchsize 64 versions in parallel for a +// // better partitioning with less reductions +// int parallel = 1; +// if (prob_m > 16 * thread_m_blocks) { +// parallel = prob_m / (16 * thread_m_blocks); +// prob_m = 16 * thread_m_blocks; +// } + +// int k_tiles = prob_k / 16 / thread_k_blocks; +// int n_tiles = prob_n / 16 / thread_n_blocks; +// int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + +// if constexpr (!has_act_order && group_blocks != -1) { +// if (group_blocks >= thread_k_blocks) { +// // Ensure that the number of tiles in each stripe is a multiple of the +// // groupsize; this avoids an annoying special case where a stripe +// starts +// // in the middle of group. +// iters = (group_blocks / thread_k_blocks) * +// ceildiv(iters, (group_blocks / thread_k_blocks)); +// } +// } + +// int slice_row = (iters * blockIdx.x) % k_tiles; +// int slice_col_par = (iters * blockIdx.x) / k_tiles; +// int slice_col = slice_col_par; +// int slice_iters; // number of threadblock tiles in the current slice +// int slice_count = +// 0; // total number of active threadblocks in the current slice +// int slice_idx; // index of threadblock in current slice; numbered bottom +// to +// // top + +// // We can easily implement parallel problem execution by just remapping +// // indices and advancing global pointers +// if (slice_col_par >= n_tiles) { +// locks += (slice_col_par / n_tiles) * n_tiles; +// slice_col = slice_col_par % n_tiles; +// sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; +// } + +// // Compute all information about the current slice which is required for +// // synchronization. 
+// auto init_slice = [&]() { +// slice_iters = +// iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); +// if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = +// 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) +// slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int +// col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if +// (col_first <= k_tiles * (slice_col_par + 1)) { +// int col_off = col_first - k_tiles * slice_col_par; +// slice_count = ceildiv(k_tiles - col_off, iters); +// if (col_off > 0) slice_count++; +// int delta_first = iters * blockIdx.x - col_first; +// if (delta_first < 0 || (col_off == 0 && delta_first == 0)) +// slice_idx = slice_count - 1; +// else { +// slice_idx = slice_count - 1 - delta_first / iters; +// if (col_off > 0) slice_idx--; +// } +// } +// if (slice_col == n_tiles) { +// sorted_ids += 16 * thread_m_blocks; +// locks += n_tiles; +// slice_col = 0; +// } +// }; +// init_slice(); + +// // A sizes/strides + +// // stride of the A matrix in global memory +// int a_gl_stride = prob_k / 8; +// // stride of an A matrix tile in shared memory +// constexpr int a_sh_stride = 16 * thread_k_blocks / 8; +// // delta between subsequent A tiles in global memory +// constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; +// // between subsequent accesses within a tile +// int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); +// // between shared memory writes +// constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); +// // between shared memory tile reads +// constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / +// 4)); +// // within a shared memory tile +// constexpr int a_sh_rd_delta_i = a_sh_stride * 16; +// // overall size of a tile +// constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); +// // number of shared write iterations for a tile +// constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + +// // B sizes/strides +// int b_gl_stride = 16 * prob_n / (pack_factor * 4); +// constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / +// 4; constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; constexpr +// int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + +// int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; +// int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); +// constexpr int b_sh_wr_delta = threads * b_thread_vecs; +// constexpr int b_sh_rd_delta = threads * b_thread_vecs; +// constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; +// constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + +// // Scale sizes/strides without act_order +// int s_gl_stride = prob_n / 8; +// constexpr int s_sh_stride = 16 * thread_n_blocks / 8; +// constexpr int s_tb_groups = +// !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks +// ? thread_k_blocks / group_blocks +// : 1; +// constexpr int s_sh_stage = s_tb_groups * s_sh_stride; +// int s_gl_rd_delta = s_gl_stride; +// // Scale size/strides with act_order +// constexpr int tb_k = 16 * thread_k_blocks; +// constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0; +// // constexpr int act_s_row_stride = 1; +// // int act_s_col_stride = act_s_row_stride * num_groups; +// int act_s_col_stride = 1; +// int act_s_col_warp_stride = act_s_col_stride * 8; +// int tb_n_warps = thread_n_blocks / 4; +// int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + +// // Zero-points sizes/strides +// int zp_gl_stride = (prob_n / pack_factor) / 4; +// constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; +// constexpr int zp_tb_groups = s_tb_groups; +// constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; +// int zp_gl_rd_delta = zp_gl_stride; + +// // Global A read index of current thread. +// int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + +// (threadIdx.x % a_gl_rd_delta_o); +// a_gl_rd += a_gl_rd_delta_o * slice_row; +// // Shared write index of current thread. +// int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + +// (threadIdx.x % a_gl_rd_delta_o); +// // Shared read index. +// int a_sh_rd = +// a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; +// a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + +// int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + +// (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; +// b_gl_rd += b_sh_stride * slice_col; +// b_gl_rd += b_gl_rd_delta_o * slice_row; +// int b_sh_wr = threadIdx.x * b_thread_vecs; +// int b_sh_rd = threadIdx.x * b_thread_vecs; + +// // For act_order +// constexpr int k_iter_size = tb_k / b_sh_wr_iters; +// int slice_k_start = tb_k * slice_row; +// int slice_k_finish = slice_k_start + tb_k * slice_iters; +// int slice_k_start_shared_fetch = slice_k_start; +// int slice_n_offset = act_s_col_tb_stride * slice_col; + +// // No act_order +// int s_gl_rd; +// if constexpr (!has_act_order) { +// if constexpr (group_blocks == -1) { +// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; +// } else { +// s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +// + +// s_sh_stride * slice_col + threadIdx.x; +// } +// } +// int s_sh_wr = threadIdx.x; +// bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + +// // Zero-points +// int zp_gl_rd; +// if constexpr (has_zp) { +// if constexpr (group_blocks == -1) { +// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; +// } else { +// zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / +// group_blocks) + +// zp_sh_stride * slice_col + threadIdx.x; +// } +// } +// int zp_sh_wr = threadIdx.x; +// bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + +// // We use a different scale layout for grouped and column-wise quantization +// as +// // we scale a `half2` tile in column-major layout in the former and in +// // row-major in the latter case. 
+// int s_sh_rd; +// if constexpr (group_blocks != -1) +// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + +// (threadIdx.x % 32) / 4; +// else +// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + +// (threadIdx.x % 32) % 4; + +// // Zero-points have the same read layout as the scales +// // (without column-wise case) +// constexpr int num_col_threads = 8; +// constexpr int num_row_threads = 4; +// constexpr int num_ints_per_thread = 8 / pack_factor; +// int zp_sh_rd; +// if constexpr (has_zp) { +// zp_sh_rd = num_ints_per_thread * num_col_threads * +// ((threadIdx.x / 32) % (thread_n_blocks / 4)) + +// num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); +// } + +// int sh_first_group_id = -1; +// int sh_num_groups = -1; +// constexpr int sh_max_num_groups = 32; + +// int shs_size; +// if constexpr (has_act_order) +// shs_size = sh_max_num_groups * s_sh_stride + threads; +// else +// shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + +// extern __shared__ int4 sh[]; +// // Shared memory storage for global fetch pipelines. +// int4* sh_a = sh; +// int4* sh_b = sh_a + (stages * a_sh_stage); +// int4* sh_g_idx = sh_b + (stages * b_sh_stage); +// int4* sh_zp = sh_g_idx + (stages * g_idx_stage); +// int4* sh_s = sh_zp + (stages * zp_sh_stage); + +// // Precompute which thread should not read memory in which iterations; this +// is +// // needed if there are more threads than required for a certain tilesize or +// // when the batchsize is not a multiple of 16. +// bool a_sh_wr_pred[a_sh_wr_iters]; +// #pragma unroll +// for (int i = 0; i < a_sh_wr_iters; i++) { +// int a_idx = a_sh_wr_delta * i + a_sh_wr; +// int row = a_idx / a_gl_rd_delta_o; +// if (row >= prob_m) { +// a_sh_wr_pred[i] = false; +// } else { +// a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; +// } +// } + +// // To ensure that writing and reading A tiles to/from shared memory, the +// // latter in fragment format, is fully bank conflict free, we need to use a +// // rather fancy XOR-based layout. The key here is that neither reads nor +// // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the +// // same shared memory banks. Further, it seems (based on NSight-Compute) +// that +// // each warp must also write a consecutive memory segment? +// auto transform_a = [&](int i) { +// int row = i / a_gl_rd_delta_o; +// return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; +// }; +// // Since the computation of this remapping is non-trivial and, due to our +// main +// // loop unrolls, all shared memory accesses are static, we simply +// precompute +// // both transformed reads and writes. +// int a_sh_wr_trans[a_sh_wr_iters]; +// #pragma unroll +// for (int i = 0; i < a_sh_wr_iters; i++) +// a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); +// int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) { +// #pragma unroll +// for (int j = 0; j < thread_m_blocks; j++) +// a_sh_rd_trans[i][j] = +// transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); +// } + +// // Since B-accesses have non-constant stride they have to be computed at +// // runtime; we break dependencies between subsequent accesses with a tile +// by +// // maintining multiple pointers (we have enough registers), a tiny +// // optimization. 
+// const int4* B_ptr[b_sh_wr_iters]; +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) +// B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + +// // Register storage for double buffer of shared memory reads. +// FragA frag_a[2][thread_m_blocks]; +// I4 frag_b_quant[2][b_thread_vecs]; +// FragC frag_c[thread_m_blocks][4][2]; +// FragS frag_s[2][4]; // No act-order +// FragS act_frag_s[2][4][4]; // For act-order +// int frag_qzp[2][num_ints_per_thread]; // Zero-points +// FragZP frag_zp; // Zero-points in fp16 + +// // Zero accumulators. +// auto zero_accums = [&]() { +// #pragma unroll +// for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) +// reinterpret_cast(frag_c)[i] = 0; +// }; + +// auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, +// int last_group_id) { +// sh_first_group_id = first_group_id; +// sh_num_groups = last_group_id - first_group_id + 1; + +// if (sh_num_groups < sh_max_num_groups) { +// sh_num_groups = sh_max_num_groups; +// } + +// if (sh_first_group_id + sh_num_groups > num_groups) { +// sh_num_groups = num_groups - sh_first_group_id; +// } + +// int row_offset = first_group_id * s_gl_stride; + +// if (is_async) { +// for (int i = 0; i < sh_num_groups; i++) { +// if (threadIdx.x < s_sh_stride) { +// cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], +// &scales_ptr[row_offset + (i * s_gl_stride) + +// slice_n_offset + threadIdx.x]); +// } +// } +// } else { +// for (int i = 0; i < sh_num_groups; i++) { +// if (threadIdx.x < s_sh_stride) { +// sh_s[(i * s_sh_stride) + threadIdx.x] = +// scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + +// threadIdx.x]; +// } +// } +// } +// }; +// // Asynchronously fetch the next A, B and s tile from global to the next +// // shared memory pipeline location. +// auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { +// if (pred) { +// int4* sh_a_stage = sh_a + a_sh_stage * pipe; +// #pragma unroll +// for (int i = 0; i < a_sh_wr_iters; i++) { +// int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; +// int row = a_idx / a_gl_stride; +// int sorted_row = +// replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; +// int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; +// if (sorted_row < tot_m * (replicate_input ? 1 : topk) && +// new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { +// cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], +// a_sh_wr_pred[i]); +// } +// } +// int4* sh_b_stage = sh_b + b_sh_stage * pipe; +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) { +// #pragma unroll +// for (int j = 0; j < b_thread_vecs; j++) { +// cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + +// j); +// } +// B_ptr[i] += b_gl_rd_delta_o; +// } + +// if constexpr (has_act_order) { +// // Fetch g_idx thread-block portion +// int full_pipe = a_off; +// int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; +// if (cur_k < prob_k && cur_k < slice_k_finish) { +// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + +// int4 const* cur_g_idx_stage_ptr = +// reinterpret_cast(&g_idx[cur_k]); + +// if (threadIdx.x < g_idx_stage) { +// cp_async4_pred(&sh_g_idx_stage[threadIdx.x], +// &cur_g_idx_stage_ptr[threadIdx.x]); +// } +// } +// } else { +// if constexpr (group_blocks != -1) { +// int4* sh_s_stage = sh_s + s_sh_stage * pipe; + +// if constexpr (group_blocks >= thread_k_blocks) { +// // Only fetch scales if this tile starts a new group +// if (pipe % (group_blocks / thread_k_blocks) == 0) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); +// } +// s_gl_rd += s_gl_rd_delta; +// } +// } else { +// for (int i = 0; i < s_tb_groups; i++) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], +// &scales_ptr[s_gl_rd]); +// } +// s_gl_rd += s_gl_rd_delta; +// } +// } +// } + +// if constexpr (has_zp && group_blocks != -1) { +// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + +// if constexpr (group_blocks >= thread_k_blocks) { +// // Only fetch zero-points if this tile starts a new group +// if (pipe % (group_blocks / thread_k_blocks) == 0) { +// if (zp_sh_wr_pred) { +// cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); +// } +// zp_gl_rd += zp_gl_rd_delta; +// } +// } else { +// for (int i = 0; i < zp_tb_groups; i++) { +// if (zp_sh_wr_pred) { +// cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], +// &zp_ptr[zp_gl_rd]); +// } +// zp_gl_rd += zp_gl_rd_delta; +// } +// } +// } +// } +// } +// // Insert a fence even when we are winding down the pipeline to ensure +// that +// // waiting is also correct at this point. +// cp_async_fence(); +// }; + +// auto fetch_zp_to_shared = [&]() { +// if (zp_sh_wr_pred) { +// cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); +// } +// }; + +// // Wait until the next thread tile has been loaded to shared memory. +// auto wait_for_stage = [&]() { +// // We only have `stages - 2` active fetches since we are double buffering +// // and can only issue the next fetch when it is guaranteed that the +// previous +// // shared memory load is fully complete (as it may otherwise be +// // overwritten). +// cp_async_wait(); +// __syncthreads(); +// }; + +// // Load the next sub-tile from the current location in the shared memory +// pipe +// // into the current register buffer. 
+// auto fetch_to_registers = [&](int k, int pipe) { +// int4* sh_a_stage = sh_a + a_sh_stage * pipe; +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) +// ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % +// b_sh_wr_iters][i]]); +// int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +// #pragma unroll +// for (int i = 0; i < b_thread_vecs; i++) { +// frag_b_quant[k % 2][i] = *reinterpret_cast( +// &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); +// } +// }; + +// bool is_same_group[stages]; +// int same_group_id[stages]; + +// auto init_same_group = [&](int pipe) { +// if constexpr (!has_act_order) { +// is_same_group[pipe] = false; +// same_group_id[pipe] = 0; +// return; +// } + +// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; +// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + +// int group_id_1 = sh_g_idx_int_ptr[0]; +// int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + +// is_same_group[pipe] = group_id_1 == group_id_2; +// same_group_id[pipe] = group_id_1; +// }; + +// auto fetch_scales_to_registers = [&](int k, int full_pipe) { +// int pipe = full_pipe % stages; + +// if constexpr (!has_act_order) { +// // No act-order case +// if constexpr (group_blocks != -1) { +// if constexpr (group_blocks >= thread_k_blocks) { +// int4* sh_s_stage = +// sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * +// (pipe / (group_blocks / +// thread_k_blocks))); +// reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; +// } else { +// int warp_id = threadIdx.x / 32; +// int n_warps = thread_n_blocks / 4; + +// int warp_row = warp_id / n_warps; + +// int cur_k = warp_row * 16; +// cur_k += k_iter_size * (k % b_sh_wr_iters); + +// int k_blocks = cur_k / 16; +// int cur_group_id = k_blocks / group_blocks; + +// int4* sh_s_stage = sh_s + s_sh_stage * pipe; + +// reinterpret_cast(&frag_s[k % 2])[0] = +// sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; +// } +// } + +// return; +// } + +// // Act-order case + +// // Determine K of the "current" thread-block +// int cur_k = slice_k_start + tb_k * full_pipe; +// if (cur_k >= prob_k || cur_k >= slice_k_finish) { +// return; +// } + +// // Reset (to current thread-block) since we read g_idx portion from the +// // shared memory +// cur_k = 0; + +// // Progress to current iteration +// cur_k += k_iter_size * (k % b_sh_wr_iters); + +// // Determine "position" inside the thread-block (based on warp and +// // thread-id) +// int warp_id = threadIdx.x / 32; +// int n_warps = +// thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + +// int warp_row = warp_id / n_warps; +// int warp_col = warp_id % n_warps; + +// cur_k += warp_row * 16; + +// int th_id = threadIdx.x % 32; +// cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + +// int s_col_shift = +// /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + +// (th_id / 4) * act_s_col_stride; + +// if (is_same_group[pipe]) { +// if (k % 2 == 0) { +// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = +// sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + +// s_col_shift]; +// } else { +// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = +// *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); +// } + +// for (int i = 1; i < 4; i++) { +// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = +// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); +// } +// return; +// } + +// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; +// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + +// constexpr int 
k_frag_offsets[4] = {0, 1, 8, +// 9}; // Tensor core offsets per thread + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// int actual_k = cur_k + k_frag_offsets[i]; + +// int group_id = sh_g_idx_int_ptr[actual_k]; +// int rel_group_id = group_id - sh_first_group_id; + +// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = +// sh_s[rel_group_id * s_sh_stride + s_col_shift]; +// } +// }; + +// auto fetch_zp_to_registers = [&](int k, int full_pipe) { +// // This code does not handle group_blocks == 0, +// // which signifies act_order. +// // has_zp implies AWQ, which doesn't have act_order, +// static_assert(!has_zp || group_blocks != 0); + +// if constexpr (has_zp) { +// int pipe = full_pipe % stages; + +// if constexpr (group_blocks == -1) { +// for (int i = 0; i < num_ints_per_thread; i++) { +// frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; +// } + +// } else if constexpr (group_blocks >= thread_k_blocks) { +// int4* sh_zp_stage = +// sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * +// (pipe / (group_blocks / +// thread_k_blocks))); +// for (int i = 0; i < num_ints_per_thread; i++) { +// frag_qzp[k % 2][i] = +// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; +// } +// } else { +// int warp_id = threadIdx.x / 32; +// int n_warps = thread_n_blocks / 4; + +// int warp_row = warp_id / n_warps; + +// int cur_k = warp_row * 16; +// cur_k += k_iter_size * (k % b_sh_wr_iters); + +// int k_blocks = cur_k / 16; +// int cur_group_id = 0; + +// // Suppress bogus and persistent divide-by-zero warning +// #pragma nv_diagnostic push +// #pragma nv_diag_suppress divide_by_zero +// cur_group_id = k_blocks / group_blocks; +// #pragma nv_diagnostic pop + +// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + +// sh_zp_stage += cur_group_id * zp_sh_stride; + +// for (int i = 0; i < num_ints_per_thread; i++) { +// frag_qzp[k % 2][i] = +// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; +// } +// } +// } +// }; + +// // Execute the actual tensor core matmul of a sub-tile. +// auto matmul = [&](int k) { +// if constexpr (has_zp) { +// FragB frag_zp_0; +// FragB frag_zp_1; +// int zp_quant_0, zp_quant_1; + +// if constexpr (w_type.size_bits() == 4) { +// zp_quant_0 = frag_qzp[k % 2][0]; +// zp_quant_1 = zp_quant_0 >> 8; +// } else { +// static_assert(w_type.size_bits() == 8); +// zp_quant_0 = frag_qzp[k % 2][0]; +// zp_quant_1 = frag_qzp[k % 2][1]; +// } + +// frag_zp_0 = dequant(zp_quant_0); +// frag_zp_1 = dequant(zp_quant_1); + +// frag_zp[0] = frag_zp_0[0]; +// frag_zp[1] = frag_zp_0[1]; +// frag_zp[2] = frag_zp_1[0]; +// frag_zp[3] = frag_zp_1[1]; +// } + +// // We have the m dimension as the inner loop in order to encourage +// overlapping +// // dequantization and matmul operations. 
+// #pragma unroll +// for (int j = 0; j < 4; j++) { +// int b_quant_0, b_quant_1; +// if constexpr (w_type.size_bits() == 4) { +// b_quant_0 = frag_b_quant[k % 2][0][j]; +// b_quant_1 = b_quant_0 >> 8; +// } else { +// static_assert(w_type.size_bits() == 8); +// int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); +// b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; +// b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; +// } + +// FragB frag_b0 = dequant(b_quant_0); +// FragB frag_b1 = dequant(b_quant_1); + +// // Apply scale to frag_b0 +// if constexpr (has_act_order) { +// scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], +// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); +// } else { +// if constexpr (group_blocks != -1) { +// scale(frag_b0, frag_s[k % 2][j], 0); +// } +// } + +// // Apply zero-point to frag_b1 +// if constexpr (has_zp) { +// sub_zp(frag_b1, frag_zp[j], 1); +// } + +// // Apply scale to frag_b1 +// if constexpr (has_act_order) { +// scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], +// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + +// } else { +// if constexpr (group_blocks != -1) { +// scale(frag_b1, frag_s[k % 2][j], 1); +// } +// } + +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) { +// mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); +// mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); +// } +// } +// }; + +// // Since we slice across the k dimension of a tile in order to increase the +// // number of warps while keeping the n dimension of a tile reasonable, we +// have +// // multiple warps that accumulate their partial sums of the same output +// // location; which we have to reduce over in the end. We do in shared +// memory. auto thread_block_reduce = [&]() { +// constexpr int red_off = threads / b_sh_stride_threads / 2; +// if (red_off >= 1) { +// int red_idx = threadIdx.x / b_sh_stride_threads; +// constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; +// constexpr int red_sh_delta = b_sh_stride_threads; +// int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + +// (threadIdx.x % b_sh_stride_threads); + +// // Parallel logarithmic shared memory reduction. We make sure to avoid +// any +// // unnecessary read or write iterations, e.g., for two warps we write +// only +// // once by warp 1 and read only once by warp 0. 
+ +// #pragma unroll +// for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +// #pragma unroll +// for (int i = red_off; i > 0; i /= 2) { +// if (i <= red_idx && red_idx < 2 * i) { +// #pragma unroll +// for (int j = 0; j < 4 * 2; j++) { +// int red_sh_wr = +// red_sh_delta * j + (red_sh_rd - red_sh_stride * i); +// if (i < red_off) { +// float* c_rd = +// reinterpret_cast(&sh[red_sh_delta * j + +// red_sh_rd]); +// float* c_wr = reinterpret_cast(&sh[red_sh_wr]); +// #pragma unroll +// for (int k = 0; k < 4; k++) +// reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += +// c_rd[k] + c_wr[k]; +// } +// sh[red_sh_wr] = +// reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; +// } +// } +// __syncthreads(); +// } +// if (red_idx == 0) { +// #pragma unroll +// for (int i = 0; i < 4 * 2; i++) { +// float* c_rd = +// reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +// #pragma unroll +// for (int j = 0; j < 4; j++) +// reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += +// c_rd[j]; +// } +// } +// __syncthreads(); +// } +// } +// }; + +// // Since multiple threadblocks may process parts of the same column slice, +// we +// // finally have to globally reduce over the results. As the striped +// // partitioning minimizes the number of such reductions and our outputs are +// // usually rather small, we perform this reduction serially in L2 cache. +// auto global_reduce = [&](bool first = false, bool last = false) { +// // We are very careful here to reduce directly in the output buffer to +// // maximize L2 cache utilization in this step. To do this, we write out +// // results in FP16 (but still reduce with FP32 compute). +// constexpr int active_threads = 32 * thread_n_blocks / 4; +// if (threadIdx.x < active_threads) { +// int c_gl_stride = prob_n / 8; +// int c_gl_wr_delta_o = 8 * c_gl_stride; +// int c_gl_wr_delta_i = 4 * (active_threads / 32); +// int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + +// 4 * (threadIdx.x / 32) + threadIdx.x % 4; +// c_gl_wr += (2 * thread_n_blocks) * slice_col; +// constexpr int c_sh_wr_delta = active_threads; +// int c_sh_wr = threadIdx.x; + +// int row = (threadIdx.x % 32) / 4; + +// if (!first) { +// // Interestingly, doing direct global accesses here really seems to mess up +// // the compiler and lead to slowdowns, hence we also use async-copies even +// // though these fetches are not actually asynchronous. 
+// #pragma unroll +// for (int i = 0; i < thread_m_blocks * 4; i++) { +// int c_idx = +// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % +// 2); +// int sorted_row = sorted_ids[c_idx / c_gl_stride]; +// int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; +// cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], +// sorted_row < tot_m * topk && +// (8 * (i / 2) + row < prob_m && +// (i < (thread_m_blocks - 1) * 4 || +// sorted_ids[8 * (i / 2) + row] < tot_m * +// topk))); +// } +// cp_async_fence(); +// cp_async_wait<0>(); +// } + +// #pragma unroll +// for (int i = 0; i < thread_m_blocks * 4; i++) { +// if (8 * (i / 2) + row < prob_m && +// (i < (thread_m_blocks - 1) * 4 || +// sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { +// if (!first) { +// int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +// #pragma unroll +// for (int j = 0; j < 2 * 4; j++) { +// reinterpret_cast( +// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += +// __half2float(reinterpret_cast<__half*>(&c_red)[j]); +// } +// } +// if (!last) { +// int4 c; +// #pragma unroll +// for (int j = 0; j < 2 * 4; j++) { +// reinterpret_cast<__half*>(&c)[j] = +// __float2half(reinterpret_cast( +// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); +// } +// int c_idx = +// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % +// 2); +// int row = sorted_ids[c_idx / c_gl_stride]; +// if (row < tot_m * topk) { +// int new_idx = row * c_gl_stride + c_idx % c_gl_stride; +// C[new_idx] = c; +// } +// } +// } +// } +// } +// }; + +// // Write out the reduce final result in the correct layout. We only +// actually +// // reshuffle matrix fragments in this step, the reduction above is +// performed +// // in fragment layout. +// auto write_result = [&]() { +// int c_gl_stride = prob_n / 8; +// constexpr int c_sh_stride = 2 * thread_n_blocks + 1; +// int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); +// constexpr int c_sh_rd_delta = +// c_sh_stride * (threads / (2 * thread_n_blocks)); + +// int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + +// (threadIdx.x % (2 * thread_n_blocks)); +// c_gl_wr += (2 * thread_n_blocks) * slice_col; +// int c_sh_wr = +// (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % +// 4; +// c_sh_wr += 32 * (threadIdx.x / 32); +// int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + +// (threadIdx.x % (2 * thread_n_blocks)); + +// int c_gl_wr_end = c_gl_stride * prob_m; + +// // We first reorder in shared memory to guarantee the most efficient +// final +// // global write patterns +// auto write = [&](int idx, float c0, float c1, FragS& s) { +// half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + +// // For per-column quantization we finally apply the scale here (only +// for +// // 4-bit) +// if constexpr (!has_act_order && group_blocks == -1 && +// w_type.size_bits() == 4) { +// res = __hmul2(res, s[0]); +// } + +// ((half2*)sh)[idx] = res; +// }; +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) { +// #pragma unroll +// for (int j = 0; j < 4; j++) { +// int wr = c_sh_wr + 8 * j; +// write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], +// frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); +// write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], +// frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); +// write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], +// frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); +// 
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], +// frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); +// } +// c_sh_wr += 16 * (4 * c_sh_stride); +// } +// } +// __syncthreads(); + +// #pragma unroll +// for (int i = 0; +// i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); +// i++) { +// if (c_gl_wr < c_gl_wr_end) { +// int row = sorted_ids[c_gl_wr / c_gl_stride]; +// if (row < tot_m * topk) { +// int off = row * c_gl_stride + c_gl_wr % c_gl_stride; +// if (!apply_weights) { +// C[off] = sh[c_sh_rd]; +// } else { +// __half* ctrg = reinterpret_cast<__half*>(&C[off]); +// __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); +// for (int j = 0; j < 8; ++j) { +// ctrg[j] = __float2half(topk_weights[row] * +// __half2float(csrc[j])); +// } +// } +// c_gl_wr += c_gl_wr_delta; +// c_sh_rd += c_sh_rd_delta; +// } +// } +// } +// }; + +// // Start global fetch and register load pipelines. +// auto start_pipes = [&]() { + +// #pragma unroll +// for (int i = 0; i < stages - 1; i++) { +// if (has_act_order && i == 0) { +// int last_g_idx = slice_k_start + stages * tb_k * 2; +// if (last_g_idx >= prob_k) { +// last_g_idx = prob_k - 1; +// } +// fetch_scales_to_shared(true, g_idx[slice_k_start], +// g_idx[last_g_idx]); +// } + +// if constexpr (has_zp && group_blocks == -1) { +// if (i == 0) { +// fetch_zp_to_shared(); +// } +// } +// fetch_to_shared(i, i, i < slice_iters); +// } + +// zero_accums(); +// wait_for_stage(); +// init_same_group(0); +// fetch_to_registers(0, 0); +// fetch_scales_to_registers(0, 0); +// fetch_zp_to_registers(0, 0); +// a_gl_rd += a_gl_rd_delta_o * (stages - 1); +// slice_k_start_shared_fetch += tb_k * (stages - 1); +// }; +// if (slice_iters) { +// start_pipes(); +// } + +// // Main loop. +// while (slice_iters) { +// // We unroll over both the global fetch and the register load pipeline to +// // ensure all shared memory accesses are static. Note that both pipelines +// // have even length meaning that the next iteration will always start at +// // index 0. +// #pragma unroll +// for (int pipe = 0; pipe < stages;) { +// #pragma unroll +// for (int k = 0; k < b_sh_wr_iters; k++) { +// fetch_to_registers(k + 1, pipe % stages); +// fetch_scales_to_registers(k + 1, pipe); +// fetch_zp_to_registers(k + 1, pipe); +// if (k == b_sh_wr_iters - 2) { +// fetch_to_shared((pipe + stages - 1) % stages, pipe, +// slice_iters >= stages); +// pipe++; +// wait_for_stage(); +// init_same_group(pipe % stages); +// } +// matmul(k); +// } +// slice_iters--; +// if (slice_iters == 0) { +// break; +// } +// } + +// a_gl_rd += a_gl_rd_delta_o * stages; +// slice_k_start += tb_k * stages; +// slice_k_start_shared_fetch += tb_k * stages; + +// if constexpr (has_act_order) { +// int first_group_id = g_idx[slice_k_start]; +// int last_g_idx = slice_k_start + stages * tb_k * 2; +// if (last_g_idx >= prob_k) { +// last_g_idx = prob_k - 1; +// } +// int last_group_id = g_idx[last_g_idx]; +// if (last_group_id >= sh_first_group_id + sh_num_groups) { +// fetch_scales_to_shared(false, first_group_id, last_group_id); +// __syncthreads(); +// } +// } + +// // Process results and, if necessary, proceed to the next column slice. +// // While this pattern may not be the most readable, other ways of writing +// // the loop seemed to noticeably worse performance after compilation. 
+// if (slice_iters == 0) { +// cp_async_wait<0>(); +// bool last = slice_idx == slice_count - 1; +// if constexpr (!has_act_order && group_blocks == -1) { +// if constexpr (w_type.size_bits() == 8) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); +// } +// cp_async_fence(); +// } else { +// // For 4-bit per-column scales, we only fetch them here in the +// // final step before write-out +// if (last) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); +// } +// cp_async_fence(); +// } +// } +// } + +// thread_block_reduce(); +// if constexpr (!has_act_order && group_blocks == -1) { +// if constexpr (w_type.size_bits() == 8) { +// cp_async_wait<0>(); +// __syncthreads(); +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; +// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; +// } + +// } else { +// if (last) { +// cp_async_wait<0>(); +// __syncthreads(); +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; +// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; +// } +// } +// } +// } + +// // For 8-bit channelwise, we apply the scale before the global +// reduction +// // that converts the fp32 results to fp16 (so that we avoid possible +// // overflow in fp16) +// if constexpr (!has_act_order && group_blocks == -1 && +// w_type.size_bits() == 8) { +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) { +// #pragma unroll +// for (int j = 0; j < 4; j++) { +// scale_float(reinterpret_cast(&frag_c[i][j][0][0]), +// frag_s[j / 2][2 * (j % 2) + 0]); +// scale_float(reinterpret_cast(&frag_c[i][j][0][2]), +// frag_s[j / 2][2 * (j % 2) + 0]); + +// scale_float(reinterpret_cast(&frag_c[i][j][1][0]), +// frag_s[j / 2][2 * (j % 2) + 1]); +// scale_float(reinterpret_cast(&frag_c[i][j][1][2]), +// frag_s[j / 2][2 * (j % 2) + 1]); +// } +// } +// } +// } + +// if (slice_count > 1) { // only globally reduce if there is more than +// one +// // block in a slice +// barrier_acquire(&locks[slice_col], slice_idx); +// global_reduce(slice_idx == 0, last); +// barrier_release(&locks[slice_col], last); +// } +// if (last) // only the last block in a slice actually writes the result +// write_result(); +// slice_row = 0; +// slice_col_par++; +// slice_col++; +// init_slice(); +// if (slice_iters) { +// a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + +// (threadIdx.x % a_gl_rd_delta_o); +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) +// B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; +// if (slice_col == 0) { +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; +// } + +// // Update slice k/n for scales loading +// if constexpr (has_act_order) { +// slice_k_start = tb_k * slice_row; +// slice_k_finish = slice_k_start + tb_k * slice_iters; +// slice_k_start_shared_fetch = slice_k_start; +// slice_n_offset = act_s_col_tb_stride * slice_col; + +// } else { +// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; +// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; +// } +// start_pipes(); +// } +// } +// } +// } + +// template shared +// // fetch pipeline +// const bool has_act_order, // whether act_order is enabled +// const bool has_zp, // whether zero-points are enabled +// const int group_blocks = -1 // number of consecutive 16x16 blocks +// // with a separate quantization scale +// > +// __global__ void MarlinMoE( +// const int4* 
__restrict__ A, // fp16 input matrix of shape mxk +// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn +// int4* __restrict__ C, // fp16 output buffer of shape mxn +// const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts +// const float* __restrict__ topk_weights, // float topk weights +// const int4* __restrict__ scales_ptr, // fp16 quantization scales of +// shape +// // (k/groupsize)xn +// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape +// // (k/groupsize)x(n/pack_factor) +// const int* __restrict__ g_idx, // int32 group indices of shape k +// const int* __restrict__ expert_offsets, +// int num_groups, // number of scale groups per output channel +// int expert_idx, // idx of current expert +// int num_experts, // number of experts +// int topk, // topk parameter of moe +// int prob_m, // batch dimension m +// int prob_n, // output dimension n +// int prob_k, // reduction dimension k +// int tot_m, // total number of rows in A and C +// int* locks, // extra global storage for barrier +// synchronization bool replicate_input, // do we use the same input for +// each expert? bool apply_weights, // apply weights to output int +// current_m_block, // current m block to start kernel computation from +// int max_par, // maximum parallelism +// int cfg_max_m_blocks // upper bound on m blocks +// ) { +// int m_block_ctr = current_m_block; + +// const int* sorted_ids_expert = +// sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * +// max_par; +// int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; +// if (tot_its == 0) { +// return; +// } +// int tot_m_blocks = ceildiv(tot_its, 16); +// int pad = 16 * tot_m_blocks - tot_its; + +// if (m_block_ctr >= tot_m_blocks) { +// return; +// } + +// int max_block = tot_m_blocks - m_block_ctr; +// prob_m = tot_its - 16 * m_block_ctr; + +// int par = 1; +// if (max_block > cfg_max_m_blocks) { +// // Note that parallel > 1 currently only works for inputs without any +// // padding +// par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); +// if (par > max_par) par = max_par; +// prob_m = (16 * cfg_max_m_blocks) * par; +// m_block_ctr += cfg_max_m_blocks * (par - 1); +// max_block = cfg_max_m_blocks; +// } + +// if (max_block == 1) { +// MarlinMoESingle( +// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, +// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, +// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, +// current_m_block); +// } else if (max_block == 2) { +// MarlinMoESingle( +// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, +// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, +// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, +// current_m_block); +// } else if (max_block == 3) { +// MarlinMoESingle( +// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, +// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, +// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, +// current_m_block); +// } else { +// MarlinMoESingle( +// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, +// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, +// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, +// current_m_block); +// } +// } #else @@ -1643,89 +1698,92 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template shared - // fetch 
pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks - -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} +// template shared +// // fetch pipeline +// const bool has_act_order, // whether act_order is enabled +// const bool has_zp, // whether zero-points are enabled +// const int group_blocks = -1 // number of consecutive 16x16 blocks +// // with a separate quantization scale +// > +// __global__ void MarlinMoE( +// const int4* __restrict__ A, // fp16 input matrix of shape mxk +// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn +// int4* __restrict__ C, // fp16 output buffer of shape mxn +// const int* __restrict__ sorted_ids, // int32 sorted ids of experts +// const float* __restrict__ topk_weights, // float topk weights +// const int4* __restrict__ scales_ptr, // fp16 quantization scales of +// shape +// // (k/groupsize)xn +// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape +// // (k/groupsize)x(n/pack_factor) +// const int* __restrict__ g_idx, // int32 group indices of shape k +// const int* __restrict__ expert_offsets, +// int num_groups, // number of scale groups per output channel +// int expert_idx, // idx of current expert +// int num_experts, // number of experts +// int topk, // topk parameter of moe +// int prob_m, // batch dimension m +// int prob_n, // output dimension n +// int prob_k, // reduction dimension k +// int tot_m, // total number of rows in A and C +// int* locks, // extra global storage for barrier +// synchronization bool replicate_input, // do we use the same input for +// each expert? 
bool apply_weights, // apply weights to output int +// current_m_block, // current m block to start kernel computation from +// int max_par, // maximum parallelism +// int cfg_max_m_blocks // upper bound on m blocks + +// ) { +// // Marlin is not implemented yet for SM < 8.0 +// assert(false); +// return; +// } #endif -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ - } +// // 8 warps are a good choice since every SM has 4 schedulers and having more +// // than 1 warp per schedule allows some more latency hiding. At the same +// time, +// // we want relatively few warps to have many registers per warp and small +// tiles. 
const int USER_THREADS = +// 256; // Note: This is only used with user-provided +// thread_k/n +// const int STAGES = 4; // 4 pipeline stages fit into shared memory +// // const int SHARED_MEM = +// // 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +// static constexpr int min_thread_n = 64; +// static constexpr int min_thread_k = 64; + +// #define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ +// THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ +// NUM_THREADS) \ +// else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ +// thread_n_blocks == THREAD_N_BLOCKS && \ +// thread_k_blocks == THREAD_K_BLOCKS && \ +// has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ +// group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ +// cudaFuncSetAttribute(MarlinMoE, \ +// cudaFuncAttributeMaxDynamicSharedMemorySize, \ +// max_shared_mem); \ +// MarlinMoE \ +// <<>>( \ +// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ +// zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ +// num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ +// replicate_input, apply_weights, m_block, max_par, \ +// exec_cfg.max_m_blocks); \ +// } typedef struct { int thread_k; @@ -1901,52 +1959,63 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - -#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 
2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) +// #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +// #define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 3, 
N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ +// \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ +// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION( \ + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, \ + has_act_order, has_zp, group_blocks, num_threads, blocks, \ + max_shared_mem, stream, A_ptr, B_ptr, C_ptr, sorted_ids_ptr, \ + topk_weights_ptr, s_ptr, zp_ptr, g_idx_ptr, expert_offsets_ptr, \ + num_groups, expert_idx, num_experts, topk, prob_m, prob_n, \ + prob_k, tot_m, locks, replicate_input, apply_weights, m_block, \ + max_par, exec_cfg.max_m_blocks)) { \ + } void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, @@ -2091,25 +2160,41 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, // make it max possible value int thread_m_blocks = exec_cfg.max_m_blocks; + int cfg_max_m_blocks = exec_cfg.max_m_blocks; + if (false) { } - GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) - - AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) - AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) + // else if(call_marlin_moe_kernel_ku4b8( + // q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, + // has_act_order, has_zp, group_blocks, num_threads, blocks, + // max_shared_mem, stream, A_ptr, B_ptr, C_ptr, sorted_ids_ptr, + // topk_weights_ptr, s_ptr, zp_ptr, g_idx_ptr, expert_offsets_ptr, + // num_groups, expert_idx, num_experts, topk, prob_m, prob_n, + // prob_k, tot_m, locks, replicate_input, apply_weights, m_block, + // max_par, exec_cfg.max_m_blocks)) { + // } + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) + + // GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + // GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + // GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + // GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + // GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + // GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + // GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + // GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + + // AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) + // AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) + // AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) + // AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) + // AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) + // AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) + // AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) + // AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + From 4b11a7de528f82a68a530b6288ddfc9abef0af3d Mon Sep 17 00:00:00 2001 From: 
ElizaWszola Date: Mon, 16 Sep 2024 03:29:57 -0400 Subject: [PATCH 26/49] it compiles --- csrc/moe/marlin_moe_kernel.cu | 1251 +++++++++++++++++ csrc/moe/marlin_moe_kernel.cuh | 1164 +-------------- .../layers/fused_moe/fused_marlin_moe.py | 30 +- 3 files changed, 1280 insertions(+), 1165 deletions(-) create mode 100644 csrc/moe/marlin_moe_kernel.cu diff --git a/csrc/moe/marlin_moe_kernel.cu b/csrc/moe/marlin_moe_kernel.cu new file mode 100644 index 000000000000..021641ee1ba6 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel.cu @@ -0,0 +1,1251 @@ +#include "marlin_moe_kernel.cuh" + +namespace marlin_moe { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
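The rounding applied on the next line relies on the `ceildiv` helper that the shared Marlin headers are assumed to provide; a minimal sketch of that helper and of the round-up identity being used (illustrative only, not part of the diff):

// Integer ceiling division: the smallest q such that q * b >= a (a, b > 0).
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// With g = group_blocks / thread_k_blocks, g * ceildiv(iters, g) is the
// smallest multiple of g that is >= iters, so every stripe covers whole
// scale groups. Example: iters = 7, g = 4  ->  4 * ceildiv(7, 4) = 8.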
+ iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. 
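The per-iteration guards computed next are consumed by the predicated 16-byte async-copy wrapper (`cp_async4_pred`) used throughout the fetch code. Assuming it follows the usual Marlin pattern, it looks roughly like the sketch below; the exact header implementation may differ:

// Predicated 16-byte global->shared copy via Ampere cp.async; when pred is
// false the copy is skipped, so out-of-bounds rows are never touched.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}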
+ bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
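Each fetch stage below ends with `cp_async_fence()`, and `wait_for_stage()` later blocks in `cp_async_wait<stages - 2>()`. These wrappers are assumed to map onto the cp.async commit/wait-group instructions, roughly:

// Commit all cp.async operations issued so far into one group.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` committed cp.async groups are still in flight;
// with n = stages - 2 exactly one prefetch stays outstanding.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}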
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
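`fetch_to_registers` below loads A fragments with an `ldsm4` wrapper around the ldmatrix instruction. Assuming the `FragA` typedef from marlin_moe_kernel.cuh (four 32-bit registers of packed fp16 per thread), a sketch of that wrapper:

// Load four 8x8 fp16 tiles from shared memory into the tensor-core fragment
// layout expected by mma.sync (one ldmatrix.x4 per call).
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(smem));
}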
+ auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = 
sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
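Each `mma(...)` call in the loop below is assumed to wrap a single m16n8k16 fp16 tensor-core instruction that accumulates in fp32 (using the `FragA`/`FragB`/`FragC` typedefs from the header), along the lines of:

// One 16x8x16 tensor-core MAC: D = A * B + C, fp16 inputs, fp32 accumulators.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                           FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}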
+ #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. 
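The serial reduction is ordered by the `barrier_acquire` / `barrier_release` helpers on `locks[slice_col]` (used further down once a slice finishes). They are assumed to spin on a per-slice counter in global memory, roughly as in this sketch:

// Spin until the lock counter reaches `count`, i.e. until all earlier blocks
// in this slice have published their partial results.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Acquire load so subsequent reads see the releasing block's writes.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Either reset the lock (last block in the slice) or bump the counter so the
// next block in line may proceed.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make prior global writes visible before incrementing the counter.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}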
+ auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. 
Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + 
slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#endif + +} diff --git a/csrc/moe/marlin_moe_kernel.cuh b/csrc/moe/marlin_moe_kernel.cuh index 5330cfca9751..815d9561089e 100644 --- a/csrc/moe/marlin_moe_kernel.cuh +++ b/csrc/moe/marlin_moe_kernel.cuh @@ -323,1108 +323,7 @@ __device__ inline void MarlinMoESingle( bool replicate_input, // do we use the same input for each expert? bool apply_weights, // apply weights to output int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. 
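The per-expert m-block bookkeeping in MarlinMoE is plain integer arithmetic and can be rehearsed on the host. The sizes below (tot_its, cfg_max_m_blocks, max_par) are made-up examples for the first launch (m_block_ctr == 0), not defaults taken from the code:

    #include <cstdio>

    static int ceildiv(int a, int b) { return (a + b - 1) / b; }

    int main() {
      int tot_its          = 1000; // assumed: rows routed to this expert
      int cfg_max_m_blocks = 4;    // assumed: largest thread_m_blocks instantiation
      int max_par          = 16;   // assumed parallelism cap

      int tot_m_blocks = ceildiv(tot_its, 16);
      int pad          = 16 * tot_m_blocks - tot_its;
      int max_block    = tot_m_blocks;   // first launch, so m_block_ctr == 0
      int prob_m       = tot_its;

      int par = 1;
      if (max_block > cfg_max_m_blocks) {
        par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);
        if (par > max_par) par = max_par;
        prob_m    = 16 * cfg_max_m_blocks * par;
        max_block = cfg_max_m_blocks;
      }
      std::printf("tot_m_blocks=%d pad=%d par=%d prob_m=%d thread_m_blocks=%d\n",
                  tot_m_blocks, pad, par, prob_m, max_block);
      return 0;
    }

With these assumed numbers one launch covers par parallel 64-row batches, and any remainder is picked up by later launches starting at a higher current_m_block.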
- iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. 
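Since pack_factor halves when moving from 4-bit to 8-bit weights, a B tile of the same shape occupies twice as many int4 vectors, and b_thread_vecs compensates so the same number of threads still covers one fetch row. A quick check of that arithmetic (thread_n_blocks = 16 is just an assumed tile width):

    #include <cstdio>

    int main() {
      const int thread_n_blocks = 16;  // assumed tile width
      const int bit_widths[] = {4, 8};
      for (int size_bits : bit_widths) {
        int pack_factor         = 32 / size_bits;  // quantized weights per int32
        int b_sh_stride         = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
        int b_thread_vecs       = (size_bits == 4) ? 1 : 2;
        int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
        std::printf("%d-bit: pack_factor=%d b_sh_stride=%d int4s, "
                    "%d vec(s)/thread, %d threads per fetch row\n",
                    size_bits, pack_factor, b_sh_stride, b_thread_vecs,
                    b_sh_stride_threads);
      }
      return 0;
    }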
- int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. 
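All of the dynamic shared memory is carved out of the single extern sh[] array in declaration order (A, B, g_idx, zero-points, scales), so the per-stage sizes above fully determine each pointer's offset. A host-side sketch of that bookkeeping for one assumed, purely illustrative configuration (4-bit weights, no act_order, no zero-points, grouped scales):

    #include <cstdio>

    int main() {
      // Assumed configuration: 64 x 256 (k x n) tile, 4 m-blocks, 4 pipeline
      // stages, groupsize 128 (group_blocks = 8). Illustrative values only.
      const int thread_m_blocks = 4, thread_n_blocks = 16, thread_k_blocks = 4;
      const int stages = 4, group_blocks = 8, pack_factor = 32 / 4;

      int a_sh_stride = 16 * thread_k_blocks / 8;
      int a_sh_stage  = a_sh_stride * (16 * thread_m_blocks);
      int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
      int b_sh_stage  = b_sh_stride * thread_k_blocks;
      int g_idx_stage = 0;  // has_act_order == false
      int zp_sh_stage = 0;  // has_zp == false
      int s_sh_stride = 16 * thread_n_blocks / 8;
      // No act_order, grouped case of the s_tb_groups expression:
      int s_tb_groups = (group_blocks < thread_k_blocks)
                            ? thread_k_blocks / group_blocks : 1;
      int s_sh_stage  = s_tb_groups * s_sh_stride;

      int off_a = 0;
      int off_b = off_a + stages * a_sh_stage;
      int off_g = off_b + stages * b_sh_stage;
      int off_z = off_g + stages * g_idx_stage;
      int off_s = off_z + stages * zp_sh_stage;
      int total = off_s + stages * s_sh_stage;  // shs_size for group_blocks > 0
      std::printf("int4 offsets: sh_a=%d sh_b=%d sh_g_idx=%d sh_zp=%d sh_s=%d\n",
                  off_a, off_b, off_g, off_z, off_s);
      std::printf("A+B+g_idx+zp+scales regions: %d bytes\n", total * 16);
      return 0;
    }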
- bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. 
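Most of the register file goes to the fp32 accumulators: frag_c holds thread_m_blocks * 4 * 2 * 4 floats per thread (exactly what zero_accums() clears below), so the cost grows linearly with the m-tile height. A back-of-the-envelope count, independent of any kernel state:

    #include <cstdio>

    int main() {
      for (int thread_m_blocks = 1; thread_m_blocks <= 4; ++thread_m_blocks) {
        int acc_floats = thread_m_blocks * 4 * 2 * 4;  // fp32 values in frag_c
        std::printf("thread_m_blocks=%d -> %3d accumulator registers per thread\n",
                    thread_m_blocks, acc_floats);
      }
      return 0;
    }

This is presumably part of why the dispatcher instantiates only small thread_m_blocks values and picks the one that just covers the remaining rows.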
- auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
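The sorted_ids indirection is what lets one kernel instance serve a single expert: each entry selects a row of the (tot_m * topk)-row output, and when replicate_input is set the activation row to read is recovered by dividing the id by topk. A toy walk-through with assumed routing (encoding the flattened (token, k-slot) pair as token * topk + k is an assumption made for this example):

    #include <cstdio>

    int main() {
      const int topk = 2;
      // Assumed: ids routed to one expert, already grouped the way the
      // host-side alignment step would emit them.
      const int sorted_ids[] = {0, 3, 4, 6, 9, 11};
      for (int id : sorted_ids) {
        int a_row = id / topk;  // row of A read when replicate_input == true
        int c_row = id;         // row of the tot_m * topk output that is written
        std::printf("sorted id %2d -> read A row %d, write C row %d\n",
                    id, a_row, c_row);
      }
      return 0;
    }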
- auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = 
sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. 
- #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. 
- auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. 
Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - 
slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} +); template = tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} +); #else diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 866b18d725a8..fafd74493be0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -78,16 +78,21 @@ def single_marlin_moe( max_workspace_size = (N // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, - device="cuda", + device=hidden_states.device, requires_grad=False) + w_zeros = torch.empty((0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + scalar_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, - block_size_m, True, False) + w_zeros, g_idx, perm, workspace, scalar_type, M, N, K, True, False, E, + topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -106,6 +111,8 @@ def fused_marlin_moe( override_config: 
Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, num_bits: int = 8, ) -> torch.Tensor: """ @@ -176,6 +183,19 @@ def fused_marlin_moe( device="cuda", requires_grad=False) + has_zp1 = w1_zeros is not None + has_zp2 = w2_zeros is not None + if w1_zeros is None: + w1_zeros = torch.empty((0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + if w2_zeros is None: + w2_zeros = torch.empty((0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + scalar_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -192,6 +212,7 @@ def fused_marlin_moe( topk_weights, topk_ids, w1_scale, + w1_zeros, g_idx1, perm1, workspace, @@ -200,6 +221,7 @@ def fused_marlin_moe( 2 * N, K, True, + has_zp1, E, topk, block_size_m, @@ -216,6 +238,7 @@ def fused_marlin_moe( topk_weights, topk_ids, w2_scale, + w2_zeros, g_idx2, perm2, workspace, @@ -224,6 +247,7 @@ def fused_marlin_moe( K, N, True, + has_zp2, E, topk, block_size_m, From 507af0cd639157d3715bbe6d32aa5b6e61ec3a4f Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 17 Sep 2024 09:54:53 -0400 Subject: [PATCH 27/49] try to compile the kernel code --- CMakeLists.txt | 2 + csrc/moe/marlin_moe_kernel.cu | 2412 +++++++++++++------------ csrc/moe/marlin_moe_kernel.cuh | 121 +- csrc/moe/marlin_moe_kernel_ku4.cu | 52 +- csrc/moe/marlin_moe_kernel_ku4b8.cu | 143 +- csrc/moe/marlin_moe_kernel_ku8.cu | 52 +- csrc/moe/marlin_moe_kernel_ku8b128.cu | 115 +- csrc/moe/marlin_moe_ops.cu | 1734 +----------------- 8 files changed, 1427 insertions(+), 3204 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c88c31c83da..bd3322ad4cd1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -298,6 +298,8 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/marlin_moe_kernel.cu") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_moe_kernel.cu b/csrc/moe/marlin_moe_kernel.cu index 021641ee1ba6..2090eb848a16 100644 --- a/csrc/moe/marlin_moe_kernel.cu +++ b/csrc/moe/marlin_moe_kernel.cu @@ -1,1149 +1,1155 @@ +// #include +// #include +// #include +// #include +// #include + #include "marlin_moe_kernel.cuh" namespace marlin_moe { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert 
- int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. 
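The striped partitioning is easiest to see on the host: every threadblock owns `iters` consecutive k-tiles and simply wraps into the next column slice when one column is exhausted. The tile counts and grid size below are assumptions for illustration:

    #include <cstdio>

    static int ceildiv(int a, int b) { return (a + b - 1) / b; }

    int main() {
      const int k_tiles = 8, n_tiles = 6, parallel = 1, grid_size = 16;  // assumed
      int iters = ceildiv(k_tiles * n_tiles * parallel, grid_size);      // tiles per block
      for (int block = 0; block < grid_size; ++block) {
        int slice_row     = (iters * block) % k_tiles;
        int slice_col_par = (iters * block) / k_tiles;
        std::printf("block %2d: starts in column %d at k-row %d\n",
                    block, slice_col_par, slice_row);
      }
      return 0;
    }

Blocks whose stripe shares a column with other blocks are the ones the slice_count / slice_idx bookkeeping later coordinates through `locks` for the global reduction.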
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? 
stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. 
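The effect of transform_a is easiest to verify with a few concrete indices: with this tile shape, the eight threads of a quarter-warp that each read one int4 from the same logical A column would otherwise all start in the same 4-byte bank. A standalone check (thread_k_blocks = 8, i.e. a_gl_rd_delta_o = 16, is an assumed configuration):

    #include <cstdio>

    int main() {
      const int a_gl_rd_delta_o = 16;  // assumed: int4 elements per A row
      auto transform_a = [&](int i) {
        int row = i / a_gl_rd_delta_o;
        // Same expression as the kernel: '+' binds tighter than '^', so the
        // row index is XOR-ed into the low bits of the final linear index.
        return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
      };
      for (int lane = 0; lane < 8; ++lane) {
        int plain    = a_gl_rd_delta_o * lane;  // same column, consecutive rows
        int swizzled = transform_a(plain);
        // A 16-byte int4 element at index idx starts at bank (idx * 4) % 32.
        std::printf("lane %d: plain idx %3d (bank %2d) -> swizzled idx %3d (bank %2d)\n",
                    lane, plain, plain * 4 % 32, swizzled, swizzled * 4 % 32);
      }
      return 0;
    }

The plain indices all land in bank 0 (an 8-way conflict), while the swizzled ones spread across distinct banks.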
- auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < 
zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due 
to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. 
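The comment above introduces the tensor-core sub-tile loop: A fragments are pulled from shared memory with ldsm4 and multiplied against dequantized B fragments via mma. Below is a minimal sketch of the two underlying PTX instructions for one m16n8k16 fp16 tile with fp32 accumulation (assumes sm_80; the helper names and register layout are illustrative, not the kernel's ldsm4/mma wrappers):

#include <cuda_fp16.h>
#include <cstdint>

// Load four 8x8 fp16 fragments of A from shared memory into registers.
// 'smem_ptr' is the shared-memory address of the 16-byte row owned by this lane.
__device__ void ldmatrix_x4(uint32_t (&a)[4], const void* smem_ptr) {
  uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(addr));
}

// One m16n8k16 tensor-core MMA: D(fp32) = A(fp16) * B(fp16) + C(fp32).
__device__ void mma_m16n8k16(float (&d)[4], const uint32_t (&a)[4],
                             const uint32_t (&b)[2], const float (&c)[4]) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}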
- #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. 
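The comment above explains why partial results from blocks that share a column slice are reduced serially through global memory. A simplified sketch of the lock pattern this relies on, analogous to the barrier_acquire/barrier_release calls used later in the kernel (memory ordering is reduced to plain __threadfence here; the real code uses acquire/release semantics):

// Spin until all blocks ordered before us in this slice have released the lock.
__device__ void slice_barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    while (atomicAdd(lock, 0) != count) {
      // busy-wait until the previous block publishes its partial result
    }
  }
  __syncthreads();
  __threadfence();  // order our subsequent reads after the observed release
}

// Publish our partial result and hand the lock to the next block in the slice
// (or reset it if we are the last block of this slice).
__device__ void slice_barrier_release(int* lock, bool last) {
  __syncthreads();
  if (threadIdx.x == 0) {
    __threadfence();          // make our global writes visible first
    if (last) {
      atomicExch(lock, 0);    // reset for the next use of this slice
    } else {
      atomicAdd(lock, 1);     // pass the baton to the next block
    }
  }
}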
- auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
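write_result below reorders fragments through shared memory and, when apply_weights is set, multiplies each output row by its top-k routing weight while converting to fp16. A minimal sketch of that per-row weighting step on one int4-packed group of 8 halves (names are illustrative, not the kernel's):

#include <cuda_fp16.h>

// Scale 8 packed fp16 outputs by a per-row routing weight, doing the multiply
// in fp32 the same way the kernel applies topk_weights on write-out.
__device__ void apply_topk_weight(int4* out, const int4* in, float row_weight) {
  int4 v = *in;
  const __half* src = reinterpret_cast<const __half*>(&v);
  int4 r;
  __half* dst = reinterpret_cast<__half*>(&r);
  #pragma unroll
  for (int j = 0; j < 8; ++j) {
    dst[j] = __float2half(row_weight * __half2float(src[j]));
  }
  *out = r;
}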
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. 
Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - 
slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} +// template shared +// // fetch pipeline +// const bool has_act_order, // whether act_order is enabled +// const bool has_zp, // whether zero-points are enabled +// const int group_blocks = -1 // number of consecutive 16x16 blocks +// // with a separate quantization scale +// > +// __device__ inline void MarlinMoESingle( +// const int4* __restrict__ A, // fp16 input matrix of shape mxk +// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn +// int4* __restrict__ C, // fp16 output buffer of shape mxn +// const int* __restrict__ sorted_ids, // int32 sorted ids of experts +// const float* __restrict__ topk_weights, // float topk weights +// const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape +// // (k/groupsize)xn +// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape +// // (k/groupsize)x(n/pack_factor) +// const int* __restrict__ g_idx, // int32 group indices of shape k +// const int* __restrict__ expert_offsets, +// int num_groups, // number of scale groups per output channel +// int expert_idx, // idx of current expert +// int num_experts, // number of experts +// int topk, // topk parameter of moe +// int prob_m, // batch dimension m +// int prob_n, // output dimension n +// int prob_k, // reduction dimension k +// int tot_m, // total number of rows in A and C +// int* locks, // extra global storage for barrier synchronization +// bool replicate_input, // do we use the same input for each expert? +// bool apply_weights, // apply weights to output +// int current_m_block // current m block to start kernel computation from +// ) { +// static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); +// constexpr int pack_factor = 32 / w_type.size_bits(); + +// // For larger GEMMs we run multiple batchsize 64 versions in parallel for a +// // better partitioning with less reductions +// int parallel = 1; +// if (prob_m > 16 * thread_m_blocks) { +// parallel = prob_m / (16 * thread_m_blocks); +// prob_m = 16 * thread_m_blocks; +// } + +// int k_tiles = prob_k / 16 / thread_k_blocks; +// int n_tiles = prob_n / 16 / thread_n_blocks; +// int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + +// if constexpr (!has_act_order && group_blocks != -1) { +// if (group_blocks >= thread_k_blocks) { +// // Ensure that the number of tiles in each stripe is a multiple of the +// // groupsize; this avoids an annoying special case where a stripe starts +// // in the middle of group. 
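Both the commented-out copy above and the active kernel partition the k x n tile grid into contiguous stripes of tiles per thread block, rounding the per-block iteration count so a stripe never starts in the middle of a scale group. A small host-side sketch of that bookkeeping, using assumed tile counts (illustrative values only):

#include <cstdio>

static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Illustrative problem: 16 k-tiles x 6 n-tiles striped over 8 blocks,
  // with group_blocks / thread_k_blocks == 4.
  int k_tiles = 16, n_tiles = 6, blocks = 8, group_ratio = 4;
  int iters = ceildiv(k_tiles * n_tiles, blocks);
  // Round up so every stripe covers whole scale groups.
  iters = group_ratio * ceildiv(iters, group_ratio);
  for (int b = 0; b < blocks; ++b) {
    int slice_row = (iters * b) % k_tiles;  // starting k-tile of this stripe
    int slice_col = (iters * b) / k_tiles;  // column slice it belongs to
    printf("block %d: starts at (row=%d, col=%d), %d tile-iters\n", b,
           slice_row, slice_col, iters);
  }
  return 0;
}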
+// iters = (group_blocks / thread_k_blocks) * +// ceildiv(iters, (group_blocks / thread_k_blocks)); +// } +// } + +// int slice_row = (iters * blockIdx.x) % k_tiles; +// int slice_col_par = (iters * blockIdx.x) / k_tiles; +// int slice_col = slice_col_par; +// int slice_iters; // number of threadblock tiles in the current slice +// int slice_count = +// 0; // total number of active threadblocks in the current slice +// int slice_idx; // index of threadblock in current slice; numbered bottom to +// // top + +// // We can easily implement parallel problem execution by just remapping +// // indices and advancing global pointers +// if (slice_col_par >= n_tiles) { +// locks += (slice_col_par / n_tiles) * n_tiles; +// slice_col = slice_col_par % n_tiles; +// sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; +// } + +// // Compute all information about the current slice which is required for +// // synchronization. +// auto init_slice = [&]() { +// slice_iters = +// iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); +// if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; +// if (slice_iters == 0) return; +// if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; +// slice_count = 1; +// slice_idx = 0; +// int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); +// if (col_first <= k_tiles * (slice_col_par + 1)) { +// int col_off = col_first - k_tiles * slice_col_par; +// slice_count = ceildiv(k_tiles - col_off, iters); +// if (col_off > 0) slice_count++; +// int delta_first = iters * blockIdx.x - col_first; +// if (delta_first < 0 || (col_off == 0 && delta_first == 0)) +// slice_idx = slice_count - 1; +// else { +// slice_idx = slice_count - 1 - delta_first / iters; +// if (col_off > 0) slice_idx--; +// } +// } +// if (slice_col == n_tiles) { +// sorted_ids += 16 * thread_m_blocks; +// locks += n_tiles; +// slice_col = 0; +// } +// }; +// init_slice(); + +// // A sizes/strides + +// // stride of the A matrix in global memory +// int a_gl_stride = prob_k / 8; +// // stride of an A matrix tile in shared memory +// constexpr int a_sh_stride = 16 * thread_k_blocks / 8; +// // delta between subsequent A tiles in global memory +// constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; +// // between subsequent accesses within a tile +// int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); +// // between shared memory writes +// constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); +// // between shared memory tile reads +// constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); +// // within a shared memory tile +// constexpr int a_sh_rd_delta_i = a_sh_stride * 16; +// // overall size of a tile +// constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); +// // number of shared write iterations for a tile +// constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + +// // B sizes/strides +// int b_gl_stride = 16 * prob_n / (pack_factor * 4); +// constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; +// constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; +// constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + +// int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; +// int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); +// constexpr int b_sh_wr_delta = threads * b_thread_vecs; +// constexpr int b_sh_rd_delta = threads * b_thread_vecs; +// constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; +// constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + +// // Scale sizes/strides without act_order +// int s_gl_stride = prob_n / 8; +// constexpr int s_sh_stride = 16 * thread_n_blocks / 8; +// constexpr int s_tb_groups = +// !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks +// ? thread_k_blocks / group_blocks +// : 1; +// constexpr int s_sh_stage = s_tb_groups * s_sh_stride; +// int s_gl_rd_delta = s_gl_stride; +// // Scale size/strides with act_order +// constexpr int tb_k = 16 * thread_k_blocks; +// constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; +// // constexpr int act_s_row_stride = 1; +// // int act_s_col_stride = act_s_row_stride * num_groups; +// int act_s_col_stride = 1; +// int act_s_col_warp_stride = act_s_col_stride * 8; +// int tb_n_warps = thread_n_blocks / 4; +// int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + +// // Zero-points sizes/strides +// int zp_gl_stride = (prob_n / pack_factor) / 4; +// constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; +// constexpr int zp_tb_groups = s_tb_groups; +// constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; +// int zp_gl_rd_delta = zp_gl_stride; + +// // Global A read index of current thread. +// int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + +// (threadIdx.x % a_gl_rd_delta_o); +// a_gl_rd += a_gl_rd_delta_o * slice_row; +// // Shared write index of current thread. +// int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + +// (threadIdx.x % a_gl_rd_delta_o); +// // Shared read index. 
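In the strides above, pack_factor = 32 / w_type.size_bits() sets how many quantized weights fit into one int32, which is why the 8-bit path uses b_thread_vecs = 2 where the 4-bit path uses 1. A tiny sketch of that arithmetic (values only, not tied to the kernel's types):

#include <cstdio>

int main() {
  for (int bits : {4, 8}) {
    int pack_factor = 32 / bits;             // quantized weights per int32
    int b_thread_vecs = (bits == 4) ? 1 : 2; // 16-byte B vectors per thread
    printf("%d-bit: %d weights per int32, %d int4-vector(s) of B per thread\n",
           bits, pack_factor, b_thread_vecs);
  }
  return 0;
}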
+// int a_sh_rd = +// a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; +// a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + +// int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + +// (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; +// b_gl_rd += b_sh_stride * slice_col; +// b_gl_rd += b_gl_rd_delta_o * slice_row; +// int b_sh_wr = threadIdx.x * b_thread_vecs; +// int b_sh_rd = threadIdx.x * b_thread_vecs; + +// // For act_order +// constexpr int k_iter_size = tb_k / b_sh_wr_iters; +// int slice_k_start = tb_k * slice_row; +// int slice_k_finish = slice_k_start + tb_k * slice_iters; +// int slice_k_start_shared_fetch = slice_k_start; +// int slice_n_offset = act_s_col_tb_stride * slice_col; + +// // No act_order +// int s_gl_rd; +// if constexpr (!has_act_order) { +// if constexpr (group_blocks == -1) { +// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; +// } else { +// s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + +// s_sh_stride * slice_col + threadIdx.x; +// } +// } +// int s_sh_wr = threadIdx.x; +// bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + +// // Zero-points +// int zp_gl_rd; +// if constexpr (has_zp) { +// if constexpr (group_blocks == -1) { +// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; +// } else { +// zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + +// zp_sh_stride * slice_col + threadIdx.x; +// } +// } +// int zp_sh_wr = threadIdx.x; +// bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + +// // We use a different scale layout for grouped and column-wise quantization as +// // we scale a `half2` tile in column-major layout in the former and in +// // row-major in the latter case. +// int s_sh_rd; +// if constexpr (group_blocks != -1) +// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + +// (threadIdx.x % 32) / 4; +// else +// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + +// (threadIdx.x % 32) % 4; + +// // Zero-points have the same read layout as the scales +// // (without column-wise case) +// constexpr int num_col_threads = 8; +// constexpr int num_row_threads = 4; +// constexpr int num_ints_per_thread = 8 / pack_factor; +// int zp_sh_rd; +// if constexpr (has_zp) { +// zp_sh_rd = num_ints_per_thread * num_col_threads * +// ((threadIdx.x / 32) % (thread_n_blocks / 4)) + +// num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); +// } + +// int sh_first_group_id = -1; +// int sh_num_groups = -1; +// constexpr int sh_max_num_groups = 32; + +// int shs_size; +// if constexpr (has_act_order) +// shs_size = sh_max_num_groups * s_sh_stride + threads; +// else +// shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + +// extern __shared__ int4 sh[]; +// // Shared memory storage for global fetch pipelines. +// int4* sh_a = sh; +// int4* sh_b = sh_a + (stages * a_sh_stage); +// int4* sh_g_idx = sh_b + (stages * b_sh_stage); +// int4* sh_zp = sh_g_idx + (stages * g_idx_stage); +// int4* sh_s = sh_zp + (stages * zp_sh_stage); + +// // Precompute which thread should not read memory in which iterations; this is +// // needed if there are more threads than required for a certain tilesize or +// // when the batchsize is not a multiple of 16. 
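The pointer setup above carves a single dynamic shared-memory allocation into per-stage regions for A, B, g_idx, zero-points and scales. A minimal sketch of the same carve-up pattern with placeholder stage sizes (the launcher is still responsible for requesting the total byte count at launch time):

// Carve one dynamic shared-memory arena into pipeline-stage regions.
// The stage sizes are illustrative placeholders, not the kernel's values.
template <int STAGES, int A_STAGE, int B_STAGE, int S_STAGE>
__device__ void carve_shared(int4*& sh_a, int4*& sh_b, int4*& sh_s) {
  extern __shared__ int4 sh[];
  sh_a = sh;                        // STAGES * A_STAGE int4 elements
  sh_b = sh_a + STAGES * A_STAGE;   // STAGES * B_STAGE int4 elements
  sh_s = sh_b + STAGES * B_STAGE;   // STAGES * S_STAGE int4 elements
}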
+// bool a_sh_wr_pred[a_sh_wr_iters]; +// #pragma unroll +// for (int i = 0; i < a_sh_wr_iters; i++) { +// int a_idx = a_sh_wr_delta * i + a_sh_wr; +// int row = a_idx / a_gl_rd_delta_o; +// if (row >= prob_m) { +// a_sh_wr_pred[i] = false; +// } else { +// a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; +// } +// } + +// // To ensure that writing and reading A tiles to/from shared memory, the +// // latter in fragment format, is fully bank conflict free, we need to use a +// // rather fancy XOR-based layout. The key here is that neither reads nor +// // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the +// // same shared memory banks. Further, it seems (based on NSight-Compute) that +// // each warp must also write a consecutive memory segment? +// auto transform_a = [&](int i) { +// int row = i / a_gl_rd_delta_o; +// return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; +// }; +// // Since the computation of this remapping is non-trivial and, due to our main +// // loop unrolls, all shared memory accesses are static, we simply precompute +// // both transformed reads and writes. +// int a_sh_wr_trans[a_sh_wr_iters]; +// #pragma unroll +// for (int i = 0; i < a_sh_wr_iters; i++) +// a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); +// int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) { +// #pragma unroll +// for (int j = 0; j < thread_m_blocks; j++) +// a_sh_rd_trans[i][j] = +// transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); +// } + +// // Since B-accesses have non-constant stride they have to be computed at +// // runtime; we break dependencies between subsequent accesses with a tile by +// // maintining multiple pointers (we have enough registers), a tiny +// // optimization. +// const int4* B_ptr[b_sh_wr_iters]; +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) +// B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + +// // Register storage for double buffer of shared memory reads. +// FragA frag_a[2][thread_m_blocks]; +// I4 frag_b_quant[2][b_thread_vecs]; +// FragC frag_c[thread_m_blocks][4][2]; +// FragS frag_s[2][4]; // No act-order +// FragS act_frag_s[2][4][4]; // For act-order +// int frag_qzp[2][num_ints_per_thread]; // Zero-points +// FragZP frag_zp; // Zero-points in fp16 + +// // Zero accumulators. 
+// auto zero_accums = [&]() { +// #pragma unroll +// for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) +// reinterpret_cast(frag_c)[i] = 0; +// }; + +// auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, +// int last_group_id) { +// sh_first_group_id = first_group_id; +// sh_num_groups = last_group_id - first_group_id + 1; + +// if (sh_num_groups < sh_max_num_groups) { +// sh_num_groups = sh_max_num_groups; +// } + +// if (sh_first_group_id + sh_num_groups > num_groups) { +// sh_num_groups = num_groups - sh_first_group_id; +// } + +// int row_offset = first_group_id * s_gl_stride; + +// if (is_async) { +// for (int i = 0; i < sh_num_groups; i++) { +// if (threadIdx.x < s_sh_stride) { +// cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], +// &scales_ptr[row_offset + (i * s_gl_stride) + +// slice_n_offset + threadIdx.x]); +// } +// } +// } else { +// for (int i = 0; i < sh_num_groups; i++) { +// if (threadIdx.x < s_sh_stride) { +// sh_s[(i * s_sh_stride) + threadIdx.x] = +// scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + +// threadIdx.x]; +// } +// } +// } +// }; +// // Asynchronously fetch the next A, B and s tile from global to the next +// // shared memory pipeline location. +// auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { +// if (pred) { +// int4* sh_a_stage = sh_a + a_sh_stage * pipe; +// #pragma unroll +// for (int i = 0; i < a_sh_wr_iters; i++) { +// int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; +// int row = a_idx / a_gl_stride; +// int sorted_row = +// replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; +// int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; +// if (sorted_row < tot_m * (replicate_input ? 1 : topk) && +// new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { +// cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], +// a_sh_wr_pred[i]); +// } +// } +// int4* sh_b_stage = sh_b + b_sh_stage * pipe; +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) { +// #pragma unroll +// for (int j = 0; j < b_thread_vecs; j++) { +// cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); +// } +// B_ptr[i] += b_gl_rd_delta_o; +// } + +// if constexpr (has_act_order) { +// // Fetch g_idx thread-block portion +// int full_pipe = a_off; +// int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; +// if (cur_k < prob_k && cur_k < slice_k_finish) { +// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + +// int4 const* cur_g_idx_stage_ptr = +// reinterpret_cast(&g_idx[cur_k]); + +// if (threadIdx.x < g_idx_stage) { +// cp_async4_pred(&sh_g_idx_stage[threadIdx.x], +// &cur_g_idx_stage_ptr[threadIdx.x]); +// } +// } +// } else { +// if constexpr (group_blocks != -1) { +// int4* sh_s_stage = sh_s + s_sh_stage * pipe; + +// if constexpr (group_blocks >= thread_k_blocks) { +// // Only fetch scales if this tile starts a new group +// if (pipe % (group_blocks / thread_k_blocks) == 0) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); +// } +// s_gl_rd += s_gl_rd_delta; +// } +// } else { +// for (int i = 0; i < s_tb_groups; i++) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], +// &scales_ptr[s_gl_rd]); +// } +// s_gl_rd += s_gl_rd_delta; +// } +// } +// } + +// if constexpr (has_zp && group_blocks != -1) { +// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + +// if constexpr (group_blocks >= thread_k_blocks) { +// // Only fetch zero-points if this tile starts a new group +// if (pipe % (group_blocks / thread_k_blocks) == 0) { +// if (zp_sh_wr_pred) { +// cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); +// } +// zp_gl_rd += zp_gl_rd_delta; +// } +// } else { +// for (int i = 0; i < zp_tb_groups; i++) { +// if (zp_sh_wr_pred) { +// cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], +// &zp_ptr[zp_gl_rd]); +// } +// zp_gl_rd += zp_gl_rd_delta; +// } +// } +// } +// } +// } +// // Insert a fence even when we are winding down the pipeline to ensure that +// // waiting is also correct at this point. +// cp_async_fence(); +// }; + +// auto fetch_zp_to_shared = [&]() { +// if (zp_sh_wr_pred) { +// cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); +// } +// }; + +// // Wait until the next thread tile has been loaded to shared memory. +// auto wait_for_stage = [&]() { +// // We only have `stages - 2` active fetches since we are double buffering +// // and can only issue the next fetch when it is guaranteed that the previous +// // shared memory load is fully complete (as it may otherwise be +// // overwritten). +// cp_async_wait(); +// __syncthreads(); +// }; + +// // Load the next sub-tile from the current location in the shared memory pipe +// // into the current register buffer. 
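fetch_to_shared above issues the global-to-shared copies through cp_async4/cp_async4_pred and pairs them with a fence and a later wait. A minimal sketch of the sm_80 PTX that such helpers typically wrap (illustrative wrappers, not the kernel's own):

#include <cstdint>

// Asynchronously copy 16 bytes from global to shared memory (Ampere cp.async).
__device__ void cp_async_16B(void* smem_dst, const void* gmem_src) {
  uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(dst),
               "l"(gmem_src));
}

// Close the current batch of asynchronous copies.
__device__ void cp_async_commit() { asm volatile("cp.async.commit_group;\n"); }

// Wait until at most N committed groups are still in flight.
template <int N>
__device__ void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}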
+// auto fetch_to_registers = [&](int k, int pipe) { +// int4* sh_a_stage = sh_a + a_sh_stage * pipe; +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) +// ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); +// int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +// #pragma unroll +// for (int i = 0; i < b_thread_vecs; i++) { +// frag_b_quant[k % 2][i] = *reinterpret_cast( +// &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); +// } +// }; + +// bool is_same_group[stages]; +// int same_group_id[stages]; + +// auto init_same_group = [&](int pipe) { +// if constexpr (!has_act_order) { +// is_same_group[pipe] = false; +// same_group_id[pipe] = 0; +// return; +// } + +// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; +// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + +// int group_id_1 = sh_g_idx_int_ptr[0]; +// int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + +// is_same_group[pipe] = group_id_1 == group_id_2; +// same_group_id[pipe] = group_id_1; +// }; + +// auto fetch_scales_to_registers = [&](int k, int full_pipe) { +// int pipe = full_pipe % stages; + +// if constexpr (!has_act_order) { +// // No act-order case +// if constexpr (group_blocks != -1) { +// if constexpr (group_blocks >= thread_k_blocks) { +// int4* sh_s_stage = +// sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * +// (pipe / (group_blocks / thread_k_blocks))); +// reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; +// } else { +// int warp_id = threadIdx.x / 32; +// int n_warps = thread_n_blocks / 4; + +// int warp_row = warp_id / n_warps; + +// int cur_k = warp_row * 16; +// cur_k += k_iter_size * (k % b_sh_wr_iters); + +// int k_blocks = cur_k / 16; +// int cur_group_id = k_blocks / group_blocks; + +// int4* sh_s_stage = sh_s + s_sh_stage * pipe; + +// reinterpret_cast(&frag_s[k % 2])[0] = +// sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; +// } +// } + +// return; +// } + +// // Act-order case + +// // Determine K of the "current" thread-block +// int cur_k = slice_k_start + tb_k * full_pipe; +// if (cur_k >= prob_k || cur_k >= slice_k_finish) { +// return; +// } + +// // Reset (to current thread-block) since we read g_idx portion from the +// // shared memory +// cur_k = 0; + +// // Progress to current iteration +// cur_k += k_iter_size * (k % b_sh_wr_iters); + +// // Determine "position" inside the thread-block (based on warp and +// // thread-id) +// int warp_id = threadIdx.x / 32; +// int n_warps = +// thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + +// int warp_row = warp_id / n_warps; +// int warp_col = warp_id % n_warps; + +// cur_k += warp_row * 16; + +// int th_id = threadIdx.x % 32; +// cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + +// int s_col_shift = +// /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + +// (th_id / 4) * act_s_col_stride; + +// if (is_same_group[pipe]) { +// if (k % 2 == 0) { +// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = +// sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + +// s_col_shift]; +// } else { +// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = +// *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); +// } + +// for (int i = 1; i < 4; i++) { +// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = +// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); +// } +// return; +// } + +// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; +// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + +// constexpr int 
k_frag_offsets[4] = {0, 1, 8, +// 9}; // Tensor core offsets per thread + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// int actual_k = cur_k + k_frag_offsets[i]; + +// int group_id = sh_g_idx_int_ptr[actual_k]; +// int rel_group_id = group_id - sh_first_group_id; + +// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = +// sh_s[rel_group_id * s_sh_stride + s_col_shift]; +// } +// }; + +// auto fetch_zp_to_registers = [&](int k, int full_pipe) { +// // This code does not handle group_blocks == 0, +// // which signifies act_order. +// // has_zp implies AWQ, which doesn't have act_order, +// static_assert(!has_zp || group_blocks != 0); + +// if constexpr (has_zp) { +// int pipe = full_pipe % stages; + +// if constexpr (group_blocks == -1) { +// for (int i = 0; i < num_ints_per_thread; i++) { +// frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; +// } + +// } else if constexpr (group_blocks >= thread_k_blocks) { +// int4* sh_zp_stage = +// sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * +// (pipe / (group_blocks / thread_k_blocks))); +// for (int i = 0; i < num_ints_per_thread; i++) { +// frag_qzp[k % 2][i] = +// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; +// } +// } else { +// int warp_id = threadIdx.x / 32; +// int n_warps = thread_n_blocks / 4; + +// int warp_row = warp_id / n_warps; + +// int cur_k = warp_row * 16; +// cur_k += k_iter_size * (k % b_sh_wr_iters); + +// int k_blocks = cur_k / 16; +// int cur_group_id = 0; + +// // Suppress bogus and persistent divide-by-zero warning +// #pragma nv_diagnostic push +// #pragma nv_diag_suppress divide_by_zero +// cur_group_id = k_blocks / group_blocks; +// #pragma nv_diagnostic pop + +// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + +// sh_zp_stage += cur_group_id * zp_sh_stride; + +// for (int i = 0; i < num_ints_per_thread; i++) { +// frag_qzp[k % 2][i] = +// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; +// } +// } +// } +// }; + +// // Execute the actual tensor core matmul of a sub-tile. +// auto matmul = [&](int k) { +// if constexpr (has_zp) { +// FragB frag_zp_0; +// FragB frag_zp_1; +// int zp_quant_0, zp_quant_1; + +// if constexpr (w_type.size_bits() == 4) { +// zp_quant_0 = frag_qzp[k % 2][0]; +// zp_quant_1 = zp_quant_0 >> 8; +// } else { +// static_assert(w_type.size_bits() == 8); +// zp_quant_0 = frag_qzp[k % 2][0]; +// zp_quant_1 = frag_qzp[k % 2][1]; +// } + +// frag_zp_0 = dequant(zp_quant_0); +// frag_zp_1 = dequant(zp_quant_1); + +// frag_zp[0] = frag_zp_0[0]; +// frag_zp[1] = frag_zp_0[1]; +// frag_zp[2] = frag_zp_1[0]; +// frag_zp[3] = frag_zp_1[1]; +// } + +// // We have the m dimension as the inner loop in order to encourage overlapping +// // dequantization and matmul operations. 
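Inside the matmul loop that follows, dequantized B values are shifted by a zero-point (when has_zp is set) and multiplied by their group scale before reaching the MMA. A one-line half2 sketch of that adjust-and-scale step (the kernel's sub_zp/scale helpers apply the same arithmetic to whole fragments):

#include <cuda_fp16.h>

// (w - zero_point) * scale on two packed fp16 values at once.
__device__ half2 dequant_adjust(half2 w, half2 zero_point, half2 scale) {
  return __hmul2(__hsub2(w, zero_point), scale);
}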
+// #pragma unroll +// for (int j = 0; j < 4; j++) { +// int b_quant_0, b_quant_1; +// if constexpr (w_type.size_bits() == 4) { +// b_quant_0 = frag_b_quant[k % 2][0][j]; +// b_quant_1 = b_quant_0 >> 8; +// } else { +// static_assert(w_type.size_bits() == 8); +// int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); +// b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; +// b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; +// } + +// FragB frag_b0 = dequant(b_quant_0); +// FragB frag_b1 = dequant(b_quant_1); + +// // Apply scale to frag_b0 +// if constexpr (has_act_order) { +// scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], +// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); +// } else { +// if constexpr (group_blocks != -1) { +// scale(frag_b0, frag_s[k % 2][j], 0); +// } +// } + +// // Apply zero-point to frag_b1 +// if constexpr (has_zp) { +// sub_zp(frag_b1, frag_zp[j], 1); +// } + +// // Apply scale to frag_b1 +// if constexpr (has_act_order) { +// scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], +// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + +// } else { +// if constexpr (group_blocks != -1) { +// scale(frag_b1, frag_s[k % 2][j], 1); +// } +// } + +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) { +// mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); +// mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); +// } +// } +// }; + +// // Since we slice across the k dimension of a tile in order to increase the +// // number of warps while keeping the n dimension of a tile reasonable, we have +// // multiple warps that accumulate their partial sums of the same output +// // location; which we have to reduce over in the end. We do in shared memory. +// auto thread_block_reduce = [&]() { +// constexpr int red_off = threads / b_sh_stride_threads / 2; +// if (red_off >= 1) { +// int red_idx = threadIdx.x / b_sh_stride_threads; +// constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; +// constexpr int red_sh_delta = b_sh_stride_threads; +// int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + +// (threadIdx.x % b_sh_stride_threads); + +// // Parallel logarithmic shared memory reduction. We make sure to avoid any +// // unnecessary read or write iterations, e.g., for two warps we write only +// // once by warp 1 and read only once by warp 0. + +// #pragma unroll +// for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +// #pragma unroll +// for (int i = red_off; i > 0; i /= 2) { +// if (i <= red_idx && red_idx < 2 * i) { +// #pragma unroll +// for (int j = 0; j < 4 * 2; j++) { +// int red_sh_wr = +// red_sh_delta * j + (red_sh_rd - red_sh_stride * i); +// if (i < red_off) { +// float* c_rd = +// reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); +// float* c_wr = reinterpret_cast(&sh[red_sh_wr]); +// #pragma unroll +// for (int k = 0; k < 4; k++) +// reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += +// c_rd[k] + c_wr[k]; +// } +// sh[red_sh_wr] = +// reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; +// } +// } +// __syncthreads(); +// } +// if (red_idx == 0) { +// #pragma unroll +// for (int i = 0; i < 4 * 2; i++) { +// float* c_rd = +// reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +// #pragma unroll +// for (int j = 0; j < 4; j++) +// reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += +// c_rd[j]; +// } +// } +// __syncthreads(); +// } +// } +// }; + +// // Since multiple threadblocks may process parts of the same column slice, we +// // finally have to globally reduce over the results. 
As the striped +// // partitioning minimizes the number of such reductions and our outputs are +// // usually rather small, we perform this reduction serially in L2 cache. +// auto global_reduce = [&](bool first = false, bool last = false) { +// // We are very careful here to reduce directly in the output buffer to +// // maximize L2 cache utilization in this step. To do this, we write out +// // results in FP16 (but still reduce with FP32 compute). +// constexpr int active_threads = 32 * thread_n_blocks / 4; +// if (threadIdx.x < active_threads) { +// int c_gl_stride = prob_n / 8; +// int c_gl_wr_delta_o = 8 * c_gl_stride; +// int c_gl_wr_delta_i = 4 * (active_threads / 32); +// int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + +// 4 * (threadIdx.x / 32) + threadIdx.x % 4; +// c_gl_wr += (2 * thread_n_blocks) * slice_col; +// constexpr int c_sh_wr_delta = active_threads; +// int c_sh_wr = threadIdx.x; + +// int row = (threadIdx.x % 32) / 4; + +// if (!first) { +// // Interestingly, doing direct global accesses here really seems to mess up +// // the compiler and lead to slowdowns, hence we also use async-copies even +// // though these fetches are not actually asynchronous. +// #pragma unroll +// for (int i = 0; i < thread_m_blocks * 4; i++) { +// int c_idx = +// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); +// int sorted_row = sorted_ids[c_idx / c_gl_stride]; +// int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; +// cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], +// sorted_row < tot_m * topk && +// (8 * (i / 2) + row < prob_m && +// (i < (thread_m_blocks - 1) * 4 || +// sorted_ids[8 * (i / 2) + row] < tot_m * topk))); +// } +// cp_async_fence(); +// cp_async_wait<0>(); +// } + +// #pragma unroll +// for (int i = 0; i < thread_m_blocks * 4; i++) { +// if (8 * (i / 2) + row < prob_m && +// (i < (thread_m_blocks - 1) * 4 || +// sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { +// if (!first) { +// int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +// #pragma unroll +// for (int j = 0; j < 2 * 4; j++) { +// reinterpret_cast( +// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += +// __half2float(reinterpret_cast<__half*>(&c_red)[j]); +// } +// } +// if (!last) { +// int4 c; +// #pragma unroll +// for (int j = 0; j < 2 * 4; j++) { +// reinterpret_cast<__half*>(&c)[j] = +// __float2half(reinterpret_cast( +// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); +// } +// int c_idx = +// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); +// int row = sorted_ids[c_idx / c_gl_stride]; +// if (row < tot_m * topk) { +// int new_idx = row * c_gl_stride + c_idx % c_gl_stride; +// C[new_idx] = c; +// } +// } +// } +// } +// } +// }; + +// // Write out the reduce final result in the correct layout. We only actually +// // reshuffle matrix fragments in this step, the reduction above is performed +// // in fragment layout. 
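global_reduce above keeps the running sums in fp32 registers and only narrows to fp16 when writing C; for 8-bit channelwise weights the per-column scale is additionally applied while still in fp32 (before this reduction) to avoid fp16 overflow. A compact sketch of that reduce-in-fp32, store-in-fp16 round trip (simplified; the kernel applies the channelwise scale earlier than shown here):

#include <cuda_fp16.h>

// Accumulate a previously written fp16 partial result into an fp32 register,
// optionally apply a per-column scale while still in fp32, then narrow back
// to fp16 for storage.
__device__ __half reduce_and_store(float acc, __half prev_partial,
                                   float col_scale, bool apply_scale) {
  acc += __half2float(prev_partial);  // reduce with fp32 compute
  if (apply_scale) {
    acc *= col_scale;                 // scale before narrowing to fp16
  }
  return __float2half(acc);           // only the final value is fp16
}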
+// auto write_result = [&]() { +// int c_gl_stride = prob_n / 8; +// constexpr int c_sh_stride = 2 * thread_n_blocks + 1; +// int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); +// constexpr int c_sh_rd_delta = +// c_sh_stride * (threads / (2 * thread_n_blocks)); + +// int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + +// (threadIdx.x % (2 * thread_n_blocks)); +// c_gl_wr += (2 * thread_n_blocks) * slice_col; +// int c_sh_wr = +// (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; +// c_sh_wr += 32 * (threadIdx.x / 32); +// int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + +// (threadIdx.x % (2 * thread_n_blocks)); + +// int c_gl_wr_end = c_gl_stride * prob_m; + +// // We first reorder in shared memory to guarantee the most efficient final +// // global write patterns +// auto write = [&](int idx, float c0, float c1, FragS& s) { +// half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + +// // For per-column quantization we finally apply the scale here (only for +// // 4-bit) +// if constexpr (!has_act_order && group_blocks == -1 && +// w_type.size_bits() == 4) { +// res = __hmul2(res, s[0]); +// } + +// ((half2*)sh)[idx] = res; +// }; +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) { +// #pragma unroll +// for (int j = 0; j < 4; j++) { +// int wr = c_sh_wr + 8 * j; +// write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], +// frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); +// write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], +// frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); +// write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], +// frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); +// write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], +// frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); +// } +// c_sh_wr += 16 * (4 * c_sh_stride); +// } +// } +// __syncthreads(); + +// #pragma unroll +// for (int i = 0; +// i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); +// i++) { +// if (c_gl_wr < c_gl_wr_end) { +// int row = sorted_ids[c_gl_wr / c_gl_stride]; +// if (row < tot_m * topk) { +// int off = row * c_gl_stride + c_gl_wr % c_gl_stride; +// if (!apply_weights) { +// C[off] = sh[c_sh_rd]; +// } else { +// __half* ctrg = reinterpret_cast<__half*>(&C[off]); +// __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); +// for (int j = 0; j < 8; ++j) { +// ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); +// } +// } +// c_gl_wr += c_gl_wr_delta; +// c_sh_rd += c_sh_rd_delta; +// } +// } +// } +// }; + +// // Start global fetch and register load pipelines. 
+// auto start_pipes = [&]() { + +// #pragma unroll +// for (int i = 0; i < stages - 1; i++) { +// if (has_act_order && i == 0) { +// int last_g_idx = slice_k_start + stages * tb_k * 2; +// if (last_g_idx >= prob_k) { +// last_g_idx = prob_k - 1; +// } +// fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); +// } + +// if constexpr (has_zp && group_blocks == -1) { +// if (i == 0) { +// fetch_zp_to_shared(); +// } +// } +// fetch_to_shared(i, i, i < slice_iters); +// } + +// zero_accums(); +// wait_for_stage(); +// init_same_group(0); +// fetch_to_registers(0, 0); +// fetch_scales_to_registers(0, 0); +// fetch_zp_to_registers(0, 0); +// a_gl_rd += a_gl_rd_delta_o * (stages - 1); +// slice_k_start_shared_fetch += tb_k * (stages - 1); +// }; +// if (slice_iters) { +// start_pipes(); +// } + +// // Main loop. +// while (slice_iters) { +// // We unroll over both the global fetch and the register load pipeline to +// // ensure all shared memory accesses are static. Note that both pipelines +// // have even length meaning that the next iteration will always start at +// // index 0. +// #pragma unroll +// for (int pipe = 0; pipe < stages;) { +// #pragma unroll +// for (int k = 0; k < b_sh_wr_iters; k++) { +// fetch_to_registers(k + 1, pipe % stages); +// fetch_scales_to_registers(k + 1, pipe); +// fetch_zp_to_registers(k + 1, pipe); +// if (k == b_sh_wr_iters - 2) { +// fetch_to_shared((pipe + stages - 1) % stages, pipe, +// slice_iters >= stages); +// pipe++; +// wait_for_stage(); +// init_same_group(pipe % stages); +// } +// matmul(k); +// } +// slice_iters--; +// if (slice_iters == 0) { +// break; +// } +// } + +// a_gl_rd += a_gl_rd_delta_o * stages; +// slice_k_start += tb_k * stages; +// slice_k_start_shared_fetch += tb_k * stages; + +// if constexpr (has_act_order) { +// int first_group_id = g_idx[slice_k_start]; +// int last_g_idx = slice_k_start + stages * tb_k * 2; +// if (last_g_idx >= prob_k) { +// last_g_idx = prob_k - 1; +// } +// int last_group_id = g_idx[last_g_idx]; +// if (last_group_id >= sh_first_group_id + sh_num_groups) { +// fetch_scales_to_shared(false, first_group_id, last_group_id); +// __syncthreads(); +// } +// } + +// // Process results and, if necessary, proceed to the next column slice. +// // While this pattern may not be the most readable, other ways of writing +// // the loop seemed to noticeably worse performance after compilation. 
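// ---------------------------------------------------------------------------
// Illustrative sketch only (assumed names, not part of this patch): the main
// loop above is software-pipelined over `stages` tiles. start_pipes() fills
// the pipeline with stages-1 prefetches, and each iteration issues the fetch
// for the tile stages-1 ahead before running the matmul on the current tile,
// so global-memory latency overlaps with compute. A host-side skeleton of
// that schedule:
#include <cstdio>

constexpr int kStages = 4;

void fetch_to_shared(int tile) { std::printf("fetch   tile %d\n", tile); }
void matmul(int tile) { std::printf("compute tile %d\n", tile); }

int main() {
  const int num_tiles = 8;
  // Fill the pipeline (start_pipes analogue): prefetch the first stages-1 tiles.
  for (int t = 0; t < kStages - 1 && t < num_tiles; ++t) fetch_to_shared(t);
  // Steady state: always keep stages-1 fetches in flight ahead of the compute.
  for (int t = 0; t < num_tiles; ++t) {
    if (t + kStages - 1 < num_tiles) fetch_to_shared(t + kStages - 1);
    matmul(t);
  }
  return 0;
}
// ---------------------------------------------------------------------------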
+// if (slice_iters == 0) { +// cp_async_wait<0>(); +// bool last = slice_idx == slice_count - 1; +// if constexpr (!has_act_order && group_blocks == -1) { +// if constexpr (w_type.size_bits() == 8) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); +// } +// cp_async_fence(); +// } else { +// // For 4-bit per-column scales, we only fetch them here in the +// // final step before write-out +// if (last) { +// if (s_sh_wr_pred) { +// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); +// } +// cp_async_fence(); +// } +// } +// } + +// thread_block_reduce(); +// if constexpr (!has_act_order && group_blocks == -1) { +// if constexpr (w_type.size_bits() == 8) { +// cp_async_wait<0>(); +// __syncthreads(); +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; +// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; +// } + +// } else { +// if (last) { +// cp_async_wait<0>(); +// __syncthreads(); +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; +// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; +// } +// } +// } +// } + +// // For 8-bit channelwise, we apply the scale before the global reduction +// // that converts the fp32 results to fp16 (so that we avoid possible +// // overflow in fp16) +// if constexpr (!has_act_order && group_blocks == -1 && +// w_type.size_bits() == 8) { +// if (threadIdx.x / 32 < thread_n_blocks / 4) { +// #pragma unroll +// for (int i = 0; i < thread_m_blocks; i++) { +// #pragma unroll +// for (int j = 0; j < 4; j++) { +// scale_float(reinterpret_cast(&frag_c[i][j][0][0]), +// frag_s[j / 2][2 * (j % 2) + 0]); +// scale_float(reinterpret_cast(&frag_c[i][j][0][2]), +// frag_s[j / 2][2 * (j % 2) + 0]); + +// scale_float(reinterpret_cast(&frag_c[i][j][1][0]), +// frag_s[j / 2][2 * (j % 2) + 1]); +// scale_float(reinterpret_cast(&frag_c[i][j][1][2]), +// frag_s[j / 2][2 * (j % 2) + 1]); +// } +// } +// } +// } + +// if (slice_count > 1) { // only globally reduce if there is more than one +// // block in a slice +// barrier_acquire(&locks[slice_col], slice_idx); +// global_reduce(slice_idx == 0, last); +// barrier_release(&locks[slice_col], last); +// } +// if (last) // only the last block in a slice actually writes the result +// write_result(); +// slice_row = 0; +// slice_col_par++; +// slice_col++; +// init_slice(); +// if (slice_iters) { +// a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + +// (threadIdx.x % a_gl_rd_delta_o); +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) +// B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; +// if (slice_col == 0) { +// #pragma unroll +// for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; +// } + +// // Update slice k/n for scales loading +// if constexpr (has_act_order) { +// slice_k_start = tb_k * slice_row; +// slice_k_finish = slice_k_start + tb_k * slice_iters; +// slice_k_start_shared_fetch = slice_k_start; +// slice_n_offset = act_s_col_tb_stride * slice_col; + +// } else { +// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; +// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; +// } +// start_pipes(); +// } +// } +// } +// } template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const bool 
has_zp, // whether zero-points are enabled + const int group_blocks // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1186,66 +1192,66 @@ __global__ void MarlinMoE( int max_par, // maximum parallelism int cfg_max_m_blocks // upper bound on m blocks ) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } + // int m_block_ctr = current_m_block; + + // const int* sorted_ids_expert = + // sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + // int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + // if (tot_its == 0) { + // return; + // } + // int tot_m_blocks = ceildiv(tot_its, 16); + // int pad = 16 * tot_m_blocks - tot_its; + + // if (m_block_ctr >= tot_m_blocks) { + // return; + // } + + // int max_block = tot_m_blocks - m_block_ctr; + // prob_m = tot_its - 16 * m_block_ctr; + + // int par = 1; + // if (max_block > cfg_max_m_blocks) { + // // Note that parallel > 1 currently only works for inputs without any + // // padding + // par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + // if (par > max_par) par = max_par; + // prob_m = (16 * cfg_max_m_blocks) * par; + // m_block_ctr += cfg_max_m_blocks * (par - 1); + // max_block = cfg_max_m_blocks; + // } + + // if (max_block == 1) { + // MarlinMoESingle( + // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + // current_m_block); + // } else if (max_block == 2) { + // MarlinMoESingle( 
+ // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + // current_m_block); + // } else if (max_block == 3) { + // MarlinMoESingle( + // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + // current_m_block); + // } else { + // MarlinMoESingle( + // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + // current_m_block); + // } } #endif -} +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel.cuh b/csrc/moe/marlin_moe_kernel.cuh index 815d9561089e..7dad50ed481e 100644 --- a/csrc/moe/marlin_moe_kernel.cuh +++ b/csrc/moe/marlin_moe_kernel.cuh @@ -285,46 +285,6 @@ __device__ inline void barrier_release(int* lock, bool reset = false) { } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? 
- bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -); - template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -424,81 +384,8 @@ __global__ void MarlinMoE( const int USER_THREADS = 256; // Note: This is only used with user-provided thread_k/n const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; -// #define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ -// THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ -// NUM_THREADS) \ -// else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ -// thread_n_blocks == THREAD_N_BLOCKS && \ -// thread_k_blocks == THREAD_K_BLOCKS && \ -// has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ -// group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ -// cudaFuncSetAttribute(MarlinMoE, \ -// cudaFuncAttributeMaxDynamicSharedMemorySize, \ -// max_shared_mem); \ -// MarlinMoE \ -// <<>>( \ -// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ -// zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ -// num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ -// replicate_input, apply_weights, m_block, max_par, \ -// cfg_max_m_blocks); \ -// } - -// #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 
false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - -// #define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - - } // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_moe_kernel_ku4.cu index d50d8f14d785..e84d4ad8de1e 100644 --- a/csrc/moe/marlin_moe_kernel_ku4.cu +++ b/csrc/moe/marlin_moe_kernel_ku4.cu @@ -2,9 +2,9 @@ namespace marlin_moe { -#define __CALL_IF_MOE_4(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ +#define __CALL_IF_MOE_4(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ @@ -25,23 +25,22 @@ namespace marlin_moe { cfg_max_m_blocks); \ } - #define AWQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ + \ __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ + \ __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ + \ __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ @@ -50,26 +49,25 @@ namespace 
marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_m_blocks, - int thread_n_blocks, int thread_k_blocks, bool has_act_order, - bool has_zp, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - if (false) { - } - AWQ_CALL_IF_MOE_4(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF_MOE_4(vllm::kU4, 4, 8, 128) - else { - return false; - } - return true; + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + AWQ_CALL_IF_MOE_4(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE_4(vllm::kU4, 4, 8, 128) + else { + return false; + } + return true; } } // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_moe_kernel_ku4b8.cu index f5832b550a5d..de437454df77 100644 --- a/csrc/moe/marlin_moe_kernel_ku4b8.cu +++ b/csrc/moe/marlin_moe_kernel_ku4b8.cu @@ -2,9 +2,9 @@ namespace marlin_moe { -#define __CALL_IF_MOE_4_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ +#define __CALL_IF_MOE_4_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, \ + GROUP_BLOCKS, NUM_THREADS) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ @@ -15,65 +15,98 @@ namespace marlin_moe { HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, \ max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ } -#define GPTQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - 
__CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ +// #define __CALL_IF_MOE_4_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ +// THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, \ +// GROUP_BLOCKS, NUM_THREADS) \ +// else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ +// thread_n_blocks == THREAD_N_BLOCKS && \ +// thread_k_blocks == THREAD_K_BLOCKS && \ +// has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ +// group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ +// cudaFuncSetAttribute(MarlinMoE, \ +// cudaFuncAttributeMaxDynamicSharedMemorySize, \ +// max_shared_mem); \ +// MarlinMoE \ +// <<>>( \ +// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ +// zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ +// num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ +// replicate_input, apply_weights, m_block, max_par, \ +// cfg_max_m_blocks); \ +// } + +#define GPTQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_4_8(W_TYPE, 4, 
N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_m_blocks, - int thread_n_blocks, int thread_k_blocks, bool has_act_order, - bool has_zp, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - if (false) { - } - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 4, 8, 128) - else { - return false; - } - return true; + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 4, 8, 128) + else { + return false; + } + return true; } } // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_moe_kernel_ku8.cu index b07491910002..931e074351dc 100644 --- a/csrc/moe/marlin_moe_kernel_ku8.cu +++ b/csrc/moe/marlin_moe_kernel_ku8.cu @@ -2,9 +2,9 @@ namespace marlin_moe { -#define __CALL_IF_MOE_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ +#define __CALL_IF_MOE_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ @@ -25,23 +25,22 @@ namespace marlin_moe { cfg_max_m_blocks); \ } - #define AWQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ + \ __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ 
__CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ + \ __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ + \ __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ @@ -50,26 +49,25 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_m_blocks, - int thread_n_blocks, int thread_k_blocks, bool has_act_order, - bool has_zp, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - if (false) { - } - AWQ_CALL_IF_MOE_8(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF_MOE_8(vllm::kU8, 4, 8, 128) - else { - return false; - } - return true; + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + AWQ_CALL_IF_MOE_8(vllm::kU8, 16, 4, 256) + AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 8, 256) + AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 4, 128) + AWQ_CALL_IF_MOE_8(vllm::kU8, 4, 8, 128) + else { + return false; + } + return true; } } // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_moe_kernel_ku8b128.cu index 22f042f0d43a..671466ae26c9 100644 --- a/csrc/moe/marlin_moe_kernel_ku8b128.cu +++ b/csrc/moe/marlin_moe_kernel_ku8b128.cu @@ -2,9 +2,9 @@ namespace marlin_moe { -#define __CALL_IF_MOE_8_128(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ +#define __CALL_IF_MOE_8_128(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, \ + GROUP_BLOCKS, NUM_THREADS) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ @@ -25,55 +25,74 @@ namespace marlin_moe { cfg_max_m_blocks); \ } -#define GPTQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, 
K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) +#define GPTQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) \ + \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, \ + NUM_THREADS) // We return bool 
so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_m_blocks, - int thread_n_blocks, int thread_k_blocks, bool has_act_order, - bool has_zp, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - if (false) { - } - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 4, 8, 128) - else { - return false; - } - return true; + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 4, 8, 128) + else { + return false; + } + return true; } } // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index dba94bde9fc1..cb01faeeb11b 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -27,9 +27,10 @@ #include "core/scalar_type.hpp" #include "marlin_moe_kernel_ku4b8.cu" -#include "marlin_moe_kernel_ku8b128.cu" -#include "marlin_moe_kernel_ku4.cu" -#include "marlin_moe_kernel_ku8.cu" +// #include "marlin_moe_kernel_ku8b128.cu" +// #include "marlin_moe_kernel_ku4.cu" +// #include "marlin_moe_kernel_ku8.cu" +// #include "marlin_moe_kernel.cuh" template inline std::string str(T x) { @@ -38,291 +39,8 @@ inline std::string str(T x) { namespace marlin_moe { -// constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// // Instances of `Vec` are used to organize groups of >>registers<<, as needed -// // for instance as inputs to tensor core operations. Consequently, all -// // corresponding index accesses must be compile-time constants, which is why -// we -// // extensively use `#pragma unroll` throughout the kernel code to guarantee -// // this. 
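// ---------------------------------------------------------------------------
// Illustrative sketch only (assumed names, not part of this patch): the
// call_marlin_moe_kernel_* entry points above all share one shape. Each
// *_CALL_IF_MOE_* macro expands to an `else if` branch, the leading
// `if (false) {}` gives the first expansion something to chain onto, and the
// trailing `else { return false; }` lets the caller fall through and try a
// kernel file compiled for a different weight type.
#include <cstdio>

#define CALL_IF(NB, KB)                            \
  else if (n_blocks == NB && k_blocks == KB) {     \
    std::printf("launch config %dx%d\n", NB, KB);  \
  }

bool dispatch(int n_blocks, int k_blocks) {
  if (false) {
  }
  CALL_IF(16, 4)
  CALL_IF(8, 8)
  CALL_IF(8, 4)
  else {
    return false;  // no compiled configuration matched
  }
  return true;
}

int main() {
  bool ok = dispatch(8, 8);       // matches: prints the launch line, returns true
  bool missing = dispatch(2, 2);  // no match: returns false
  std::printf("%d %d\n", ok, missing);
  return 0;
}
// ---------------------------------------------------------------------------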
-// template -// struct Vec { -// T elems[n]; -// __device__ T& operator[](int i) { return elems[i]; } -// }; - -// using I4 = Vec; - -// // Matrix fragments for tensor core instructions; their precise layout is -// // documented here: -// // -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -// using FragA = Vec; -// using FragB = Vec; -// using FragC = Vec; -// using FragS = Vec; // quantization scales -// using FragZP = Vec; - -// // Predicated asynchronous global->shared copy; used for inputs A where we -// apply -// // predication to handle batchsizes that are not multiples of 16. -// __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, -// bool pred = true) { -// const int BYTES = 16; -// uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); -// asm volatile( -// "{\n" -// " .reg .pred p;\n" -// " setp.ne.b32 p, %0, 0;\n" -// " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -// "}\n" ::"r"((int)pred), -// "r"(smem), "l"(glob_ptr), "n"(BYTES)); -// } - -// // Asynchronous global->shared copy -// __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { -// const int BYTES = 16; -// uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); -// asm volatile( -// "{\n" -// " cp.async.cg.shared.global [%0], [%1], %2;\n" -// "}\n" ::"r"(smem), -// "l"(glob_ptr), "n"(BYTES)); -// } - -// // Async copy fence. -// __device__ inline void cp_async_fence() { -// asm volatile("cp.async.commit_group;\n" ::); -// } - -// // Wait until at most `n` async copy stages are still pending. -// template -// __device__ inline void cp_async_wait() { -// asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -// } - -// // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// // output/accumulation. -// __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, -// FragC& frag_c) { -// const uint32_t* a = reinterpret_cast(&a_frag); -// const uint32_t* b = reinterpret_cast(&frag_b); -// float* c = reinterpret_cast(&frag_c); -// asm volatile( -// "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " -// "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -// : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) -// : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), -// "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -// } - -// // Instruction for loading a full 16x16 matrix fragment of operand A from -// shared -// // memory, directly in tensor core layout. -// __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { -// uint32_t* a = reinterpret_cast(&frag_a); -// uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); -// asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, -// [%4];\n" -// : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) -// : "r"(smem)); -// } - -// // Lookup-table based 3-input logical operation; explicitly used for -// // dequantization as the compiler does not seem to automatically recognize it -// in -// // all cases. 
-// template -// __device__ inline int lop3(int a, int b, int c) { -// int res; -// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" -// : "=r"(res) -// : "r"(a), "r"(b), "r"(c), "n"(lut)); -// return res; -// } - -// // Constructs destination register by taking bytes from 2 sources (based on -// // mask) -// template -// __device__ inline uint32_t prmt(uint32_t a) { -// uint32_t res; -// asm volatile("prmt.b32 %0, %1, %2, %3;\n" -// : "=r"(res) -// : "r"(a), "n"(start_byte), "n"(mask)); -// return res; -// } - -// template -// __device__ inline FragB dequant(int q); - -// // Efficiently dequantize 4bit values packed in an int32 value into a full -// // B-fragment of 4 fp16 values. We mostly follow the strategy in the link -// below, -// // with some small changes: -// // -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// template <> -// __device__ inline FragB dequant(int q) { -// const int LO = 0x000f000f; -// const int HI = 0x00f000f0; -// const int EX = 0x64006400; -// // Guarantee that the `(a & b) | c` operations are LOP3s. -// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); -// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); -// // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point -// // directly into `SUB` and `ADD`. -// const int SUB = 0x64086408; -// const int MUL = 0x2c002c00; -// const int ADD = 0xd480d480; -// FragB frag_b; -// frag_b[0] = __hsub2(*reinterpret_cast(&lo), -// *reinterpret_cast(&SUB)); -// frag_b[1] = __hfma2(*reinterpret_cast(&hi), -// *reinterpret_cast(&MUL), -// *reinterpret_cast(&ADD)); -// return frag_b; -// } - -// // Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// // Reference: -// // -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// template <> -// __device__ inline FragB dequant(int q) { -// static constexpr uint32_t mask_for_elt_01 = 0x5250; -// static constexpr uint32_t mask_for_elt_23 = 0x5351; -// static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - -// uint32_t lo = prmt(q); -// uint32_t hi = prmt(q); - -// static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - -// FragB frag_b; -// frag_b[0] = __hsub2(*reinterpret_cast(&lo), -// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); -// frag_b[1] = __hsub2(*reinterpret_cast(&hi), -// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); -// return frag_b; -// } - -// template <> -// __device__ inline FragB dequant(int q) { -// const int LO = 0x000f000f; -// const int HI = 0x00f000f0; -// const int EX = 0x64006400; -// // Guarantee that the `(a & b) | c` operations are LOP3s. 
-// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); -// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - -// const int SUB = 0x64006400; -// const int MUL = 0x2c002c00; -// const int ADD = 0xd400d400; -// FragB frag_b; -// frag_b[0] = __hsub2(*reinterpret_cast(&lo), -// *reinterpret_cast(&SUB)); -// frag_b[1] = __hfma2(*reinterpret_cast(&hi), -// *reinterpret_cast(&MUL), -// *reinterpret_cast(&ADD)); -// return frag_b; -// } - -// template <> -// __device__ inline FragB dequant(int q) { -// static constexpr uint32_t mask_for_elt_01 = 0x5250; -// static constexpr uint32_t mask_for_elt_23 = 0x5351; -// static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - -// uint32_t lo = prmt(q); -// uint32_t hi = prmt(q); - -// static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - -// FragB frag_b; -// frag_b[0] = __hsub2(*reinterpret_cast(&lo), -// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); -// frag_b[1] = __hsub2(*reinterpret_cast(&hi), -// *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); -// return frag_b; -// } - -// // Multiply dequantized values by the corresponding quantization scale; used -// // only for grouped quantization. -// __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { -// half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); -// frag_b[0] = __hmul2(frag_b[0], s); -// frag_b[1] = __hmul2(frag_b[1], s); -// } - -// __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { -// half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); -// frag_b[0] = __hsub2(frag_b[0], zp); -// frag_b[1] = __hsub2(frag_b[1], zp); -// } - -// // Given 2 floats multiply by 2 scales (halves) -// __device__ inline void scale_float(float* c, FragS& s) { -// __half* s_ptr = reinterpret_cast<__half*>(&s); -// c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); -// c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -// } - -// // Same as above, but for act_order (each K is multiplied individually) -// __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& -// frag_s_2, -// FragS& frag_s_3, FragS& frag_s_4, int i) { -// __half2 s_val_1_2; -// s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; -// s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - -// __half2 s_val_3_4; -// s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; -// s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - -// frag_b[0] = __hmul2(frag_b[0], s_val_1_2); -// frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -// } - -// // Wait until barrier reaches `count`, then lock for current threadblock. -// __device__ inline void barrier_acquire(int* lock, int count) { -// if (threadIdx.x == 0) { -// int state = -1; -// do -// // Guarantee that subsequent writes by this threadblock will be visible -// // globally. -// asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" -// : "=r"(state) -// : "l"(lock)); -// while (state != count); -// } -// __syncthreads(); -// } - -// // Release barrier and increment visitation count. -// __device__ inline void barrier_release(int* lock, bool reset = false) { -// __syncthreads(); -// if (threadIdx.x == 0) { -// if (reset) { -// lock[0] = 0; -// return; -// } -// int val = 1; -// // Make sure that all writes since acquiring this barrier are visible -// // globally, while releasing the barrier. 
-// asm volatile("fence.acq_rel.gpu;\n"); -// asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" -// : -// : "l"(lock), "r"(val)); -// } -// } - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -400,1285 +118,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -// template shared -// // fetch pipeline -// const bool has_act_order, // whether act_order is enabled -// const bool has_zp, // whether zero-points are enabled -// const int group_blocks = -1 // number of consecutive 16x16 blocks -// // with a separate quantization scale -// > -// __device__ inline void MarlinMoESingle( -// const int4* __restrict__ A, // fp16 input matrix of shape mxk -// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn -// int4* __restrict__ C, // fp16 output buffer of shape mxn -// const int* __restrict__ sorted_ids, // int32 sorted ids of experts -// const float* __restrict__ topk_weights, // float topk weights -// const int4* __restrict__ scales_ptr, // fp16 quantization scales of -// shape -// // (k/groupsize)xn -// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape -// // (k/groupsize)x(n/pack_factor) -// const int* __restrict__ g_idx, // int32 group indices of shape k -// const int* __restrict__ expert_offsets, -// int num_groups, // number of scale groups per output channel -// int expert_idx, // idx of current expert -// int num_experts, // number of experts -// int topk, // topk parameter of moe -// int prob_m, // batch dimension m -// int prob_n, // output dimension n -// int prob_k, // reduction dimension k -// int tot_m, // total number of rows in A and C -// int* locks, // extra global storage for barrier -// synchronization bool replicate_input, // do we use the same input for -// each expert? bool apply_weights, // apply weights to output int -// current_m_block // current m block to start kernel computation from -// ) { -// static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); -// constexpr int pack_factor = 32 / w_type.size_bits(); - -// // For larger GEMMs we run multiple batchsize 64 versions in parallel for a -// // better partitioning with less reductions -// int parallel = 1; -// if (prob_m > 16 * thread_m_blocks) { -// parallel = prob_m / (16 * thread_m_blocks); -// prob_m = 16 * thread_m_blocks; -// } - -// int k_tiles = prob_k / 16 / thread_k_blocks; -// int n_tiles = prob_n / 16 / thread_n_blocks; -// int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - -// if constexpr (!has_act_order && group_blocks != -1) { -// if (group_blocks >= thread_k_blocks) { -// // Ensure that the number of tiles in each stripe is a multiple of the -// // groupsize; this avoids an annoying special case where a stripe -// starts -// // in the middle of group. 
-// iters = (group_blocks / thread_k_blocks) * -// ceildiv(iters, (group_blocks / thread_k_blocks)); -// } -// } - -// int slice_row = (iters * blockIdx.x) % k_tiles; -// int slice_col_par = (iters * blockIdx.x) / k_tiles; -// int slice_col = slice_col_par; -// int slice_iters; // number of threadblock tiles in the current slice -// int slice_count = -// 0; // total number of active threadblocks in the current slice -// int slice_idx; // index of threadblock in current slice; numbered bottom -// to -// // top - -// // We can easily implement parallel problem execution by just remapping -// // indices and advancing global pointers -// if (slice_col_par >= n_tiles) { -// locks += (slice_col_par / n_tiles) * n_tiles; -// slice_col = slice_col_par % n_tiles; -// sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; -// } - -// // Compute all information about the current slice which is required for -// // synchronization. -// auto init_slice = [&]() { -// slice_iters = -// iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); -// if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = -// 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) -// slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int -// col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if -// (col_first <= k_tiles * (slice_col_par + 1)) { -// int col_off = col_first - k_tiles * slice_col_par; -// slice_count = ceildiv(k_tiles - col_off, iters); -// if (col_off > 0) slice_count++; -// int delta_first = iters * blockIdx.x - col_first; -// if (delta_first < 0 || (col_off == 0 && delta_first == 0)) -// slice_idx = slice_count - 1; -// else { -// slice_idx = slice_count - 1 - delta_first / iters; -// if (col_off > 0) slice_idx--; -// } -// } -// if (slice_col == n_tiles) { -// sorted_ids += 16 * thread_m_blocks; -// locks += n_tiles; -// slice_col = 0; -// } -// }; -// init_slice(); - -// // A sizes/strides - -// // stride of the A matrix in global memory -// int a_gl_stride = prob_k / 8; -// // stride of an A matrix tile in shared memory -// constexpr int a_sh_stride = 16 * thread_k_blocks / 8; -// // delta between subsequent A tiles in global memory -// constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; -// // between subsequent accesses within a tile -// int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); -// // between shared memory writes -// constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); -// // between shared memory tile reads -// constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / -// 4)); -// // within a shared memory tile -// constexpr int a_sh_rd_delta_i = a_sh_stride * 16; -// // overall size of a tile -// constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); -// // number of shared write iterations for a tile -// constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - -// // B sizes/strides -// int b_gl_stride = 16 * prob_n / (pack_factor * 4); -// constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / -// 4; constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; constexpr -// int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - -// int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; -// int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); -// constexpr int b_sh_wr_delta = threads * b_thread_vecs; -// constexpr int b_sh_rd_delta = threads * b_thread_vecs; -// constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; -// constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - -// // Scale sizes/strides without act_order -// int s_gl_stride = prob_n / 8; -// constexpr int s_sh_stride = 16 * thread_n_blocks / 8; -// constexpr int s_tb_groups = -// !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks -// ? thread_k_blocks / group_blocks -// : 1; -// constexpr int s_sh_stage = s_tb_groups * s_sh_stride; -// int s_gl_rd_delta = s_gl_stride; -// // Scale size/strides with act_order -// constexpr int tb_k = 16 * thread_k_blocks; -// constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; -// // constexpr int act_s_row_stride = 1; -// // int act_s_col_stride = act_s_row_stride * num_groups; -// int act_s_col_stride = 1; -// int act_s_col_warp_stride = act_s_col_stride * 8; -// int tb_n_warps = thread_n_blocks / 4; -// int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - -// // Zero-points sizes/strides -// int zp_gl_stride = (prob_n / pack_factor) / 4; -// constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; -// constexpr int zp_tb_groups = s_tb_groups; -// constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; -// int zp_gl_rd_delta = zp_gl_stride; - -// // Global A read index of current thread. -// int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + -// (threadIdx.x % a_gl_rd_delta_o); -// a_gl_rd += a_gl_rd_delta_o * slice_row; -// // Shared write index of current thread. -// int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + -// (threadIdx.x % a_gl_rd_delta_o); -// // Shared read index. 
-// int a_sh_rd = -// a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; -// a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - -// int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + -// (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; -// b_gl_rd += b_sh_stride * slice_col; -// b_gl_rd += b_gl_rd_delta_o * slice_row; -// int b_sh_wr = threadIdx.x * b_thread_vecs; -// int b_sh_rd = threadIdx.x * b_thread_vecs; - -// // For act_order -// constexpr int k_iter_size = tb_k / b_sh_wr_iters; -// int slice_k_start = tb_k * slice_row; -// int slice_k_finish = slice_k_start + tb_k * slice_iters; -// int slice_k_start_shared_fetch = slice_k_start; -// int slice_n_offset = act_s_col_tb_stride * slice_col; - -// // No act_order -// int s_gl_rd; -// if constexpr (!has_act_order) { -// if constexpr (group_blocks == -1) { -// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; -// } else { -// s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) -// + -// s_sh_stride * slice_col + threadIdx.x; -// } -// } -// int s_sh_wr = threadIdx.x; -// bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - -// // Zero-points -// int zp_gl_rd; -// if constexpr (has_zp) { -// if constexpr (group_blocks == -1) { -// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; -// } else { -// zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / -// group_blocks) + -// zp_sh_stride * slice_col + threadIdx.x; -// } -// } -// int zp_sh_wr = threadIdx.x; -// bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - -// // We use a different scale layout for grouped and column-wise quantization -// as -// // we scale a `half2` tile in column-major layout in the former and in -// // row-major in the latter case. -// int s_sh_rd; -// if constexpr (group_blocks != -1) -// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + -// (threadIdx.x % 32) / 4; -// else -// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + -// (threadIdx.x % 32) % 4; - -// // Zero-points have the same read layout as the scales -// // (without column-wise case) -// constexpr int num_col_threads = 8; -// constexpr int num_row_threads = 4; -// constexpr int num_ints_per_thread = 8 / pack_factor; -// int zp_sh_rd; -// if constexpr (has_zp) { -// zp_sh_rd = num_ints_per_thread * num_col_threads * -// ((threadIdx.x / 32) % (thread_n_blocks / 4)) + -// num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); -// } - -// int sh_first_group_id = -1; -// int sh_num_groups = -1; -// constexpr int sh_max_num_groups = 32; - -// int shs_size; -// if constexpr (has_act_order) -// shs_size = sh_max_num_groups * s_sh_stride + threads; -// else -// shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - -// extern __shared__ int4 sh[]; -// // Shared memory storage for global fetch pipelines. -// int4* sh_a = sh; -// int4* sh_b = sh_a + (stages * a_sh_stage); -// int4* sh_g_idx = sh_b + (stages * b_sh_stage); -// int4* sh_zp = sh_g_idx + (stages * g_idx_stage); -// int4* sh_s = sh_zp + (stages * zp_sh_stage); - -// // Precompute which thread should not read memory in which iterations; this -// is -// // needed if there are more threads than required for a certain tilesize or -// // when the batchsize is not a multiple of 16. 
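The predicate precomputation that follows (a_sh_wr_pred) is what lets the kernel handle batch sizes that are not a multiple of the 16-row tile: each thread decides once per write iteration whether its A row actually exists and reuses that flag inside the unrolled main loop. A minimal host-side sketch of the idea; prob_m, tile_rows and rows_per_iter are made-up example sizes, not values from the patch.

#include <cstdio>

int main() {
  constexpr int prob_m = 13;          // real batch rows (not a multiple of 16)
  constexpr int tile_rows = 16;       // rows covered by one m-tile
  constexpr int rows_per_iter = 4;    // rows written per shared-memory iteration
  bool pred[tile_rows / rows_per_iter];
  for (int i = 0; i < tile_rows / rows_per_iter; ++i) {
    int row = i * rows_per_iter;      // first row touched in this iteration
    pred[i] = row < prob_m;           // mask loads that would run past the batch
    std::printf("write iter %d (rows %d..%d): %s\n", i, row,
                row + rows_per_iter - 1, pred[i] ? "load" : "masked");
  }
  return 0;
}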
-// bool a_sh_wr_pred[a_sh_wr_iters]; -// #pragma unroll -// for (int i = 0; i < a_sh_wr_iters; i++) { -// int a_idx = a_sh_wr_delta * i + a_sh_wr; -// int row = a_idx / a_gl_rd_delta_o; -// if (row >= prob_m) { -// a_sh_wr_pred[i] = false; -// } else { -// a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; -// } -// } - -// // To ensure that writing and reading A tiles to/from shared memory, the -// // latter in fragment format, is fully bank conflict free, we need to use a -// // rather fancy XOR-based layout. The key here is that neither reads nor -// // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the -// // same shared memory banks. Further, it seems (based on NSight-Compute) -// that -// // each warp must also write a consecutive memory segment? -// auto transform_a = [&](int i) { -// int row = i / a_gl_rd_delta_o; -// return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; -// }; -// // Since the computation of this remapping is non-trivial and, due to our -// main -// // loop unrolls, all shared memory accesses are static, we simply -// precompute -// // both transformed reads and writes. -// int a_sh_wr_trans[a_sh_wr_iters]; -// #pragma unroll -// for (int i = 0; i < a_sh_wr_iters; i++) -// a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); -// int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) { -// #pragma unroll -// for (int j = 0; j < thread_m_blocks; j++) -// a_sh_rd_trans[i][j] = -// transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); -// } - -// // Since B-accesses have non-constant stride they have to be computed at -// // runtime; we break dependencies between subsequent accesses with a tile -// by -// // maintining multiple pointers (we have enough registers), a tiny -// // optimization. -// const int4* B_ptr[b_sh_wr_iters]; -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) -// B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - -// // Register storage for double buffer of shared memory reads. -// FragA frag_a[2][thread_m_blocks]; -// I4 frag_b_quant[2][b_thread_vecs]; -// FragC frag_c[thread_m_blocks][4][2]; -// FragS frag_s[2][4]; // No act-order -// FragS act_frag_s[2][4][4]; // For act-order -// int frag_qzp[2][num_ints_per_thread]; // Zero-points -// FragZP frag_zp; // Zero-points in fp16 - -// // Zero accumulators. 
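The transform_a lambda above is the XOR swizzle the comment describes: logical column c of row r is stored at physical column c ^ r, so the 16-byte chunks read by 8 consecutive threads do not pile onto the same shared-memory banks. A minimal host-side illustration; the tile sizes are arbitrary and only the mapping matters.

#include <cstdio>

int main() {
  constexpr int ROWS = 8;             // 8 rows of 16-byte chunks, 8 chunks each
  for (int r = 0; r < ROWS; ++r) {
    int c = 0;                        // every row accesses its logical column 0
    // Unswizzled, all rows hit physical column 0 (one bank group);
    // swizzled, the same accesses spread across columns 0..7.
    std::printf("row %d: unswizzled -> col %d, swizzled -> col %d\n", r, c, c ^ r);
  }
  return 0;
}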
-// auto zero_accums = [&]() { -// #pragma unroll -// for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) -// reinterpret_cast(frag_c)[i] = 0; -// }; - -// auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, -// int last_group_id) { -// sh_first_group_id = first_group_id; -// sh_num_groups = last_group_id - first_group_id + 1; - -// if (sh_num_groups < sh_max_num_groups) { -// sh_num_groups = sh_max_num_groups; -// } - -// if (sh_first_group_id + sh_num_groups > num_groups) { -// sh_num_groups = num_groups - sh_first_group_id; -// } - -// int row_offset = first_group_id * s_gl_stride; - -// if (is_async) { -// for (int i = 0; i < sh_num_groups; i++) { -// if (threadIdx.x < s_sh_stride) { -// cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], -// &scales_ptr[row_offset + (i * s_gl_stride) + -// slice_n_offset + threadIdx.x]); -// } -// } -// } else { -// for (int i = 0; i < sh_num_groups; i++) { -// if (threadIdx.x < s_sh_stride) { -// sh_s[(i * s_sh_stride) + threadIdx.x] = -// scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + -// threadIdx.x]; -// } -// } -// } -// }; -// // Asynchronously fetch the next A, B and s tile from global to the next -// // shared memory pipeline location. -// auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { -// if (pred) { -// int4* sh_a_stage = sh_a + a_sh_stage * pipe; -// #pragma unroll -// for (int i = 0; i < a_sh_wr_iters; i++) { -// int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; -// int row = a_idx / a_gl_stride; -// int sorted_row = -// replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; -// int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; -// if (sorted_row < tot_m * (replicate_input ? 1 : topk) && -// new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { -// cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], -// a_sh_wr_pred[i]); -// } -// } -// int4* sh_b_stage = sh_b + b_sh_stage * pipe; -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) { -// #pragma unroll -// for (int j = 0; j < b_thread_vecs; j++) { -// cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + -// j); -// } -// B_ptr[i] += b_gl_rd_delta_o; -// } - -// if constexpr (has_act_order) { -// // Fetch g_idx thread-block portion -// int full_pipe = a_off; -// int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; -// if (cur_k < prob_k && cur_k < slice_k_finish) { -// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - -// int4 const* cur_g_idx_stage_ptr = -// reinterpret_cast(&g_idx[cur_k]); - -// if (threadIdx.x < g_idx_stage) { -// cp_async4_pred(&sh_g_idx_stage[threadIdx.x], -// &cur_g_idx_stage_ptr[threadIdx.x]); -// } -// } -// } else { -// if constexpr (group_blocks != -1) { -// int4* sh_s_stage = sh_s + s_sh_stage * pipe; - -// if constexpr (group_blocks >= thread_k_blocks) { -// // Only fetch scales if this tile starts a new group -// if (pipe % (group_blocks / thread_k_blocks) == 0) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); -// } -// s_gl_rd += s_gl_rd_delta; -// } -// } else { -// for (int i = 0; i < s_tb_groups; i++) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], -// &scales_ptr[s_gl_rd]); -// } -// s_gl_rd += s_gl_rd_delta; -// } -// } -// } - -// if constexpr (has_zp && group_blocks != -1) { -// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - -// if constexpr (group_blocks >= thread_k_blocks) { -// // Only fetch zero-points if this tile starts a new group -// if (pipe % (group_blocks / thread_k_blocks) == 0) { -// if (zp_sh_wr_pred) { -// cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); -// } -// zp_gl_rd += zp_gl_rd_delta; -// } -// } else { -// for (int i = 0; i < zp_tb_groups; i++) { -// if (zp_sh_wr_pred) { -// cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], -// &zp_ptr[zp_gl_rd]); -// } -// zp_gl_rd += zp_gl_rd_delta; -// } -// } -// } -// } -// } -// // Insert a fence even when we are winding down the pipeline to ensure -// that -// // waiting is also correct at this point. -// cp_async_fence(); -// }; - -// auto fetch_zp_to_shared = [&]() { -// if (zp_sh_wr_pred) { -// cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); -// } -// }; - -// // Wait until the next thread tile has been loaded to shared memory. -// auto wait_for_stage = [&]() { -// // We only have `stages - 2` active fetches since we are double buffering -// // and can only issue the next fetch when it is guaranteed that the -// previous -// // shared memory load is fully complete (as it may otherwise be -// // overwritten). -// cp_async_wait(); -// __syncthreads(); -// }; - -// // Load the next sub-tile from the current location in the shared memory -// pipe -// // into the current register buffer. 
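fetch_to_shared / wait_for_stage above form a `stages`-deep global-to-shared prefetch ring built on the kernel's own cp_async helpers: while one stage is being consumed, fetches for later tiles are already in flight, and only `stages - 2` fetch batches are ever outstanding. Purely as an illustration of that pattern (this is not the patch's code), a self-contained sketch using the CUDA 11 __pipeline intrinsics; it assumes sm_80+ (older architectures fall back to synchronous copies), blockDim.x == TILE, and trivially simple per-tile work.

#include <cuda_pipeline.h>

template <int STAGES, int TILE>
__global__ void pipelined_sum(const float* __restrict__ in,
                              float* __restrict__ out, int n_tiles) {
  __shared__ float buf[STAGES][TILE];
  // Prime the ring: issue and commit STAGES - 1 fetch batches up front.
  for (int s = 0; s < STAGES - 1; ++s) {
    if (s < n_tiles)
      __pipeline_memcpy_async(&buf[s][threadIdx.x],
                              &in[s * TILE + threadIdx.x], sizeof(float));
    __pipeline_commit();
  }
  float acc = 0.f;
  for (int t = 0; t < n_tiles; ++t) {
    // Wait until all but the STAGES - 2 most recent batches have landed,
    // i.e. until tile t is resident in buf[t % STAGES].
    __pipeline_wait_prior(STAGES - 2);
    acc += buf[t % STAGES][threadIdx.x];  // each thread consumes its own element
    int nxt = t + STAGES - 1;             // refill the slot we just freed
    if (nxt < n_tiles)
      __pipeline_memcpy_async(&buf[nxt % STAGES][threadIdx.x],
                              &in[nxt * TILE + threadIdx.x], sizeof(float));
    __pipeline_commit();                  // commit once per iteration regardless
  }
  out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}

A possible launch, with hypothetical sizes: pipelined_sum<4, 128><<<1, 128>>>(d_in, d_out, 32), where d_in holds 32 * 128 floats and d_out holds 128 floats.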
-// auto fetch_to_registers = [&](int k, int pipe) { -// int4* sh_a_stage = sh_a + a_sh_stage * pipe; -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) -// ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % -// b_sh_wr_iters][i]]); -// int4* sh_b_stage = sh_b + b_sh_stage * pipe; - -// #pragma unroll -// for (int i = 0; i < b_thread_vecs; i++) { -// frag_b_quant[k % 2][i] = *reinterpret_cast( -// &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); -// } -// }; - -// bool is_same_group[stages]; -// int same_group_id[stages]; - -// auto init_same_group = [&](int pipe) { -// if constexpr (!has_act_order) { -// is_same_group[pipe] = false; -// same_group_id[pipe] = 0; -// return; -// } - -// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; -// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - -// int group_id_1 = sh_g_idx_int_ptr[0]; -// int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - -// is_same_group[pipe] = group_id_1 == group_id_2; -// same_group_id[pipe] = group_id_1; -// }; - -// auto fetch_scales_to_registers = [&](int k, int full_pipe) { -// int pipe = full_pipe % stages; - -// if constexpr (!has_act_order) { -// // No act-order case -// if constexpr (group_blocks != -1) { -// if constexpr (group_blocks >= thread_k_blocks) { -// int4* sh_s_stage = -// sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * -// (pipe / (group_blocks / -// thread_k_blocks))); -// reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; -// } else { -// int warp_id = threadIdx.x / 32; -// int n_warps = thread_n_blocks / 4; - -// int warp_row = warp_id / n_warps; - -// int cur_k = warp_row * 16; -// cur_k += k_iter_size * (k % b_sh_wr_iters); - -// int k_blocks = cur_k / 16; -// int cur_group_id = k_blocks / group_blocks; - -// int4* sh_s_stage = sh_s + s_sh_stage * pipe; - -// reinterpret_cast(&frag_s[k % 2])[0] = -// sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; -// } -// } - -// return; -// } - -// // Act-order case - -// // Determine K of the "current" thread-block -// int cur_k = slice_k_start + tb_k * full_pipe; -// if (cur_k >= prob_k || cur_k >= slice_k_finish) { -// return; -// } - -// // Reset (to current thread-block) since we read g_idx portion from the -// // shared memory -// cur_k = 0; - -// // Progress to current iteration -// cur_k += k_iter_size * (k % b_sh_wr_iters); - -// // Determine "position" inside the thread-block (based on warp and -// // thread-id) -// int warp_id = threadIdx.x / 32; -// int n_warps = -// thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - -// int warp_row = warp_id / n_warps; -// int warp_col = warp_id % n_warps; - -// cur_k += warp_row * 16; - -// int th_id = threadIdx.x % 32; -// cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - -// int s_col_shift = -// /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + -// (th_id / 4) * act_s_col_stride; - -// if (is_same_group[pipe]) { -// if (k % 2 == 0) { -// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = -// sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + -// s_col_shift]; -// } else { -// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = -// *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); -// } - -// for (int i = 1; i < 4; i++) { -// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = -// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); -// } -// return; -// } - -// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; -// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - -// constexpr int 
k_frag_offsets[4] = {0, 1, 8, -// 9}; // Tensor core offsets per thread - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// int actual_k = cur_k + k_frag_offsets[i]; - -// int group_id = sh_g_idx_int_ptr[actual_k]; -// int rel_group_id = group_id - sh_first_group_id; - -// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = -// sh_s[rel_group_id * s_sh_stride + s_col_shift]; -// } -// }; - -// auto fetch_zp_to_registers = [&](int k, int full_pipe) { -// // This code does not handle group_blocks == 0, -// // which signifies act_order. -// // has_zp implies AWQ, which doesn't have act_order, -// static_assert(!has_zp || group_blocks != 0); - -// if constexpr (has_zp) { -// int pipe = full_pipe % stages; - -// if constexpr (group_blocks == -1) { -// for (int i = 0; i < num_ints_per_thread; i++) { -// frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; -// } - -// } else if constexpr (group_blocks >= thread_k_blocks) { -// int4* sh_zp_stage = -// sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * -// (pipe / (group_blocks / -// thread_k_blocks))); -// for (int i = 0; i < num_ints_per_thread; i++) { -// frag_qzp[k % 2][i] = -// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; -// } -// } else { -// int warp_id = threadIdx.x / 32; -// int n_warps = thread_n_blocks / 4; - -// int warp_row = warp_id / n_warps; - -// int cur_k = warp_row * 16; -// cur_k += k_iter_size * (k % b_sh_wr_iters); - -// int k_blocks = cur_k / 16; -// int cur_group_id = 0; - -// // Suppress bogus and persistent divide-by-zero warning -// #pragma nv_diagnostic push -// #pragma nv_diag_suppress divide_by_zero -// cur_group_id = k_blocks / group_blocks; -// #pragma nv_diagnostic pop - -// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - -// sh_zp_stage += cur_group_id * zp_sh_stride; - -// for (int i = 0; i < num_ints_per_thread; i++) { -// frag_qzp[k % 2][i] = -// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; -// } -// } -// } -// }; - -// // Execute the actual tensor core matmul of a sub-tile. -// auto matmul = [&](int k) { -// if constexpr (has_zp) { -// FragB frag_zp_0; -// FragB frag_zp_1; -// int zp_quant_0, zp_quant_1; - -// if constexpr (w_type.size_bits() == 4) { -// zp_quant_0 = frag_qzp[k % 2][0]; -// zp_quant_1 = zp_quant_0 >> 8; -// } else { -// static_assert(w_type.size_bits() == 8); -// zp_quant_0 = frag_qzp[k % 2][0]; -// zp_quant_1 = frag_qzp[k % 2][1]; -// } - -// frag_zp_0 = dequant(zp_quant_0); -// frag_zp_1 = dequant(zp_quant_1); - -// frag_zp[0] = frag_zp_0[0]; -// frag_zp[1] = frag_zp_0[1]; -// frag_zp[2] = frag_zp_1[0]; -// frag_zp[3] = frag_zp_1[1]; -// } - -// // We have the m dimension as the inner loop in order to encourage -// overlapping -// // dequantization and matmul operations. 
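In the matmul body that follows, a 4-bit step extracts both halves of the B fragment from a single 32-bit register (the high half is simply the same word shifted right by 8 bits), whereas an 8-bit step must consume two registers per step. A small host-side illustration of that difference; the packed values below are arbitrary.

#include <cstdint>
#include <cstdio>

int main() {
  // 4-bit weights: one 32-bit word carries eight nibbles, so the "high"
  // fragment is just the same word shifted right by 8 bits.
  uint32_t q4 = 0x76543210u;
  uint32_t q4_lo = q4;
  uint32_t q4_hi = q4 >> 8;

  // 8-bit weights: only four values fit per word, so two words are consumed.
  uint32_t q8[2] = {0x03020100u, 0x07060504u};

  std::printf("4-bit: lo=0x%08x hi=0x%08x\n", q4_lo, q4_hi);
  std::printf("8-bit: w0=0x%08x w1=0x%08x\n", q8[0], q8[1]);
  return 0;
}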
-// #pragma unroll -// for (int j = 0; j < 4; j++) { -// int b_quant_0, b_quant_1; -// if constexpr (w_type.size_bits() == 4) { -// b_quant_0 = frag_b_quant[k % 2][0][j]; -// b_quant_1 = b_quant_0 >> 8; -// } else { -// static_assert(w_type.size_bits() == 8); -// int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); -// b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; -// b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; -// } - -// FragB frag_b0 = dequant(b_quant_0); -// FragB frag_b1 = dequant(b_quant_1); - -// // Apply scale to frag_b0 -// if constexpr (has_act_order) { -// scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], -// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); -// } else { -// if constexpr (group_blocks != -1) { -// scale(frag_b0, frag_s[k % 2][j], 0); -// } -// } - -// // Apply zero-point to frag_b1 -// if constexpr (has_zp) { -// sub_zp(frag_b1, frag_zp[j], 1); -// } - -// // Apply scale to frag_b1 -// if constexpr (has_act_order) { -// scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], -// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - -// } else { -// if constexpr (group_blocks != -1) { -// scale(frag_b1, frag_s[k % 2][j], 1); -// } -// } - -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) { -// mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); -// mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); -// } -// } -// }; - -// // Since we slice across the k dimension of a tile in order to increase the -// // number of warps while keeping the n dimension of a tile reasonable, we -// have -// // multiple warps that accumulate their partial sums of the same output -// // location; which we have to reduce over in the end. We do in shared -// memory. auto thread_block_reduce = [&]() { -// constexpr int red_off = threads / b_sh_stride_threads / 2; -// if (red_off >= 1) { -// int red_idx = threadIdx.x / b_sh_stride_threads; -// constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; -// constexpr int red_sh_delta = b_sh_stride_threads; -// int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + -// (threadIdx.x % b_sh_stride_threads); - -// // Parallel logarithmic shared memory reduction. We make sure to avoid -// any -// // unnecessary read or write iterations, e.g., for two warps we write -// only -// // once by warp 1 and read only once by warp 0. 
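The thread_block_reduce lambda above folds together the partial sums of warps that cover the same output columns in log2(#warps) shared-memory rounds. The same reduction shape in a generic, self-contained form (power-of-two block size assumed; this is not the patch's kernel, which reduces 8-value fragments in place rather than single floats).

__global__ void block_reduce(const float* __restrict__ in,
                             float* __restrict__ out) {
  extern __shared__ float sh[];
  const int tid = threadIdx.x;
  sh[tid] = in[blockIdx.x * blockDim.x + tid];
  __syncthreads();
  // Halve the number of active threads every round: log2(blockDim.x) steps.
  for (int off = blockDim.x / 2; off > 0; off >>= 1) {
    if (tid < off) sh[tid] += sh[tid + off];
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = sh[0];
}

Launch with block_reduce<<<grid, block, block * sizeof(float)>>>(in, out) and a power-of-two block size.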
- -// #pragma unroll -// for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -// #pragma unroll -// for (int i = red_off; i > 0; i /= 2) { -// if (i <= red_idx && red_idx < 2 * i) { -// #pragma unroll -// for (int j = 0; j < 4 * 2; j++) { -// int red_sh_wr = -// red_sh_delta * j + (red_sh_rd - red_sh_stride * i); -// if (i < red_off) { -// float* c_rd = -// reinterpret_cast(&sh[red_sh_delta * j + -// red_sh_rd]); -// float* c_wr = reinterpret_cast(&sh[red_sh_wr]); -// #pragma unroll -// for (int k = 0; k < 4; k++) -// reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += -// c_rd[k] + c_wr[k]; -// } -// sh[red_sh_wr] = -// reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; -// } -// } -// __syncthreads(); -// } -// if (red_idx == 0) { -// #pragma unroll -// for (int i = 0; i < 4 * 2; i++) { -// float* c_rd = -// reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -// #pragma unroll -// for (int j = 0; j < 4; j++) -// reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += -// c_rd[j]; -// } -// } -// __syncthreads(); -// } -// } -// }; - -// // Since multiple threadblocks may process parts of the same column slice, -// we -// // finally have to globally reduce over the results. As the striped -// // partitioning minimizes the number of such reductions and our outputs are -// // usually rather small, we perform this reduction serially in L2 cache. -// auto global_reduce = [&](bool first = false, bool last = false) { -// // We are very careful here to reduce directly in the output buffer to -// // maximize L2 cache utilization in this step. To do this, we write out -// // results in FP16 (but still reduce with FP32 compute). -// constexpr int active_threads = 32 * thread_n_blocks / 4; -// if (threadIdx.x < active_threads) { -// int c_gl_stride = prob_n / 8; -// int c_gl_wr_delta_o = 8 * c_gl_stride; -// int c_gl_wr_delta_i = 4 * (active_threads / 32); -// int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + -// 4 * (threadIdx.x / 32) + threadIdx.x % 4; -// c_gl_wr += (2 * thread_n_blocks) * slice_col; -// constexpr int c_sh_wr_delta = active_threads; -// int c_sh_wr = threadIdx.x; - -// int row = (threadIdx.x % 32) / 4; - -// if (!first) { -// // Interestingly, doing direct global accesses here really seems to mess up -// // the compiler and lead to slowdowns, hence we also use async-copies even -// // though these fetches are not actually asynchronous. 
-// #pragma unroll -// for (int i = 0; i < thread_m_blocks * 4; i++) { -// int c_idx = -// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % -// 2); -// int sorted_row = sorted_ids[c_idx / c_gl_stride]; -// int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; -// cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], -// sorted_row < tot_m * topk && -// (8 * (i / 2) + row < prob_m && -// (i < (thread_m_blocks - 1) * 4 || -// sorted_ids[8 * (i / 2) + row] < tot_m * -// topk))); -// } -// cp_async_fence(); -// cp_async_wait<0>(); -// } - -// #pragma unroll -// for (int i = 0; i < thread_m_blocks * 4; i++) { -// if (8 * (i / 2) + row < prob_m && -// (i < (thread_m_blocks - 1) * 4 || -// sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { -// if (!first) { -// int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -// #pragma unroll -// for (int j = 0; j < 2 * 4; j++) { -// reinterpret_cast( -// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += -// __half2float(reinterpret_cast<__half*>(&c_red)[j]); -// } -// } -// if (!last) { -// int4 c; -// #pragma unroll -// for (int j = 0; j < 2 * 4; j++) { -// reinterpret_cast<__half*>(&c)[j] = -// __float2half(reinterpret_cast( -// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); -// } -// int c_idx = -// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % -// 2); -// int row = sorted_ids[c_idx / c_gl_stride]; -// if (row < tot_m * topk) { -// int new_idx = row * c_gl_stride + c_idx % c_gl_stride; -// C[new_idx] = c; -// } -// } -// } -// } -// } -// }; - -// // Write out the reduce final result in the correct layout. We only -// actually -// // reshuffle matrix fragments in this step, the reduction above is -// performed -// // in fragment layout. -// auto write_result = [&]() { -// int c_gl_stride = prob_n / 8; -// constexpr int c_sh_stride = 2 * thread_n_blocks + 1; -// int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); -// constexpr int c_sh_rd_delta = -// c_sh_stride * (threads / (2 * thread_n_blocks)); - -// int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + -// (threadIdx.x % (2 * thread_n_blocks)); -// c_gl_wr += (2 * thread_n_blocks) * slice_col; -// int c_sh_wr = -// (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % -// 4; -// c_sh_wr += 32 * (threadIdx.x / 32); -// int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + -// (threadIdx.x % (2 * thread_n_blocks)); - -// int c_gl_wr_end = c_gl_stride * prob_m; - -// // We first reorder in shared memory to guarantee the most efficient -// final -// // global write patterns -// auto write = [&](int idx, float c0, float c1, FragS& s) { -// half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - -// // For per-column quantization we finally apply the scale here (only -// for -// // 4-bit) -// if constexpr (!has_act_order && group_blocks == -1 && -// w_type.size_bits() == 4) { -// res = __hmul2(res, s[0]); -// } - -// ((half2*)sh)[idx] = res; -// }; -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) { -// #pragma unroll -// for (int j = 0; j < 4; j++) { -// int wr = c_sh_wr + 8 * j; -// write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], -// frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); -// write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], -// frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); -// write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], -// frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); -// 
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], -// frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); -// } -// c_sh_wr += 16 * (4 * c_sh_stride); -// } -// } -// __syncthreads(); - -// #pragma unroll -// for (int i = 0; -// i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); -// i++) { -// if (c_gl_wr < c_gl_wr_end) { -// int row = sorted_ids[c_gl_wr / c_gl_stride]; -// if (row < tot_m * topk) { -// int off = row * c_gl_stride + c_gl_wr % c_gl_stride; -// if (!apply_weights) { -// C[off] = sh[c_sh_rd]; -// } else { -// __half* ctrg = reinterpret_cast<__half*>(&C[off]); -// __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); -// for (int j = 0; j < 8; ++j) { -// ctrg[j] = __float2half(topk_weights[row] * -// __half2float(csrc[j])); -// } -// } -// c_gl_wr += c_gl_wr_delta; -// c_sh_rd += c_sh_rd_delta; -// } -// } -// } -// }; - -// // Start global fetch and register load pipelines. -// auto start_pipes = [&]() { - -// #pragma unroll -// for (int i = 0; i < stages - 1; i++) { -// if (has_act_order && i == 0) { -// int last_g_idx = slice_k_start + stages * tb_k * 2; -// if (last_g_idx >= prob_k) { -// last_g_idx = prob_k - 1; -// } -// fetch_scales_to_shared(true, g_idx[slice_k_start], -// g_idx[last_g_idx]); -// } - -// if constexpr (has_zp && group_blocks == -1) { -// if (i == 0) { -// fetch_zp_to_shared(); -// } -// } -// fetch_to_shared(i, i, i < slice_iters); -// } - -// zero_accums(); -// wait_for_stage(); -// init_same_group(0); -// fetch_to_registers(0, 0); -// fetch_scales_to_registers(0, 0); -// fetch_zp_to_registers(0, 0); -// a_gl_rd += a_gl_rd_delta_o * (stages - 1); -// slice_k_start_shared_fetch += tb_k * (stages - 1); -// }; -// if (slice_iters) { -// start_pipes(); -// } - -// // Main loop. -// while (slice_iters) { -// // We unroll over both the global fetch and the register load pipeline to -// // ensure all shared memory accesses are static. Note that both pipelines -// // have even length meaning that the next iteration will always start at -// // index 0. -// #pragma unroll -// for (int pipe = 0; pipe < stages;) { -// #pragma unroll -// for (int k = 0; k < b_sh_wr_iters; k++) { -// fetch_to_registers(k + 1, pipe % stages); -// fetch_scales_to_registers(k + 1, pipe); -// fetch_zp_to_registers(k + 1, pipe); -// if (k == b_sh_wr_iters - 2) { -// fetch_to_shared((pipe + stages - 1) % stages, pipe, -// slice_iters >= stages); -// pipe++; -// wait_for_stage(); -// init_same_group(pipe % stages); -// } -// matmul(k); -// } -// slice_iters--; -// if (slice_iters == 0) { -// break; -// } -// } - -// a_gl_rd += a_gl_rd_delta_o * stages; -// slice_k_start += tb_k * stages; -// slice_k_start_shared_fetch += tb_k * stages; - -// if constexpr (has_act_order) { -// int first_group_id = g_idx[slice_k_start]; -// int last_g_idx = slice_k_start + stages * tb_k * 2; -// if (last_g_idx >= prob_k) { -// last_g_idx = prob_k - 1; -// } -// int last_group_id = g_idx[last_g_idx]; -// if (last_group_id >= sh_first_group_id + sh_num_groups) { -// fetch_scales_to_shared(false, first_group_id, last_group_id); -// __syncthreads(); -// } -// } - -// // Process results and, if necessary, proceed to the next column slice. -// // While this pattern may not be the most readable, other ways of writing -// // the loop seemed to noticeably worse performance after compilation. 
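A note on the epilogue that follows: for 8-bit weights with per-column scales, the scale is applied while the accumulators are still fp32, before the global reduction converts them to fp16, because fp16 saturates near 65504 and raw int8 accumulations can exceed that. A tiny numeric illustration with made-up values.

#include <cstdio>

int main() {
  const float FP16_MAX = 65504.0f;   // largest finite fp16 value
  float acc = 120000.0f;             // fp32 partial sum (example magnitude)
  float scale = 0.01f;               // per-channel dequant scale (example)

  bool overflows_if_converted_first = acc > FP16_MAX;    // 120000 > 65504
  float scaled_first = acc * scale;                      // 1200, fits in fp16
  std::printf("convert-then-scale overflows: %d, scale-then-convert: %.1f\n",
              overflows_if_converted_first, scaled_first);
  return 0;
}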
-// if (slice_iters == 0) { -// cp_async_wait<0>(); -// bool last = slice_idx == slice_count - 1; -// if constexpr (!has_act_order && group_blocks == -1) { -// if constexpr (w_type.size_bits() == 8) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); -// } -// cp_async_fence(); -// } else { -// // For 4-bit per-column scales, we only fetch them here in the -// // final step before write-out -// if (last) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); -// } -// cp_async_fence(); -// } -// } -// } - -// thread_block_reduce(); -// if constexpr (!has_act_order && group_blocks == -1) { -// if constexpr (w_type.size_bits() == 8) { -// cp_async_wait<0>(); -// __syncthreads(); -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; -// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; -// } - -// } else { -// if (last) { -// cp_async_wait<0>(); -// __syncthreads(); -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; -// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; -// } -// } -// } -// } - -// // For 8-bit channelwise, we apply the scale before the global -// reduction -// // that converts the fp32 results to fp16 (so that we avoid possible -// // overflow in fp16) -// if constexpr (!has_act_order && group_blocks == -1 && -// w_type.size_bits() == 8) { -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) { -// #pragma unroll -// for (int j = 0; j < 4; j++) { -// scale_float(reinterpret_cast(&frag_c[i][j][0][0]), -// frag_s[j / 2][2 * (j % 2) + 0]); -// scale_float(reinterpret_cast(&frag_c[i][j][0][2]), -// frag_s[j / 2][2 * (j % 2) + 0]); - -// scale_float(reinterpret_cast(&frag_c[i][j][1][0]), -// frag_s[j / 2][2 * (j % 2) + 1]); -// scale_float(reinterpret_cast(&frag_c[i][j][1][2]), -// frag_s[j / 2][2 * (j % 2) + 1]); -// } -// } -// } -// } - -// if (slice_count > 1) { // only globally reduce if there is more than -// one -// // block in a slice -// barrier_acquire(&locks[slice_col], slice_idx); -// global_reduce(slice_idx == 0, last); -// barrier_release(&locks[slice_col], last); -// } -// if (last) // only the last block in a slice actually writes the result -// write_result(); -// slice_row = 0; -// slice_col_par++; -// slice_col++; -// init_slice(); -// if (slice_iters) { -// a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + -// (threadIdx.x % a_gl_rd_delta_o); -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) -// B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; -// if (slice_col == 0) { -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; -// } - -// // Update slice k/n for scales loading -// if constexpr (has_act_order) { -// slice_k_start = tb_k * slice_row; -// slice_k_finish = slice_k_start + tb_k * slice_iters; -// slice_k_start_shared_fetch = slice_k_start; -// slice_n_offset = act_s_col_tb_stride * slice_col; - -// } else { -// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; -// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; -// } -// start_pipes(); -// } -// } -// } -// } - -// template shared -// // fetch pipeline -// const bool has_act_order, // whether act_order is enabled -// const bool has_zp, // whether zero-points are enabled -// const int group_blocks = -1 // number of consecutive 16x16 blocks -// // with a separate quantization scale -// > -// __global__ void MarlinMoE( -// const int4* 
__restrict__ A, // fp16 input matrix of shape mxk -// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn -// int4* __restrict__ C, // fp16 output buffer of shape mxn -// const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts -// const float* __restrict__ topk_weights, // float topk weights -// const int4* __restrict__ scales_ptr, // fp16 quantization scales of -// shape -// // (k/groupsize)xn -// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape -// // (k/groupsize)x(n/pack_factor) -// const int* __restrict__ g_idx, // int32 group indices of shape k -// const int* __restrict__ expert_offsets, -// int num_groups, // number of scale groups per output channel -// int expert_idx, // idx of current expert -// int num_experts, // number of experts -// int topk, // topk parameter of moe -// int prob_m, // batch dimension m -// int prob_n, // output dimension n -// int prob_k, // reduction dimension k -// int tot_m, // total number of rows in A and C -// int* locks, // extra global storage for barrier -// synchronization bool replicate_input, // do we use the same input for -// each expert? bool apply_weights, // apply weights to output int -// current_m_block, // current m block to start kernel computation from -// int max_par, // maximum parallelism -// int cfg_max_m_blocks // upper bound on m blocks -// ) { -// int m_block_ctr = current_m_block; - -// const int* sorted_ids_expert = -// sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * -// max_par; -// int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; -// if (tot_its == 0) { -// return; -// } -// int tot_m_blocks = ceildiv(tot_its, 16); -// int pad = 16 * tot_m_blocks - tot_its; - -// if (m_block_ctr >= tot_m_blocks) { -// return; -// } - -// int max_block = tot_m_blocks - m_block_ctr; -// prob_m = tot_its - 16 * m_block_ctr; - -// int par = 1; -// if (max_block > cfg_max_m_blocks) { -// // Note that parallel > 1 currently only works for inputs without any -// // padding -// par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); -// if (par > max_par) par = max_par; -// prob_m = (16 * cfg_max_m_blocks) * par; -// m_block_ctr += cfg_max_m_blocks * (par - 1); -// max_block = cfg_max_m_blocks; -// } - -// if (max_block == 1) { -// MarlinMoESingle( -// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, -// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, -// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, -// current_m_block); -// } else if (max_block == 2) { -// MarlinMoESingle( -// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, -// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, -// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, -// current_m_block); -// } else if (max_block == 3) { -// MarlinMoESingle( -// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, -// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, -// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, -// current_m_block); -// } else { -// MarlinMoESingle( -// A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, -// expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, -// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, -// current_m_block); -// } -// } - #else __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -1698,93 +137,8 @@ __global__ void compute_expert_offsets(int 
const* __restrict__ topk_ids, return; } -// template shared -// // fetch pipeline -// const bool has_act_order, // whether act_order is enabled -// const bool has_zp, // whether zero-points are enabled -// const int group_blocks = -1 // number of consecutive 16x16 blocks -// // with a separate quantization scale -// > -// __global__ void MarlinMoE( -// const int4* __restrict__ A, // fp16 input matrix of shape mxk -// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn -// int4* __restrict__ C, // fp16 output buffer of shape mxn -// const int* __restrict__ sorted_ids, // int32 sorted ids of experts -// const float* __restrict__ topk_weights, // float topk weights -// const int4* __restrict__ scales_ptr, // fp16 quantization scales of -// shape -// // (k/groupsize)xn -// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape -// // (k/groupsize)x(n/pack_factor) -// const int* __restrict__ g_idx, // int32 group indices of shape k -// const int* __restrict__ expert_offsets, -// int num_groups, // number of scale groups per output channel -// int expert_idx, // idx of current expert -// int num_experts, // number of experts -// int topk, // topk parameter of moe -// int prob_m, // batch dimension m -// int prob_n, // output dimension n -// int prob_k, // reduction dimension k -// int tot_m, // total number of rows in A and C -// int* locks, // extra global storage for barrier -// synchronization bool replicate_input, // do we use the same input for -// each expert? bool apply_weights, // apply weights to output int -// current_m_block, // current m block to start kernel computation from -// int max_par, // maximum parallelism -// int cfg_max_m_blocks // upper bound on m blocks - -// ) { -// // Marlin is not implemented yet for SM < 8.0 -// assert(false); -// return; -// } - #endif -// // 8 warps are a good choice since every SM has 4 schedulers and having more -// // than 1 warp per schedule allows some more latency hiding. At the same -// time, -// // we want relatively few warps to have many registers per warp and small -// tiles. 
const int USER_THREADS = -// 256; // Note: This is only used with user-provided -// thread_k/n -// const int STAGES = 4; // 4 pipeline stages fit into shared memory -// // const int SHARED_MEM = -// // 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -// static constexpr int min_thread_n = 64; -// static constexpr int min_thread_k = 64; - -// #define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ -// THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ -// NUM_THREADS) \ -// else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ -// thread_n_blocks == THREAD_N_BLOCKS && \ -// thread_k_blocks == THREAD_K_BLOCKS && \ -// has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ -// group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ -// cudaFuncSetAttribute(MarlinMoE, \ -// cudaFuncAttributeMaxDynamicSharedMemorySize, \ -// max_shared_mem); \ -// MarlinMoE \ -// <<>>( \ -// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ -// zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ -// num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ -// replicate_input, apply_weights, m_block, max_par, \ -// exec_cfg.max_m_blocks); \ -// } - typedef struct { int thread_k; int thread_n; @@ -1959,53 +313,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -// #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - -// #define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 
2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ -// \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ -// __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - #define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ else if (KERNEL_FUNCTION( \ q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, \ @@ -2164,37 +471,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - // else if(call_marlin_moe_kernel_ku4b8( - // q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, - // has_act_order, has_zp, group_blocks, num_threads, blocks, - // max_shared_mem, stream, A_ptr, B_ptr, C_ptr, sorted_ids_ptr, - // topk_weights_ptr, s_ptr, zp_ptr, g_idx_ptr, expert_offsets_ptr, - // num_groups, expert_idx, num_experts, topk, prob_m, prob_n, - // prob_k, tot_m, locks, replicate_input, apply_weights, m_block, - // max_par, exec_cfg.max_m_blocks)) { - // } CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) - - // GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - // GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - // GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - // GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - // GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - // GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - // GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - // GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) - - // AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) - // AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) - // AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) - // AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) - // AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) - // AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) - // AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) - // AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) + // CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) + // CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) + // CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + From 1b76e45aab62817f53c16f7bdbb2a0a3961f8505 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 19 Sep 2024 11:13:47 -0400 Subject: [PATCH 28/49] Compilation works --- CMakeLists.txt | 11 +- csrc/moe/marlin_moe_kernel.cu | 1257 -------------------- csrc/moe/marlin_moe_kernel.cuh | 391 ------ csrc/moe/marlin_moe_kernel.h | 1584 +++++++++++++++++++++++++ csrc/moe/marlin_moe_kernel_ku4.cu | 48 +- csrc/moe/marlin_moe_kernel_ku4.h | 20 + csrc/moe/marlin_moe_kernel_ku4b8.cu | 100 +- csrc/moe/marlin_moe_kernel_ku4b8.h | 20 + csrc/moe/marlin_moe_kernel_ku8.cu | 48 +- csrc/moe/marlin_moe_kernel_ku8.h | 20 + 
csrc/moe/marlin_moe_kernel_ku8b128.cu | 77 +- csrc/moe/marlin_moe_kernel_ku8b128.h | 18 + csrc/moe/marlin_moe_ops.cu | 15 +- 13 files changed, 1749 insertions(+), 1860 deletions(-) delete mode 100644 csrc/moe/marlin_moe_kernel.cu delete mode 100644 csrc/moe/marlin_moe_kernel.cuh create mode 100644 csrc/moe/marlin_moe_kernel.h create mode 100644 csrc/moe/marlin_moe_kernel_ku4.h create mode 100644 csrc/moe/marlin_moe_kernel_ku4b8.h create mode 100644 csrc/moe/marlin_moe_kernel_ku8.h create mode 100644 csrc/moe/marlin_moe_kernel_ku8b128.h diff --git a/CMakeLists.txt b/CMakeLists.txt index bd3322ad4cd1..c5934c59f36a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -295,11 +295,18 @@ define_gpu_extension_target( set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" + "csrc/moe/marlin_moe_kernel.h" + "csrc/moe/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_moe_kernel_ku4.h" + "csrc/moe/marlin_moe_kernel_ku4.cu" + "csrc/moe/marlin_moe_kernel_ku8.h" + "csrc/moe/marlin_moe_kernel_ku8.cu" "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC - "csrc/moe/marlin_moe_kernel.cu") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_moe_kernel.cu b/csrc/moe/marlin_moe_kernel.cu deleted file mode 100644 index 2090eb848a16..000000000000 --- a/csrc/moe/marlin_moe_kernel.cu +++ /dev/null @@ -1,1257 +0,0 @@ -// #include -// #include -// #include -// #include -// #include - -#include "marlin_moe_kernel.cuh" - -namespace marlin_moe { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// template shared -// // fetch pipeline -// const bool has_act_order, // whether act_order is enabled -// const bool has_zp, // whether zero-points are enabled -// const int group_blocks = -1 // number of consecutive 16x16 blocks -// // with a separate quantization scale -// > -// __device__ inline void MarlinMoESingle( -// const int4* __restrict__ A, // fp16 input matrix of shape mxk -// const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn -// int4* __restrict__ C, // fp16 output buffer of shape mxn -// const int* __restrict__ sorted_ids, // int32 sorted ids of experts -// const float* __restrict__ topk_weights, // float topk weights -// const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape -// // (k/groupsize)xn -// const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape -// // (k/groupsize)x(n/pack_factor) -// const int* __restrict__ g_idx, // int32 group indices of shape k -// const int* __restrict__ expert_offsets, -// int num_groups, // number of scale groups per output channel -// int expert_idx, // idx of current expert -// int num_experts, // number of experts -// int topk, // topk parameter of moe -// int prob_m, // batch dimension m -// int prob_n, // output dimension n -// int prob_k, // reduction dimension k -// int tot_m, // total number of rows in A and C -// int* locks, // extra global storage for barrier synchronization -// bool replicate_input, // do we use the same input for each expert? 
-// bool apply_weights, // apply weights to output -// int current_m_block // current m block to start kernel computation from -// ) { -// static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); -// constexpr int pack_factor = 32 / w_type.size_bits(); - -// // For larger GEMMs we run multiple batchsize 64 versions in parallel for a -// // better partitioning with less reductions -// int parallel = 1; -// if (prob_m > 16 * thread_m_blocks) { -// parallel = prob_m / (16 * thread_m_blocks); -// prob_m = 16 * thread_m_blocks; -// } - -// int k_tiles = prob_k / 16 / thread_k_blocks; -// int n_tiles = prob_n / 16 / thread_n_blocks; -// int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - -// if constexpr (!has_act_order && group_blocks != -1) { -// if (group_blocks >= thread_k_blocks) { -// // Ensure that the number of tiles in each stripe is a multiple of the -// // groupsize; this avoids an annoying special case where a stripe starts -// // in the middle of group. -// iters = (group_blocks / thread_k_blocks) * -// ceildiv(iters, (group_blocks / thread_k_blocks)); -// } -// } - -// int slice_row = (iters * blockIdx.x) % k_tiles; -// int slice_col_par = (iters * blockIdx.x) / k_tiles; -// int slice_col = slice_col_par; -// int slice_iters; // number of threadblock tiles in the current slice -// int slice_count = -// 0; // total number of active threadblocks in the current slice -// int slice_idx; // index of threadblock in current slice; numbered bottom to -// // top - -// // We can easily implement parallel problem execution by just remapping -// // indices and advancing global pointers -// if (slice_col_par >= n_tiles) { -// locks += (slice_col_par / n_tiles) * n_tiles; -// slice_col = slice_col_par % n_tiles; -// sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; -// } - -// // Compute all information about the current slice which is required for -// // synchronization. 
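The init_slice lambda that follows implements the striped partitioning set up just above: the k_tiles x n_tiles (x parallel) tile grid is cut into equal stripes of `iters` tiles, and each thread block starts its stripe at a (row, column) derived from its block index. The same arithmetic on the host, with example sizes only (not values from the patch).

#include <cstdio>

static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  int k_tiles = 8, n_tiles = 3, parallel = 1, num_blocks = 5;  // example sizes
  int iters = ceildiv(k_tiles * n_tiles * parallel, num_blocks);
  for (int blk = 0; blk < num_blocks; ++blk) {
    int slice_row = (iters * blk) % k_tiles;   // where in K the stripe starts
    int slice_col = (iters * blk) / k_tiles;   // which output column tile
    std::printf("block %d: starts at (row %d, col %d), %d tiles\n",
                blk, slice_row, slice_col, iters);
  }
  return 0;
}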
-// auto init_slice = [&]() { -// slice_iters = -// iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); -// if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; -// if (slice_iters == 0) return; -// if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; -// slice_count = 1; -// slice_idx = 0; -// int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); -// if (col_first <= k_tiles * (slice_col_par + 1)) { -// int col_off = col_first - k_tiles * slice_col_par; -// slice_count = ceildiv(k_tiles - col_off, iters); -// if (col_off > 0) slice_count++; -// int delta_first = iters * blockIdx.x - col_first; -// if (delta_first < 0 || (col_off == 0 && delta_first == 0)) -// slice_idx = slice_count - 1; -// else { -// slice_idx = slice_count - 1 - delta_first / iters; -// if (col_off > 0) slice_idx--; -// } -// } -// if (slice_col == n_tiles) { -// sorted_ids += 16 * thread_m_blocks; -// locks += n_tiles; -// slice_col = 0; -// } -// }; -// init_slice(); - -// // A sizes/strides - -// // stride of the A matrix in global memory -// int a_gl_stride = prob_k / 8; -// // stride of an A matrix tile in shared memory -// constexpr int a_sh_stride = 16 * thread_k_blocks / 8; -// // delta between subsequent A tiles in global memory -// constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; -// // between subsequent accesses within a tile -// int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); -// // between shared memory writes -// constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); -// // between shared memory tile reads -// constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); -// // within a shared memory tile -// constexpr int a_sh_rd_delta_i = a_sh_stride * 16; -// // overall size of a tile -// constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); -// // number of shared write iterations for a tile -// constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - -// // B sizes/strides -// int b_gl_stride = 16 * prob_n / (pack_factor * 4); -// constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; -// constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; -// constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - -// int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; -// int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); -// constexpr int b_sh_wr_delta = threads * b_thread_vecs; -// constexpr int b_sh_rd_delta = threads * b_thread_vecs; -// constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; -// constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - -// // Scale sizes/strides without act_order -// int s_gl_stride = prob_n / 8; -// constexpr int s_sh_stride = 16 * thread_n_blocks / 8; -// constexpr int s_tb_groups = -// !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks -// ? thread_k_blocks / group_blocks -// : 1; -// constexpr int s_sh_stage = s_tb_groups * s_sh_stride; -// int s_gl_rd_delta = s_gl_stride; -// // Scale size/strides with act_order -// constexpr int tb_k = 16 * thread_k_blocks; -// constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0; -// // constexpr int act_s_row_stride = 1; -// // int act_s_col_stride = act_s_row_stride * num_groups; -// int act_s_col_stride = 1; -// int act_s_col_warp_stride = act_s_col_stride * 8; -// int tb_n_warps = thread_n_blocks / 4; -// int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - -// // Zero-points sizes/strides -// int zp_gl_stride = (prob_n / pack_factor) / 4; -// constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; -// constexpr int zp_tb_groups = s_tb_groups; -// constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; -// int zp_gl_rd_delta = zp_gl_stride; - -// // Global A read index of current thread. -// int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + -// (threadIdx.x % a_gl_rd_delta_o); -// a_gl_rd += a_gl_rd_delta_o * slice_row; -// // Shared write index of current thread. -// int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + -// (threadIdx.x % a_gl_rd_delta_o); -// // Shared read index. -// int a_sh_rd = -// a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; -// a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - -// int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + -// (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; -// b_gl_rd += b_sh_stride * slice_col; -// b_gl_rd += b_gl_rd_delta_o * slice_row; -// int b_sh_wr = threadIdx.x * b_thread_vecs; -// int b_sh_rd = threadIdx.x * b_thread_vecs; - -// // For act_order -// constexpr int k_iter_size = tb_k / b_sh_wr_iters; -// int slice_k_start = tb_k * slice_row; -// int slice_k_finish = slice_k_start + tb_k * slice_iters; -// int slice_k_start_shared_fetch = slice_k_start; -// int slice_n_offset = act_s_col_tb_stride * slice_col; - -// // No act_order -// int s_gl_rd; -// if constexpr (!has_act_order) { -// if constexpr (group_blocks == -1) { -// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; -// } else { -// s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + -// s_sh_stride * slice_col + threadIdx.x; -// } -// } -// int s_sh_wr = threadIdx.x; -// bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - -// // Zero-points -// int zp_gl_rd; -// if constexpr (has_zp) { -// if constexpr (group_blocks == -1) { -// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; -// } else { -// zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + -// zp_sh_stride * slice_col + threadIdx.x; -// } -// } -// int zp_sh_wr = threadIdx.x; -// bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - -// // We use a different scale layout for grouped and column-wise quantization as -// // we scale a `half2` tile in column-major layout in the former and in -// // row-major in the latter case. 
-// int s_sh_rd; -// if constexpr (group_blocks != -1) -// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + -// (threadIdx.x % 32) / 4; -// else -// s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + -// (threadIdx.x % 32) % 4; - -// // Zero-points have the same read layout as the scales -// // (without column-wise case) -// constexpr int num_col_threads = 8; -// constexpr int num_row_threads = 4; -// constexpr int num_ints_per_thread = 8 / pack_factor; -// int zp_sh_rd; -// if constexpr (has_zp) { -// zp_sh_rd = num_ints_per_thread * num_col_threads * -// ((threadIdx.x / 32) % (thread_n_blocks / 4)) + -// num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); -// } - -// int sh_first_group_id = -1; -// int sh_num_groups = -1; -// constexpr int sh_max_num_groups = 32; - -// int shs_size; -// if constexpr (has_act_order) -// shs_size = sh_max_num_groups * s_sh_stride + threads; -// else -// shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - -// extern __shared__ int4 sh[]; -// // Shared memory storage for global fetch pipelines. -// int4* sh_a = sh; -// int4* sh_b = sh_a + (stages * a_sh_stage); -// int4* sh_g_idx = sh_b + (stages * b_sh_stage); -// int4* sh_zp = sh_g_idx + (stages * g_idx_stage); -// int4* sh_s = sh_zp + (stages * zp_sh_stage); - -// // Precompute which thread should not read memory in which iterations; this is -// // needed if there are more threads than required for a certain tilesize or -// // when the batchsize is not a multiple of 16. -// bool a_sh_wr_pred[a_sh_wr_iters]; -// #pragma unroll -// for (int i = 0; i < a_sh_wr_iters; i++) { -// int a_idx = a_sh_wr_delta * i + a_sh_wr; -// int row = a_idx / a_gl_rd_delta_o; -// if (row >= prob_m) { -// a_sh_wr_pred[i] = false; -// } else { -// a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; -// } -// } - -// // To ensure that writing and reading A tiles to/from shared memory, the -// // latter in fragment format, is fully bank conflict free, we need to use a -// // rather fancy XOR-based layout. The key here is that neither reads nor -// // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the -// // same shared memory banks. Further, it seems (based on NSight-Compute) that -// // each warp must also write a consecutive memory segment? -// auto transform_a = [&](int i) { -// int row = i / a_gl_rd_delta_o; -// return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; -// }; -// // Since the computation of this remapping is non-trivial and, due to our main -// // loop unrolls, all shared memory accesses are static, we simply precompute -// // both transformed reads and writes. -// int a_sh_wr_trans[a_sh_wr_iters]; -// #pragma unroll -// for (int i = 0; i < a_sh_wr_iters; i++) -// a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); -// int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) { -// #pragma unroll -// for (int j = 0; j < thread_m_blocks; j++) -// a_sh_rd_trans[i][j] = -// transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); -// } - -// // Since B-accesses have non-constant stride they have to be computed at -// // runtime; we break dependencies between subsequent accesses with a tile by -// // maintining multiple pointers (we have enough registers), a tiny -// // optimization. 
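The B_ptr array that follows is the "multiple pointers" optimization the comment above describes: rather than recomputing B's non-constant index every iteration (a serial dependency chain), one pointer per write iteration is kept and advanced independently. A host-side sketch with arbitrary strides and data.

#include <cstdio>

int main() {
  static int B[64];
  for (int i = 0; i < 64; ++i) B[i] = i;
  const int iters = 4, delta_i = 2, delta_o = 16;   // example strides
  const int* B_ptr[iters];
  for (int i = 0; i < iters; ++i) B_ptr[i] = B + delta_i * i;  // set up once
  for (int step = 0; step < 3; ++step) {
    for (int i = 0; i < iters; ++i) {
      std::printf("step %d iter %d reads %d\n", step, i, *B_ptr[i]);
      B_ptr[i] += delta_o;                          // advance independently
    }
  }
  return 0;
}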
-// const int4* B_ptr[b_sh_wr_iters]; -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) -// B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - -// // Register storage for double buffer of shared memory reads. -// FragA frag_a[2][thread_m_blocks]; -// I4 frag_b_quant[2][b_thread_vecs]; -// FragC frag_c[thread_m_blocks][4][2]; -// FragS frag_s[2][4]; // No act-order -// FragS act_frag_s[2][4][4]; // For act-order -// int frag_qzp[2][num_ints_per_thread]; // Zero-points -// FragZP frag_zp; // Zero-points in fp16 - -// // Zero accumulators. -// auto zero_accums = [&]() { -// #pragma unroll -// for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) -// reinterpret_cast(frag_c)[i] = 0; -// }; - -// auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, -// int last_group_id) { -// sh_first_group_id = first_group_id; -// sh_num_groups = last_group_id - first_group_id + 1; - -// if (sh_num_groups < sh_max_num_groups) { -// sh_num_groups = sh_max_num_groups; -// } - -// if (sh_first_group_id + sh_num_groups > num_groups) { -// sh_num_groups = num_groups - sh_first_group_id; -// } - -// int row_offset = first_group_id * s_gl_stride; - -// if (is_async) { -// for (int i = 0; i < sh_num_groups; i++) { -// if (threadIdx.x < s_sh_stride) { -// cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], -// &scales_ptr[row_offset + (i * s_gl_stride) + -// slice_n_offset + threadIdx.x]); -// } -// } -// } else { -// for (int i = 0; i < sh_num_groups; i++) { -// if (threadIdx.x < s_sh_stride) { -// sh_s[(i * s_sh_stride) + threadIdx.x] = -// scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + -// threadIdx.x]; -// } -// } -// } -// }; -// // Asynchronously fetch the next A, B and s tile from global to the next -// // shared memory pipeline location. -// auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { -// if (pred) { -// int4* sh_a_stage = sh_a + a_sh_stage * pipe; -// #pragma unroll -// for (int i = 0; i < a_sh_wr_iters; i++) { -// int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; -// int row = a_idx / a_gl_stride; -// int sorted_row = -// replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; -// int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; -// if (sorted_row < tot_m * (replicate_input ? 1 : topk) && -// new_idx < a_gl_stride * tot_m * (replicate_input ? 
1 : topk)) { -// cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], -// a_sh_wr_pred[i]); -// } -// } -// int4* sh_b_stage = sh_b + b_sh_stage * pipe; -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) { -// #pragma unroll -// for (int j = 0; j < b_thread_vecs; j++) { -// cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); -// } -// B_ptr[i] += b_gl_rd_delta_o; -// } - -// if constexpr (has_act_order) { -// // Fetch g_idx thread-block portion -// int full_pipe = a_off; -// int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; -// if (cur_k < prob_k && cur_k < slice_k_finish) { -// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - -// int4 const* cur_g_idx_stage_ptr = -// reinterpret_cast(&g_idx[cur_k]); - -// if (threadIdx.x < g_idx_stage) { -// cp_async4_pred(&sh_g_idx_stage[threadIdx.x], -// &cur_g_idx_stage_ptr[threadIdx.x]); -// } -// } -// } else { -// if constexpr (group_blocks != -1) { -// int4* sh_s_stage = sh_s + s_sh_stage * pipe; - -// if constexpr (group_blocks >= thread_k_blocks) { -// // Only fetch scales if this tile starts a new group -// if (pipe % (group_blocks / thread_k_blocks) == 0) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); -// } -// s_gl_rd += s_gl_rd_delta; -// } -// } else { -// for (int i = 0; i < s_tb_groups; i++) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], -// &scales_ptr[s_gl_rd]); -// } -// s_gl_rd += s_gl_rd_delta; -// } -// } -// } - -// if constexpr (has_zp && group_blocks != -1) { -// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - -// if constexpr (group_blocks >= thread_k_blocks) { -// // Only fetch zero-points if this tile starts a new group -// if (pipe % (group_blocks / thread_k_blocks) == 0) { -// if (zp_sh_wr_pred) { -// cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); -// } -// zp_gl_rd += zp_gl_rd_delta; -// } -// } else { -// for (int i = 0; i < zp_tb_groups; i++) { -// if (zp_sh_wr_pred) { -// cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], -// &zp_ptr[zp_gl_rd]); -// } -// zp_gl_rd += zp_gl_rd_delta; -// } -// } -// } -// } -// } -// // Insert a fence even when we are winding down the pipeline to ensure that -// // waiting is also correct at this point. -// cp_async_fence(); -// }; - -// auto fetch_zp_to_shared = [&]() { -// if (zp_sh_wr_pred) { -// cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); -// } -// }; - -// // Wait until the next thread tile has been loaded to shared memory. -// auto wait_for_stage = [&]() { -// // We only have `stages - 2` active fetches since we are double buffering -// // and can only issue the next fetch when it is guaranteed that the previous -// // shared memory load is fully complete (as it may otherwise be -// // overwritten). -// cp_async_wait(); -// __syncthreads(); -// }; - -// // Load the next sub-tile from the current location in the shared memory pipe -// // into the current register buffer. 
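// Host-side reference for the row remapping used in the A fetch above: each row
// of the expert's tile is looked up through sorted_ids, and divided by topk when
// the same input row is shared by all of its top-k experts (replicate_input).
// Shapes here are arbitrary example values.
#include <cstddef>
#include <vector>

std::vector<float> gather_expert_rows(const std::vector<float>& A, int prob_k,
                                      const std::vector<int>& sorted_ids, int topk,
                                      bool replicate_input) {
  std::vector<float> out(sorted_ids.size() * prob_k);
  for (std::size_t r = 0; r < sorted_ids.size(); r++) {
    int sorted_row = replicate_input ? sorted_ids[r] / topk : sorted_ids[r];
    for (int c = 0; c < prob_k; c++)
      out[r * prob_k + c] = A[static_cast<std::size_t>(sorted_row) * prob_k + c];
  }
  return out;
}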
-// auto fetch_to_registers = [&](int k, int pipe) { -// int4* sh_a_stage = sh_a + a_sh_stage * pipe; -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) -// ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); -// int4* sh_b_stage = sh_b + b_sh_stage * pipe; - -// #pragma unroll -// for (int i = 0; i < b_thread_vecs; i++) { -// frag_b_quant[k % 2][i] = *reinterpret_cast( -// &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); -// } -// }; - -// bool is_same_group[stages]; -// int same_group_id[stages]; - -// auto init_same_group = [&](int pipe) { -// if constexpr (!has_act_order) { -// is_same_group[pipe] = false; -// same_group_id[pipe] = 0; -// return; -// } - -// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; -// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - -// int group_id_1 = sh_g_idx_int_ptr[0]; -// int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - -// is_same_group[pipe] = group_id_1 == group_id_2; -// same_group_id[pipe] = group_id_1; -// }; - -// auto fetch_scales_to_registers = [&](int k, int full_pipe) { -// int pipe = full_pipe % stages; - -// if constexpr (!has_act_order) { -// // No act-order case -// if constexpr (group_blocks != -1) { -// if constexpr (group_blocks >= thread_k_blocks) { -// int4* sh_s_stage = -// sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * -// (pipe / (group_blocks / thread_k_blocks))); -// reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; -// } else { -// int warp_id = threadIdx.x / 32; -// int n_warps = thread_n_blocks / 4; - -// int warp_row = warp_id / n_warps; - -// int cur_k = warp_row * 16; -// cur_k += k_iter_size * (k % b_sh_wr_iters); - -// int k_blocks = cur_k / 16; -// int cur_group_id = k_blocks / group_blocks; - -// int4* sh_s_stage = sh_s + s_sh_stage * pipe; - -// reinterpret_cast(&frag_s[k % 2])[0] = -// sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; -// } -// } - -// return; -// } - -// // Act-order case - -// // Determine K of the "current" thread-block -// int cur_k = slice_k_start + tb_k * full_pipe; -// if (cur_k >= prob_k || cur_k >= slice_k_finish) { -// return; -// } - -// // Reset (to current thread-block) since we read g_idx portion from the -// // shared memory -// cur_k = 0; - -// // Progress to current iteration -// cur_k += k_iter_size * (k % b_sh_wr_iters); - -// // Determine "position" inside the thread-block (based on warp and -// // thread-id) -// int warp_id = threadIdx.x / 32; -// int n_warps = -// thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - -// int warp_row = warp_id / n_warps; -// int warp_col = warp_id % n_warps; - -// cur_k += warp_row * 16; - -// int th_id = threadIdx.x % 32; -// cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - -// int s_col_shift = -// /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + -// (th_id / 4) * act_s_col_stride; - -// if (is_same_group[pipe]) { -// if (k % 2 == 0) { -// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = -// sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + -// s_col_shift]; -// } else { -// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = -// *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); -// } - -// for (int i = 1; i < 4; i++) { -// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = -// *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); -// } -// return; -// } - -// int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; -// int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - -// constexpr int 
k_frag_offsets[4] = {0, 1, 8, -// 9}; // Tensor core offsets per thread - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// int actual_k = cur_k + k_frag_offsets[i]; - -// int group_id = sh_g_idx_int_ptr[actual_k]; -// int rel_group_id = group_id - sh_first_group_id; - -// *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = -// sh_s[rel_group_id * s_sh_stride + s_col_shift]; -// } -// }; - -// auto fetch_zp_to_registers = [&](int k, int full_pipe) { -// // This code does not handle group_blocks == 0, -// // which signifies act_order. -// // has_zp implies AWQ, which doesn't have act_order, -// static_assert(!has_zp || group_blocks != 0); - -// if constexpr (has_zp) { -// int pipe = full_pipe % stages; - -// if constexpr (group_blocks == -1) { -// for (int i = 0; i < num_ints_per_thread; i++) { -// frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; -// } - -// } else if constexpr (group_blocks >= thread_k_blocks) { -// int4* sh_zp_stage = -// sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * -// (pipe / (group_blocks / thread_k_blocks))); -// for (int i = 0; i < num_ints_per_thread; i++) { -// frag_qzp[k % 2][i] = -// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; -// } -// } else { -// int warp_id = threadIdx.x / 32; -// int n_warps = thread_n_blocks / 4; - -// int warp_row = warp_id / n_warps; - -// int cur_k = warp_row * 16; -// cur_k += k_iter_size * (k % b_sh_wr_iters); - -// int k_blocks = cur_k / 16; -// int cur_group_id = 0; - -// // Suppress bogus and persistent divide-by-zero warning -// #pragma nv_diagnostic push -// #pragma nv_diag_suppress divide_by_zero -// cur_group_id = k_blocks / group_blocks; -// #pragma nv_diagnostic pop - -// int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - -// sh_zp_stage += cur_group_id * zp_sh_stride; - -// for (int i = 0; i < num_ints_per_thread; i++) { -// frag_qzp[k % 2][i] = -// (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; -// } -// } -// } -// }; - -// // Execute the actual tensor core matmul of a sub-tile. -// auto matmul = [&](int k) { -// if constexpr (has_zp) { -// FragB frag_zp_0; -// FragB frag_zp_1; -// int zp_quant_0, zp_quant_1; - -// if constexpr (w_type.size_bits() == 4) { -// zp_quant_0 = frag_qzp[k % 2][0]; -// zp_quant_1 = zp_quant_0 >> 8; -// } else { -// static_assert(w_type.size_bits() == 8); -// zp_quant_0 = frag_qzp[k % 2][0]; -// zp_quant_1 = frag_qzp[k % 2][1]; -// } - -// frag_zp_0 = dequant(zp_quant_0); -// frag_zp_1 = dequant(zp_quant_1); - -// frag_zp[0] = frag_zp_0[0]; -// frag_zp[1] = frag_zp_0[1]; -// frag_zp[2] = frag_zp_1[0]; -// frag_zp[3] = frag_zp_1[1]; -// } - -// // We have the m dimension as the inner loop in order to encourage overlapping -// // dequantization and matmul operations. 
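// Scalar analog of the loop ordering explained above: the packed B value is
// dequantized once per output column j and then reused across every m block in
// the inner loop, so dequantization overlaps with (and is amortized over) the
// matmul work. This uses a plain "nibble minus 8, times scale" reference and
// ignores Marlin's interleaved packing order.
void rank1_update_ref(const float* a, int m_blocks, unsigned b_packed, float scale,
                      float* c /* m_blocks x 8, row-major */) {
  for (int j = 0; j < 8; j++) {                   // 8 nibbles per packed 32-bit word
    int q = (b_packed >> (4 * j)) & 0xF;
    float b = scale * static_cast<float>(q - 8);  // dequantize once ...
    for (int i = 0; i < m_blocks; i++)            // ... reuse for all m blocks
      c[i * 8 + j] += a[i] * b;
  }
}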
-// #pragma unroll -// for (int j = 0; j < 4; j++) { -// int b_quant_0, b_quant_1; -// if constexpr (w_type.size_bits() == 4) { -// b_quant_0 = frag_b_quant[k % 2][0][j]; -// b_quant_1 = b_quant_0 >> 8; -// } else { -// static_assert(w_type.size_bits() == 8); -// int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); -// b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; -// b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; -// } - -// FragB frag_b0 = dequant(b_quant_0); -// FragB frag_b1 = dequant(b_quant_1); - -// // Apply scale to frag_b0 -// if constexpr (has_act_order) { -// scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], -// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); -// } else { -// if constexpr (group_blocks != -1) { -// scale(frag_b0, frag_s[k % 2][j], 0); -// } -// } - -// // Apply zero-point to frag_b1 -// if constexpr (has_zp) { -// sub_zp(frag_b1, frag_zp[j], 1); -// } - -// // Apply scale to frag_b1 -// if constexpr (has_act_order) { -// scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], -// act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - -// } else { -// if constexpr (group_blocks != -1) { -// scale(frag_b1, frag_s[k % 2][j], 1); -// } -// } - -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) { -// mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); -// mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); -// } -// } -// }; - -// // Since we slice across the k dimension of a tile in order to increase the -// // number of warps while keeping the n dimension of a tile reasonable, we have -// // multiple warps that accumulate their partial sums of the same output -// // location; which we have to reduce over in the end. We do in shared memory. -// auto thread_block_reduce = [&]() { -// constexpr int red_off = threads / b_sh_stride_threads / 2; -// if (red_off >= 1) { -// int red_idx = threadIdx.x / b_sh_stride_threads; -// constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; -// constexpr int red_sh_delta = b_sh_stride_threads; -// int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + -// (threadIdx.x % b_sh_stride_threads); - -// // Parallel logarithmic shared memory reduction. We make sure to avoid any -// // unnecessary read or write iterations, e.g., for two warps we write only -// // once by warp 1 and read only once by warp 0. - -// #pragma unroll -// for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -// #pragma unroll -// for (int i = red_off; i > 0; i /= 2) { -// if (i <= red_idx && red_idx < 2 * i) { -// #pragma unroll -// for (int j = 0; j < 4 * 2; j++) { -// int red_sh_wr = -// red_sh_delta * j + (red_sh_rd - red_sh_stride * i); -// if (i < red_off) { -// float* c_rd = -// reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); -// float* c_wr = reinterpret_cast(&sh[red_sh_wr]); -// #pragma unroll -// for (int k = 0; k < 4; k++) -// reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += -// c_rd[k] + c_wr[k]; -// } -// sh[red_sh_wr] = -// reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; -// } -// } -// __syncthreads(); -// } -// if (red_idx == 0) { -// #pragma unroll -// for (int i = 0; i < 4 * 2; i++) { -// float* c_rd = -// reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -// #pragma unroll -// for (int j = 0; j < 4; j++) -// reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += -// c_rd[j]; -// } -// } -// __syncthreads(); -// } -// } -// }; - -// // Since multiple threadblocks may process parts of the same column slice, we -// // finally have to globally reduce over the results. 
As the striped -// // partitioning minimizes the number of such reductions and our outputs are -// // usually rather small, we perform this reduction serially in L2 cache. -// auto global_reduce = [&](bool first = false, bool last = false) { -// // We are very careful here to reduce directly in the output buffer to -// // maximize L2 cache utilization in this step. To do this, we write out -// // results in FP16 (but still reduce with FP32 compute). -// constexpr int active_threads = 32 * thread_n_blocks / 4; -// if (threadIdx.x < active_threads) { -// int c_gl_stride = prob_n / 8; -// int c_gl_wr_delta_o = 8 * c_gl_stride; -// int c_gl_wr_delta_i = 4 * (active_threads / 32); -// int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + -// 4 * (threadIdx.x / 32) + threadIdx.x % 4; -// c_gl_wr += (2 * thread_n_blocks) * slice_col; -// constexpr int c_sh_wr_delta = active_threads; -// int c_sh_wr = threadIdx.x; - -// int row = (threadIdx.x % 32) / 4; - -// if (!first) { -// // Interestingly, doing direct global accesses here really seems to mess up -// // the compiler and lead to slowdowns, hence we also use async-copies even -// // though these fetches are not actually asynchronous. -// #pragma unroll -// for (int i = 0; i < thread_m_blocks * 4; i++) { -// int c_idx = -// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); -// int sorted_row = sorted_ids[c_idx / c_gl_stride]; -// int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; -// cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], -// sorted_row < tot_m * topk && -// (8 * (i / 2) + row < prob_m && -// (i < (thread_m_blocks - 1) * 4 || -// sorted_ids[8 * (i / 2) + row] < tot_m * topk))); -// } -// cp_async_fence(); -// cp_async_wait<0>(); -// } - -// #pragma unroll -// for (int i = 0; i < thread_m_blocks * 4; i++) { -// if (8 * (i / 2) + row < prob_m && -// (i < (thread_m_blocks - 1) * 4 || -// sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { -// if (!first) { -// int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -// #pragma unroll -// for (int j = 0; j < 2 * 4; j++) { -// reinterpret_cast( -// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += -// __half2float(reinterpret_cast<__half*>(&c_red)[j]); -// } -// } -// if (!last) { -// int4 c; -// #pragma unroll -// for (int j = 0; j < 2 * 4; j++) { -// reinterpret_cast<__half*>(&c)[j] = -// __float2half(reinterpret_cast( -// &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); -// } -// int c_idx = -// c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); -// int row = sorted_ids[c_idx / c_gl_stride]; -// if (row < tot_m * topk) { -// int new_idx = row * c_gl_stride + c_idx % c_gl_stride; -// C[new_idx] = c; -// } -// } -// } -// } -// } -// }; - -// // Write out the reduce final result in the correct layout. We only actually -// // reshuffle matrix fragments in this step, the reduction above is performed -// // in fragment layout. 
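// Condensed sketch of the first/last protocol used by global_reduce above,
// reduced to a single output element: the first block of a slice overwrites C,
// later blocks read the fp16 partial back and accumulate in fp32, and the last
// block keeps the sum in registers for write_result(). Top-k weighting and
// scaling are omitted here.
#include <cuda_fp16.h>

__device__ float reduce_one_output(__half* C, int idx, float partial, bool first,
                                   bool last) {
  float acc = partial;
  if (!first) acc += __half2float(C[idx]);  // accumulate the previous partial in fp32
  if (!last) C[idx] = __float2half(acc);    // hand the running partial to the next block
  return acc;
}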
-// auto write_result = [&]() { -// int c_gl_stride = prob_n / 8; -// constexpr int c_sh_stride = 2 * thread_n_blocks + 1; -// int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); -// constexpr int c_sh_rd_delta = -// c_sh_stride * (threads / (2 * thread_n_blocks)); - -// int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + -// (threadIdx.x % (2 * thread_n_blocks)); -// c_gl_wr += (2 * thread_n_blocks) * slice_col; -// int c_sh_wr = -// (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; -// c_sh_wr += 32 * (threadIdx.x / 32); -// int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + -// (threadIdx.x % (2 * thread_n_blocks)); - -// int c_gl_wr_end = c_gl_stride * prob_m; - -// // We first reorder in shared memory to guarantee the most efficient final -// // global write patterns -// auto write = [&](int idx, float c0, float c1, FragS& s) { -// half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - -// // For per-column quantization we finally apply the scale here (only for -// // 4-bit) -// if constexpr (!has_act_order && group_blocks == -1 && -// w_type.size_bits() == 4) { -// res = __hmul2(res, s[0]); -// } - -// ((half2*)sh)[idx] = res; -// }; -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) { -// #pragma unroll -// for (int j = 0; j < 4; j++) { -// int wr = c_sh_wr + 8 * j; -// write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], -// frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); -// write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], -// frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); -// write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], -// frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); -// write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], -// frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); -// } -// c_sh_wr += 16 * (4 * c_sh_stride); -// } -// } -// __syncthreads(); - -// #pragma unroll -// for (int i = 0; -// i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); -// i++) { -// if (c_gl_wr < c_gl_wr_end) { -// int row = sorted_ids[c_gl_wr / c_gl_stride]; -// if (row < tot_m * topk) { -// int off = row * c_gl_stride + c_gl_wr % c_gl_stride; -// if (!apply_weights) { -// C[off] = sh[c_sh_rd]; -// } else { -// __half* ctrg = reinterpret_cast<__half*>(&C[off]); -// __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); -// for (int j = 0; j < 8; ++j) { -// ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); -// } -// } -// c_gl_wr += c_gl_wr_delta; -// c_sh_rd += c_sh_rd_delta; -// } -// } -// } -// }; - -// // Start global fetch and register load pipelines. 
-// auto start_pipes = [&]() { - -// #pragma unroll -// for (int i = 0; i < stages - 1; i++) { -// if (has_act_order && i == 0) { -// int last_g_idx = slice_k_start + stages * tb_k * 2; -// if (last_g_idx >= prob_k) { -// last_g_idx = prob_k - 1; -// } -// fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); -// } - -// if constexpr (has_zp && group_blocks == -1) { -// if (i == 0) { -// fetch_zp_to_shared(); -// } -// } -// fetch_to_shared(i, i, i < slice_iters); -// } - -// zero_accums(); -// wait_for_stage(); -// init_same_group(0); -// fetch_to_registers(0, 0); -// fetch_scales_to_registers(0, 0); -// fetch_zp_to_registers(0, 0); -// a_gl_rd += a_gl_rd_delta_o * (stages - 1); -// slice_k_start_shared_fetch += tb_k * (stages - 1); -// }; -// if (slice_iters) { -// start_pipes(); -// } - -// // Main loop. -// while (slice_iters) { -// // We unroll over both the global fetch and the register load pipeline to -// // ensure all shared memory accesses are static. Note that both pipelines -// // have even length meaning that the next iteration will always start at -// // index 0. -// #pragma unroll -// for (int pipe = 0; pipe < stages;) { -// #pragma unroll -// for (int k = 0; k < b_sh_wr_iters; k++) { -// fetch_to_registers(k + 1, pipe % stages); -// fetch_scales_to_registers(k + 1, pipe); -// fetch_zp_to_registers(k + 1, pipe); -// if (k == b_sh_wr_iters - 2) { -// fetch_to_shared((pipe + stages - 1) % stages, pipe, -// slice_iters >= stages); -// pipe++; -// wait_for_stage(); -// init_same_group(pipe % stages); -// } -// matmul(k); -// } -// slice_iters--; -// if (slice_iters == 0) { -// break; -// } -// } - -// a_gl_rd += a_gl_rd_delta_o * stages; -// slice_k_start += tb_k * stages; -// slice_k_start_shared_fetch += tb_k * stages; - -// if constexpr (has_act_order) { -// int first_group_id = g_idx[slice_k_start]; -// int last_g_idx = slice_k_start + stages * tb_k * 2; -// if (last_g_idx >= prob_k) { -// last_g_idx = prob_k - 1; -// } -// int last_group_id = g_idx[last_g_idx]; -// if (last_group_id >= sh_first_group_id + sh_num_groups) { -// fetch_scales_to_shared(false, first_group_id, last_group_id); -// __syncthreads(); -// } -// } - -// // Process results and, if necessary, proceed to the next column slice. -// // While this pattern may not be the most readable, other ways of writing -// // the loop seemed to noticeably worse performance after compilation. 
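// Control-flow skeleton of the main loop above, with the memory helpers reduced
// to stubs: the global-fetch pipeline (stages deep) and the register double
// buffer (indexed by k % 2) are advanced inside one fully unrollable nest, and
// the next global fetch is kicked off two sub-iterations before its stage is
// needed.
void main_loop_skeleton(int slice_iters, int stages, int b_sh_wr_iters) {
  auto fetch_to_registers = [](int /*k*/, int /*pipe*/) {};
  auto fetch_to_shared = [](int /*pipe*/, int /*a_off*/, bool /*pred*/) {};
  auto wait_for_stage = [] {};
  auto matmul = [](int /*k*/) {};

  while (slice_iters) {
    for (int pipe = 0; pipe < stages;) {
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
          pipe++;
          wait_for_stage();
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) break;
    }
  }
}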
-// if (slice_iters == 0) { -// cp_async_wait<0>(); -// bool last = slice_idx == slice_count - 1; -// if constexpr (!has_act_order && group_blocks == -1) { -// if constexpr (w_type.size_bits() == 8) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); -// } -// cp_async_fence(); -// } else { -// // For 4-bit per-column scales, we only fetch them here in the -// // final step before write-out -// if (last) { -// if (s_sh_wr_pred) { -// cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); -// } -// cp_async_fence(); -// } -// } -// } - -// thread_block_reduce(); -// if constexpr (!has_act_order && group_blocks == -1) { -// if constexpr (w_type.size_bits() == 8) { -// cp_async_wait<0>(); -// __syncthreads(); -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; -// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; -// } - -// } else { -// if (last) { -// cp_async_wait<0>(); -// __syncthreads(); -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; -// reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; -// } -// } -// } -// } - -// // For 8-bit channelwise, we apply the scale before the global reduction -// // that converts the fp32 results to fp16 (so that we avoid possible -// // overflow in fp16) -// if constexpr (!has_act_order && group_blocks == -1 && -// w_type.size_bits() == 8) { -// if (threadIdx.x / 32 < thread_n_blocks / 4) { -// #pragma unroll -// for (int i = 0; i < thread_m_blocks; i++) { -// #pragma unroll -// for (int j = 0; j < 4; j++) { -// scale_float(reinterpret_cast(&frag_c[i][j][0][0]), -// frag_s[j / 2][2 * (j % 2) + 0]); -// scale_float(reinterpret_cast(&frag_c[i][j][0][2]), -// frag_s[j / 2][2 * (j % 2) + 0]); - -// scale_float(reinterpret_cast(&frag_c[i][j][1][0]), -// frag_s[j / 2][2 * (j % 2) + 1]); -// scale_float(reinterpret_cast(&frag_c[i][j][1][2]), -// frag_s[j / 2][2 * (j % 2) + 1]); -// } -// } -// } -// } - -// if (slice_count > 1) { // only globally reduce if there is more than one -// // block in a slice -// barrier_acquire(&locks[slice_col], slice_idx); -// global_reduce(slice_idx == 0, last); -// barrier_release(&locks[slice_col], last); -// } -// if (last) // only the last block in a slice actually writes the result -// write_result(); -// slice_row = 0; -// slice_col_par++; -// slice_col++; -// init_slice(); -// if (slice_iters) { -// a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + -// (threadIdx.x % a_gl_rd_delta_o); -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) -// B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; -// if (slice_col == 0) { -// #pragma unroll -// for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; -// } - -// // Update slice k/n for scales loading -// if constexpr (has_act_order) { -// slice_k_start = tb_k * slice_row; -// slice_k_finish = slice_k_start + tb_k * slice_iters; -// slice_k_start_shared_fetch = slice_k_start; -// slice_n_offset = act_s_col_tb_stride * slice_col; - -// } else { -// s_gl_rd = s_sh_stride * slice_col + threadIdx.x; -// zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; -// } -// start_pipes(); -// } -// } -// } -// } - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input 
matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - // int m_block_ctr = current_m_block; - - // const int* sorted_ids_expert = - // sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - // int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - // if (tot_its == 0) { - // return; - // } - // int tot_m_blocks = ceildiv(tot_its, 16); - // int pad = 16 * tot_m_blocks - tot_its; - - // if (m_block_ctr >= tot_m_blocks) { - // return; - // } - - // int max_block = tot_m_blocks - m_block_ctr; - // prob_m = tot_its - 16 * m_block_ctr; - - // int par = 1; - // if (max_block > cfg_max_m_blocks) { - // // Note that parallel > 1 currently only works for inputs without any - // // padding - // par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - // if (par > max_par) par = max_par; - // prob_m = (16 * cfg_max_m_blocks) * par; - // m_block_ctr += cfg_max_m_blocks * (par - 1); - // max_block = cfg_max_m_blocks; - // } - - // if (max_block == 1) { - // MarlinMoESingle( - // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - // current_m_block); - // } else if (max_block == 2) { - // MarlinMoESingle( - // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - // current_m_block); - // } else if (max_block == 3) { - // MarlinMoESingle( - // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - // current_m_block); - // } else { - // MarlinMoESingle( - // A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - // expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - // prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - // current_m_block); - // } -} - -#endif - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel.cuh b/csrc/moe/marlin_moe_kernel.cuh deleted file mode 100644 index 7dad50ed481e..000000000000 --- 
a/csrc/moe/marlin_moe_kernel.cuh +++ /dev/null @@ -1,391 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/scalar_type.hpp" - -namespace marlin_moe { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales -using FragZP = Vec; - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. 
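// Host-side check of the lookup-table convention behind lop3 below: the template
// immediate is f(0xF0, 0xCC, 0xAA), and bit i of the result is looked up at index
// (a_i << 2 | b_i << 1 | c_i). This emulation is based on the documented PTX
// lop3.b32 semantics; (0xf0 & 0xcc) | 0xaa == 0xEA encodes (a & b) | c, which is
// exactly how the dequantizers use it.
#include <cassert>
#include <cstdint>

static uint32_t lop3_emulated(uint32_t a, uint32_t b, uint32_t c, uint32_t lut) {
  uint32_t res = 0;
  for (int i = 0; i < 32; i++) {
    uint32_t idx = (((a >> i) & 1u) << 2) | (((b >> i) & 1u) << 1) | ((c >> i) & 1u);
    res |= ((lut >> idx) & 1u) << i;
  }
  return res;
}

int main() {
  constexpr uint32_t lut = (0xf0 & 0xcc) | 0xaa;  // 0xEA
  uint32_t q = 0x76543210, LO = 0x000f000f, EX = 0x64006400;
  assert(lop3_emulated(q, LO, EX, lut) == ((q & LO) | EX));
  return 0;
}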
-template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { - half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. 
- asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -); - -#else - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? 
- bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks - -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel.h b/csrc/moe/marlin_moe_kernel.h new file mode 100644 index 000000000000..b2f612952344 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel.h @@ -0,0 +1,1584 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales +using FragZP = Vec; + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
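// Numerical reference for the m16n8k16 mma wrapper defined next: it computes
// D = A * B + C with A of shape 16x16 and B of shape 16x8, fp16 inputs and fp32
// accumulation. The per-thread fragment layout is defined by the PTX ISA and is
// not modeled here.
#include <cuda_fp16.h>

void mma_m16n8k16_reference(const __half A[16][16], const __half B[16][8],
                            float C[16][8]) {
  for (int m = 0; m < 16; m++)
    for (int n = 0; n < 8; n++) {
      float acc = C[m][n];
      for (int k = 0; k < 16; k++)
        acc += __half2float(A[m][k]) * __half2float(B[k][n]);
      C[m][n] = acc;
    }
}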
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
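// Plain-integer reference for what the dequant specializations above produce
// before scaling: 4-bit or 8-bit values with the symmetric bias (8 or 128) fused
// in, or left as the raw unsigned value so that sub_zp() can remove a per-group
// zero-point. The fp16 magic constants above encode exactly these subtractions
// (0x6408 is 1024 + 8 and 0x6480 is 1024 + 128 as fp16 bit patterns). The lane
// extraction below ignores Marlin's interleaved packing order, which the kernel
// handles with prmt and shifts.
#include <cstdint>

float dequant_reference(uint32_t packed, int lane, int size_bits, bool fused_bias,
                        float scale, int zero_point) {
  uint32_t mask = (1u << size_bits) - 1u;
  int q = static_cast<int>((packed >> (lane * size_bits)) & mask);
  int sub = fused_bias ? (size_bits == 4 ? 8 : 128)  // symmetric-bias formats
                       : zero_point;                 // zero-point formats (see sub_zp)
  return scale * static_cast<float>(q - sub);
}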
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { + half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. 
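// Host sketch of the striped partitioning set up above: the k_tiles x n_tiles
// tile grid (times `parallel`) is cut into stripes of `iters` tiles, one stripe
// per threadblock, and each block derives its starting tile from its block
// index. Problem sizes and the block count are assumed example values.
#include <cstdio>

int main() {
  int prob_k = 4096, prob_n = 4096, parallel = 1, num_blocks = 108;
  int thread_k_blocks = 4, thread_n_blocks = 8;  // example tile configuration
  int k_tiles = prob_k / 16 / thread_k_blocks;
  int n_tiles = prob_n / 16 / thread_n_blocks;
  int iters = (k_tiles * n_tiles * parallel + num_blocks - 1) / num_blocks;  // ceildiv
  for (int block = 0; block < 4; block++) {
    int slice_row = (iters * block) % k_tiles;
    int slice_col_par = (iters * block) / k_tiles;
    printf("block %d: starts at k-tile %d of column slice %d\n", block, slice_row,
           slice_col_par);
  }
  return 0;
}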
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? 
stages * s_sh_stage : threads;
+
+  extern __shared__ int4 sh[];
+  // Shared memory storage for global fetch pipelines.
+  int4* sh_a = sh;
+  int4* sh_b = sh_a + (stages * a_sh_stage);
+  int4* sh_g_idx = sh_b + (stages * b_sh_stage);
+  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
+  int4* sh_s = sh_zp + (stages * zp_sh_stage);
+
+  // Precompute which thread should not read memory in which iterations; this is
+  // needed if there are more threads than required for a certain tilesize or
+  // when the batchsize is not a multiple of 16.
+  bool a_sh_wr_pred[a_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < a_sh_wr_iters; i++) {
+    int a_idx = a_sh_wr_delta * i + a_sh_wr;
+    int row = a_idx / a_gl_rd_delta_o;
+    if (row >= prob_m) {
+      a_sh_wr_pred[i] = false;
+    } else {
+      a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
+    }
+  }
+
+  // To ensure that writing and reading A tiles to/from shared memory, the
+  // latter in fragment format, is fully bank conflict free, we need to use a
+  // rather fancy XOR-based layout. The key here is that neither reads nor
+  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
+  // same shared memory banks. Further, it seems (based on NSight-Compute) that
+  // each warp must also write a consecutive memory segment?
+  auto transform_a = [&](int i) {
+    int row = i / a_gl_rd_delta_o;
+    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
+  };
+  // Since the computation of this remapping is non-trivial and, due to our main
+  // loop unrolls, all shared memory accesses are static, we simply precompute
+  // both transformed reads and writes.
+  int a_sh_wr_trans[a_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < a_sh_wr_iters; i++)
+    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
+  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
+  #pragma unroll
+  for (int i = 0; i < b_sh_wr_iters; i++) {
+  #pragma unroll
+    for (int j = 0; j < thread_m_blocks; j++)
+      a_sh_rd_trans[i][j] =
+          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
+  }
+
+  // Since B-accesses have non-constant stride they have to be computed at
+  // runtime; we break dependencies between subsequent accesses within a tile by
+  // maintaining multiple pointers (we have enough registers), a tiny
+  // optimization.
+  const int4* B_ptr[b_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < b_sh_wr_iters; i++)
+    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
+
+  // Register storage for double buffer of shared memory reads.
+  FragA frag_a[2][thread_m_blocks];
+  I4 frag_b_quant[2][b_thread_vecs];
+  FragC frag_c[thread_m_blocks][4][2];
+  FragS frag_s[2][4];                    // No act-order
+  FragS act_frag_s[2][4][4];             // For act-order
+  int frag_qzp[2][num_ints_per_thread];  // Zero-points
+  FragZP frag_zp;                        // Zero-points in fp16
+
+  // Zero accumulators.
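+  // (frag_c are the fp32 accumulators that live in registers for the whole
+  //  main loop: thread_m_blocks * 4 * 2 FragC entries of 4 floats each, i.e.
+  //  128 floats per thread when thread_m_blocks == 4.)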
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < 
zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due 
to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
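+    // In the 4-bit case one 32-bit word of frag_b_quant packs both B
+    // fragments, so the second one is recovered by shifting the same word
+    // right by 8; in the 8-bit case b_thread_vecs == 2, so the two fragments
+    // come from two separate 32-bit words instead.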
+  #pragma unroll
+    for (int j = 0; j < 4; j++) {
+      int b_quant_0, b_quant_1;
+      if constexpr (w_type.size_bits() == 4) {
+        b_quant_0 = frag_b_quant[k % 2][0][j];
+        b_quant_1 = b_quant_0 >> 8;
+      } else {
+        static_assert(w_type.size_bits() == 8);
+        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
+        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+      }
+
+      FragB frag_b0 = dequant<w_type_id>(b_quant_0);
+      FragB frag_b1 = dequant<w_type_id>(b_quant_1);
+
+      // Apply scale to frag_b0
+      if constexpr (has_act_order) {
+        scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+               act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
+      } else {
+        if constexpr (group_blocks != -1) {
+          scale(frag_b0, frag_s[k % 2][j], 0);
+        }
+      }
+
+      // Apply zero-point to frag_b1
+      if constexpr (has_zp) {
+        sub_zp(frag_b1, frag_zp[j], 1);
+      }
+
+      // Apply scale to frag_b1
+      if constexpr (has_act_order) {
+        scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+               act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
+
+      } else {
+        if constexpr (group_blocks != -1) {
+          scale(frag_b1, frag_s[k % 2][j], 1);
+        }
+      }
+
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+      }
+    }
+  };
+
+  // Since we slice across the k dimension of a tile in order to increase the
+  // number of warps while keeping the n dimension of a tile reasonable, we have
+  // multiple warps that accumulate their partial sums of the same output
+  // location, which we have to reduce over in the end. We do this in shared
+  // memory.
+  auto thread_block_reduce = [&]() {
+    constexpr int red_off = threads / b_sh_stride_threads / 2;
+    if (red_off >= 1) {
+      int red_idx = threadIdx.x / b_sh_stride_threads;
+      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride_threads;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
+                      (threadIdx.x % b_sh_stride_threads);
+
+      // Parallel logarithmic shared memory reduction. We make sure to avoid any
+      // unnecessary read or write iterations, e.g., for two warps we write only
+      // once by warp 1 and read only once by warp 0.
+
+  #pragma unroll
+      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+  #pragma unroll
+        for (int i = red_off; i > 0; i /= 2) {
+          if (i <= red_idx && red_idx < 2 * i) {
+  #pragma unroll
+            for (int j = 0; j < 4 * 2; j++) {
+              int red_sh_wr =
+                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+              if (i < red_off) {
+                float* c_rd =
+                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+  #pragma unroll
+                for (int k = 0; k < 4; k++)
+                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
+                      c_rd[k] + c_wr[k];
+              }
+              sh[red_sh_wr] =
+                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+            }
+          }
+          __syncthreads();
+        }
+        if (red_idx == 0) {
+  #pragma unroll
+          for (int i = 0; i < 4 * 2; i++) {
+            float* c_rd =
+                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+  #pragma unroll
+            for (int j = 0; j < 4; j++)
+              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
+                  c_rd[j];
+          }
+        }
+        __syncthreads();
+      }
+    }
+  };
+
+  // Since multiple threadblocks may process parts of the same column slice, we
+  // finally have to globally reduce over the results. As the striped
+  // partitioning minimizes the number of such reductions and our outputs are
+  // usually rather small, we perform this reduction serially in L2 cache.
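+  // Concretely, the first block of a slice skips the read-back, every later
+  // block loads the fp16 partials already in C, accumulates them in fp32 and
+  // writes them back, and the last block keeps the running sum in registers
+  // for write_result. Access to C is serialized by the per-slice lock
+  // (barrier_acquire / barrier_release in the main loop).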
+  auto global_reduce = [&](bool first = false, bool last = false) {
+    // We are very careful here to reduce directly in the output buffer to
+    // maximize L2 cache utilization in this step. To do this, we write out
+    // results in FP16 (but still reduce with FP32 compute).
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    if (threadIdx.x < active_threads) {
+      int c_gl_stride = prob_n / 8;
+      int c_gl_wr_delta_o = 8 * c_gl_stride;
+      int c_gl_wr_delta_i = 4 * (active_threads / 32);
+      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
+                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
+      c_gl_wr += (2 * thread_n_blocks) * slice_col;
+      constexpr int c_sh_wr_delta = active_threads;
+      int c_sh_wr = threadIdx.x;
+
+      int row = (threadIdx.x % 32) / 4;
+
+      if (!first) {
+        // Interestingly, doing direct global accesses here really seems to mess up
+        // the compiler and lead to slowdowns, hence we also use async-copies even
+        // though these fetches are not actually asynchronous.
+  #pragma unroll
+        for (int i = 0; i < thread_m_blocks * 4; i++) {
+          int c_idx =
+              c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
+          int sorted_row = sorted_ids[c_idx / c_gl_stride];
+          int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
+          cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx],
+                         sorted_row < tot_m * topk &&
+                             (8 * (i / 2) + row < prob_m &&
+                              (i < (thread_m_blocks - 1) * 4 ||
+                               sorted_ids[8 * (i / 2) + row] < tot_m * topk)));
+        }
+        cp_async_fence();
+        cp_async_wait<0>();
+      }
+
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks * 4; i++) {
+        if (8 * (i / 2) + row < prob_m &&
+            (i < (thread_m_blocks - 1) * 4 ||
+             sorted_ids[8 * (i / 2) + row] < tot_m * topk)) {
+          if (!first) {
+            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+  #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<float*>(
+                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
+                  __half2float(reinterpret_cast<__half*>(&c_red)[j]);
+            }
+          }
+          if (!last) {
+            int4 c;
+  #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<__half*>(&c)[j] =
+                  __float2half(reinterpret_cast<float*>(
+                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
+            }
+            int c_idx =
+                c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
+            int row = sorted_ids[c_idx / c_gl_stride];
+            if (row < tot_m * topk) {
+              int new_idx = row * c_gl_stride + c_idx % c_gl_stride;
+              C[new_idx] = c;
+            }
+          }
+        }
+      }
+    }
+  };
+
+  // Write out the reduced final result in the correct layout. We only actually
+  // reshuffle matrix fragments in this step; the reduction above is performed
+  // in fragment layout.
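+  // Results are first staged through shared memory so that the final global
+  // stores are coalesced 16-byte writes; the per-column scale (only in the
+  // 4-bit, channelwise case) and the optional topk weight are applied on the
+  // way out.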
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. 
Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + 
slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each 
expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_moe_kernel_ku4.cu index e84d4ad8de1e..d445be33fb16 100644 --- a/csrc/moe/marlin_moe_kernel_ku4.cu +++ b/csrc/moe/marlin_moe_kernel_ku4.cu @@ -1,22 +1,19 @@ -#include "marlin_moe_kernel.cuh" +#include "marlin_moe_kernel_ku4.h" namespace marlin_moe { -#define __CALL_IF_MOE_4(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ +#define __CALL_IF_MOE_4(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - MarlinMoE \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -25,26 +22,11 @@ namespace marlin_moe { cfg_max_m_blocks); \ } -#define AWQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, 
NUM_THREADS) +#define AWQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. diff --git a/csrc/moe/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_moe_kernel_ku4.h new file mode 100644 index 000000000000..7524d65bc9c9 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku4.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_moe_kernel_ku4b8.cu index de437454df77..5f1b47999ba4 100644 --- a/csrc/moe/marlin_moe_kernel_ku4b8.cu +++ b/csrc/moe/marlin_moe_kernel_ku4b8.cu @@ -1,89 +1,33 @@ -#include "marlin_moe_kernel.cuh" +#include "marlin_moe_kernel_ku4b8.h" namespace marlin_moe { -#define __CALL_IF_MOE_4_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ +#define __CALL_IF_MOE_4_8(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ } -// #define __CALL_IF_MOE_4_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ -// THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, \ -// GROUP_BLOCKS, NUM_THREADS) \ -// else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ -// thread_n_blocks == THREAD_N_BLOCKS && \ -// thread_k_blocks == THREAD_K_BLOCKS && \ -// has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ -// group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ -// cudaFuncSetAttribute(MarlinMoE, \ -// cudaFuncAttributeMaxDynamicSharedMemorySize, \ -// 
max_shared_mem); \ -// MarlinMoE \ -// <<>>( \ -// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ -// zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ -// num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ -// replicate_input, apply_weights, m_block, max_par, \ -// cfg_max_m_blocks); \ -// } - -#define GPTQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) +#define GPTQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_moe_kernel_ku4b8.h new file mode 100644 index 000000000000..01c67fd402cc --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku4b8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. 
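+//
+// Call-site sketch (the actual dispatch lives in marlin_moe_ops.cu, and
+// CALL_MOE_KERNEL_FUNCTION is assumed to splice the boolean call into the
+// surrounding if / else-if chain):
+//
+//   if (false) {
+//   }
+//   CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
+//   CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
+//   CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
+//   CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8)
+//   else {
+//     TORCH_CHECK(false, ...);  // unsupported shapes / config
+//   }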
+bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_moe_kernel_ku8.cu index 931e074351dc..b8ed6f8ef50e 100644 --- a/csrc/moe/marlin_moe_kernel_ku8.cu +++ b/csrc/moe/marlin_moe_kernel_ku8.cu @@ -1,22 +1,19 @@ -#include "marlin_moe_kernel.cuh" +#include "marlin_moe_kernel_ku8.h" namespace marlin_moe { -#define __CALL_IF_MOE_8(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ +#define __CALL_IF_MOE_8(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - MarlinMoE \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -25,26 +22,11 @@ namespace marlin_moe { cfg_max_m_blocks); \ } -#define AWQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) +#define AWQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, 
false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. diff --git a/csrc/moe/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_moe_kernel_ku8.h new file mode 100644 index 000000000000..53a712eb1968 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku8( + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_moe_kernel_ku8b128.cu index 671466ae26c9..5b96e8773ed2 100644 --- a/csrc/moe/marlin_moe_kernel_ku8b128.cu +++ b/csrc/moe/marlin_moe_kernel_ku8b128.cu @@ -1,22 +1,19 @@ -#include "marlin_moe_kernel.cuh" +#include "marlin_moe_kernel_ku8b128.h" namespace marlin_moe { -#define __CALL_IF_MOE_8_128(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ +#define __CALL_IF_MOE_8_128(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - MarlinMoE \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -25,51 +22,15 @@ namespace marlin_moe { cfg_max_m_blocks); \ } -#define GPTQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - 
__CALL_IF_MOE_8_128(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) \ - \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, \ - NUM_THREADS) +#define GPTQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, \ + NUM_THREADS) \ + __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_moe_kernel_ku8b128.h new file mode 100644 index 000000000000..b6d2f64f2a33 --- /dev/null +++ b/csrc/moe/marlin_moe_kernel_ku8b128.h @@ -0,0 +1,18 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, + const int4* A_ptr, const int4* B_ptr, int4* C_ptr, + const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, + const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, + int num_groups, int expert_idx, int num_experts, int topk, int prob_m, + int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, + bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + +} \ No newline at end of file diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index cb01faeeb11b..1f7da2387323 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -26,11 +26,10 @@ #include #include "core/scalar_type.hpp" -#include "marlin_moe_kernel_ku4b8.cu" -// #include "marlin_moe_kernel_ku8b128.cu" -// #include "marlin_moe_kernel_ku4.cu" -// #include "marlin_moe_kernel_ku8.cu" -// #include "marlin_moe_kernel.cuh" +#include "marlin_moe_kernel_ku4b8.h" +#include "marlin_moe_kernel_ku8b128.h" +#include "marlin_moe_kernel_ku4.h" +#include "marlin_moe_kernel_ku8.h" template inline std::string str(T x) { @@ -472,9 +471,9 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) - // 
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) - // CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - // CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + From 0c7cbb5c435e133b2fb2bf88f7496fca4e0ecd0d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 20 Sep 2024 03:17:08 -0400 Subject: [PATCH 29/49] Cleanup --- CMakeLists.txt | 18 +++++----- csrc/moe/marlin_moe_kernel.h | 51 ++++++++++++++++++++++----- csrc/moe/marlin_moe_kernel_ku4.cu | 34 +++--------------- csrc/moe/marlin_moe_kernel_ku4.h | 2 +- csrc/moe/marlin_moe_kernel_ku4b8.cu | 35 +++--------------- csrc/moe/marlin_moe_kernel_ku4b8.h | 2 +- csrc/moe/marlin_moe_kernel_ku8.cu | 34 +++--------------- csrc/moe/marlin_moe_kernel_ku8.h | 2 +- csrc/moe/marlin_moe_kernel_ku8b128.cu | 38 +++----------------- 9 files changed, 70 insertions(+), 146 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5934c59f36a..94027f9e74e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -295,19 +295,19 @@ define_gpu_extension_target( set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" - "csrc/moe/marlin_moe_kernel.h" - "csrc/moe/marlin_moe_kernel_ku4b8.h" - "csrc/moe/marlin_moe_kernel_ku4b8.cu" - "csrc/moe/marlin_moe_kernel_ku8b128.h" - "csrc/moe/marlin_moe_kernel_ku8b128.cu" - "csrc/moe/marlin_moe_kernel_ku4.h" - "csrc/moe/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_moe_kernel_ku8.h" - "csrc/moe/marlin_moe_kernel_ku8.cu" "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/marlin_moe_kernel.h" + "csrc/moe/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_moe_kernel_ku4.h" + "csrc/moe/marlin_moe_kernel_ku4.cu" + "csrc/moe/marlin_moe_kernel_ku8.h" + "csrc/moe/marlin_moe_kernel_ku8.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_moe_kernel.h b/csrc/moe/marlin_moe_kernel.h index b2f612952344..d44294262d83 100644 --- a/csrc/moe/marlin_moe_kernel.h +++ b/csrc/moe/marlin_moe_kernel.h @@ -296,7 +296,7 @@ template __device__ void MarlinMoESingle( @@ -1432,10 +1432,10 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1532,10 +1532,10 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void 
MarlinMoE( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -1581,4 +1581,37 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + } // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_moe_kernel_ku4.cu index d445be33fb16..7b4a396b5bb1 100644 --- a/csrc/moe/marlin_moe_kernel_ku4.cu +++ b/csrc/moe/marlin_moe_kernel_ku4.cu @@ -2,32 +2,6 @@ namespace marlin_moe { -#define __CALL_IF_MOE_4(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define AWQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku4( @@ -42,10 +16,10 @@ bool call_marlin_moe_kernel_ku4( bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { if (false) { } - AWQ_CALL_IF_MOE_4(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF_MOE_4(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF_MOE_4(vllm::kU4, 4, 8, 128) + AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) else { return false; } diff --git a/csrc/moe/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_moe_kernel_ku4.h index 7524d65bc9c9..656d3a1c3b5d 100644 --- a/csrc/moe/marlin_moe_kernel_ku4.h +++ b/csrc/moe/marlin_moe_kernel_ku4.h @@ -2,7 +2,7 @@ #include "marlin_moe_kernel.h" -namespace marlin_moe { +namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_moe_kernel_ku4b8.cu index 5f1b47999ba4..fa6cc315eed7 100644 --- a/csrc/moe/marlin_moe_kernel_ku4b8.cu +++ b/csrc/moe/marlin_moe_kernel_ku4b8.cu @@ -2,33 +2,6 @@ namespace marlin_moe { -#define __CALL_IF_MOE_4_8(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define GPTQ_CALL_IF_MOE_4(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE_4_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4b8( @@ -43,10 +16,10 @@ bool call_marlin_moe_kernel_ku4b8( bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { if (false) { } - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF_MOE_4(vllm::kU4B8, 4, 8, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) else { return false; } diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_moe_kernel_ku4b8.h index 01c67fd402cc..3975792d8c49 100644 --- a/csrc/moe/marlin_moe_kernel_ku4b8.h +++ b/csrc/moe/marlin_moe_kernel_ku4b8.h @@ -2,7 +2,7 @@ #include "marlin_moe_kernel.h" -namespace marlin_moe { +namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
diff --git a/csrc/moe/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_moe_kernel_ku8.cu index b8ed6f8ef50e..2ab0d88d763d 100644 --- a/csrc/moe/marlin_moe_kernel_ku8.cu +++ b/csrc/moe/marlin_moe_kernel_ku8.cu @@ -2,32 +2,6 @@ namespace marlin_moe { -#define __CALL_IF_MOE_8(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define AWQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku8( @@ -42,10 +16,10 @@ bool call_marlin_moe_kernel_ku8( bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { if (false) { } - AWQ_CALL_IF_MOE_8(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF_MOE_8(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF_MOE_8(vllm::kU8, 4, 8, 128) + AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) else { return false; } diff --git a/csrc/moe/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_moe_kernel_ku8.h index 53a712eb1968..25bae5f4875c 100644 --- a/csrc/moe/marlin_moe_kernel_ku8.h +++ b/csrc/moe/marlin_moe_kernel_ku8.h @@ -2,7 +2,7 @@ #include "marlin_moe_kernel.h" -namespace marlin_moe { +namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_moe_kernel_ku8b128.cu index 5b96e8773ed2..d399fed939bd 100644 --- a/csrc/moe/marlin_moe_kernel_ku8b128.cu +++ b/csrc/moe/marlin_moe_kernel_ku8b128.cu @@ -2,36 +2,6 @@ namespace marlin_moe { -#define __CALL_IF_MOE_8_128(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define GPTQ_CALL_IF_MOE_8(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, \ - NUM_THREADS) \ - __CALL_IF_MOE_8_128(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku8b128( @@ -46,10 +16,10 @@ bool call_marlin_moe_kernel_ku8b128( bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { if (false) { } - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF_MOE_8(vllm::kU8B128, 4, 8, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { return false; } From 8a8f92500fcb1851df835df199418df549f5da06 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 20 Sep 2024 03:27:42 -0400 Subject: [PATCH 30/49] function name --- csrc/moe/marlin_moe_ops.cu | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1f7da2387323..682146cacead 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -323,18 +323,17 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, max_par, exec_cfg.max_m_blocks)) { \ } -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, void* zp, - const void* g_idx, const void* perm, void* a_tmp, - void* expert_offsets, int prob_m, int prob_n, - int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, void* zp, + const void* g_idx, 
const void* perm, void* a_tmp, + void* expert_offsets, int prob_m, int prob_n, int prob_k, + void* workspace, vllm::ScalarType const& q_type, + bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -579,7 +578,7 @@ torch::Tensor marlin_gemm_moe( " is not size_n / pack_factor = ", size_n / pack_factor); } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), From 936f2b9b30aaed9cc3c02146632e07fc10e4663d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 20 Sep 2024 04:49:39 -0400 Subject: [PATCH 31/49] Move kernel files to a separate directory --- CMakeLists.txt | 18 +++++------ .../{ => marlin_kernels}/marlin_moe_kernel.h | 0 .../marlin_moe_kernel_ku4.cu | 7 +++-- .../marlin_moe_kernel_ku4.h | 4 +-- .../marlin_moe_kernel_ku4b8.cu | 7 +++-- .../marlin_moe_kernel_ku4b8.h | 4 +-- .../marlin_moe_kernel_ku8.cu | 7 +++-- .../marlin_moe_kernel_ku8.h | 4 +-- .../marlin_moe_kernel_ku8b128.cu | 7 +++-- .../marlin_moe_kernel_ku8b128.h | 6 ++-- csrc/moe/marlin_moe_ops.cu | 30 ++++++++----------- 11 files changed, 51 insertions(+), 43 deletions(-) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel.h (100%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku4.cu (86%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku4.h (82%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku4b8.cu (86%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku4b8.h (82%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku8.cu (86%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku8.h (82%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku8b128.cu (86%) rename csrc/moe/{ => marlin_kernels}/marlin_moe_kernel_ku8b128.h (79%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 94027f9e74e6..72100a109c82 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -299,15 +299,15 @@ set(VLLM_MOE_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC - "csrc/moe/marlin_moe_kernel.h" - "csrc/moe/marlin_moe_kernel_ku4b8.h" - "csrc/moe/marlin_moe_kernel_ku4b8.cu" - "csrc/moe/marlin_moe_kernel_ku8b128.h" - "csrc/moe/marlin_moe_kernel_ku8b128.cu" - "csrc/moe/marlin_moe_kernel_ku4.h" - "csrc/moe/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_moe_kernel_ku8.h" - "csrc/moe/marlin_moe_kernel_ku8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h similarity index 100% rename from csrc/moe/marlin_moe_kernel.h rename to csrc/moe/marlin_kernels/marlin_moe_kernel.h diff --git 
a/csrc/moe/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu similarity index 86% rename from csrc/moe/marlin_moe_kernel_ku4.cu rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu index 7b4a396b5bb1..79525182c0f3 100644 --- a/csrc/moe/marlin_moe_kernel_ku4.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu @@ -5,8 +5,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, @@ -14,6 +14,9 @@ bool call_marlin_moe_kernel_ku4( int num_groups, int expert_idx, int num_experts, int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + + bool has_zp = true; + if (false) { } AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) diff --git a/csrc/moe/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h similarity index 82% rename from csrc/moe/marlin_moe_kernel_ku4.h rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h index 656d3a1c3b5d..9d805e5c7837 100644 --- a/csrc/moe/marlin_moe_kernel_ku4.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h @@ -7,8 +7,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu similarity index 86% rename from csrc/moe/marlin_moe_kernel_ku4b8.cu rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu index fa6cc315eed7..aae4a43a1b8b 100644 --- a/csrc/moe/marlin_moe_kernel_ku4b8.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -5,8 +5,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, @@ -14,6 +14,9 @@ bool call_marlin_moe_kernel_ku4b8( int num_groups, int expert_idx, int num_experts, int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + + bool has_zp = false; + if (false) { } GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) diff --git a/csrc/moe/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h similarity index 82% rename from csrc/moe/marlin_moe_kernel_ku4b8.h rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h index 3975792d8c49..6cd187ae96a4 100644 --- a/csrc/moe/marlin_moe_kernel_ku4b8.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -7,8 +7,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, diff --git a/csrc/moe/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu similarity index 86% rename from csrc/moe/marlin_moe_kernel_ku8.cu rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu index 2ab0d88d763d..95e388e220a9 100644 --- a/csrc/moe/marlin_moe_kernel_ku8.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu @@ -5,8 +5,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, @@ -14,6 +14,9 @@ bool call_marlin_moe_kernel_ku8( int num_groups, int expert_idx, int num_experts, int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + + bool has_zp = true; + if (false) { } AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) diff --git a/csrc/moe/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h similarity index 82% rename from csrc/moe/marlin_moe_kernel_ku8.h rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h index 25bae5f4875c..85ad32d46d34 100644 --- a/csrc/moe/marlin_moe_kernel_ku8.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h @@ -7,8 +7,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu similarity index 86% rename from csrc/moe/marlin_moe_kernel_ku8b128.cu rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu index d399fed939bd..9176268744a0 100644 --- a/csrc/moe/marlin_moe_kernel_ku8b128.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -5,8 +5,8 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, @@ -14,6 +14,9 @@ bool call_marlin_moe_kernel_ku8b128( int num_groups, int expert_idx, int num_experts, int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { + + bool has_zp = false; + if (false) { } GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) diff --git a/csrc/moe/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h similarity index 79% rename from csrc/moe/marlin_moe_kernel_ku8b128.h rename to csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h index b6d2f64f2a33..bad01604f12d 100644 --- a/csrc/moe/marlin_moe_kernel_ku8b128.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -5,8 +5,8 @@ namespace marlin_moe { bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, bool has_zp, int group_blocks, + vllm::ScalarType const& q_type, int thread_n_blocks, + int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, @@ -15,4 +15,4 @@ bool call_marlin_moe_kernel_ku8b128( int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 682146cacead..bf9266578d70 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -26,10 +26,10 @@ #include #include "core/scalar_type.hpp" -#include "marlin_moe_kernel_ku4b8.h" -#include "marlin_moe_kernel_ku8b128.h" -#include "marlin_moe_kernel_ku4.h" -#include "marlin_moe_kernel_ku8.h" +#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" +#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" +#include "marlin_kernels/marlin_moe_kernel_ku4.h" +#include "marlin_kernels/marlin_moe_kernel_ku8.h" template inline std::string str(T x) { @@ -312,15 +312,15 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ - else if (KERNEL_FUNCTION( \ - q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, \ - has_act_order, has_zp, group_blocks, num_threads, blocks, \ - max_shared_mem, stream, A_ptr, B_ptr, C_ptr, sorted_ids_ptr, \ - topk_weights_ptr, s_ptr, zp_ptr, g_idx_ptr, expert_offsets_ptr, \ - num_groups, expert_idx, num_experts, topk, prob_m, prob_n, \ - prob_k, tot_m, locks, replicate_input, apply_weights, m_block, \ - max_par, exec_cfg.max_m_blocks)) { \ +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION( \ + q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ + group_blocks, num_threads, blocks, max_shared_mem, stream, \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + 
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks)) { \ } void marlin_mm_moe(const void* A, const void* B, void* C, @@ -462,9 +462,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C, int tot_m_blocks = ceildiv(tot_m, 16); for (int m_block = 0; m_block < tot_m_blocks; m_block += 4 * exec_cfg.max_m_blocks) { - // make it max possible value - int thread_m_blocks = exec_cfg.max_m_blocks; - int cfg_max_m_blocks = exec_cfg.max_m_blocks; if (false) { @@ -479,7 +476,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C, ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + ", thread_n_blocks = " + str(thread_n_blocks) + ", thread_k_blocks = " + str(thread_k_blocks)); } From 98ec9b6f7ef64e29600eef15e26b5f2212279fdb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 20 Sep 2024 11:30:11 -0400 Subject: [PATCH 32/49] Unit tests --- .../marlin_kernels/marlin_moe_kernel_ku4.cu | 19 +- .../marlin_kernels/marlin_moe_kernel_ku4.h | 18 +- .../marlin_kernels/marlin_moe_kernel_ku4b8.cu | 19 +- .../marlin_kernels/marlin_moe_kernel_ku4b8.h | 18 +- .../marlin_kernels/marlin_moe_kernel_ku8.cu | 19 +- .../marlin_kernels/marlin_moe_kernel_ku8.h | 18 +- .../marlin_moe_kernel_ku8b128.cu | 19 +- .../marlin_moe_kernel_ku8b128.h | 18 +- csrc/moe/marlin_moe_ops.cu | 12 +- tests/kernels/test_awq_marlin.py | 220 ++++++++++++++++++ tests/kernels/test_moe.py | 16 +- .../layers/fused_moe/fused_marlin_moe.py | 116 ++++++--- .../compressed_tensors_moe.py | 12 +- .../layers/quantization/gptq_marlin.py | 12 +- 14 files changed, 400 insertions(+), 136 deletions(-) create mode 100644 tests/kernels/test_awq_marlin.py diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu index 79525182c0f3..77bc0dd90edd 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu @@ -5,16 +5,15 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { bool has_zp = true; if (false) { diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h index 9d805e5c7837..833fadf37721 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h @@ -7,14 +7,14 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); } // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu index aae4a43a1b8b..f7e57b037594 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -5,16 +5,15 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { bool has_zp = false; if (false) { diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h index 6cd187ae96a4..494da8f10e26 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -7,14 +7,14 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); } // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu index 95e388e220a9..7abbc45440bf 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu @@ -5,16 +5,15 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { bool has_zp = true; if (false) { diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h index 85ad32d46d34..03a0132aa347 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h @@ -7,14 +7,14 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); } // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu index 9176268744a0..a901f0b11cd7 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -5,16 +5,15 @@ namespace marlin_moe { // We return bool so we can create these different kernel calls as a sequence // of if-elseif's. 
bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) { - + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { bool has_zp = false; if (false) { diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h index bad01604f12d..f3018aa0c1ab 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -5,14 +5,14 @@ namespace marlin_moe { bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, - int thread_k_blocks, bool has_act_order, int group_blocks, - int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, - const int4* A_ptr, const int4* B_ptr, int4* C_ptr, - const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr, - const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr, - int num_groups, int expert_idx, int num_experts, int topk, int prob_m, - int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input, - bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks); + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); } diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index bf9266578d70..7f7017deac40 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -444,7 +444,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, const int4* zp_ptr = (const int4*)zp + (((group_size == -1 || group_size == 0) ? 
1 : prob_k / group_size) * - prob_n / 8) * + prob_n / 4) * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; @@ -565,12 +565,12 @@ torch::Tensor marlin_gemm_moe( // Verify b_zeros if (has_zp) { int rank = b_zeros.sizes().size(); - TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); - TORCH_CHECK(b_zeros.size(0) == num_groups, - "b_zeros dim 0 = ", b_zeros.size(0), + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_scales.size(1), + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_scales.size(2), " is not size_n / pack_factor = ", size_n / pack_factor); } diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py new file mode 100644 index 000000000000..e60b7d976c8f --- /dev/null +++ b/tests/kernels/test_awq_marlin.py @@ -0,0 +1,220 @@ +"""Test AWQ with fused MoE Marlin kernels. + +Run `pytest tests/kernels/test_awq_marlin.py`. +""" +from typing import List + +import pytest +import torch + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + awq_marlin_quantize +) + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + +def torch_moe(a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + +@pytest.mark.skip("TODO") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_fused_marlin_moe_awq( + m: int, + n: int, + k: int, + e: int, + topk: int, + 
group_size: int, + num_bits: int, +): + torch.manual_seed(7) + + if topk > e: + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + w_ref1_l = [] + qweights1_l = [] + scales1_l = [] + zp1_l = [] + + for i in range(w1.shape[0]): + w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size) + w_ref1_l.append(w_ref1) + qweights1_l.append(qweight1) + scales1_l.append(scales1) + zp1_l.append(zp1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweights1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + zp1 = stack_and_dev(zp1_l) + + w_ref2_l = [] + qweights2_l = [] + scales2_l = [] + zp2_l = [] + + for i in range(w2.shape[0]): + w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size) + w_ref2_l.append(w_ref2) + qweights2_l.append(qweight2) + scales2_l.append(scales2) + zp2_l.append(zp2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweights2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + zp2 = stack_and_dev(zp2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_marlin_moe( + a, + qweight1, + qweight2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + w1_zeros=zp1, + w2_zeros=zp2, + num_bits=num_bits, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 + + +# @pytest.mark.skip("This test is here for the sake of debugging, " +# "don't run it in automated tests.") +# @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +# @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +# @pytest.mark.parametrize("k", [128, 1024, 512]) +# @pytest.mark.parametrize("e", [4, 8, 64]) +# @pytest.mark.parametrize("topk", [2, 6]) +# @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +# @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("m", [1]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [4]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("group_size", [-1]) +@pytest.mark.parametrize("num_bits", [4]) +def test_single_marlin_moe_multiply_awq( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + num_bits: int, +): + if topk > e: + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + zp_l = [] + + for i in range(w.shape[0]): + w_ref, qweight, scales, zp = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + zp_l.append(zp) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l) + zp = stack_and_dev(zp_l) + + print(scales.dtype) + print(zp.dtype) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + marlin_output = 
single_marlin_moe(a, + qweight, + scales, + score, + topk, + renormalize=False, + w_zeros=zp, + num_bits=num_bits) + + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8072cf09e5b6..69c5548589d0 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -232,15 +232,15 @@ def test_fused_marlin_moe( a, qweight1, qweight2, + scales1, + scales2, score, - g_idx1, - g_idx2, - sort_indices1, - sort_indices2, topk_weights, topk_ids, - w1_scale=scales1, - w2_scale=scales2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, num_bits=num_bits, ) @@ -310,10 +310,10 @@ def test_single_marlin_moe_multiply( qweight, scales, score, - g_idx, - sort_indices, topk, renormalize=False, + g_idx=g_idx, + sort_indices=sort_indices, num_bits=num_bits) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index fafd74493be0..87e2330ca1f9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -10,15 +10,23 @@ from vllm.scalar_type import scalar_types +def get_scalar_type(num_bits: int, has_zp: bool): + if has_zp: + return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + + def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, topk: int, renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, ) -> torch.Tensor: @@ -33,10 +41,12 @@ def single_marlin_moe( - scales (torch.Tensor): The quantization scales. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - g_idx (torch.Tensor): The act_order indices. - - perm (torch.Tensor): The act_order input permutation. + - g_idx (Optional[torch.Tensor]): Optional act_order indices. + - sort_indices (Optional[torch.Tensor]): Optional act_order input + permutation. - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. 
@@ -81,18 +91,31 @@ def single_marlin_moe( device=hidden_states.device, requires_grad=False) - w_zeros = torch.empty((0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) + has_zp = w_zeros is not None + if w_zeros is None: + w_zeros = torch.empty((0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if g_idx is None: + g_idx = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + if sort_indices is None: + sort_indices = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type = get_scalar_type(num_bits, has_zp) intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - w_zeros, g_idx, perm, workspace, scalar_type, M, N, K, True, False, E, - topk, block_size_m, True, False) + w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, True, + has_zp, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -101,18 +124,18 @@ def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - perm1: torch.Tensor, - perm2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - override_config: Optional[Dict[str, Any]] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, + override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, ) -> torch.Tensor: """ @@ -123,21 +146,22 @@ def fused_marlin_moe( - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - g_idx1 (torch.Tensor): The first set of act_order indices. - - g_idx2 (torch.Tensor): The second set of act_order indices. - - perm1 (torch.Tensor): The first act_order input permutation. - - perm2 (torch.Tensor): The second act_order input permutation. + - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. + - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. + - sort_indices1 (Optional[torch.Tensor]): The first act_order input + permutation. + - sort_indices2 (Optional[torch.Tensor]): The second act_order input + permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. + - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. 
+ - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -196,8 +220,32 @@ def fused_marlin_moe( device=hidden_states.device, requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + if g_idx1 is None: + g_idx1 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + if g_idx2 is None: + g_idx2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + if sort_indices1 is None: + sort_indices1 = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + if sort_indices2 is None: + sort_indices2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, has_zp1) + scalar_type2 = get_scalar_type(num_bits, has_zp2) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -214,9 +262,9 @@ def fused_marlin_moe( w1_scale, w1_zeros, g_idx1, - perm1, + sort_indices1, workspace, - scalar_type, + scalar_type1, M, 2 * N, K, @@ -240,9 +288,9 @@ def fused_marlin_moe( w2_scale, w2_zeros, g_idx2, - perm2, + sort_indices2, workspace, - scalar_type, + scalar_type2, M, K, N, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7dee2fca8115..80e7f48d647a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -286,14 +286,14 @@ def apply( x, layer.w13_weight_packed, layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index cc699f5b4554..dd46f0ce5a39 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -602,14 +602,14 @@ def apply( x, layer.w13_qweight, layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) From fa23e51cec112fce860ac896072bd6732f558304 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 24 Sep 2024 14:50:46 -0400 Subject: [PATCH 33/49] working kernel --- csrc/moe/marlin_kernels/marlin_moe_kernel.h | 31 ++++----- csrc/moe/marlin_moe_ops.cu | 13 +--- csrc/quantization/gptq_marlin/gptq_marlin.cu | 2 +- tests/kernels/test_awq_marlin.py | 66 +++++++++----------- 4 files changed, 49 insertions(+), 63 deletions(-) diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h 
b/csrc/moe/marlin_kernels/marlin_moe_kernel.h index d44294262d83..808fcdae7a7c 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -38,7 +38,7 @@ using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; // quantization scales -using FragZP = Vec; +using FragZP = Vec; // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. @@ -230,13 +230,6 @@ __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { frag_b[1] = __hsub2(frag_b[1], zp); } -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - // Same as above, but for act_order (each K is multiplied individually) __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, FragS& frag_s_3, FragS& frag_s_4, int i) { @@ -252,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { @@ -440,6 +440,7 @@ __device__ void MarlinMoESingle( : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; @@ -536,12 +537,6 @@ __device__ void MarlinMoESingle( int sh_num_groups = -1; constexpr int sh_max_num_groups = 32; - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. int4* sh_a = sh; @@ -674,6 +669,7 @@ __device__ void MarlinMoESingle( for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } + B_ptr[i] += b_gl_rd_delta_o; } @@ -990,6 +986,10 @@ __device__ void MarlinMoESingle( FragB frag_b0 = dequant(b_quant_0); FragB frag_b1 = dequant(b_quant_1); + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); + } // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -1193,6 +1193,7 @@ __device__ void MarlinMoESingle( ((half2*)sh)[idx] = res; }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1277,6 +1278,7 @@ __device__ void MarlinMoESingle( // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at // index 0. 
+ #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll @@ -1420,6 +1422,7 @@ __device__ void MarlinMoESingle( s_gl_rd = s_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } + start_pipes(); } } diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 7f7017deac40..705251008628 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -436,16 +436,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C, int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = - (const int4*)s + - (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * - prob_n / 8) * - expert_idx; + const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; const int4* zp_ptr = - (const int4*)zp + - (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * - prob_n / 4) * - expert_idx; + (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; @@ -570,7 +563,7 @@ torch::Tensor marlin_gemm_moe( "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups); TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, - "b_zeros dim 2 = ", b_scales.size(2), + "b_zeros dim 2 = ", b_zeros.size(2), " is not size_n / pack_factor = ", size_n / pack_factor); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9b4a6a515107..f943185bab7f 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -2258,7 +2258,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_scales.size(1), + "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n / pack_factor = ", size_n / pack_factor); } diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index e60b7d976c8f..e408636bdb2d 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -8,23 +8,24 @@ import torch from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.scalar_type import scalar_types from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - awq_marlin_quantize -) + awq_marlin_quantize) +from vllm.scalar_type import scalar_types + def stack_and_dev(tensors: List[torch.Tensor]): dev = tensors[0].device return torch.stack(tensors, dim=0).to(dev) + def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( torch.abs(output_ref)) + def torch_moe(a, w1, w2, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -55,7 +56,7 @@ def torch_moe_single(a, w, score, topk): out[mask] = a[mask] @ w[i].transpose(0, 1) return (out.view(B, -1, w.shape[1])).sum(dim=1) -@pytest.mark.skip("TODO") + @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) 
@pytest.mark.parametrize("k", [128, 1024, 512]) @@ -77,8 +78,7 @@ def test_fused_marlin_moe_awq( if topk > e: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -123,15 +123,6 @@ def test_fused_marlin_moe_awq( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, False) - - triton_output = fused_moe( - a, - w_ref1.transpose(1, 2).contiguous(), - w_ref2.transpose(1, 2).contiguous(), - score, - topk, - renormalize=False, - ) marlin_output = fused_marlin_moe( a, qweight1, @@ -146,25 +137,26 @@ def test_fused_marlin_moe_awq( num_bits=num_bits, ) - assert compute_max_diff(marlin_output, triton_output) < 4e-2 + torch_output = torch_moe( + a, + w_ref1.transpose(1, 2), + w_ref2.transpose(1, 2), + score, + topk, + ) + + assert compute_max_diff(marlin_output, torch_output) < 4e-2 # @pytest.mark.skip("This test is here for the sake of debugging, " # "don't run it in automated tests.") -# @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -# @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -# @pytest.mark.parametrize("k", [128, 1024, 512]) -# @pytest.mark.parametrize("e", [4, 8, 64]) -# @pytest.mark.parametrize("topk", [2, 6]) -# @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -# @pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("m", [1]) -@pytest.mark.parametrize("n", [128]) -@pytest.mark.parametrize("k", [128]) -@pytest.mark.parametrize("e", [4]) -@pytest.mark.parametrize("topk", [2]) -@pytest.mark.parametrize("group_size", [-1]) -@pytest.mark.parametrize("num_bits", [4]) +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("num_bits", [4, 8]) def test_single_marlin_moe_multiply_awq( m: int, n: int, @@ -174,11 +166,12 @@ def test_single_marlin_moe_multiply_awq( group_size: int, num_bits: int, ): + torch.manual_seed(7) + if topk > e: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -198,11 +191,8 @@ def test_single_marlin_moe_multiply_awq( w_ref = stack_and_dev(w_ref_l) qweight = stack_and_dev(qweights_l).contiguous() - scales = stack_and_dev(scales_l) - zp = stack_and_dev(zp_l) - - print(scales.dtype) - print(zp.dtype) + scales = stack_and_dev(scales_l).contiguous() + zp = stack_and_dev(zp_l).contiguous() score = torch.randn((m, e), device="cuda", dtype=dtype) From e98bc45f96df15f64fee44f04d2a962dfb488ff0 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 25 Sep 2024 02:43:33 -0400 Subject: [PATCH 34/49] clean up unit tests, disable single awq test --- tests/kernels/test_awq_marlin.py | 50 +++----------------------------- tests/kernels/test_moe.py | 46 ++--------------------------- tests/kernels/utils.py | 45 ++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 90 deletions(-) diff --git 
a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index e408636bdb2d..4481ebaf9c1f 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -2,12 +2,11 @@ Run `pytest tests/kernels/test_awq_marlin.py`. """ -from typing import List - import pytest import torch -from vllm.model_executor.layers.activation import SiluAndMul +from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, + torch_moe_single) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -16,47 +15,6 @@ from vllm.scalar_type import scalar_types -def stack_and_dev(tensors: List[torch.Tensor]): - dev = tensors[0].device - return torch.stack(tensors, dim=0).to(dev) - - -def compute_max_diff(output, output_ref): - return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) - - -def torch_moe(a, w1, w2, score, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - -def torch_moe_single(a, w, score, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - _, topk_ids = torch.topk(score, topk) - topk_ids = topk_ids.view(-1) - for i in range(w.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = a[mask] @ w[i].transpose(0, 1) - return (out.view(B, -1, w.shape[1])).sum(dim=1) - - @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) @@ -148,8 +106,8 @@ def test_fused_marlin_moe_awq( assert compute_max_diff(marlin_output, torch_output) < 4e-2 -# @pytest.mark.skip("This test is here for the sake of debugging, " -# "don't run it in automated tests.") +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f5e8b3e270a9..5f03a65c62f9 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,14 +2,13 @@ Run `pytest tests/kernels/test_moe.py`. 
""" -from typing import List - import pytest import torch from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from vllm.model_executor.layers.activation import SiluAndMul +from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, + torch_moe_single) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) @@ -21,37 +20,6 @@ from vllm.utils import seed_everything -def torch_moe(a, w1, w2, score, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - -def torch_moe_single(a, w, score, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - _, topk_ids = torch.topk(score, topk) - topk_ids = topk_ids.view(-1) - for i in range(w.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = a[mask] @ w[i].transpose(0, 1) - return (out.view(B, -1, w.shape[1])).sum(dim=1) - - @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -124,16 +92,6 @@ def test_mixtral_moe(dtype: torch.dtype): atol=mixtral_moe_tol[dtype]) -def stack_and_dev(tensors: List[torch.Tensor]): - dev = tensors[0].device - return torch.stack(tensors, dim=0).to(dev) - - -def compute_max_diff(output, output_ref): - return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) - - @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5746932c30a4..41dfdcd08ff0 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -10,6 +10,7 @@ import torch from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType +from vllm.model_executor.layers.activation import SiluAndMul from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -960,3 +961,47 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, kwargs, test_utils=test_utils, raise_exception=raise_exception) if cond else {} + + +# Marlin MoE test utils + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +def torch_moe(a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = 
topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) From 2f09e582485bd84a22fe4849d603a33f9b8fd10b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 25 Sep 2024 08:34:50 -0400 Subject: [PATCH 35/49] make has_zero_point boolean explicitly passed to fused_marlin_moe --- tests/kernels/test_awq_marlin.py | 2 ++ .../layers/fused_moe/fused_marlin_moe.py | 24 ++++++++++++------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 4481ebaf9c1f..aeacc16e5396 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -90,6 +90,7 @@ def test_fused_marlin_moe_awq( score, topk_weights, topk_ids, + has_zero_point=True, w1_zeros=zp1, w2_zeros=zp2, num_bits=num_bits, @@ -160,6 +161,7 @@ def test_single_marlin_moe_multiply_awq( score, topk, renormalize=False, + has_zero_point=True, w_zeros=zp, num_bits=num_bits) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 87e2330ca1f9..60fcb6d100d8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -24,6 +24,7 @@ def single_marlin_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, + has_zero_point: bool = False, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, @@ -91,7 +92,9 @@ def single_marlin_moe( device=hidden_states.device, requires_grad=False) - has_zp = w_zeros is not None + if has_zero_point: + assert w_zeros is not None and w_zeros.nelement() > 0 + if w_zeros is None: w_zeros = torch.empty((0), dtype=hidden_states.dtype, @@ -110,12 +113,12 @@ def single_marlin_moe( device=hidden_states.device, requires_grad=False) - scalar_type = get_scalar_type(num_bits, has_zp) + scalar_type = get_scalar_type(num_bits, has_zero_point) intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, True, - has_zp, E, topk, block_size_m, True, False) + has_zero_point, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -129,6 +132,7 @@ def fused_marlin_moe( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + has_zero_point: bool = False, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -207,8 +211,10 @@ def fused_marlin_moe( device="cuda", requires_grad=False) - has_zp1 = w1_zeros is not None - has_zp2 = w2_zeros is not None + if has_zero_point: + assert w1_zeros is not None and 
w1_zeros.nelement() > 0 + assert w2_zeros is not None and w2_zeros.nelement() > 0 + if w1_zeros is None: w1_zeros = torch.empty((0), dtype=hidden_states.dtype, @@ -244,8 +250,8 @@ def fused_marlin_moe( device=hidden_states.device, requires_grad=False) - scalar_type1 = get_scalar_type(num_bits, has_zp1) - scalar_type2 = get_scalar_type(num_bits, has_zp2) + scalar_type1 = get_scalar_type(num_bits, has_zero_point) + scalar_type2 = get_scalar_type(num_bits, has_zero_point) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -269,7 +275,7 @@ def fused_marlin_moe( 2 * N, K, True, - has_zp1, + has_zero_point, E, topk, block_size_m, @@ -295,7 +301,7 @@ def fused_marlin_moe( K, N, True, - has_zp2, + has_zero_point, E, topk, block_size_m, From 000796acdf1e6184eeb36272c5ddd6ffbc41fac3 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 26 Sep 2024 18:03:11 +0000 Subject: [PATCH 36/49] add awq moe --- .../model_executor/layers/quantization/awq.py | 191 +++++++++++++++++- vllm/model_executor/model_loader/utils.py | 2 +- 2 files changed, 188 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 410b3cb5321c..e564b18e7d32 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,14 +1,22 @@ -from typing import Any, Dict, List, Optional +from typing import Callable, Any, Dict, List, Optional import torch - +from torch.nn import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): """Config class for AWQ. 
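For reference, a rough sketch of how the reworked entry point from PATCH 35 is expected to be called once AWQ zero points are in play. The tensor names are illustrative, standing in for tensors prepared the way tests/kernels/test_awq_marlin.py prepares them, and are not lifted from these diffs:

    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        fused_marlin_moe)

    # Scales are now positional; zero points, g_idx and sort_indices are
    # optional keywords; has_zero_point must be set explicitly so that the
    # zero-point-aware scalar types (uint4/uint8 rather than
    # uint4b8/uint8b128) are selected.
    output = fused_marlin_moe(
        hidden_states,   # (num_tokens, hidden_size), fp16 as in the tests
        w13_qweight,     # Marlin-repacked expert weights, gate/up proj
        w2_qweight,      # Marlin-repacked expert weights, down proj
        w13_scales,
        w2_scales,
        router_logits,
        topk_weights,    # from fused_topk / FusedMoE.select_experts
        topk_ids,
        has_zero_point=True,
        w1_zeros=w13_qzeros,
        w2_zeros=w2_qzeros,
        num_bits=4,
    )

This mirrors the call that the AWQ MoE path introduced in this patch series ends up making through FusedMoE.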
@@ -64,9 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQLinearMethod"]: + prefix: str) -> Optional["QuantizedMethodBase"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -170,3 +180,176 @@ def apply(self, if bias is not None: out.add_(bias) return out.reshape(out_shape) + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + self.num_bits = self.quant_config.weight_bits + self.packed_factor = self.quant_config.pack_factor + self.group_size = self.quant_config.group_size + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": "group", + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
+ w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1], + layer.w13_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1], + layer.w2_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.group_size + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] , + size_n=layer.w2_scales.shape[2] * self.packed_factor, + group_size=self.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits) \ No newline at end of file diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2bfe6ea09bd6..995bb253db8a 100644 --- 
a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From e8289ae95dbe7f100898935997906791b01adcc2 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 26 Sep 2024 19:39:56 +0000 Subject: [PATCH 37/49] update --- .../model_executor/layers/quantization/awq.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index e564b18e7d32..b82714bd5ba5 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -18,6 +18,7 @@ marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -181,13 +182,11 @@ def apply(self, out.add_(bias) return out.reshape(out_shape) + class AWQMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - self.num_bits = self.quant_config.weight_bits - self.packed_factor = self.quant_config.pack_factor - self.group_size = self.quant_config.group_size def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -255,61 +254,60 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1], - layer.w13_qweight.shape[2] * self.packed_factor, - self.num_bits, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1], - layer.w2_qweight.shape[2] * self.packed_factor, - self.num_bits, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * 
self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) - # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], - group_size=self.group_size + group_size=self.quant_config.group_size, ) - + # for @eliza: why do we need to apply a pack factor to the scales? + # they're not in packed format? replace_tensor(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] , - size_n=layer.w2_scales.shape[2] * self.packed_factor, - group_size=self.group_size, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) @@ -352,4 +350,5 @@ def apply( g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits) \ No newline at end of file + num_bits=self.quant_config.weight_bits, + ) From 0385aa85eabc5005e706f112e96f370ae3e28326 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 27 Sep 2024 17:07:25 +0000 Subject: [PATCH 38/49] update awq --- vllm/_custom_ops.py | 14 ++++++++ .../layers/fused_moe/fused_moe.py | 2 +- .../model_executor/layers/quantization/awq.py | 36 +++++++++++++------ .../layers/quantization/gptq_marlin.py | 1 + .../layers/quantization/utils/marlin_utils.py | 15 ++++++++ 5 files changed, 57 insertions(+), 11 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 77c46584ef53..8ce01b2d8253 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -317,6 +317,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return output +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb9..1a98666204f9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -443,7 +443,7 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index b82714bd5ba5..ba912aa6552d 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -12,11 +12,10 @@ PackedvLLMParameter) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, 
marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_moe_permute_scales, moe_awq_to_marlin_zero_points, + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): @@ -276,7 +275,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) - marlin_w13_qweight = ops.gptq_marlin_moe_repack( + marlin_w13_qweight = ops.awq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, size_k=layer.w13_qweight.shape[1], @@ -285,7 +284,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( + marlin_w2_qweight = ops.awq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, size_k=layer.w2_qweight.shape[1], @@ -294,23 +293,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Why does this take the intermediate size for size_k? marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) - # for @eliza: why do we need to apply a pack factor to the scales? - # they're not in packed format? + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_k=layer.intermediate_size_per_partition, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + def apply( self, layer: torch.nn.Module, @@ -346,6 +360,8 @@ def apply( router_logits, topk_weights, topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index dd46f0ce5a39..04bea28ec463 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -554,6 +554,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) # Repack scales + # Why does this take the intermediate size for size_k? 
marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f184414..db8ec78f937e 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -188,6 +188,7 @@ def marlin_moe_permute_scales( device=s.device, dtype=s.dtype, ) + for e in range(num_experts): output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) return output @@ -238,6 +239,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) def replace_tensor(layer: torch.nn.Module, name: str, From 091a4bbb76385ffe35d6e8a8ef9ebe14da528792 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 30 Sep 2024 03:17:33 -0400 Subject: [PATCH 39/49] Post-merge fix, remove e=4 case from unit tests to speed them up a bit --- tests/kernels/test_awq_marlin.py | 10 ++-------- tests/kernels/test_moe.py | 15 ++++++--------- vllm/_custom_ops.py | 9 +++++---- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index aeacc16e5396..338f46cbe09f 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("num_bits", [4, 8]) @@ -33,9 +33,6 @@ def test_fused_marlin_moe_awq( ): torch.manual_seed(7) - if topk > e: - return - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 @@ -112,7 +109,7 @@ def test_fused_marlin_moe_awq( @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("num_bits", [4, 8]) @@ -127,9 +124,6 @@ def test_single_marlin_moe_multiply_awq( ): torch.manual_seed(7) - if topk > e: - return - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index eacd12b9f9ee..360ef1330bd6 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -97,7 +97,7 @@ def test_mixtral_moe(dtype: torch.dtype): 
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @@ -116,9 +116,6 @@ def test_fused_marlin_moe( ): seed_everything(7) - if topk > e: - return - # Filter act_order if act_order: if group_size == -1: @@ -237,10 +234,12 @@ def test_fused_marlin_moe( device="cuda", requires_grad=False) + zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False) + opcheck(torch.ops._moe_C.marlin_gemm_moe, (a, qweight1, sorted_token_ids, topk_weights, topk_ids, - scales1, g_idx1, sort_indices1, workspace, quant_type, m, - 2 * n, k, True, e, topk, block_size_m, True, False)) + scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, + 2 * n, k, True, False, e, topk, block_size_m, True, False)) @pytest.mark.skip("This test is here for the sake of debugging, " @@ -248,7 +247,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @@ -265,8 +264,6 @@ def test_single_marlin_moe_multiply( num_bits: int, is_k_full: bool, ): - if topk > e: - return # Filter act_order if act_order: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ebdb06ba7013..bc7b4293c119 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -819,10 +819,11 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor, - g_idx: torch.Tensor, perm: torch.Tensor, - workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, size_n: int, size_k: int, - is_k_full: bool, num_experts: int, topk: int, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, + has_zero_point: bool, num_experts: int, topk: int, moe_block_size: int, replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), From 3d125547c775e3048e4c327f2a5dbb272f490a8b Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 30 Sep 2024 15:16:38 +0000 Subject: [PATCH 40/49] move to marlin; clean-up --- .../model_executor/layers/quantization/awq.py | 208 +----------------- .../layers/quantization/awq_marlin.py | 206 ++++++++++++++++- vllm/model_executor/model_loader/utils.py | 4 +- 3 files changed, 204 insertions(+), 214 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ba912aa6552d..30380ec0407c 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,22 +1,14 @@ -from typing import Callable, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional import torch -from torch.nn import Parameter + from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, 
FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, moe_awq_to_marlin_zero_points, - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) - class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -72,11 +64,9 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizedMethodBase"]: + prefix: str) -> Optional["AWQLinearMethod"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) - elif isinstance(layer, FusedMoE): - return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,192 +169,4 @@ def apply(self, pack_factor) if bias is not None: out.add_(bias) - return out.reshape(out_shape) - - -class AWQMoEMethod(FusedMoEMethodBase): - - def __init__(self, quant_config: AWQConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": "group", - }) - - w13_qweight = Parameter(torch.empty(num_experts, - hidden_size, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qweight", w13_qweight) - set_weight_attrs(w13_qweight, extra_weight_attrs) - - w2_qweight = Parameter(torch.empty(num_experts, - intermediate_size, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qweight", w2_qweight) - set_weight_attrs(w2_qweight, extra_weight_attrs) - - num_groups_w13 = hidden_size // self.quant_config.group_size - num_groups_w2 = intermediate_size // self.quant_config.group_size - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - w13_scales = Parameter(torch.empty(num_experts, - num_groups_w13, - intermediate_size * 2, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_scales", w13_scales) - set_weight_attrs(w13_scales, extra_weight_attrs) - - w2_scales = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_scales", w2_scales) - set_weight_attrs(w2_scales, extra_weight_attrs) - - # WEIGHT_ZERO_POINT - # Allocate 2 zero points for w1 and w3 respectively. 
- w13_qzeros = Parameter(torch.empty(num_experts, - num_groups_w13, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qzeros", w13_qzeros) - set_weight_attrs(w13_qzeros, extra_weight_attrs) - - w2_qzeros = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qzeros", w2_qzeros) - set_weight_attrs(w2_qzeros, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - num_experts = layer.w13_qweight.shape[0] - device = layer.w13_qweight.device - - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - - marlin_w13_qweight = ops.awq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - size_k=layer.w13_qweight.shape[1], - size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits, - ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) - - marlin_w2_qweight = ops.awq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - size_k=layer.w2_qweight.shape[1], - size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits, - ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) - - # Why does this take the intermediate size for size_k? 
- marlin_w13_scales = marlin_moe_permute_scales( - s=layer.w13_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w13_scales.shape[2], - group_size=self.quant_config.group_size, - ) - - replace_tensor(layer, "w13_scales", marlin_w13_scales) - - marlin_w2_scales = marlin_moe_permute_scales( - s=layer.w2_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w2_scales.shape[2], - group_size=self.quant_config.group_size, - ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) - - marlin_w13_zp = moe_awq_to_marlin_zero_points( - layer.w13_qzeros, - size_k=layer.w13_qzeros.shape[1], - size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w13_qzeros", marlin_w13_zp) - - marlin_w2_zp = moe_awq_to_marlin_zero_points( - layer.w2_qzeros, - size_k=layer.w2_qzeros.shape[1], - size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w2_qzeros", marlin_w2_zp) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe) - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) - - return fused_marlin_moe( - x, - layer.w13_qweight, - layer.w2_qweight, - layer.w13_scales, - layer.w2_scales, - router_logits, - topk_weights, - topk_ids, - w1_zeros=layer.w13_qzeros, - w2_zeros=layer.w2_qzeros, - g_idx1=layer.w13_g_idx, - g_idx2=layer.w2_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.quant_config.weight_bits, - ) + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff4..9704b1adbce5 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,16 +1,21 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch +from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + 
marlin_permute_scales, moe_awq_to_marlin_zero_points, replace_tensor, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -34,12 +39,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool, self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits - if weight_bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {weight_bits}. " + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " f"Supported num_bits = {self.TYPE_MAP.keys()}") - self.quant_type = self.TYPE_MAP[weight_bits] + self.quant_type = self.TYPE_MAP[self.weight_bits] verify_marlin_supported(self.quant_type, group_size=self.group_size, @@ -97,10 +103,12 @@ def override_quantization_method(cls, hf_quant_cfg, return None def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQMarlinLinearMethod"]: + prefix: str) -> Optional["QuantizeMethodBase"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -270,4 +278,182 @@ def apply( quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) \ No newline at end of file + bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP, + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
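+        # Like the packed qweights above, the zero points are packed along
+        # the output dimension, (32 // weight_bits) values per int32, so the
+        # last dim is divided by pack_factor; there is one row of zero
+        # points per quantization group.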
+ w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + 
layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 995bb253db8a..792c359a559a 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,9 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq", "awq_marlin" + ] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From b54b633cbf21ae4a2b600b96be3f04603d9d5c9a Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 30 Sep 2024 16:35:23 +0000 Subject: [PATCH 41/49] fix typo; add test --- tests/weight_loading/models-large.txt | 1 + vllm/model_executor/layers/quantization/awq_marlin.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 2f5c6c5a117f..8ab7f05d7d1b 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -2,3 +2,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main +awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 9704b1adbce5..5c689f03925a 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -293,7 +293,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, "is_transposed": True, "quant_method": - FusedMoeWeightScaleSupported.GROUP, + FusedMoeWeightScaleSupported.GROUP.value, }) w13_qweight = Parameter(torch.empty(num_experts, From e0e5a749b7a41c4554ed02e659b4bf90bc8ac04a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 1 Oct 2024 03:19:05 -0400 Subject: [PATCH 42/49] Michael's feedback, cleanup --- csrc/moe/marlin_moe_ops.cu | 6 +-- csrc/moe/marlin_moe_ops.h | 5 +- csrc/moe/torch_bindings.cpp | 5 +- tests/kernels/test_awq_marlin.py | 2 - tests/kernels/test_moe.py | 8 +-- vllm/_custom_ops.py | 6 +-- .../layers/fused_moe/fused_marlin_moe.py | 49 +++++++++---------- 7 files changed, 38 insertions(+), 43 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e540f0723649..ec0836131ba8 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -484,9 +484,9 @@ torch::Tensor marlin_gemm_moe( torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, - int64_t topk, int64_t moe_block_size, bool replicate_input, - bool apply_weights) { + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights) { 
+ bool has_zp = b_zeros.size(1) != 0; if (has_zp) { TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 0a54d93cedeb..0013787a623d 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -11,6 +11,5 @@ torch::Tensor marlin_gemm_moe( torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts, - int64_t topk, int64_t moe_block_size, bool replicate_input, - bool apply_weights); + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 85098df34b2d..576305d48ae4 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -15,9 +15,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, " - "int topk, int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, int " + "moe_block_size, bool replicate_input, bool apply_weights) -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 338f46cbe09f..f1a0b09e8e46 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -87,7 +87,6 @@ def test_fused_marlin_moe_awq( score, topk_weights, topk_ids, - has_zero_point=True, w1_zeros=zp1, w2_zeros=zp2, num_bits=num_bits, @@ -155,7 +154,6 @@ def test_single_marlin_moe_multiply_awq( score, topk, renormalize=False, - has_zero_point=True, w_zeros=zp, num_bits=num_bits) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 360ef1330bd6..b73c45b9cd19 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -234,12 +234,14 @@ def test_fused_marlin_moe( device="cuda", requires_grad=False) - zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False) - + zp = torch.empty((0, 0), + dtype=dtype, + device="cuda", + requires_grad=False) opcheck(torch.ops._moe_C.marlin_gemm_moe, (a, qweight1, sorted_token_ids, topk_weights, topk_ids, scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, - 2 * n, k, True, False, e, topk, block_size_m, True, False)) + 2 * n, k, True, e, topk, block_size_m, True, False)) @pytest.mark.skip("This test is here for the sake of debugging, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bc7b4293c119..6081fa674579 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -822,9 +822,9 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, - has_zero_point: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, + size_k: int, is_k_full: bool, num_experts: int, + topk: int, moe_block_size: int, + replicate_input: 
bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e57b15936aa8..466b0edd81fe 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -24,7 +24,6 @@ def single_marlin_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, - has_zero_point: bool = False, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, @@ -93,11 +92,9 @@ def single_marlin_moe( device=hidden_states.device, requires_grad=False) - if has_zero_point: - assert w_zeros is not None and w_zeros.nelement() > 0 - + has_zero_point = w_zeros is not None if w_zeros is None: - w_zeros = torch.empty((0), + w_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) @@ -119,7 +116,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, - is_k_full, has_zero_point, E, topk, block_size_m, True, False) + is_k_full, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -133,7 +130,6 @@ def fused_marlin_moe( gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - has_zero_point: bool = False, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -187,6 +183,20 @@ def fused_marlin_moe( assert hidden_states.dtype == torch.float16 assert num_bits in [4, 8] + has_no_act_order = (g_idx1 is None and g_idx2 is None + and sort_indices1 is None and sort_indices2 is None) + has_all_act_order = (g_idx1 is not None and g_idx2 is not None + and sort_indices1 is not None + and sort_indices2 is not None) + assert has_no_act_order or has_all_act_order, ( + "g_idx and sorted_indices " + "must be all not None or must be all None") + + has_no_zp = w1_zeros is None and w2_zeros is None + has_all_zp = w1_zeros is not None and w2_zeros is not None + assert has_no_zp or has_all_zp, ("zero points must be both not None or " + "must be both None") + M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -213,47 +223,36 @@ def fused_marlin_moe( device="cuda", requires_grad=False) - if has_zero_point: - assert w1_zeros is not None and w1_zeros.nelement() > 0 - assert w2_zeros is not None and w2_zeros.nelement() > 0 - - if w1_zeros is None: - w1_zeros = torch.empty((0), + if has_no_zp: + w1_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) - if w2_zeros is None: - w2_zeros = torch.empty((0), + w2_zeros = torch.empty((0, 0), dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=False) - if g_idx1 is None: + if has_no_act_order: g_idx1 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if g_idx2 is None: g_idx2 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if sort_indices1 is None: sort_indices1 = torch.empty((0), dtype=torch.int32, device=hidden_states.device, requires_grad=False) - - if sort_indices2 is None: sort_indices2 = torch.empty((0, 0), dtype=torch.int32, device=hidden_states.device, 
requires_grad=False) - scalar_type1 = get_scalar_type(num_bits, has_zero_point) - scalar_type2 = get_scalar_type(num_bits, has_zero_point) + scalar_type1 = get_scalar_type(num_bits, has_all_zp) + scalar_type2 = get_scalar_type(num_bits, has_all_zp) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -277,7 +276,6 @@ def fused_marlin_moe( 2 * N, K, is_k_full, - has_zero_point, E, topk, block_size_m, @@ -303,7 +301,6 @@ def fused_marlin_moe( K, N, is_k_full, - has_zero_point, E, topk, block_size_m, From bbf575e2985b8476bf858866be9158bb0bf2a0e1 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 1 Oct 2024 13:49:52 +0000 Subject: [PATCH 43/49] use replace_parameters; clean-up --- .../layers/quantization/awq_marlin.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index cc98cbfb70ad..294fe11815c0 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -11,14 +11,11 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, moe_awq_to_marlin_zero_points) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, @@ -379,7 +376,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.awq_marlin_moe_repack( layer.w2_qweight, @@ -388,7 +385,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Why does this take the intermediate size for size_k? 
marlin_w13_scales = marlin_moe_permute_scales( @@ -398,7 +395,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w13_scales", marlin_w13_scales) + replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, @@ -406,21 +403,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) + replace_parameter(layer, "w2_scales", marlin_w2_scales) marlin_w13_zp = moe_awq_to_marlin_zero_points( layer.w13_qzeros, size_k=layer.w13_qzeros.shape[1], size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) marlin_w2_zp = moe_awq_to_marlin_zero_points( layer.w2_qzeros, size_k=layer.w2_qzeros.shape[1], size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) def apply( self, From 79126f906b5eafda1df6866f7328f6d78f2eeec3 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 1 Oct 2024 13:54:42 +0000 Subject: [PATCH 44/49] more clean-up --- vllm/model_executor/layers/quantization/awq.py | 2 +- .../model_executor/layers/quantization/gptq_marlin.py | 1 - .../layers/quantization/utils/marlin_utils.py | 11 ----------- vllm/model_executor/model_loader/utils.py | 2 +- 4 files changed, 2 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 30380ec0407c..410b3cb5321c 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -169,4 +169,4 @@ def apply(self, pack_factor) if bias is not None: out.add_(bias) - return out.reshape(out_shape) \ No newline at end of file + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b9b43413b35d..e77191796bd7 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -509,7 +509,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales - # Why does this take the intermediate size for size_k? 
marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 1275b4474a06..9a1defa40971 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -273,17 +273,6 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return output -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 792c359a559a..b95c0b7cd061 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,7 +24,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ - "fp8", "compressed-tensors", "gptq_marlin", "awq", "awq_marlin" + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" ] if (model_config.quantization is not None From 87d46dc91021432b51c7b19730f3b04403c119ed Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 1 Oct 2024 11:33:57 -0400 Subject: [PATCH 45/49] Delete 8-bit zero point code --- CMakeLists.txt | 2 -- .../marlin_kernels/marlin_moe_kernel_ku8.cu | 31 ------------------- .../marlin_kernels/marlin_moe_kernel_ku8.h | 20 ------------ csrc/moe/marlin_moe_ops.cu | 8 ++--- tests/kernels/test_awq_marlin.py | 10 +++--- .../layers/fused_moe/fused_marlin_moe.py | 3 +- 6 files changed, 9 insertions(+), 65 deletions(-) delete mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu delete mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h diff --git a/CMakeLists.txt b/CMakeLists.txt index df22ce47e54b..8c66c31aa6ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,8 +332,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu deleted file mode 100644 index 7abbc45440bf..000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku8.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. 
-bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = true; - - if (false) { - } - AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h deleted file mode 100644 index 03a0132aa347..000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index ec0836131ba8..b3cccd4c566f 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -30,7 +30,6 @@ #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku4.h" -#include "marlin_kernels/marlin_moe_kernel_ku8.h" template inline std::string str(T x) { @@ -461,7 +460,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C, CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -488,9 +486,9 @@ torch::Tensor marlin_gemm_moe( int64_t moe_block_size, bool replicate_input, bool apply_weights) { bool has_zp = b_zeros.size(1) != 0; if (has_zp) { - TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", - b_q_type->str()); + TORCH_CHECK( + *b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. 
Got = ", b_q_type->str()); } else { TORCH_CHECK( *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index f1a0b09e8e46..0738ea9b97ed 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -21,7 +21,6 @@ @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe_awq( m: int, n: int, @@ -29,11 +28,11 @@ def test_fused_marlin_moe_awq( e: int, topk: int, group_size: int, - num_bits: int, ): torch.manual_seed(7) - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) + num_bits = 4 + quant_type = scalar_types.uint4 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -111,7 +110,6 @@ def test_fused_marlin_moe_awq( @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_single_marlin_moe_multiply_awq( m: int, n: int, @@ -119,11 +117,11 @@ def test_single_marlin_moe_multiply_awq( e: int, topk: int, group_size: int, - num_bits: int, ): torch.manual_seed(7) - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) + num_bits = 4 + quant_type = scalar_types.uint4 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 466b0edd81fe..66f589dba785 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -12,7 +12,8 @@ def get_scalar_type(num_bits: int, has_zp: bool): if has_zp: - return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + assert num_bits == 4 + return scalar_types.uint4 else: return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 From 8fe6da46bf7aa2673a480c72558a80aeeedcc544 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 1 Oct 2024 17:57:00 +0000 Subject: [PATCH 46/49] fix file reverted from some commit hoopla --- .../run_model_weight_loading_test.sh | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh index 0cb45d1780c2..e80c1d6c5849 100755 --- a/tests/weight_loading/run_model_weight_loading_test.sh +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -1,7 +1,20 @@ #!/bin/bash SUCCESS=0 -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" +while getopts "c:" OPT; do + case ${OPT} in + c ) + CONFIG="$OPTARG" + ;; + \? 
) + usage + exit 1 + ;; + esac +done + + +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do From a966417fa4433c6bbb863f14db27d15c1547dbb4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Oct 2024 01:14:48 -0400 Subject: [PATCH 47/49] Make workspace smaller, add very small thread config --- csrc/moe/marlin_moe_ops.cu | 2 ++ vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index b3cccd4c566f..69d66b5d7101 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -157,6 +157,7 @@ thread_config_t small_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 2X, same K {64, 256, 256}, // Reduce K 2X, increase N 2X {64, 128, 128}, // Reduce K 2X, same N + {64, 64, 128}, // Reduce both 2X }; thread_config_t large_batch_thread_configs[] = { @@ -167,6 +168,7 @@ thread_config_t large_batch_thread_configs[] = { {128, 128, 256}, // Reduce N 2X, increase K 2X {64, 128, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 64, 128}, // Reduce N 4X, same K }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 66f589dba785..5964d5a5465f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -218,7 +218,7 @@ def fused_marlin_moe( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + max_workspace_size = (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda", From fa4d269f6e9479084e24717cf7b4c02bc981bc87 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Oct 2024 10:38:57 -0400 Subject: [PATCH 48/49] try to make required cache smaller --- csrc/moe/marlin_moe_ops.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e540f0723649..aec22281e1af 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -192,7 +192,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int load_groups = tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 4; + return load_groups * tb_n * 3; } else { int tb_scales = tb_groups * tb_n * 2; From fb8a1e744164d27cf8ed15f9b1895dbf11c020d6 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Oct 2024 12:34:27 -0400 Subject: [PATCH 49/49] revert --- csrc/moe/marlin_moe_ops.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index ce6af7fa0186..69d66b5d7101 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -193,7 +193,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int load_groups = tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 3; + return load_groups * tb_n * 4; } else { int tb_scales = tb_groups * tb_n * 2;
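
For readers following the API churn across PATCH 42-49, here is a brief illustrative sketch of the two conventions the series settles on: (1) the explicit has_zero_point flag is gone — the Python wrappers pass either real zero points or a (0, 0)-shaped placeholder, and the CUDA entry point infers has_zp from b_zeros.size(1) != 0; (2) the Marlin MoE workspace no longer scales with the batch dimension M. The helper names and example shapes below are assumptions for illustration only, not vLLM code.

import torch

def zp_or_placeholder(w_zeros, dtype, device="cpu"):
    # Convention from fused_marlin_moe after PATCH 42: a missing zero point is
    # replaced by a (0, 0)-shaped tensor so the kernel side can detect it.
    if w_zeros is None:
        return torch.empty((0, 0), dtype=dtype, device=device)
    return w_zeros

def kernel_sees_zp(b_zeros):
    # Same inference marlin_gemm_moe now performs: has_zp = b_zeros.size(1) != 0.
    return b_zeros.size(1) != 0

def max_workspace_size(N, K):
    # PATCH 47 drops the ((M + 255) // 256) factor from the old formula
    # ((M + 255) // 256) * (max(2 * N, K) // 64) * 16, leaving:
    return (max(2 * N, K) // 64) * 16

if __name__ == "__main__":
    dt = torch.float16
    zp = torch.zeros((8, 32, 256), dtype=dt)   # made-up AWQ-style zero points
    assert kernel_sees_zp(zp_or_placeholder(zp, dt))
    assert not kernel_sees_zp(zp_or_placeholder(None, dt))

    # With example sizes N = 4096, K = 14336, the workspace is
    # (max(8192, 14336) // 64) * 16 = 224 * 16 = 3584 ints, independent of M.
    assert max_workspace_size(4096, 14336) == 3584

The (0, 0) shape (rather than the earlier (0) placeholder in the tests) presumably matters because the C++ check indexes dimension 1 of b_zeros, which would be out of range on a one-dimensional empty tensor.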