@@ -278,7 +278,9 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
                       KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType, int32_t M, bool IsSmallN>
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue, int32_t M,
+          bool IsSmallN>  // IsSmallN is true if N < 8192
 struct sm90_int8_config {
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule =
@@ -287,12 +289,14 @@ struct sm90_int8_config {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType, bool IsSmallN>
-struct sm90_int8_config<InType, OutType, 128, IsSmallN> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          bool IsSmallN>
+struct sm90_int8_config<InType, OutType, Epilogue, 128, IsSmallN> {
   // Specialization for M in (64, 128] and any N
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule =
@@ -301,47 +305,51 @@ struct sm90_int8_config<InType, OutType, 128, IsSmallN> {
   using TileShape = Shape<_64, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType, bool IsSmallN>
-struct sm90_int8_config<InType, OutType, 64, IsSmallN> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          bool IsSmallN>
+struct sm90_int8_config<InType, OutType, Epilogue, 64, IsSmallN> {
   // Specialization for M in (32, 64] and any N
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _64, _256>;
   using ClusterShape = Shape<_1, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType>
-struct sm90_int8_config<InType, OutType, 32, false> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config<InType, OutType, Epilogue, 32, false> {
   // Specialization for M in [1, 32] and N >= 8192
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _128, _256>;
   using ClusterShape = Shape<_1, _4, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
-template <typename InType, typename OutType>
-struct sm90_int8_config<InType, OutType, 32, true> {
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config<InType, OutType, Epilogue, 32, true> {
   // Specialization for M in [1, 32] and N < 8192
   static_assert(std::is_same<InType, int8_t>());
   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
   using TileShape = Shape<_64, _64, _256>;
   using ClusterShape = Shape<_1, _8, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
-                      EpilogueSchedule>;
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
 };
 
 }  // namespace
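
The configs above all follow one pattern: the epilogue is threaded through as a template-template parameter, and the M bucket is a non-type parameter selected via partial specialization. Below is a minimal standalone sketch of that selection mechanism, using hypothetical demo_* names that are not part of this change:

#include <cstdint>
#include <iostream>

template <typename A, typename B, typename C>
struct DemoEpilogue {};

// Primary template: used when no specialization matches
// (M bucket 0 plays the "don't care" role).
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          int32_t M, bool IsSmallN>
struct demo_config {
  static constexpr int tile_m = 128;
};

// Partial specialization for the M-in-(32, 64] bucket, any N.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          bool IsSmallN>
struct demo_config<InType, OutType, Epilogue, 64, IsSmallN> {
  static constexpr int tile_m = 64;
};

int main() {
  std::cout << demo_config<int8_t, float, DemoEpilogue, 0, false>::tile_m
            << "\n";  // 128: primary template
  std::cout << demo_config<int8_t, float, DemoEpilogue, 64, true>::tile_m
            << "\n";  // 64: the specialization wins
}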
@@ -357,9 +365,9 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   static const int32_t MDimDontCare = 0;
-
   using Cutlass3xGemmDefault =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
+      typename sm90_fp8_config<InType, OutType, Epilogue,
+                               MDimDontCare>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
       typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
   using Cutlass3xGemmM128 =
@@ -384,6 +392,70 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
+                                     torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  static const int32_t MDimDontCare = 0;
+  static const bool NDimDontCare = false;
+
+  // Same config for Large N and Small N
+  using Cutlass3xGemmDefault =
+      typename sm90_int8_config<InType, OutType, Epilogue, MDimDontCare,
+                                NDimDontCare>::Cutlass3xGemm;
+  // Same config for Large N and Small N
+  using Cutlass3xGemmM128 =
+      typename sm90_int8_config<InType, OutType, Epilogue, 128,
+                                NDimDontCare>::Cutlass3xGemm;
+  // Same config for Large N and Small N
+  using Cutlass3xGemmM64 =
+      typename sm90_int8_config<InType, OutType, Epilogue, 64,
+                                NDimDontCare>::Cutlass3xGemm;
+  // Different configs for Large N and Small N
+  using Cutlass3xGemmM32LargeN =
+      typename sm90_int8_config<InType, OutType, Epilogue, 32,
+                                false>::Cutlass3xGemm;
+  using Cutlass3xGemmM32SmallN =
+      typename sm90_int8_config<InType, OutType, Epilogue, 32,
+                                true>::Cutlass3xGemm;
+
+  uint32_t const n = out.size(1);
+  bool const is_small_n = n < 8192;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 32) {
+    // m in [1, 32]
+    if (is_small_n) {
+      return cutlass_gemm_caller<Cutlass3xGemmM32SmallN>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    } else {
+      return cutlass_gemm_caller<Cutlass3xGemmM32LargeN>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (mp2 <= 64) {
+    // m in (32, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
 void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
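
The int8 dispatch above buckets on next_pow_2(m), clamped below at 32, so every runtime M maps to exactly one config. next_pow_2 is a helper defined elsewhere in this file; the sketch below substitutes one common bit-twiddling implementation (an assumption, not necessarily the file's version) just to make the bucketing self-contained and runnable:

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iostream>

// One common implementation; the actual helper in this file may differ.
static inline uint32_t next_pow_2(uint32_t v) {
  if (v <= 1) return 1;
  return 1u << (32 - __builtin_clz(v - 1));  // GCC/Clang intrinsic
}

int main() {
  for (uint32_t m : {1u, 32u, 33u, 64u, 100u, 129u, 4096u}) {
    uint32_t const mp2 = std::max(static_cast<uint32_t>(32), next_pow_2(m));
    char const* bucket = mp2 <= 32    ? "M32 (small-N or large-N variant)"
                         : mp2 <= 64  ? "M64"
                         : mp2 <= 128 ? "M128"
                                      : "Default";
    std::cout << "m=" << m << " -> " << bucket << "\n";
  }
}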
@@ -395,15 +467,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_3x_gemm<
-          int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                             ScaledEpilogue>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-      return cutlass_gemm_caller<
-          cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
+                                             ScaledEpilogue>(
           out, a, b, a_scales, b_scales);
     }
   } else {
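
With this refactor, adding a new epilogue to the int8 path no longer requires duplicating the per-M configs: any class template taking three type parameters can be passed where Epilogue is expected, as ScaledEpilogue is above. A hedged sketch of the required shape follows; MyEpilogue and its parameter names are hypothetical, and the real ScaledEpilogue (defined elsewhere in this file) additionally supplies the visitor-tree types that cutlass_3x_gemm consumes:

// Skeleton only: shows the three-type-parameter shape the template-template
// parameter requires, not a working epilogue.
template <typename ElementAcc, typename ElementD, typename TileShape>
struct MyEpilogue {
  // Epilogue visitor-tree type aliases and a static args builder go here.
};

// It can then be plugged into the generic dispatch:
//   cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, MyEpilogue>(
//       out, a, b, my_args...);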