diff --git a/CMakeLists.txt b/CMakeLists.txt index df22ce47e54b..8c66c31aa6ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,8 +332,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu deleted file mode 100644 index 7abbc45440bf..000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku8.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = true; - - if (false) { - } - AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h deleted file mode 100644 index 03a0132aa347..000000000000 --- 
a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index ec0836131ba8..b3cccd4c566f 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -30,7 +30,6 @@ #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku4.h" -#include "marlin_kernels/marlin_moe_kernel_ku8.h" template <typename T> inline std::string str(T x) { @@ -461,7 +460,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C, CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -488,9 +486,9 @@ torch::Tensor marlin_gemm_moe( int64_t moe_block_size, bool replicate_input, bool apply_weights) { bool has_zp = b_zeros.size(1) != 0; if (has_zp) { - TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type 
== vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", - b_q_type->str()); + TORCH_CHECK( + *b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str()); } else { TORCH_CHECK( *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index f1a0b09e8e46..0738ea9b97ed 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -21,7 +21,6 @@ @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe_awq( m: int, n: int, @@ -29,11 +28,11 @@ def test_fused_marlin_moe_awq( e: int, topk: int, group_size: int, - num_bits: int, ): torch.manual_seed(7) - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) + num_bits = 4 + quant_type = scalar_types.uint4 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -111,7 +110,6 @@ def test_fused_marlin_moe_awq( @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_single_marlin_moe_multiply_awq( m: int, n: int, @@ -119,11 +117,11 @@ def test_single_marlin_moe_multiply_awq( e: int, topk: int, group_size: int, - num_bits: int, ): torch.manual_seed(7) - quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8) + num_bits = 4 + quant_type = scalar_types.uint4 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 466b0edd81fe..66f589dba785 100644 --- 
a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -12,7 +12,8 @@ def get_scalar_type(num_bits: int, has_zp: bool): if has_zp: - return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + assert num_bits == 4 + return scalar_types.uint4 else: return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128