diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index d22ba7fe4335..6fdf75863c99 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1261,7 +1261,7 @@ __device__ inline void MarlinMoESingle( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (!has_act_order) { if constexpr (w_type.size_bits() == 8) { if (group_blocks == -1) { cp_async_wait<0>(); @@ -1288,7 +1288,7 @@ __device__ inline void MarlinMoESingle( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && && w_type.size_bits() == 8) { + if constexpr (!has_act_order && w_type.size_bits() == 8) { if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1713,7 +1713,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS) -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, +void marlin_mm_moe(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, @@ -1888,6 +1888,8 @@ torch::Tensor marlin_gemm_moe( TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + TORCH_CHECK(is_k_full, "NYI: Marlin MoE kernel does not currently support !is_k_full case."); + int pack_factor = 32 / b_q_type->size_bits(); int max_par = 4; @@ -1945,7 +1947,7 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),