Commit 41da12f

alexm-redhat authored and joerunde committed
[Kernel] Add marlin_24 unit tests (vllm-project#4901)
1 parent fc3cc45 commit 41da12f

File tree

6 files changed: +649 -103 lines changed

tests/kernels/test_marlin_gemm.py

+74 -13
@@ -7,38 +7,46 @@
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.utils.marlin_perms import (
+    marlin_perm)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights)
+    MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
+    marlin_quantize, marlin_weights)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)
 
 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
 
-K_CHUNKS = [128, 256]
-N_CHUNKS = [64, 128, 256]
+MARLIN_K_CHUNKS = [128]
+MARLIN_N_CHUNKS = [64, 128, 256]
+
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [256]
 
 MNK_FACTORS = [
     (1, 1, 1),
     (1, 4, 8),
     (1, 7, 5),
-    (1, 7 * 4, 5 * 1),
     (13, 17, 67),
     (26, 37, 13),
     (67, 13, 11),
 ]
 
 
 def rand_data(shape):
-    data = torch.rand(shape).to(torch.half).cuda()
-    return data
+    return torch.randn(shape, dtype=torch.half, device="cuda")
 
 
 @pytest.mark.skipif(not is_marlin_supported(),
                     reason="Marlin is not supported on this GPU type.")
-@pytest.mark.parametrize("k_chunk", K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", N_CHUNKS)
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
 @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
 @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
     q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
 
     # Pack to Marlin format
-    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits)
+    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits,
+                                  marlin_perm[num_bits])
 
     # Run Marlin repack GPU kernel
     marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
 
 @pytest.mark.skipif(not is_marlin_supported(),
                     reason="Marlin is not supported on this GPU type.")
-@pytest.mark.parametrize("k_chunk", K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", N_CHUNKS)
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
 @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
 @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@@ -136,7 +145,8 @@ def test_marlin_gemm(
     w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
         b_weight, num_bits, group_size, act_order)
 
-    workspace = MarlinWorkspace(size_n)
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
 
     output = ops.gptq_marlin_gemm(
         a_input,
@@ -155,4 +165,55 @@ def test_marlin_gemm(
 
     torch.cuda.synchronize()
 
-    assert torch.allclose(output, output_ref, rtol=1e-2)
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_marlin_supported(),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k))
+    b_weight = rand_data((size_k, size_n))
+
+    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
+     marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)
+
+    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
+                                   GPTQ_MARLIN_24_MAX_PARALLEL)
+
+    output_ref = torch.matmul(a_input, w_24_ref)
+
+    output = ops.gptq_marlin_24_gemm(
+        a_input,
+        marlin_24_q_w_comp,
+        marlin_24_meta,
+        marlin_24_s,
+        workspace_24.scratch,
+        num_bits,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+    )
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
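
The new assertion replaces the old torch.allclose(rtol=1e-2) check with a scale-independent relative error (compute_max_diff) compared against a 0.04 threshold. As a rough illustration of what such a metric computes (a sketch only, not necessarily vLLM's exact compute_max_diff implementation):

import torch


def relative_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> float:
    # Hypothetical sketch: mean absolute deviation normalized by the mean
    # magnitude of the reference, so the 0.04 threshold does not depend on
    # the absolute scale of the GEMM output.
    return (torch.mean(torch.abs(output - output_ref)) /
            torch.mean(torch.abs(output_ref))).item()

On a Marlin-capable GPU, the new test can be run in isolation with:
pytest tests/kernels/test_marlin_gemm.py -k test_marlin_24_gemm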

vllm/model_executor/layers/quantization/gptq_marlin_24.py

+19 -8
@@ -12,6 +12,15 @@
 
 logger = init_logger(__name__)
 
+GPTQ_MARLIN_24_TILE = 16
+GPTQ_MARLIN_24_MIN_THREAD_N = 128
+GPTQ_MARLIN_24_MIN_THREAD_K = 128
+GPTQ_MARLIN_24_MAX_PARALLEL = 16
+
+GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
+
 
 class GPTQMarlin24Config(QuantizationConfig):
     """Config class for Marlin24.
@@ -25,15 +34,17 @@ def __init__(
         self.weight_bits = weight_bits
         self.group_size = group_size
 
-        if self.weight_bits != 4 and self.weight_bits != 8:
-            raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
-                self.weight_bits))
-
-        if self.group_size != 128 and self.group_size != -1:
+        # Verify
+        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
+            raise ValueError(
+                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
+                "are supported.")
+        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             raise ValueError(
-                "Currently, only group size 128 and -1 (channelwise) "
-                "is supported for Marlin24, but got group_size of "
-                f"{self.group_size}")
+                f"Marlin_24 does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
 
         # 4 Bits packed into 32 bit datatype.
         self.pack_factor = 32 // self.weight_bits
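
For illustration, a minimal sketch of how the new table-driven validation behaves, assuming GPTQMarlin24Config is constructed directly with the weight_bits and group_size arguments implied by the __init__ hunk above:

import pytest

from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQMarlin24Config)

# Supported combination: 4-bit weights with group size 128.
cfg = GPTQMarlin24Config(weight_bits=4, group_size=128)
assert cfg.pack_factor == 8  # 32 // 4 values packed per 32-bit word

# Values outside GPTQ_MARLIN_24_SUPPORTED_NUM_BITS / _GROUP_SIZES now raise
# a descriptive ValueError instead of the previous hard-coded checks.
with pytest.raises(ValueError):
    GPTQMarlin24Config(weight_bits=3, group_size=128)  # unsupported bits
with pytest.raises(ValueError):
    GPTQMarlin24Config(weight_bits=4, group_size=64)  # unsupported group size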
