
Commit 67409d3

Author: Varun Sundar Rabindranath (committed)

Commit message: format
1 parent 2748e67 commit 67409d3

File tree: 1 file changed, +98 −28 lines changed


csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu (+98 −28)
@@ -278,7 +278,9 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
                       KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType, int32_t M, bool IsSmallN>
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue, int32_t M,
+          bool IsSmallN>  // IsSmallN is true if N < 8192
 struct sm90_int8_config {
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule =
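The key change in this hunk is that sm90_int8_config now takes the epilogue as a template-template parameter, so the int8 configs can carry an epilogue functor (e.g. ScaledEpilogue) through to cutlass_3x_gemm the same way the fp8 configs already do. A minimal, self-contained sketch of the mechanism follows; the Toy* names are invented purely for illustration and are not part of the diff:

#include <cstdint>
#include <cstdio>

// Toy epilogue with the same arity as the diff's
// "template <typename, typename, typename> typename Epilogue" parameter.
template <typename A, typename B, typename C>
struct ToyScaledEpilogue {
  static void run() { std::puts("epilogue instantiated"); }
};

// Toy config mirroring the new sm90_int8_config signature: it receives the
// epilogue template itself and instantiates it with whatever types it needs.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct ToyConfig {
  using E = Epilogue<InType, OutType, float>;
};

int main() {
  ToyConfig<int8_t, float, ToyScaledEpilogue>::E::run();
  return 0;
}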
@@ -287,12 +289,14 @@ struct sm90_int8_config {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType, bool IsSmallN>
-struct sm90_int8_config<InType, OutType, 128, IsSmallN> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          bool IsSmallN>
+struct sm90_int8_config<InType, OutType, Epilogue, 128, IsSmallN> {
   // Specialization for M in (64, 128] and any N
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule =
@@ -301,47 +305,51 @@ struct sm90_int8_config<InType, OutType, 128, IsSmallN> {
   using TileShape = Shape<_64, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType, bool IsSmallN>
-struct sm90_int8_config<InType, OutType, 64, IsSmallN> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          bool IsSmallN>
+struct sm90_int8_config<InType, OutType, Epilogue, 64, IsSmallN> {
   // Specialization for M in (32, 64] and any N
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _64, _256>;
   using ClusterShape = Shape<_1, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType>
-struct sm90_int8_config<InType, OutType, 32, false> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config<InType, OutType, Epilogue, 32, false> {
   // Specialization for M in [1, 32] and N >= 8192
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _128, _256>;
   using ClusterShape = Shape<_1, _4, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType>
-struct sm90_int8_config<InType, OutType, 32, true> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config<InType, OutType, Epilogue, 32, true> {
   // Specialization for M in [1, 32] and N < 8192
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _64, _256>;
   using ClusterShape = Shape<_1, _8, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
 }  // namespace
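Taken together, the int8 specializations above select the following shapes (read directly off the diff; the kernel and epilogue schedules of the default and M128 configs fall outside the shown hunks):

  M bucket      N          TileShape          ClusterShape
  (128, inf)    any        128 x 128 x 128    2 x 1 x 1
  (64, 128]     any         64 x 128 x 128    2 x 1 x 1
  (32, 64]      any         64 x  64 x 256    1 x 1 x 1
  [1, 32]       >= 8192     64 x 128 x 256    1 x 4 x 1
  [1, 32]       <  8192     64 x  64 x 256    1 x 8 x 1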
@@ -357,9 +365,9 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   static const int32_t MDimDontCare = 0;
-
   using Cutlass3xGemmDefault =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
+      typename sm90_fp8_config<InType, OutType, Epilogue,
+                               MDimDontCare>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
       typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
@@ -384,6 +392,70 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  static const int32_t MDimDontCare = 0;
+  static const bool NDimDontCare = false;
+
+  // Same config for Large N and Small N
+  using Cutlass3xGemmDefault =
+      typename sm90_int8_config<InType, OutType, Epilogue, MDimDontCare,
+                                NDimDontCare>::Cutlass3xGemm;
+  // Same config for Large N and Small N
+  using Cutlass3xGemmM128 =
+      typename sm90_int8_config<InType, OutType, Epilogue, 128,
+                                NDimDontCare>::Cutlass3xGemm;
+  // Same config for Large N and Small N
+  using Cutlass3xGemmM64 =
+      typename sm90_int8_config<InType, OutType, Epilogue, 64,
+                                NDimDontCare>::Cutlass3xGemm;
+  // Different configs for Large N and Small N
+  using Cutlass3xGemmM32LargeN =
+      typename sm90_int8_config<InType, OutType, Epilogue, 32,
+                                false>::Cutlass3xGemm;
+  using Cutlass3xGemmM32SmallN =
+      typename sm90_int8_config<InType, OutType, Epilogue, 32,
+                                true>::Cutlass3xGemm;
+
+  uint32_t const n = out.size(1);  // N dim of the output; a.size(1) would be K
+  bool const is_small_n = n < 8192;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 32) {
+    // m in [1, 32]
+    if (is_small_n) {
+      return cutlass_gemm_caller<Cutlass3xGemmM32SmallN>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    } else {
+      return cutlass_gemm_caller<Cutlass3xGemmM32LargeN>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (mp2 <= 64) {
+    // m in (32, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
 void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
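The dispatcher rounds M up to a power of two (clamped to at least 32) and picks a config bucket from the result. Below is a self-contained sketch of that bucketing with a hypothetical next_pow_2 implementation; the real helper is defined elsewhere in the vLLM sources and is only assumed here, not shown in this diff:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Hypothetical stand-in for the next_pow_2 helper the dispatch code calls.
// __builtin_clz is a GCC/Clang builtin counting leading zero bits.
static inline uint32_t next_pow_2(uint32_t v) {
  if (v <= 1) return 1;
  return 1u << (32 - __builtin_clz(v - 1));
}

int main() {
  // Mirror the bucketing in cutlass_gemm_sm90_int8_dispatch:
  // clamp to at least 32, then round M up to the next power of two.
  for (uint32_t m : {1u, 20u, 33u, 48u, 100u, 129u}) {
    uint32_t mp2 = std::max(32u, next_pow_2(m));
    const char* bucket = mp2 <= 32    ? "M32 (small-N vs large-N config)"
                         : mp2 <= 64  ? "M64"
                         : mp2 <= 128 ? "M128"
                                      : "Default";
    std::printf("m=%3u -> mp2=%3u -> %s\n", m, mp2, bucket);
  }
  return 0;
}

For example, m = 48 rounds up to mp2 = 64 and lands in the M64 bucket, i.e. the Shape<_64, _64, _256> tile configuration from the specializations above.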
@@ -395,15 +467,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_gemm_caller<cutlass_3x_gemm<
-        int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
-        KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
+    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                           ScaledEpilogue>(
+        out, a, b, a_scales, b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-    return cutlass_gemm_caller<
-        cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
-                        ClusterShape, KernelSchedule, EpilogueSchedule>>(
+    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
+                                           ScaledEpilogue>(
         out, a, b, a_scales, b_scales);
   }
 } else {
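With both call sites routed through cutlass_gemm_sm90_int8_dispatch, the int8 path now selects a tile and cluster configuration from the problem shape for bf16 and fp16 outputs alike, rather than instantiating a single fixed cutlass_3x_gemm for every M and N.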
