@@ -2,7 +2,7 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 import torch
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
-                            bias: bool,
+                            use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -43,31 +58,27 @@ def cutlass_fp8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
 
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
+    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
 
 
 def cutlass_int8_gemm_helper(m: int,
                              n: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
-                             bias: bool,
+                             use_bias: bool,
                              out_dtype: Type[torch.dtype] = torch.bfloat16,
                              device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1
 
-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
 
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(dtype=out_dtype)
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
 
 
@@ -102,83 +110,83 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                          per_out_ch: bool, bias: bool):
-    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("m", [512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                           per_out_ch: bool, bias: bool):
-    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype],
-                                        bias: bool):
+                                        use_bias: bool):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
-                                       bias: bool):
+                                       use_bias: bool):
     cutlass_fp8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
-                            bias,
+                            use_bias,
                             out_dtype=out_dtype)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool, device: str):
-    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                             torch.bfloat16, device)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool, device: str):
+                                   use_bias: bool, device: str):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)
 
@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 # kernel must handle any M thrown at it.
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool):
+                                  use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
-            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)
 
 
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool):
+                                   use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
             cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
-                                     bias)
+                                     use_bias)
 
 
 # Test working with a subset of A and B
@@ -229,9 +238,11 @@ def test_cutlass_subset():
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
-    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
-                        scale_b *
-                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
 
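For readers skimming the diff, the following is a minimal, self-contained sketch (not part of the PR) of how the new baseline_scaled_mm reference broadcasts a per-token activation scale and a per-output-channel weight scale; the shapes and tensor values are illustrative assumptions only.

import torch

# Hypothetical sizes chosen only for illustration.
m, k, n = 4, 8, 16
a = torch.randn(m, k)
b = torch.randn(k, n)
scale_a = torch.randn(m, 1)  # per-token scale: one value per row of the activations
scale_b = torch.randn(1, n)  # per-output-channel scale: one value per column of the weights

# Same arithmetic as baseline_scaled_mm: matmul in fp32, apply both scales, then cast.
baseline = (scale_a * (scale_b * torch.mm(a, b))).to(torch.bfloat16)
assert baseline.shape == (m, n)

Because scale_a has shape (m, 1) and scale_b has shape (1, n), PyTorch broadcasting applies each token's scale across its output row and each channel's scale down its output column, which is the behaviour the CUTLASS kernel is being checked against.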