Commit: format
ElizaWszola committed Sep 20, 2024
1 parent 72d1503 commit 00adeed
Showing 1 changed file with 33 additions and 32 deletions.
csrc/moe/marlin_moe_ops.cu: 65 changes (33 additions & 32 deletions)
@@ -344,7 +344,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
- const bool has_act_order // whether act_order is enabled
+ const bool has_act_order // whether act_order is enabled
>
__device__ inline void MarlinMoESingle(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1353,7 +1353,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
- const bool has_act_order // whether act_order is enabled
+ const bool has_act_order // whether act_order is enabled
>
__global__ void MarlinMoE(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1415,29 +1415,29 @@ __global__ void MarlinMoE(
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 2) {
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 3) {
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else {
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
- expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
- prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+ expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk,
+ prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
}
}
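Note: the hunk above only re-wraps argument lists; the underlying pattern is that MarlinMoE turns the runtime value max_block into a compile-time thread_m_blocks template argument by branching over the four possible values. A minimal, self-contained sketch of that runtime-to-compile-time dispatch, using hypothetical names (work, dispatch_on_blocks) that are not part of the file:

template <int thread_m_blocks>
__device__ void work(int v) {
  // thread_m_blocks is a compile-time constant inside each instantiation,
  // so loops and register-array sizes that depend on it are fixed at
  // compile time.
}

__device__ void dispatch_on_blocks(int max_block, int v) {
  // Map the runtime block count onto one of four template instantiations,
  // mirroring the if / else-if chain over max_block in the hunk above.
  if (max_block == 1) {
    work<1>(v);
  } else if (max_block == 2) {
    work<2>(v);
  } else if (max_block == 3) {
    work<3>(v);
  } else {
    work<4>(v);
  }
}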
@@ -1467,7 +1467,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
- const bool has_act_order // whether act_order is enabled
+ const bool has_act_order // whether act_order is enabled
>
__global__ void MarlinMoE(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1521,17 +1521,17 @@ static constexpr int min_thread_k = 64;
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) { \
- cudaFuncSetAttribute( \
- MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
- STAGES, HAS_ACT_ORDER>, \
- cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+ cudaFuncSetAttribute(MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, \
+ THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
+ max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
- STAGES, HAS_ACT_ORDER> \
+ STAGES, HAS_ACT_ORDER> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
- g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, expert_idx, \
- num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
- replicate_input, apply_weights, m_block, max_par, \
+ g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, \
+ expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, \
+ locks, replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks); \
}
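Note: the __CALL_IF_MOE macro reformatted above couples cudaFuncSetAttribute with the kernel launch so that the selected MarlinMoE instantiation can be launched with more dynamic shared memory than the default 48 KiB per-kernel limit. A minimal sketch of that launch pattern, assuming hypothetical names (demo_kernel, launch_demo) and a caller-supplied smem_bytes:

#include <cuda_runtime.h>

template <int NUM_THREADS>
__global__ void demo_kernel(int* out) {
  extern __shared__ int smem[];  // dynamic shared memory, sized at launch
  smem[threadIdx.x] = threadIdx.x;
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[NUM_THREADS - 1];
}

void launch_demo(int* out, int blocks, int smem_bytes, cudaStream_t stream) {
  // Opt the kernel into a larger dynamic shared-memory limit before the
  // launch; without this, requesting more than 48 KiB fails on most GPUs.
  // smem_bytes must be at least NUM_THREADS * sizeof(int) for this demo.
  cudaFuncSetAttribute(demo_kernel<256>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  demo_kernel<256><<<blocks, 256, smem_bytes, stream>>>(out);
}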

@@ -1709,21 +1709,20 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}

- #define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
- __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \
+ #define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS)

void marlin_mm_moe(const void* A, const void* B, void* C,
- const void* sorted_ids, const void* topk_weights,
- const void* topk_ids, const void* s, const void* g_idx,
- const void* perm, void* a_tmp, void* expert_offsets,
- int prob_m, int prob_n, int prob_k, void* workspace,
- vllm::ScalarType const& q_type, bool has_act_order,
- bool is_k_full, int num_groups, int group_size,
- int num_experts, int topk, int moe_block_size, int dev,
- cudaStream_t stream, int thread_k, int thread_n,
- int sms, int max_par, bool replicate_input,
- bool apply_weights) {
+ const void* sorted_ids, const void* topk_weights,
+ const void* topk_ids, const void* s, const void* g_idx,
+ const void* perm, void* a_tmp, void* expert_offsets,
+ int prob_m, int prob_n, int prob_k, void* workspace,
+ vllm::ScalarType const& q_type, bool has_act_order,
+ bool is_k_full, int num_groups, int group_size,
+ int num_experts, int topk, int moe_block_size, int dev,
+ cudaStream_t stream, int thread_k, int thread_n, int sms,
+ int max_par, bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");

@@ -1888,7 +1887,9 @@ torch::Tensor marlin_gemm_moe(
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());

- TORCH_CHECK(is_k_full, "NYI: Marlin MoE kernel does not currently support !is_k_full case.");
+ TORCH_CHECK(
+ is_k_full,
+ "NYI: Marlin MoE kernel does not currently support !is_k_full case.");

int pack_factor = 32 / b_q_type->size_bits();
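Note: pack_factor is the number of quantized values packed into one 32-bit word, i.e. 32 / 4 = 8 for the 4-bit uint4b8 type and 32 / 8 = 4 for the 8-bit uint8b128 type.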
