from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.utils.marlin_perms import (
+    marlin_perm)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights)
+    MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
+    marlin_quantize, marlin_weights)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]

-K_CHUNKS = [128, 256]
-N_CHUNKS = [64, 128, 256]
+MARLIN_K_CHUNKS = [128]
+MARLIN_N_CHUNKS = [64, 128, 256]
+
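+# Chunk sizes used to build the Marlin 24 (2:4 sparse) GEMM test shapes.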
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [256]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
-    (1, 7 * 4, 5 * 1),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]


def rand_data(shape):
-    data = torch.rand(shape).to(torch.half).cuda()
-    return data
+    return torch.randn(shape, dtype=torch.half, device="cuda")


@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
-@pytest.mark.parametrize("k_chunk", K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", N_CHUNKS)
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
-    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits)
+    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits,
+                                  marlin_perm[num_bits])

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
-@pytest.mark.parametrize("k_chunk", K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", N_CHUNKS)
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@@ -136,7 +145,8 @@ def test_marlin_gemm(
    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, num_bits, group_size, act_order)

-    workspace = MarlinWorkspace(size_n)
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
@@ -155,4 +165,55 @@ def test_marlin_gemm(
    torch.cuda.synchronize()

-    assert torch.allclose(output, output_ref, rtol=1e-2)
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_marlin_supported(),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k))
+    b_weight = rand_data((size_k, size_n))
+
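+    # Quantize the weights to the compressed Marlin 24 (2:4 sparse) format:
+    # fp16 reference weights, packed qweight, sparsity metadata, and scales.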
+    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
+     marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)
+
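+    # Scratch workspace sized from the kernel's minimum thread-N tile and
+    # maximum parallelism.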
+    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
+                                   GPTQ_MARLIN_24_MAX_PARALLEL)
+
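+    # Reference result computed with the dequantized (pruned) fp16 weights.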
+    output_ref = torch.matmul(a_input, w_24_ref)
+
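+    # Run the Marlin 24 GEMM kernel on the compressed representation.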
+    output = ops.gptq_marlin_24_gemm(
+        a_input,
+        marlin_24_q_w_comp,
+        marlin_24_meta,
+        marlin_24_s,
+        workspace_24.scratch,
+        num_bits,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+    )
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04