Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ElizaWszola committed Sep 19, 2024
1 parent c0c13ec commit 72d1503
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions csrc/moe/marlin_moe_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ __device__ inline void MarlinMoESingle(
}

thread_block_reduce();
- if constexpr (!has_act_order && group_blocks == -1) {
+ if constexpr (!has_act_order) {
if constexpr (w_type.size_bits() == 8) {
if (group_blocks == -1) {
cp_async_wait<0>();
Expand All @@ -1288,7 +1288,7 @@ __device__ inline void MarlinMoESingle(
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
- if constexpr (!has_act_order && && w_type.size_bits() == 8) {
+ if constexpr (!has_act_order && w_type.size_bits() == 8) {
if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
Expand Down Expand Up @@ -1713,7 +1713,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS)

- void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
+ void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, const void* g_idx,
const void* perm, void* a_tmp, void* expert_offsets,
Expand Down Expand Up @@ -1888,6 +1888,8 @@ torch::Tensor marlin_gemm_moe(
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());

+ TORCH_CHECK(is_k_full, "NYI: Marlin MoE kernel does not currently support !is_k_full case.");

int pack_factor = 32 / b_q_type->size_bits();

int max_par = 4;
Expand Down Expand Up @@ -1945,7 +1947,7 @@ torch::Tensor marlin_gemm_moe(
}
}

- marlin_moe::marlin_mm_moe_f16i4(
+ marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
Expand Down

0 comments on commit 72d1503

Please sign in to comment.