Commit
Michael's feedback, cleanup
ElizaWszola committed Oct 1, 2024
1 parent 091a4bb commit e0e5a74
Showing 7 changed files with 38 additions and 43 deletions.
6 changes: 3 additions & 3 deletions csrc/moe/marlin_moe_ops.cu
@@ -484,9 +484,9 @@ torch::Tensor marlin_gemm_moe(
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
     vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
-    int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts,
-    int64_t topk, int64_t moe_block_size, bool replicate_input,
-    bool apply_weights) {
+    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
+    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
+  bool has_zp = b_zeros.size(1) != 0;
   if (has_zp) {
     TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
                 "b_q_type must be u4 or u8 when has_zp = True. Got = ",
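
Note: the net effect of this hunk is that the kernel entry point no longer takes an explicit has_zp argument; it recovers the flag from the shape of b_zeros instead. A minimal sketch of the resulting convention from the Python side (illustrative names, not part of the diff):

    import torch

    # "No zero points" is now signaled by an empty (0, 0) sentinel tensor
    # rather than by a boolean argument; the C++ side recovers the flag via
    # `bool has_zp = b_zeros.size(1) != 0;`.
    b_zeros = torch.empty((0, 0), dtype=torch.float16)
    has_zp = b_zeros.size(1) != 0
    print(has_zp)  # False
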
5 changes: 2 additions & 3 deletions csrc/moe/marlin_moe_ops.h
@@ -11,6 +11,5 @@ torch::Tensor marlin_gemm_moe(
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
     vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
-    int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts,
-    int64_t topk, int64_t moe_block_size, bool replicate_input,
-    bool apply_weights);
+    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
+    int64_t moe_block_size, bool replicate_input, bool apply_weights);
5 changes: 2 additions & 3 deletions csrc/moe/torch_bindings.cpp
@@ -15,9 +15,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
       "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
       "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
-      "int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, "
-      "int topk, int moe_block_size, bool replicate_input, bool apply_weights)"
-      " -> Tensor");
+      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, int "
+      "moe_block_size, bool replicate_input, bool apply_weights) -> Tensor");
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }
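
Note: the schema string registered here must match the C++ signature exactly, so has_zp is dropped from it as well; a mismatch surfaces as a dispatcher registration error at import time. A toy illustration of the define/impl pairing on a recent PyTorch (a hypothetical demo op, not the real binding):

    import torch

    # Hypothetical demo::scale op: the schema string and the implementation
    # signature must agree, just like marlin_gemm_moe's schema and C++ impl.
    torch.library.define("demo::scale", "(Tensor x, float alpha) -> Tensor")

    @torch.library.impl("demo::scale", "CompositeExplicitAutograd")
    def scale(x, alpha):
        return x * alpha

    print(torch.ops.demo.scale(torch.ones(2), 3.0))  # tensor([3., 3.])
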
2 changes: 0 additions & 2 deletions tests/kernels/test_awq_marlin.py
@@ -87,7 +87,6 @@ def test_fused_marlin_moe_awq(
         score,
         topk_weights,
         topk_ids,
-        has_zero_point=True,
         w1_zeros=zp1,
         w2_zeros=zp2,
         num_bits=num_bits,
@@ -155,7 +154,6 @@ def test_single_marlin_moe_multiply_awq(
         score,
         topk,
         renormalize=False,
-        has_zero_point=True,
         w_zeros=zp,
         num_bits=num_bits)

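Note: callers now opt into the zero-point path simply by passing the zero-point tensors. A hedged sketch of the updated call shape (tensor preparation elided; argument names follow the test above):

    import torch
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        single_marlin_moe)

    def awq_moe_forward(a: torch.Tensor, qweight: torch.Tensor,
                        scales: torch.Tensor, zp: torch.Tensor,
                        score: torch.Tensor, topk: int,
                        num_bits: int) -> torch.Tensor:
        # No has_zero_point=True anymore: supplying w_zeros enables the path.
        return single_marlin_moe(a, qweight, scales, score, topk,
                                 renormalize=False, w_zeros=zp,
                                 num_bits=num_bits)
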
8 changes: 5 additions & 3 deletions tests/kernels/test_moe.py
@@ -234,12 +234,14 @@ def test_fused_marlin_moe(
                           device="cuda",
                           requires_grad=False)
 
-    zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False)
-
+    zp = torch.empty((0, 0),
+                     dtype=dtype,
+                     device="cuda",
+                     requires_grad=False)
     opcheck(torch.ops._moe_C.marlin_gemm_moe,
             (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
              scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
-             2 * n, k, True, False, e, topk, block_size_m, True, False))
+             2 * n, k, True, e, topk, block_size_m, True, False))
 
 
 @pytest.mark.skip("This test is here for the sake of debugging, "
6 changes: 3 additions & 3 deletions vllm/_custom_ops.py
@@ -822,9 +822,9 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                          b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                          perm: torch.Tensor, workspace: torch.Tensor,
                          b_q_type: ScalarType, size_m: int, size_n: int,
-                         size_k: int, is_k_full: bool,
-                         has_zero_point: bool, num_experts: int, topk: int,
-                         moe_block_size: int, replicate_input: bool,
+                         size_k: int, is_k_full: bool, num_experts: int,
+                         topk: int, moe_block_size: int,
+                         replicate_input: bool,
                          apply_weights: bool) -> torch.Tensor:
     return torch.empty((size_m, topk, size_n),
                        dtype=a.dtype,
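
Note: the fake (meta) implementation exists only for shape and dtype propagation under torch.compile and opcheck, so its signature must track the real op one-for-one, hence the same parameter removal here. The general pattern as a self-contained toy (PyTorch >= 2.4 custom-op API, not the vLLM registration itself):

    import torch

    @torch.library.custom_op("demo::double", mutates_args=())
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @double.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        # Shape/dtype propagation only, mirroring how marlin_gemm_moe_fake
        # returns torch.empty((size_m, topk, size_n), ...).
        return torch.empty_like(x)
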
49 changes: 23 additions & 26 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -24,7 +24,6 @@ def single_marlin_moe(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    has_zero_point: bool = False,
     g_idx: Optional[torch.Tensor] = None,
     sort_indices: Optional[torch.Tensor] = None,
     w_zeros: Optional[torch.Tensor] = None,
@@ -93,11 +92,9 @@ def single_marlin_moe(
                           device=hidden_states.device,
                           requires_grad=False)
 
-    if has_zero_point:
-        assert w_zeros is not None and w_zeros.nelement() > 0
-
+    has_zero_point = w_zeros is not None
     if w_zeros is None:
-        w_zeros = torch.empty((0),
+        w_zeros = torch.empty((0, 0),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device,
                               requires_grad=False)
@@ -119,7 +116,7 @@
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
         w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
-        is_k_full, has_zero_point, E, topk, block_size_m, True, False)
+        is_k_full, E, topk, block_size_m, True, False)
 
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)

@@ -133,7 +130,6 @@ def fused_marlin_moe(
     gating_output: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    has_zero_point: bool = False,
     g_idx1: Optional[torch.Tensor] = None,
     g_idx2: Optional[torch.Tensor] = None,
     sort_indices1: Optional[torch.Tensor] = None,
@@ -187,6 +183,20 @@ def fused_marlin_moe(
     assert hidden_states.dtype == torch.float16
     assert num_bits in [4, 8]
 
+    has_no_act_order = (g_idx1 is None and g_idx2 is None
+                        and sort_indices1 is None and sort_indices2 is None)
+    has_all_act_order = (g_idx1 is not None and g_idx2 is not None
+                         and sort_indices1 is not None
+                         and sort_indices2 is not None)
+    assert has_no_act_order or has_all_act_order, (
+        "g_idx and sorted_indices "
+        "must be all not None or must be all None")
+
+    has_no_zp = w1_zeros is None and w2_zeros is None
+    has_all_zp = w1_zeros is not None and w2_zeros is not None
+    assert has_no_zp or has_all_zp, ("zero points must be both not None or "
+                                     "must be both None")
+
     M, K = hidden_states.shape
     E = w1.shape[0]
     N = w2.shape[1] * 16
@@ -213,47 +223,36 @@ def fused_marlin_moe(
                             device="cuda",
                             requires_grad=False)
 
-    if has_zero_point:
-        assert w1_zeros is not None and w1_zeros.nelement() > 0
-        assert w2_zeros is not None and w2_zeros.nelement() > 0
-
-    if w1_zeros is None:
-        w1_zeros = torch.empty((0),
+    if has_no_zp:
+        w1_zeros = torch.empty((0, 0),
                                dtype=hidden_states.dtype,
                                device=hidden_states.device,
                                requires_grad=False)
-    if w2_zeros is None:
-        w2_zeros = torch.empty((0),
+        w2_zeros = torch.empty((0, 0),
                                dtype=hidden_states.dtype,
                                device=hidden_states.device,
                                requires_grad=False)
 
-    if g_idx1 is None:
+    if has_no_act_order:
         g_idx1 = torch.empty((0, 0),
                              dtype=torch.int32,
                              device=hidden_states.device,
                              requires_grad=False)
-
-    if g_idx2 is None:
         g_idx2 = torch.empty((0, 0),
                              dtype=torch.int32,
                              device=hidden_states.device,
                              requires_grad=False)
-
-    if sort_indices1 is None:
         sort_indices1 = torch.empty((0),
                                     dtype=torch.int32,
                                     device=hidden_states.device,
                                     requires_grad=False)
-
-    if sort_indices2 is None:
         sort_indices2 = torch.empty((0, 0),
                                     dtype=torch.int32,
                                     device=hidden_states.device,
                                     requires_grad=False)
 
-    scalar_type1 = get_scalar_type(num_bits, has_zero_point)
-    scalar_type2 = get_scalar_type(num_bits, has_zero_point)
+    scalar_type1 = get_scalar_type(num_bits, has_all_zp)
+    scalar_type2 = get_scalar_type(num_bits, has_all_zp)
 
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N),
@@ -277,7 +276,6 @@
         2 * N,
         K,
         is_k_full,
-        has_zero_point,
         E,
         topk,
         block_size_m,
@@ -303,7 +301,6 @@
         K,
         N,
         is_k_full,
-        has_zero_point,
         E,
         topk,
         block_size_m,
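Note: with the boolean flags gone, the two optional-tensor groups (zero points; g_idx/sort_indices) are validated as all-or-none and the flags are derived from them. A runnable restatement of the new invariant (helper name is illustrative):

    from typing import Optional

    import torch

    def infer_zp_mode(w1_zeros: Optional[torch.Tensor],
                      w2_zeros: Optional[torch.Tensor]) -> bool:
        # Mirrors the checks added to fused_marlin_moe: both zero-point
        # tensors must be supplied together or omitted together.
        has_no_zp = w1_zeros is None and w2_zeros is None
        has_all_zp = w1_zeros is not None and w2_zeros is not None
        assert has_no_zp or has_all_zp, (
            "zero points must be both not None or must be both None")
        return has_all_zp

    print(infer_zp_mode(None, None))  # False -> empty (0, 0) sentinels used
    zp = torch.zeros((8, 16), dtype=torch.float16)
    print(infer_zp_mode(zp, zp))      # True  -> zero-point kernels selected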
