This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 7d85753

varun-sundar-rabindranath (Varun Sundar Rabindranath) authored and committed
[Kernel] Update Cutlass int8 kernel configs for SM90 (vllm-project#5514)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent df3ae01 commit 7d85753
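
For quick reference, the short program below transcribes the tile and cluster shapes of the new SM90 int8 configurations introduced by the diff further down, grouped by the M/N buckets they target. The Int8ConfigSummary struct and this printer are purely illustrative and not part of vLLM or CUTLASS; only the shape values are taken from the commit.

#include <cstdio>

// Illustrative summary of the SM90 int8 configs added by this commit
// (values transcribed from the diff below); Int8ConfigSummary is a
// hypothetical helper type, not a vLLM or CUTLASS type.
struct Int8ConfigSummary {
  const char* bucket;                   // M/N range the config targets
  int tile_m, tile_n, tile_k;           // CUTLASS TileShape
  int cluster_m, cluster_n, cluster_k;  // CUTLASS ClusterShape
};

int main() {
  const Int8ConfigSummary configs[] = {
      {"M > 128 (default)", 128, 128, 128, 2, 1, 1},
      {"64 < M <= 128", 64, 128, 128, 2, 1, 1},
      {"32 < M <= 64", 64, 64, 256, 1, 1, 1},
      {"M <= 32, N >= 8192", 64, 128, 256, 1, 4, 1},
      {"M <= 32, N < 8192", 64, 64, 256, 1, 8, 1},
  };
  for (const auto& c : configs) {
    std::printf("%-20s tile %3dx%3dx%3d  cluster %dx%dx%d\n", c.bucket,
                c.tile_m, c.tile_n, c.tile_k, c.cluster_m, c.cluster_n,
                c.cluster_k);
  }
  return 0;
}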

1 file changed: +143 -22 lines changed

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

@@ -234,38 +234,39 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
 }
 
 template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue, int32_t M>
-struct sm90_fp8_config {
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_default {
+  // M in (128, inf)
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
-
   using Cutlass3xGemm =
       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                       KernelSchedule, EpilogueSchedule>;
 };
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
+struct sm90_fp8_config_M128 {
+  // M in (64, 128]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
-
   using Cutlass3xGemm =
       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                       KernelSchedule, EpilogueSchedule>;
 };
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
+struct sm90_fp8_config_M64 {
+  // M in [1, 64]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@@ -278,6 +279,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
                       KernelSchedule, EpilogueSchedule>;
 };
 
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_default {
+  // For M > 128 and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule =
+      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M128 {
+  // For M in (64, 128] and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule =
+      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M64 {
+  // For M in (32, 64] and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M32_NBig {
+  // For M in [1, 32] and N >= 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _256>;
+  using ClusterShape = Shape<_1, _4, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M32_NSmall {
+  // For M in [1, 32] and N < 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _8, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
 }  // namespace
 
 template <typename InType, typename OutType,
@@ -291,11 +364,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   using Cutlass3xGemmDefault =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
+      typename sm90_fp8_config_default<InType, OutType,
+                                       Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
+      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
+      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
 
   uint32_t const m = a.size(0);
   uint32_t const mp2 =
@@ -316,6 +390,61 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass3xGemmDefault =
+      typename sm90_int8_config_default<InType, OutType,
+                                        Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM32NBig =
+      typename sm90_int8_config_M32_NBig<InType, OutType,
+                                         Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM32NSmall =
+      typename sm90_int8_config_M32_NSmall<InType, OutType,
+                                           Epilogue>::Cutlass3xGemm;
+
+  uint32_t const n = out.size(1);
+  bool const is_small_n = n < 8192;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 32) {
+    // m in [1, 32]
+    if (is_small_n) {
+      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    } else {
+      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (mp2 <= 64) {
+    // m in (32, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
 void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
@@ -326,22 +455,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
-    using TileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule =
-        typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-    using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_3x_gemm<
-          int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                             ScaledEpilogue>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-      return cutlass_gemm_caller<
-          cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
+                                             ScaledEpilogue>(
           out, a, b, a_scales, b_scales);
     }
   } else {
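
A note on the dispatch logic added above: cutlass_gemm_sm90_int8_dispatch buckets a GEMM by next_pow_2(m) clamped to at least 32, then splits the smallest bucket on N. The next_pow_2 helper is defined elsewhere in the vLLM sources and does not appear in this diff; the self-contained sketch below mirrors the selection under the assumption that it rounds a positive integer up to the next power of two. next_pow_2_sketch and pick_int8_config are hypothetical names used only for illustration.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed behavior of vLLM's next_pow_2 helper (not shown in this diff):
// round a positive integer up to the next power of two, e.g. 33 -> 64.
static inline uint32_t next_pow_2_sketch(uint32_t num) {
  uint32_t p = 1;
  while (p < num) p <<= 1;
  return p;
}

// Mirrors the bucketing in cutlass_gemm_sm90_int8_dispatch, for illustration.
static const char* pick_int8_config(uint32_t m, uint32_t n) {
  uint32_t const mp2 = std::max(32u, next_pow_2_sketch(m));
  if (mp2 <= 32) {
    // m in [1, 32]: split on N
    return n < 8192 ? "sm90_int8_config_M32_NSmall" : "sm90_int8_config_M32_NBig";
  } else if (mp2 <= 64) {
    return "sm90_int8_config_M64";     // m in (32, 64]
  } else if (mp2 <= 128) {
    return "sm90_int8_config_M128";    // m in (64, 128]
  }
  return "sm90_int8_config_default";   // m in (128, inf)
}

int main() {
  std::printf("%s\n", pick_int8_config(16, 4096));   // ..._M32_NSmall
  std::printf("%s\n", pick_int8_config(16, 16384));  // ..._M32_NBig
  std::printf("%s\n", pick_int8_config(100, 4096));  // ..._M128
  std::printf("%s\n", pick_int8_config(512, 4096));  // ..._default
  return 0;
}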
