@@ -234,38 +234,39 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
}

template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue, int32_t M>
-struct sm90_fp8_config {
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_default {
+  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
-
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
+struct sm90_fp8_config_M128 {
+  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
-
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
+struct sm90_fp8_config_M64 {
+  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
@@ -278,6 +279,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
                      KernelSchedule, EpilogueSchedule>;
};

+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_default {
+  // For M > 128 and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule =
+      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M128 {
+  // For M in (64, 128] and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule =
+      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M64 {
+  // For M in (32, 64] and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M32_NBig {
+  // For M in [1, 32] and N >= 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _256>;
+  using ClusterShape = Shape<_1, _4, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M32_NSmall {
+  // For M in [1, 32] and N < 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _8, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
}  // namespace

template <typename InType, typename OutType,
@@ -291,11 +364,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
+      typename sm90_fp8_config_default<InType, OutType,
+                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
+      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
-      typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
+      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
@@ -316,6 +390,61 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
  }
}

+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass3xGemmDefault =
+      typename sm90_int8_config_default<InType, OutType,
+                                        Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM32NBig =
+      typename sm90_int8_config_M32_NBig<InType, OutType,
+                                         Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM32NSmall =
+      typename sm90_int8_config_M32_NSmall<InType, OutType,
+                                           Epilogue>::Cutlass3xGemm;
+
+  uint32_t const n = out.size(1);
+  bool const is_small_n = n < 8192;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 32) {
+    // m in [1, 32]
+    if (is_small_n) {
+      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    } else {
+      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (mp2 <= 64) {
+    // m in (32, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@@ -326,22 +455,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

-    using TileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule =
-        typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-    using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-
    if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_3x_gemm<
-          int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
-          KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                             ScaledEpilogue>(
+          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-      return cutlass_gemm_caller<
-          cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
+                                             ScaledEpilogue>(
          out, a, b, a_scales, b_scales);
    }
  } else {
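
For reference, below is a minimal standalone sketch (not part of the patch) of the M/N bucketing heuristic that the new cutlass_gemm_sm90_int8_dispatch applies. next_pow_2 is defined elsewhere in this file and is not shown in the diff, so the helper here is an assumed stand-in; the returned strings simply name the config structs selected above.

// Sketch only: mirrors the shape-based config selection in
// cutlass_gemm_sm90_int8_dispatch, outside of any CUTLASS/torch code.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed equivalent of the next_pow_2 helper used by the dispatch function.
static uint32_t next_pow_2(uint32_t v) {
  if (v <= 1) return 1;
  return 1u << (32 - __builtin_clz(v - 1));  // GCC/Clang builtin
}

static const char* pick_int8_config(uint32_t m, uint32_t n) {
  // Round m up to a power of two, clamped to at least 32, as in the dispatch.
  uint32_t const mp2 = std::max(32u, next_pow_2(m));
  if (mp2 <= 32) {
    return n < 8192 ? "sm90_int8_config_M32_NSmall" : "sm90_int8_config_M32_NBig";
  } else if (mp2 <= 64) {
    return "sm90_int8_config_M64";
  } else if (mp2 <= 128) {
    return "sm90_int8_config_M128";
  }
  return "sm90_int8_config_default";
}

int main() {
  // Small-m problems with narrow vs. wide N land in different M32 buckets;
  // large-m problems fall through to the default 128x128x128 config.
  std::printf("%s\n", pick_int8_config(16, 4096));    // sm90_int8_config_M32_NSmall
  std::printf("%s\n", pick_int8_config(16, 8192));    // sm90_int8_config_M32_NBig
  std::printf("%s\n", pick_int8_config(100, 4096));   // sm90_int8_config_M128
  std::printf("%s\n", pick_int8_config(4096, 4096));  // sm90_int8_config_default
}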