Fix build issues
ElizaWszola committed Aug 2, 2024
1 parent b39dba4 commit b0c4671
Showing 3 changed files with 10 additions and 19 deletions.
5 changes: 2 additions & 3 deletions csrc/moe/marlin_moe_ops.h
@@ -7,7 +7,6 @@ torch::Tensor marlin_gemm_moe(
     const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     const torch::Tensor& g_idx, const torch::Tensor& perm,
-    const torch::Tensor& expert_offsets, torch::Tensor& workspace,
-    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
-    int64_t num_experts, int64_t topk, int64_t moe_block_size,
+    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
+    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
     bool replicate_input, bool apply_weights);
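The hunk above drops the `expert_offsets` parameter (presumably left over from an earlier revision of the kernel interface) and reflows the remaining arguments, so the header declaration again matches the binding registered below. For orientation, a minimal Python-side sketch of the resulting argument order, assuming the compiled extension is exposed as `torch.ops._moe_C` (both the namespace and the wrapper itself are illustrative, not part of this commit):

    import torch

    # Hypothetical thin wrapper mirroring the fixed declaration:
    # `workspace` now directly follows `perm`, and there is no
    # `expert_offsets` argument. Requires the vLLM MoE extension
    # to be built and loaded.
    def marlin_gemm_moe(a, b_q_weights, sorted_ids, topk_weights, topk_ids,
                        b_scales, g_idx, perm, workspace, size_m, size_n,
                        size_k, is_k_full, num_experts, topk, moe_block_size,
                        replicate_input, apply_weights):
        return torch.ops._moe_C.marlin_gemm_moe(
            a, b_q_weights, sorted_ids, topk_weights, topk_ids, b_scales,
            g_idx, perm, workspace, size_m, size_n, size_k, is_k_full,
            num_experts, topk, moe_block_size, replicate_input, apply_weights)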
12 changes: 5 additions & 7 deletions csrc/moe/torch_bindings.cpp
@@ -2,23 +2,21 @@
 #include "moe_ops.h"
 #include "marlin_moe_ops.h"

 #include <torch/library.h>

-TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   // Apply topk softmax to the gating outputs.
-  ops.def(
+  m.def(
       "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
       "token_expert_indices, Tensor gating_output) -> ()");
-  ops.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

-  ops.def(
+  m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
       "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
       "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
       "bool replicate_input, bool apply_weights) -> Tensor");

-  ops.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
+  m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 }

 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
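Here the `TORCH_LIBRARY_EXPAND` block's parameter is renamed from `ops` to `m` and every `def`/`impl` call is updated to match: `def` registers the schema string and `impl` binds the CUDA function, so both must stay consistent with the declaration in the header. As a semantic reference only (the registered op is a fused CUDA kernel, not this code), a pure-PyTorch sketch of the contract the `topk_softmax` schema describes:

    import torch

    # Reference sketch: softmax over the expert dimension of the gating
    # logits, then the top-k weights and expert indices per token.
    def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
        probs = torch.softmax(gating_output, dim=-1)
        topk_weights, topk_indices = torch.topk(probs, topk, dim=-1)
        return topk_weights, topk_indices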
12 changes: 3 additions & 9 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -384,8 +384,7 @@ def fused_topk(
                            topk,
                            dtype=torch.int32,
                            device=hidden_states.device)
-    from pprint import pprint
-    pprint(vars(ops))
+
     ops.topk_softmax(
         topk_weights,
         topk_ids,
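This hunk deletes leftover `pprint` debugging from `fused_topk`. The surrounding context shows the preallocation pattern the op relies on: the `Tensor!` annotations in the schema mark buffers the kernel fills in place. A sketch of that setup, with shapes and a CUDA device assumed purely for illustration:

    import torch

    # Shapes and device are assumptions; fused_topk preallocates output
    # buffers like this and lets the kernel write into them in place.
    M, topk, num_experts = 16, 2, 8
    gating_output = torch.randn(M, num_experts, device="cuda")
    topk_weights = torch.empty(M, topk, dtype=torch.float32, device="cuda")
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device="cuda")
    token_expert_indices = torch.empty(M, topk, dtype=torch.int32,
                                       device="cuda")
    # ops.topk_softmax(topk_weights, topk_ids, token_expert_indices,
    #                  gating_output)  # fills the three buffers in place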
@@ -692,8 +691,7 @@ def single_marlin_moe(

     block_size_m = config['BLOCK_SIZE_M']

-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, block_size_m, E)
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

     max_workspace_size = (N // 64) * 16
     workspace = torch.zeros(max_workspace_size,
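`moe_align_block_size` returns a `(sorted_token_ids, expert_ids, num_tokens_post_padded)` tuple; after this change only the first element is used, so the other two are discarded with `_` placeholders, consistent with the op no longer taking an explicit `expert_offsets` argument. A worked instance of the workspace bound above, with `N` assumed purely for illustration:

    # N is an assumed value, not taken from the commit.
    N = 7168
    max_workspace_size = (N // 64) * 16
    assert max_workspace_size == 1792  # 112 blocks of 16 int32 entries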
@@ -781,8 +779,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,

     block_size_m = config['BLOCK_SIZE_M']

-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, block_size_m, E)
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

     max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
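The same simplification is applied in `fused_marlin_moe`, whose workspace bound additionally scales with the token count and covers the larger of the two grouped GEMM output widths. A worked instance with assumed sizes:

    # M, N, K are assumed values for illustration only.
    M, N, K = 512, 7168, 4096
    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    # (512 + 255) // 256 == 2;  max(14336, 4096) // 64 == 224
    assert max_workspace_size == 2 * 224 * 16  # == 7168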
@@ -806,8 +803,5 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
         block_size_m, False, True)

-    # intermediate_cache3 = torch.zeros((M, topk, K),
-    #                                   device=hidden_states.device,
-    #                                   dtype=hidden_states.dtype)
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
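The dead commented-out allocation is removed; `intermediate_cache3` comes from the second `marlin_gemm_moe` call above, and the return line reduces it over the top-k dimension (the `.view(*intermediate_cache3.shape)` is a no-op reshape). A small sketch of that reduction, with shapes assumed for illustration:

    import torch

    # One K-dim output per (token, selected expert); summing over the
    # topk dimension combines the weighted expert outputs per token.
    M, topk, K = 4, 2, 8
    intermediate_cache3 = torch.randn(M, topk, K)
    out = intermediate_cache3.sum(dim=1)  # shape (M, K)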
