Skip to content

Commit

Permalink
Marlin 24 prefill performance improvement (about 25% better on averag…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexm-neuralmagic authored May 23, 2024
1 parent ee3eea0 commit 6066253
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 32 deletions.
74 changes: 62 additions & 12 deletions benchmarks/kernels/benchmark_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_quantize)
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)

Expand Down Expand Up @@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)

# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
Expand All @@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

# Prepare
marlin_workspace = MarlinWorkspace(size_n)
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)

globals = {
# Gen params
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"a": a,
"a_tmp": a_tmp,
# Marlin params
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
"marlin_workspace": marlin_workspace,
"is_k_full": is_k_full,
# Marlin_24 params
"marlin_24_w_ref": marlin_24_w_ref,
"marlin_24_q_w_comp": marlin_24_q_w_comp,
"marlin_24_meta": marlin_24_meta,
"marlin_24_s": marlin_24_s,
"marlin_24_workspace": marlin_24_workspace,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"is_k_full": is_k_full,
"a": a,
"a_tmp": a_tmp,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"marlin_workspace": marlin_workspace,
}

min_run_time = 1
Expand Down Expand Up @@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
description="gptq_marlin_gemm",
).blocked_autorange(min_run_time=min_run_time))

if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))

results.append(
benchmark.Timer(
stmt=
Expand Down Expand Up @@ -135,8 +170,20 @@ def main(args):
continue

for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
continue

for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
continue

for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue

for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
Expand All @@ -159,7 +206,7 @@ def main(args):


# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand All @@ -178,6 +225,9 @@ def main(args):
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

args = parser.parse_args()
main(args)
55 changes: 40 additions & 15 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ namespace marlin_24 {
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int THREADS = 256;
static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory
static constexpr int STAGES = 4;

static constexpr int min_thread_n = 128;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int max_par = 64;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

Expand Down Expand Up @@ -736,10 +736,10 @@ __global__ void Marlin_24(
for (int pipe = 0; pipe < stages;) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
matmul(pipe);
wait_for_stage();

fetch_to_registers(pipe + 1, (pipe + 1) % stages);
matmul(pipe);

pipe++;
slice_iters--;
Expand Down Expand Up @@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// than better compute utilization
thread_k = 128;
thread_m = 128;
} else {
} else if (prob_n <= 256) {
thread_k = 64;
thread_m = 256;
} else {
thread_k = 32;
thread_m = 512;
}
}

Expand All @@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;

constexpr int max_m_blocks = 4;

int* locks = (int*)workspace;
for (int i = 0; i < tot_n_blocks; i += 4) {
for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
int thread_n_blocks = tot_n_blocks - i;
prob_n = tot_n - 16 * i;
int par = 1;
if (thread_n_blocks > 4) {
if (thread_n_blocks > max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_n_blocks - pad) / 64;
par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
if (par > max_par) par = max_par;
prob_n = 64 * par;
i += 4 * (par - 1);
thread_n_blocks = 4;
prob_n = (max_m_blocks * 16) * par;
i += max_m_blocks * (par - 1);
thread_n_blocks = max_m_blocks;
}

// For compilation speed, we only define the kernel configurations that have
Expand All @@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
if (false) {
} // BMxBNxBK, group
// 4-bit
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64

CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
Expand All @@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
CALL_IF_2_4(4, 16, 4, 2, -1)
CALL_IF_2_4(4, 16, 4, 2, 4)

CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64
CALL_IF_2_4(4, 32, 2, 1, 4)
CALL_IF_2_4(4, 32, 3, 1, -1)
CALL_IF_2_4(4, 32, 3, 1, 4)
CALL_IF_2_4(4, 32, 4, 1, -1)
CALL_IF_2_4(4, 32, 4, 1, 4)

// 8-bit
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64

CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
Expand All @@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
CALL_IF_2_4(8, 16, 3, 2, 4)
CALL_IF_2_4(8, 16, 4, 2, -1)
CALL_IF_2_4(8, 16, 4, 2, 4)

CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64
CALL_IF_2_4(8, 32, 2, 1, 4)
CALL_IF_2_4(8, 32, 3, 1, -1)
CALL_IF_2_4(8, 32, 3, 1, 4)
CALL_IF_2_4(8, 32, 4, 1, -1)
CALL_IF_2_4(8, 32, 4, 1, 4)
else {
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
", " + str(prob_k) + ", " + str(prob_n) + "]" +
Expand Down Expand Up @@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int thread_k = -1;
int thread_m = -1;
int sms = -1;
int max_par = 16;
int max_par = marlin_24::max_par;

int groupsize = -1;
if (b_scales.size(0) > 1) {
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_marlin_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
MARLIN_N_CHUNKS = [64, 128, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [256]
MARLIN_24_N_CHUNKS = [512]

MNK_FACTORS = [
(1, 1, 1),
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/gptq_marlin_24.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 16
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
Expand Down Expand Up @@ -53,14 +53,14 @@ def __init__(
self.tile_size = 16

# Min out_features dim
self.min_n_threads = 128
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

# Min in_features dim
self.min_k_threads = 128
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

# Permutation length used by the marlin kernels.
self.perm_len = 1024
Expand Down

0 comments on commit 6066253

Please sign in to comment.