Skip to content

Commit

Permalink
Delete 8-bit zero point code
Browse files Browse the repository at this point in the history
  • Loading branch information
ElizaWszola committed Oct 1, 2024
1 parent 3ff0ba1 commit 87d46dc
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 65 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu"
"csrc/moe/marlin_moe_ops.cu")
endif()

Expand Down
31 changes: 0 additions & 31 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu

This file was deleted.

20 changes: 0 additions & 20 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h

This file was deleted.

8 changes: 3 additions & 5 deletions csrc/moe/marlin_moe_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
#include "marlin_kernels/marlin_moe_kernel_ku8.h"

template <typename T>
inline std::string str(T x) {
Expand Down Expand Up @@ -461,7 +460,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
Expand All @@ -488,9 +486,9 @@ torch::Tensor marlin_gemm_moe(
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
b_q_type->str());
TORCH_CHECK(
*b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
Expand Down
10 changes: 4 additions & 6 deletions tests/kernels/test_awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,18 @@
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
num_bits: int,
):
torch.manual_seed(7)

quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
Expand Down Expand Up @@ -111,19 +110,18 @@ def test_fused_marlin_moe_awq(
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
num_bits: int,
):
torch.manual_seed(7)

quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128

Expand Down

0 comments on commit 87d46dc

Please sign in to comment.