Commit 914c581

tlrmchlsmth authored and Alvant committed
[Bugfix] Fix compute datatype for cutlass 3.x epilogues (vllm-project#5931)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 5a3ab2a commit 914c581

2 files changed: +70 -59 lines changed

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu (+2 -2)

@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
   using ScaleB = typename SUPER::ScaleB;

   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiplies, ElementD, ElementD,
+      cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;

   using EVTCompute0 =
       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

   using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiply_add, ElementD, ElementD,
+      cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;

   using BiasDescriptor =
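
The two changed template arguments above move the epilogue's intermediate scale arithmetic from the output element type (ElementD, e.g. fp16/bf16) to float, so the per-token and per-channel scale products are carried in fp32 and only the final result is rounded to the output dtype. Below is a minimal PyTorch sketch of the numerical effect; it is illustrative only, the tensor names, shapes, and magnitudes are made up, and this is not vLLM or CUTLASS code. The fp8 GEMM test in the next file correspondingly tightens its absolute tolerance from 1e-1 to 5e-2.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the GEMM accumulator and the per-token / per-channel scales.
acc = torch.randn(4, 8, dtype=torch.float32) * 100.0
scale_a = torch.rand(4, 1, dtype=torch.float32) / 100.0
scale_b = torch.rand(1, 8, dtype=torch.float32) / 100.0

# Old behavior: compute dtype == output dtype (here fp16), so every
# intermediate product is rounded to fp16 before the next multiply.
old = scale_a.to(torch.float16) * (scale_b.to(torch.float16) *
                                   acc.to(torch.float16))

# New behavior: compute in fp32 and round to the output dtype once at the
# end, matching the baseline_scaled_mm reference added to the tests below.
new = (scale_a * (scale_b * acc)).to(torch.float16)

print((old - new).abs().max())  # nonzero: the extra intermediate rounding
```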

tests/kernels/test_cutlass.py (+68 -57)
@@ -2,7 +2,7 @@

 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Type
+from typing import Optional, Type

 import pytest
 import torch
@@ -27,12 +27,27 @@ def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
-                            bias: bool,
+                            use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -43,31 +58,27 @@ def cutlass_fp8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1

-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None

-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

-    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
+    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)


 def cutlass_int8_gemm_helper(m: int,
                              n: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
-                             bias: bool,
+                             use_bias: bool,
                              out_dtype: Type[torch.dtype] = torch.bfloat16,
                              device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -78,22 +89,19 @@ def cutlass_int8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1

-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))

-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(dtype=out_dtype)
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)

@@ -102,83 +110,83 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                          per_out_ch: bool, bias: bool):
-    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


 @pytest.mark.parametrize("m", [512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                           per_out_ch: bool, bias: bool):
-    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype],
-                                        bias: bool):
+                                        use_bias: bool):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=out_dtype)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
-                                       bias: bool):
+                                       use_bias: bool):
     cutlass_fp8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
-                            bias,
+                            use_bias,
                             out_dtype=out_dtype)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool, device: str):
-    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                             torch.bfloat16, device)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool, device: str):
+                                   use_bias: bool, device: str):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)

@@ -190,25 +198,26 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 # kernel must handle any M thrown at it.
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool):
+                                  use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
-            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool):
+                                   use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
             cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
-                                     bias)
+                                     use_bias)


 # Test working with a subset of A and B
@@ -229,9 +238,11 @@ def test_cutlass_subset():
                                 scale_a,
                                 scale_b,
                                 out_dtype=torch.bfloat16)
-    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
-                        scale_b *
-                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)

     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
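
For readers skimming the diff, here is a hedged, self-contained sketch of the reference math the new baseline_scaled_mm helper encodes: per-token scales broadcast over rows as an (m, 1) column, per-output-channel scales broadcast over columns as a (1, n) row, the matmul and scaling run in fp32, and the result is cast to out_dtype once before the bias (already in out_dtype) is added. The reference_scaled_mm name, shapes, and int8 inputs below are illustrative, not part of the vLLM test suite:

```python
from typing import Optional

import torch


def reference_scaled_mm(a: torch.Tensor,
                        b: torch.Tensor,
                        scale_a: torch.Tensor,   # (m, 1) per-token or (1, 1)
                        scale_b: torch.Tensor,   # (1, n) per-channel or (1, 1)
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Do everything in fp32 and cast to the output dtype once at the end,
    # the same order of operations as the baseline_scaled_mm test helper.
    out = (scale_a * (scale_b * torch.mm(a.to(torch.float32),
                                         b.to(torch.float32)))).to(out_dtype)
    return out if bias is None else out + bias


m, n, k = 16, 32, 64
a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
scale_a = torch.rand(m, 1, dtype=torch.float32) / 10   # per-token scales
scale_b = torch.rand(1, n, dtype=torch.float32) / 10   # per-output-channel scales
out = reference_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([16, 32]) torch.bfloat16
```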
