[Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (vllm-project#3922)

Co-authored-by: alexm <alexm@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
3 people authored Apr 29, 2024
1 parent a754596 commit ac22cdd
Showing 14 changed files with 2,626 additions and 104 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt

@@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   "csrc/quantization/aqlm/gemm_kernels.cu"
   "csrc/quantization/awq/gemm_kernels.cu"
   "csrc/quantization/marlin/marlin_cuda_kernel.cu"
+  "csrc/quantization/gptq_marlin/gptq_marlin.cu"
+  "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
   "csrc/custom_all_reduce.cu")
 endif()
18 changes: 18 additions & 0 deletions csrc/ops.h

@@ -124,6 +124,24 @@ torch::Tensor marlin_gemm(
   int64_t size_m,
   int64_t size_n,
   int64_t size_k);
+
+torch::Tensor gptq_marlin_gemm(
+  torch::Tensor &a,
+  torch::Tensor &b_q_weight,
+  torch::Tensor &b_scales,
+  torch::Tensor &g_idx,
+  torch::Tensor &perm,
+  torch::Tensor &workspace,
+  int64_t size_m,
+  int64_t size_n,
+  int64_t size_k,
+  bool is_k_full);
+
+torch::Tensor gptq_marlin_repack(
+  torch::Tensor &b_q_weight,
+  torch::Tensor &perm,
+  int64_t size_k,
+  int64_t size_n);
 #endif

 void squeezellm_gemm(
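Taken together, the two declarations describe a two-step flow: gptq_marlin_repack converts a GPTQ-packed weight into Marlin's layout once at load time, and gptq_marlin_gemm then consumes the repacked weight on every forward pass. Below is a minimal repack sketch, assuming the extension builds as vllm._C with the ops submodule shown in csrc/pybind.cpp further down, 4-bit packing (eight weights per int32 along k), and that an empty perm tensor means no activation reordering; these are assumptions, not facts taken from the diff.

import torch
from vllm import _C  # assumed module name for the extension built from csrc/

k, n = 4096, 11008  # example layer dims: k = input features, n = output features
# GPTQ packs eight 4-bit weights per int32 along k -> shape (k // 8, n) (assumed)
gptq_qweight = torch.randint(0, 2**31 - 1, (k // 8, n),
                             dtype=torch.int32, device="cuda")
# Permutation from sorting g_idx when desc_act=True; empty when unused (assumed)
perm = torch.empty(0, dtype=torch.int32, device="cuda")

marlin_qweight = _C.ops.gptq_marlin_repack(gptq_qweight, perm, k, n)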
2 changes: 2 additions & 0 deletions csrc/pybind.cpp

@@ -67,6 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
+  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
+  ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
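Once these bindings are registered, the kernel is callable from Python. Here is a hedged end-to-end sketch; the module path, tensor shapes, group size, workspace sizing, and the reading of is_k_full are all assumptions based on the argument names, not on anything shown in the diff.

import torch
from vllm import _C  # assumed module name, as in the repack sketch above

m, k, n, group_size = 16, 4096, 11008, 128
a = torch.randn(m, k, dtype=torch.float16, device="cuda")  # fp16 activations
scales = torch.randn(k // group_size, n, dtype=torch.float16, device="cuda")
g_idx = torch.empty(0, dtype=torch.int32, device="cuda")  # no act-order (assumed)
perm = torch.empty(0, dtype=torch.int32, device="cuda")

# Repack the GPTQ-format weight into Marlin layout, as in the previous sketch.
qweight = torch.randint(0, 2**31 - 1, (k // 8, n),
                        dtype=torch.int32, device="cuda")
marlin_qweight = _C.ops.gptq_marlin_repack(qweight, perm, k, n)

# Zero-initialized scratch buffer for the kernel's global reduction;
# the exact required size is an assumption.
workspace = torch.zeros(n // 64 * 16, dtype=torch.int32, device="cuda")

# is_k_full=True is read here as "this rank holds the full k dimension"
# (relevant under tensor parallelism); an interpretation of the flag name only.
out = _C.ops.gptq_marlin_gemm(a, marlin_qweight, scales, g_idx, perm,
                              workspace, m, n, k, True)
assert out.shape == (m, n)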