diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 666d87eb92595..49cc03f827f68 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle( template shared @@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, template shared @@ -1515,26 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; -#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1711,31 +1703,16 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights,