From abe53edfa44e9af514304d45e86d14b5f496917f Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 4 Oct 2024 20:34:44 +0200 Subject: [PATCH] [Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973) Co-authored-by: Dipika Co-authored-by: Dipika Sikka --- CMakeLists.txt | 2 + csrc/moe/marlin_kernels/marlin_moe_kernel.h | 297 ++++++++++++++---- .../marlin_kernels/marlin_moe_kernel_ku4.cu | 31 ++ .../marlin_kernels/marlin_moe_kernel_ku4.h | 20 ++ .../marlin_kernels/marlin_moe_kernel_ku4b8.cu | 12 +- .../marlin_kernels/marlin_moe_kernel_ku4b8.h | 10 +- .../marlin_moe_kernel_ku8b128.cu | 12 +- .../marlin_moe_kernel_ku8b128.h | 10 +- csrc/moe/marlin_moe_ops.cu | 84 +++-- csrc/moe/torch_bindings.cpp | 2 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 2 +- tests/kernels/test_awq_marlin.py | 160 ++++++++++ tests/kernels/test_moe.py | 79 ++--- tests/kernels/utils.py | 45 +++ tests/weight_loading/models-large.txt | 1 + .../run_model_weight_loading_test.sh | 15 +- vllm/_custom_ops.py | 25 +- .../layers/fused_moe/fused_marlin_moe.py | 138 ++++++-- .../layers/quantization/awq_marlin.py | 204 +++++++++++- .../compressed_tensors_moe.py | 12 +- .../layers/quantization/gptq_marlin.py | 12 +- .../layers/quantization/utils/marlin_utils.py | 15 + vllm/model_executor/model_loader/utils.py | 4 +- 23 files changed, 969 insertions(+), 223 deletions(-) create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h create mode 100644 tests/kernels/test_awq_marlin.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b24c4abc650e..4be524808a23a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -433,6 +433,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" "csrc/moe/marlin_moe_ops.cu") set_gencode_flags_for_srcs( diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h index 0bd3017226c94..a217401b3d7c2 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -38,6 +38,7 @@ using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; // quantization scales +using FragZP = Vec; // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. @@ -175,6 +176,46 @@ __device__ inline FragB dequant(int q) { return frag_b; } +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { frag_b[1] = __hmul2(frag_b[1], s); } -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { + half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); } // Same as above, but for act_order (each K is multiplied individually) @@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { @@ -248,10 +295,11 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > -__device__ inline void MarlinMoESingle( +__device__ void MarlinMoESingle( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn @@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle( const float* __restrict__ topk_weights, // float topk weights const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel @@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle( int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - constexpr int sorted_sh_stride = threads; - constexpr int sorted_gl_stride = threads; + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + @@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle( int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. @@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + int sh_first_group_id = -1; int sh_num_groups = -1; constexpr int sh_max_num_groups = 32; - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - int* sh_sorted = (int*)(sh_s + shs_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or @@ -525,8 +598,10 @@ __device__ inline void MarlinMoESingle( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 // Zero accumulators. auto zero_accums = [&]() { @@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle( } } } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } } } // Insert a fence even when we are winding down the pipeline to ensure that @@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle( cp_async_fence(); }; - // TODO we are currently hitting illegal memory accesses when fetching - // sorted_ids to shared data: fix this - auto fetch_sorted_ids_to_shared = [&]() { - const int mpt = ceildiv(prob_m, threads); - for (int i = 0; i < mpt; i++) { - if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { - sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = - sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; - } + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); } }; @@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle( } }; + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle( FragB frag_b0 = dequant(b_quant_0); FragB frag_b1 = dequant(b_quant_1); + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); + } // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle( } } + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { @@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle( } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } fetch_to_shared(i, i, i < slice_iters); } @@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle( init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; @@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle( for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle( } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } + start_pipes(); } } @@ -1250,6 +1432,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1261,6 +1444,8 @@ __global__ void MarlinMoE( const float* __restrict__ topk_weights, // float topk weights const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel @@ -1309,29 +1494,29 @@ __global__ void MarlinMoE( if (max_block == 1) { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + stages, has_act_order, has_zp, group_blocks>( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); @@ -1347,6 +1532,7 @@ template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1358,6 +1544,8 @@ __global__ void MarlinMoE( const float* __restrict__ topk_weights, // float topk weights const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel @@ -1374,7 +1562,6 @@ __global__ void MarlinMoE( int current_m_block, // current m block to start kernel computation from int max_par, // maximum parallelism int cfg_max_m_blocks // upper bound on m blocks - ) { // Marlin is not implemented yet for SM < 8.0 assert(false); @@ -1389,37 +1576,41 @@ __global__ void MarlinMoE( const int USER_THREADS = 256; // Note: This is only used with user-provided thread_k/n const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ + HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ MarlinMoE, \ + STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ MarlinMoE \ + STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ replicate_input, apply_weights, m_block, max_par, \ cfg_max_m_blocks); \ } -#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) } // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu new file mode 100644 index 0000000000000..77bc0dd90edde --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku4.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = true; + + if (false) { + } + AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h new file mode 100644 index 0000000000000..833fadf37721f --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu index cbafd9ffe7474..f7e57b0375945 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8( bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, - int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, - int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, - bool replicate_input, bool apply_weights, int m_block, int max_par, - int cfg_max_m_blocks) { + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = false; + if (false) { } GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h index 9eacb42c115f0..494da8f10e262 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8( bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, - int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, - int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, - bool replicate_input, bool apply_weights, int m_block, int max_par, - int cfg_max_m_blocks); + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); } // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu index c46712474f715..a901f0b11cd78 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128( bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, - int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, - int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, - bool replicate_input, bool apply_weights, int m_block, int max_par, - int cfg_max_m_blocks) { + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = false; + if (false) { } GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h index 7cd9acafb3b80..f3018aa0c1ab7 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128( bool has_act_order, int group_blocks, int num_threads, int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, - int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, - int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, - bool replicate_input, bool apply_weights, int m_block, int max_par, - int cfg_max_m_blocks); + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); } diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 661490d95e791..e2db4e4196b6f 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -30,6 +30,7 @@ #include "core/registration.h" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" +#include "marlin_kernels/marlin_moe_kernel_ku4.h" template inline std::string str(T x) { @@ -157,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 2X, same K {64, 256, 256}, // Reduce K 2X, increase N 2X {64, 128, 128}, // Reduce K 2X, same N + {64, 64, 128}, // Reduce both 2X }; thread_config_t large_batch_thread_configs[] = { @@ -167,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = { {128, 128, 256}, // Reduce N 2X, increase K 2X {64, 128, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 64, 128}, // Reduce N 4X, same K }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, @@ -312,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ - else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \ - has_act_order, group_blocks, num_threads, blocks, \ - max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ - sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ - expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ - locks, replicate_input, apply_weights, m_block, \ - max_par, exec_cfg.max_m_blocks)) { \ +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION( \ + q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ + group_blocks, num_threads, blocks, max_shared_mem, stream, \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks)) { \ } void marlin_mm_moe(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par, bool replicate_input, bool apply_weights) { + const void* topk_ids, const void* s, void* zp, + const void* g_idx, const void* perm, void* a_tmp, + void* expert_offsets, int prob_m, int prob_n, int prob_k, + void* workspace, vllm::ScalarType const& q_type, + bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -436,6 +440,8 @@ void marlin_mm_moe(const void* A, const void* B, void* C, const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; + const int4* zp_ptr = + (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; @@ -456,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, } CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -475,13 +482,21 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights) { - TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + torch::Tensor& b_zeros, const torch::Tensor& g_idx, + const torch::Tensor& perm, torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights) { + bool has_zp = b_zeros.size(1) != 0; + if (has_zp) { + TORCH_CHECK( + *b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str()); + } else { + TORCH_CHECK( + *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + } int pack_factor = 32 / b_q_type->size_bits(); @@ -543,14 +558,27 @@ torch::Tensor marlin_gemm_moe( } } + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, - topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, max_par, replicate_input, apply_weights); + *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, + num_experts, topk, moe_block_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, + replicate_input, apply_weights); return c; } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index cbc8754f7a5b2..18fbc57ac7834 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -12,7 +12,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " "int moe_block_size, bool replicate_input, bool apply_weights)" diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 227bc19b914a0..5efe15d2b2f6b 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -2260,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_scales.size(1), + "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n / pack_factor = ", size_n / pack_factor); } diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py new file mode 100644 index 0000000000000..0738ea9b97edb --- /dev/null +++ b/tests/kernels/test_awq_marlin.py @@ -0,0 +1,160 @@ +"""Test AWQ with fused MoE Marlin kernels. + +Run `pytest tests/kernels/test_awq_marlin.py`. +""" +import pytest +import torch + +from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, + torch_moe_single) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + awq_marlin_quantize) +from vllm.scalar_type import scalar_types + + +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +def test_fused_marlin_moe_awq( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, +): + torch.manual_seed(7) + + num_bits = 4 + quant_type = scalar_types.uint4 + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + w_ref1_l = [] + qweights1_l = [] + scales1_l = [] + zp1_l = [] + + for i in range(w1.shape[0]): + w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size) + w_ref1_l.append(w_ref1) + qweights1_l.append(qweight1) + scales1_l.append(scales1) + zp1_l.append(zp1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweights1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + zp1 = stack_and_dev(zp1_l) + + w_ref2_l = [] + qweights2_l = [] + scales2_l = [] + zp2_l = [] + + for i in range(w2.shape[0]): + w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size) + w_ref2_l.append(w_ref2) + qweights2_l.append(qweight2) + scales2_l.append(scales2) + zp2_l.append(zp2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweights2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + zp2 = stack_and_dev(zp2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + marlin_output = fused_marlin_moe( + a, + qweight1, + qweight2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + w1_zeros=zp1, + w2_zeros=zp2, + num_bits=num_bits, + ) + + torch_output = torch_moe( + a, + w_ref1.transpose(1, 2), + w_ref2.transpose(1, 2), + score, + topk, + ) + + assert compute_max_diff(marlin_output, torch_output) < 4e-2 + + +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +def test_single_marlin_moe_multiply_awq( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, +): + torch.manual_seed(7) + + num_bits = 4 + quant_type = scalar_types.uint4 + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + zp_l = [] + + for i in range(w.shape[0]): + w_ref, qweight, scales, zp = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + zp_l.append(zp) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l).contiguous() + zp = stack_and_dev(zp_l).contiguous() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + marlin_output = single_marlin_moe(a, + qweight, + scales, + score, + topk, + renormalize=False, + w_zeros=zp, + num_bits=num_bits) + + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index cbbb5c9b79c42..b73c45b9cd198 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,16 +2,14 @@ Run `pytest tests/kernels/test_moe.py`. """ -from typing import List - import pytest import torch from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from tests.kernels.utils import opcheck +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) from vllm import _custom_ops as ops -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) @@ -24,37 +22,6 @@ from vllm.utils import seed_everything -def torch_moe(a, w1, w2, score, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - -def torch_moe_single(a, w, score, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - _, topk_ids = torch.topk(score, topk) - topk_ids = topk_ids.view(-1) - for i in range(w.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = a[mask] @ w[i].transpose(0, 1) - return (out.view(B, -1, w.shape[1])).sum(dim=1) - - @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -127,20 +94,10 @@ def test_mixtral_moe(dtype: torch.dtype): atol=mixtral_moe_tol[dtype]) -def stack_and_dev(tensors: List[torch.Tensor]): - dev = tensors[0].device - return torch.stack(tensors, dim=0).to(dev) - - -def compute_max_diff(output, output_ref): - return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) - - @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @@ -159,9 +116,6 @@ def test_fused_marlin_moe( ): seed_everything(7) - if topk > e: - return - # Filter act_order if act_order: if group_size == -1: @@ -241,15 +195,15 @@ def test_fused_marlin_moe( a, qweight1, qweight2, + scales1, + scales2, score, - g_idx1, - g_idx2, - sort_indices1, - sort_indices2, topk_weights, topk_ids, - w1_scale=scales1, - w2_scale=scales2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, num_bits=num_bits, is_k_full=is_k_full, ) @@ -280,9 +234,13 @@ def test_fused_marlin_moe( device="cuda", requires_grad=False) + zp = torch.empty((0, 0), + dtype=dtype, + device="cuda", + requires_grad=False) opcheck(torch.ops._moe_C.marlin_gemm_moe, (a, qweight1, sorted_token_ids, topk_weights, topk_ids, - scales1, g_idx1, sort_indices1, workspace, quant_type, m, + scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, 2 * n, k, True, e, topk, block_size_m, True, False)) @@ -291,7 +249,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @@ -308,8 +266,6 @@ def test_single_marlin_moe_multiply( num_bits: int, is_k_full: bool, ): - if topk > e: - return # Filter act_order if act_order: @@ -355,13 +311,14 @@ def test_single_marlin_moe_multiply( qweight, scales, score, - g_idx, - sort_indices, topk, renormalize=False, + g_idx=g_idx, + sort_indices=sort_indices, num_bits=num_bits, is_k_full=is_k_full, ) + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 08004efe9e2f8..a2d414f636e13 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -12,6 +12,7 @@ from torch._prims_common import TensorLikeType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType +from vllm.model_executor.layers.activation import SiluAndMul from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -974,6 +975,50 @@ def fp8_allclose( equal_nan=equal_nan)).item()) +# Marlin MoE test utils + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +def torch_moe(a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + + # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 3e6eba04f1a87..5fda910fde084 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantize compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main +awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main \ No newline at end of file diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh index 0cb45d1780c2c..e80c1d6c5849c 100755 --- a/tests/weight_loading/run_model_weight_loading_test.sh +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -1,7 +1,20 @@ #!/bin/bash SUCCESS=0 -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" +while getopts "c:" OPT; do + case ${OPT} in + c ) + CONFIG="$OPTARG" + ;; + \? ) + usage + exit 1 + ;; + esac +done + + +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 05f036af331f1..24e008dc38022 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -568,6 +568,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return output +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -828,11 +842,12 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor, - g_idx: torch.Tensor, perm: torch.Tensor, - workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, size_n: int, size_k: int, - is_k_full: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, num_experts: int, + topk: int, moe_block_size: int, + replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 8177e846127ee..5964d5a5465fd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -10,15 +10,24 @@ from vllm.scalar_type import scalar_types +def get_scalar_type(num_bits: int, has_zp: bool): + if has_zp: + assert num_bits == 4 + return scalar_types.uint4 + else: + return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + + def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, topk: int, renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, is_k_full: bool = True, @@ -34,10 +43,12 @@ def single_marlin_moe( - scales (torch.Tensor): The quantization scales. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - g_idx (torch.Tensor): The act_order indices. - - perm (torch.Tensor): The act_order input permutation. + - g_idx (Optional[torch.Tensor]): Optional act_order indices. + - sort_indices (Optional[torch.Tensor]): Optional act_order input + permutation. - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - num_bits (bool): The number of bits in expert weights quantization. @@ -79,16 +90,34 @@ def single_marlin_moe( max_workspace_size = (N // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, - device="cuda", + device=hidden_states.device, + requires_grad=False) + + has_zero_point = w_zeros is not None + if w_zeros is None: + w_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if g_idx is None: + g_idx = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + if sort_indices is None: + sort_indices = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type = get_scalar_type(num_bits, has_zero_point) intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk, - block_size_m, True, False) + w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, + is_k_full, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -97,16 +126,18 @@ def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - perm1: torch.Tensor, - perm2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, override_config: Optional[Dict[str, Any]] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -118,21 +149,22 @@ def fused_marlin_moe( - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - g_idx1 (torch.Tensor): The first set of act_order indices. - - g_idx2 (torch.Tensor): The second set of act_order indices. - - perm1 (torch.Tensor): The first act_order input permutation. - - perm2 (torch.Tensor): The second act_order input permutation. + - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. + - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. + - sort_indices1 (Optional[torch.Tensor]): The first act_order input + permutation. + - sort_indices2 (Optional[torch.Tensor]): The second act_order input + permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. + - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. + - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - num_bits (bool): The number of bits in expert weights quantization. Returns: @@ -152,6 +184,20 @@ def fused_marlin_moe( assert hidden_states.dtype == torch.float16 assert num_bits in [4, 8] + has_no_act_order = (g_idx1 is None and g_idx2 is None + and sort_indices1 is None and sort_indices2 is None) + has_all_act_order = (g_idx1 is not None and g_idx2 is not None + and sort_indices1 is not None + and sort_indices2 is not None) + assert has_no_act_order or has_all_act_order, ( + "g_idx and sorted_indices " + "must be all not None or must be all None") + + has_no_zp = w1_zeros is None and w2_zeros is None + has_all_zp = w1_zeros is not None and w2_zeros is not None + assert has_no_zp or has_all_zp, ("zero points must be both not None or " + "must be both None") + M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -172,14 +218,42 @@ def fused_marlin_moe( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + max_workspace_size = (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + if has_no_zp: + w1_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + w2_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if has_no_act_order: + g_idx1 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + g_idx2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + sort_indices1 = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + sort_indices2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, has_all_zp) + scalar_type2 = get_scalar_type(num_bits, has_all_zp) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -194,10 +268,11 @@ def fused_marlin_moe( topk_weights, topk_ids, w1_scale, + w1_zeros, g_idx1, - perm1, + sort_indices1, workspace, - scalar_type, + scalar_type1, M, 2 * N, K, @@ -218,10 +293,11 @@ def fused_marlin_moe( topk_weights, topk_ids, w2_scale, + w2_zeros, g_idx2, - perm2, + sort_indices2, workspace, - scalar_type, + scalar_type2, M, K, N, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index fe33b7341fd38..294fe11815c0f 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,16 +1,21 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch +from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, @@ -35,12 +40,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool, self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits - if weight_bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {weight_bits}. " + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " f"Supported num_bits = {self.TYPE_MAP.keys()}") - self.quant_type = self.TYPE_MAP[weight_bits] + self.quant_type = self.TYPE_MAP[self.weight_bits] verify_marlin_supported(self.quant_type, group_size=self.group_size, @@ -98,10 +104,12 @@ def override_quantization_method(cls, hf_quant_cfg, return None def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQMarlinLinearMethod"]: + prefix: str) -> Optional["QuantizeMethodBase"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -271,4 +279,182 @@ def apply( quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) \ No newline at end of file + bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP.value, + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_parameter(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 6666a4bf1f26a..af04d725159f9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -498,14 +498,14 @@ def apply( x, layer.w13_weight_packed, layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3d3ce711e58b0..e77191796bd7e 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -557,14 +557,14 @@ def apply( x, layer.w13_qweight, layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 53762965732ce..9a1defa409714 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -208,6 +208,7 @@ def marlin_moe_permute_scales( device=s.device, dtype=s.dtype, ) + for e in range(num_experts): output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) return output @@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2bfe6ea09bd62..b95c0b7cd0612 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,9 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" + ] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported