forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
wgmma_sm90.cu
562 lines (472 loc) · 20.1 KB
/
wgmma_sm90.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/device_kernel.h"
using namespace cute;
// Shared-memory footprint of one CTA of the SM90 GEMM kernel: the staged A and B tiles plus
// one TMA (producer) and one MMA (consumer) mbarrier word per pipeline stage.
//
// ElementA / ElementB : element types of the A and B tiles.
// SmemLayoutA         : static (BLK_M, BLK_K, PIPE) shared-memory layout of A.
// SmemLayoutB         : static (BLK_N, BLK_K, PIPE) shared-memory layout of B.
template <class ElementA,
          class ElementB,
          class SmemLayoutA,  // (M,K,P)
          class SmemLayoutB>  // (N,K,P)
struct SharedStorage
{
  // The barrier arrays below are sized from A's stage count, but each stage index is used for
  // both the A and the B stage, so both operands must be staged with the same depth.
  static_assert(decltype(size<2>(SmemLayoutA{}) == size<2>(SmemLayoutB{}))::value,
                "SmemLayoutA and SmemLayoutB must have the same pipeline depth (mode 2)");

  array_aligned<ElementA, cosize_v<SmemLayoutA>> smem_A;  // all PIPE stages of A
  array_aligned<ElementB, cosize_v<SmemLayoutB>> smem_B;  // all PIPE stages of B

  uint64_t tma_barrier[size<2>(SmemLayoutA{})];  // producer (TMA) barrier, one per stage
  uint64_t mma_barrier[size<2>(SmemLayoutA{})];  // consumer (MMA) barrier, one per stage
};
// SM90 GEMM kernel. Each CTA computes one (BLK_M,BLK_N) tile of
//   C(m,n) = alpha * sum_k A(m,k) * B(n,k) + beta * C(m,n)
// where A is viewed as (M,K) and B as (N,K) through the TMA descriptors tma_a / tma_b, and C is
// addressed through the raw pointer C with strides dC.
//
// Structure: one elected lane of warp 0 acts as the TMA producer and fills a PIPE-stage
// shared-memory ring. All threads of the block run the WGMMA consumer loop, synchronized with
// the producer through per-stage mbarriers.
//
// Launch requirements:
//   - blockDim.x == size(TiledMma) (matches __launch_bounds__; the consumer barrier expects one
//     arrival per TiledMma thread per k-tile),
//   - dynamic shared memory of at least sizeof(SharedStorage<TA,TB,SmemLayoutA,SmemLayoutB>),
//   - blockIdx.x / blockIdx.y select the M / N tile. CTAs whose tile lies (partly or fully) past
//     the M/N extent are allowed: they still reach cluster_sync() and only store in-bounds C.
// NOTE(review): partial M/N/K tiles are loaded through TMA and rely on its out-of-bounds fill
// (zeros) for correct accumulation — confirm for non-default descriptors.
template <class ProblemShape, class CtaTiler,
          class TA, class SmemLayoutA, class TmaA,
          class TB, class SmemLayoutB, class TmaB,
          class TC, class CStride, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, CUTLASS_GRID_CONSTANT TmaA const tma_a,
            TB const* B, CUTLASS_GRID_CONSTANT TmaB const tma_b,
            TC      * C, CStride dC, TiledMma mma,
            Alpha alpha, Beta beta)
{
  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                       // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                       // (BLK_M, BLK_N, BLK_K)

  static_assert(is_static<SmemLayoutA>::value);
  static_assert(is_static<SmemLayoutB>::value);

  CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutA{}) == size<0>(cta_tiler));      // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutB{}) == size<1>(cta_tiler));      // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutA{}) == size<2>(cta_tiler));      // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutB{}) == size<2>(cta_tiler));      // BLK_K
  // One barrier pair per stage is used for both operands, so the stage counts must agree
  CUTE_STATIC_ASSERT_V(size<2>(SmemLayoutA{}) == size<2>(SmemLayoutB{}));  // PIPE

  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));             // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  auto [M, N, K] = shape_MNK;
  Tensor mA = tma_a.get_tma_tensor(make_shape(M,K));                       // (M,K) TMA Tensor
  Tensor mB = tma_b.get_tma_tensor(make_shape(N,K));                       // (N,K) TMA Tensor
  Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC);          // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);                  // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});      // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});      // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});      // (BLK_M,BLK_N)

  // Shared memory tensors
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<TA, TB, SmemLayoutA, SmemLayoutB>;
  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
  Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

  //
  // Partition the copying of A and B tiles
  //
  // TUTORIAL:
  //   These are TMA partitionings, which have a dedicated custom partitioner.
  //   The Int<0>, Layout<_1> indicates that the TMAs are not multicasted.
  //     Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host.
  //   The group_modes<0,2> transforms the (X,Y,Z)-shaped tensors into ((X,Y),Z)-shaped tensors
  //     with the understanding that the TMA is responsible for everything in mode-0.
  //   The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
  //

  auto [tAgA, tAsA] = tma_partition(tma_a, Int<0>{}, Layout<_1>{},
                                    group_modes<0,2>(sA), group_modes<0,2>(gA));  // (TMA,k) and (TMA,PIPE)

  auto [tBgB, tBsB] = tma_partition(tma_b, Int<0>{}, Layout<_1>{},
                                    group_modes<0,2>(sB), group_modes<0,2>(gB));  // (TMA,k) and (TMA,PIPE)

  // The TMA is responsible for copying everything in mode-0 of tAsA and tBsB
  constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) +
                                       CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB);

  //
  // PREFETCH
  //

  auto K_PIPE_MAX = size<1>(tAsA);

  // Total count of tiles
  int k_tile_count = size<1>(tAgA);
  // Current tile index in gmem to read from
  int k_tile = 0;

  // Initialize Barriers
  int warp_idx = cutlass::canonical_warp_idx_sync();
  int lane_predicate = cute::elect_one_sync();
  uint64_t* producer_mbar = smem.tma_barrier;
  uint64_t* consumer_mbar = smem.mma_barrier;

  using ProducerBarType = cutlass::arch::ClusterTransactionBarrier;  // TMA
  using ConsumerBarType = cutlass::arch::ClusterBarrier;             // MMA

  // Every thread of the block arrives on consumer_mbar once per consumed k-tile (mainloop below),
  // so the expected arrival count is the TiledMma thread count (== the block size).
  constexpr int kConsumerArrivals = decltype(size(TiledMma{}))::value;

  CUTE_UNROLL
  for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) {
    if ((warp_idx == 0) && lane_predicate) {
      ProducerBarType::init(&producer_mbar[pipe], 1);
      ConsumerBarType::init(&consumer_mbar[pipe], kConsumerArrivals);
    }
  }
  // Ensure barrier init is complete on all CTAs
  cluster_sync();

  // Start async loads for all pipes.
  // Only k-tiles that exist are loaded: when K has fewer than K_PIPE_MAX tiles, the extra stages
  // are never consumed, so issuing them would leave TMA writes outstanding at kernel exit.
  CUTE_UNROLL
  for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe)
  {
    if ((warp_idx == 0) && lane_predicate && (k_tile_count > 0))
    {
      // Set expected Tx Bytes after each reset / init
      ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
      copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
      copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
    }
    --k_tile_count;
    ++k_tile;
  }

  //
  // Define A/B partitioning and C accumulators
  //
  // TUTORIAL:
  //   The tCrA and tCrB are actually Tensors of MMA Descriptors constructed as views of SMEM.
  //   The MMA Descriptor generation is automatic via inspection and validation of the SMEM Layouts.
  //   Because the MMA reads directly from SMEM and the fragments are descriptors rather than registers,
  //     there is no need for copy(tCsA, tCrA) in the mainloop.
  //

  ThrMMA thr_mma = mma.get_thread_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                                   // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB);                                   // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC);                                   // (MMA,MMA_M,MMA_N)

  // Allocate accumulators and clear them
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                             // (MMA,MMA_M,MMA_N)
  clear(tCrC);

  // Allocate "fragments"
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);                             // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);                             // (MMA,MMA_N,MMA_K,PIPE)

  //
  // PIPELINED MAIN LOOP
  //
  // TUTORIAL:
  //   Rather than interleaving the stages and instructions like in SM70 and SM80,
  //     the SM90 mainloops rely on explicit producer-consumer synchronization
  //     on the purely async instructions TMA and MMA.
  //   More advanced pipeline and warp-specialization strategies are available in CUTLASS mainloops.
  //

  // A PipelineState is a circular pipe index [.index()] and a pipe phase [.phase()]
  //   that flips each cycle through K_PIPE_MAX.
  auto write_state = cutlass::PipelineState<K_PIPE_MAX>();                 // TMA writes
  auto read_state  = cutlass::PipelineState<K_PIPE_MAX>();                 // MMA  reads

  // After the prefetch, k_tile_count == (#k-tiles - K_PIPE_MAX), so this loop runs exactly
  // #k-tiles times: one consumed k-tile per iteration.
  CUTE_NO_UNROLL
  while (k_tile_count > -K_PIPE_MAX)
  {
    // Wait for Producer to complete
    int read_pipe = read_state.index();
    ProducerBarType::wait(&producer_mbar[read_pipe], read_state.phase());

    // MMAs to cover 1 K_TILE
    warpgroup_arrive();
    gemm(mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);         // (V,M) x (V,N) => (V,M,N)
    warpgroup_commit_batch();

    // Wait for all MMAs in a K_TILE to complete
    warpgroup_wait<0>();

    // Notify that consumption is done
    ConsumerBarType::arrive(&consumer_mbar[read_pipe]);
    ++read_state;

    // Refill the released stage only while k-tiles remain (k_tile < #k-tiles). During the final
    // K_PIPE_MAX drain iterations nothing is issued, so no TMA is still in flight when the CTA exits.
    if ((warp_idx == 0) && lane_predicate && (k_tile_count > 0))
    {
      int pipe = write_state.index();
      // Wait for Consumer to complete consumption
      ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase());
      // Set expected Tx Bytes after each reset / init
      ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
      copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
      copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
      ++write_state;
    }
    --k_tile_count;
    ++k_tile;
  }

  //
  // Epilogue (predicated)
  //

  // Identity tensor mapping every element of this CTA's C tile to its global (m,n) coordinate,
  // partitioned exactly like tCgC so that tCcC(i) is the coordinate of tCgC(i).
  Tensor cC   = make_identity_tensor(make_shape(M,N));                     // (M,N) -> (m,n)
  Tensor gcC  = local_tile(cC, cta_tiler, cta_coord, Step<_1,_1, X>{});    // (BLK_M,BLK_N)
  Tensor tCcC = thr_mma.partition_C(gcC);                                  // (MMA,MMA_M,MMA_N)

  // Mask out elements past the M/N extent. This covers partial edge tiles and whole CTAs that
  // exist only because the host rounds the grid up to a multiple of the cluster shape.
  Tensor tCpC = make_tensor<bool>(shape(tCcC));                            // (MMA,MMA_M,MMA_N)
  CUTE_UNROLL
  for (int i = 0; i < size(tCpC); ++i) {
    tCpC(i) = elem_less(tCcC(i), make_coord(M,N));
  }

  axpby(alpha, tCrC, beta, tCgC, tCpC);
}
// Setup params for an NT GEMM and launch it.
//
// A is read as an (M,K) matrix with stride (1, ldA), B as an (N,K) matrix with stride (1, ldB),
// and C as an (M,N) matrix with stride (1, ldC); A, B and C are device pointers.
// CTA tile is 128x128x64 with a 3-stage shared-memory pipeline, MN-major 128B-swizzled smem
// layouts and a single 64x64x16 F16 WGMMA atom. The kernel is launched as 2x1x1 clusters, with
// the grid rounded up to a multiple of the cluster shape.
// The launch is enqueued on `stream` (default: the legacy default stream 0).
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int< 64>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<  3>{};                                      // Pipeline

  // Define the smem layouts (static)
  auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
  auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TB>{}, make_shape(bN,bK,bP));

  // Define the MMA
  TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});

  // Define the TMAs
  // Create Global memory tensors for TMA inspection
  Tensor mA = make_tensor(A, make_shape(M,K), dA);
  Tensor mB = make_tensor(B, make_shape(N,K), dB);

  // Create TMA Atoms with the desired copy operation on the source and destination
  Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK));
  Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK));

  //
  // Setup and Launch
  //

  // Launch parameter setup
  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(tiled_mma));
  dim3 dimCluster(2, 1, 1);
  dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
               round_up(size(ceil_div(n, bN)), dimCluster.y));
  // The stream must be passed explicitly: omitting it silently launches on stream 0
  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};

  void const* kernel_ptr = reinterpret_cast<void const*>(
                              &gemm_device<decltype(prob_shape), decltype(cta_tiler),
                                           TA, decltype(sA), decltype(tmaA),
                                           TB, decltype(sB), decltype(tmaB),
                                           TC, decltype(dC), decltype(tiled_mma),
                                           decltype(alpha), decltype(beta)>);

  // Allow more dynamic shared memory than the default per-block limit
  CUTE_CHECK_ERROR(cudaFuncSetAttribute(
    kernel_ptr,
    cudaFuncAttributeMaxDynamicSharedMemorySize,
    smem_size));

  // Kernel Launch
  cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
                                                             prob_shape, cta_tiler,
                                                             A, tmaA,
                                                             B, tmaB,
                                                             C, dC, tiled_mma,
                                                             alpha, beta);
  // Report a failed launch before the generic CUDA error check runs
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Error: Failed at kernel Launch" << std::endl;
  }
  CUTE_CHECK_LAST();
}
// Setup params for a TN GEMM and launch it.
//
// A is read as an (M,K) matrix with stride (ldA, 1), B as an (N,K) matrix with stride (ldB, 1),
// and C as an (M,N) matrix with stride (1, ldC); A, B and C are device pointers.
// CTA tile is 128x128x64 with a 3-stage shared-memory pipeline, K-major 128B-swizzled smem
// layouts and a single 64x64x16 F16 WGMMA atom. The kernel is launched as 2x1x1 clusters, with
// the grid rounded up to a multiple of the cluster shape.
// The launch is enqueued on `stream` (default: the legacy default stream 0).
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int< 64>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};                                        // Pipeline

  // Define the smem layouts (static)
  auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
  auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom<TB>{}, make_shape(bN,bK,bP));

  // Define the MMA
  TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::K,GMMA::Major::K>{});

  // Define the TMAs
  // Create Global memory tensors for TMA inspection
  Tensor mA = make_tensor(A, make_shape(M,K), dA);
  Tensor mB = make_tensor(B, make_shape(N,K), dB);

  // Create TMA Atoms with the desired copy operation on the source and destination
  Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK));
  Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK));

  //
  // Setup and Launch
  //

  // Launch parameter setup
  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(tiled_mma));
  dim3 dimCluster(2, 1, 1);
  dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
               round_up(size(ceil_div(n, bN)), dimCluster.y));
  // The stream must be passed explicitly: omitting it silently launches on stream 0
  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};

  void const* kernel_ptr = reinterpret_cast<void const*>(
                              &gemm_device<decltype(prob_shape), decltype(cta_tiler),
                                           TA, decltype(sA), decltype(tmaA),
                                           TB, decltype(sB), decltype(tmaB),
                                           TC, decltype(dC), decltype(tiled_mma),
                                           decltype(alpha), decltype(beta)>);

  // Allow more dynamic shared memory than the default per-block limit
  CUTE_CHECK_ERROR(cudaFuncSetAttribute(
    kernel_ptr,
    cudaFuncAttributeMaxDynamicSharedMemorySize,
    smem_size));

  // Kernel Launch
  cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
                                                             prob_shape, cta_tiler,
                                                             A, tmaA,
                                                             B, tmaB,
                                                             C, dC, tiled_mma,
                                                             alpha, beta);
  // Report a failed launch before the generic CUDA error check runs
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Error: Failed at kernel Launch" << std::endl;
  }
  CUTE_CHECK_LAST();
}
// Dispatch on the transpose characters to the implemented layout-specific GEMM.
//
// Supported combinations:
//   transA == 'N', transB == 'T' -> gemm_nt (A stride (1, ldA), B stride (1, ldB))
//   transA == 'T', transB == 'N' -> gemm_tn (A stride (ldA, 1), B stride (ldB, 1))
// Any other combination is reported on stderr and then trips the assert. In NDEBUG builds the
// call performs no work after the diagnostic.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  // The assert alone vanishes under NDEBUG, which would make an unsupported request a silent
  // no-op (and e.g. the timing loop in main would then report a meaningless rate).
  fprintf(stderr, "gemm: transA='%c', transB='%c' is not implemented (supported: NT, TN)\n",
          transA, transB);
  assert(false && "Not implemented");
}
// Benchmark driver: runs the WGMMA GEMM once on random +/-1 half-precision inputs, then times
// `timing_iterations` further launches and prints the achieved GFlop/s.
//
// Usage: <exe> [m=512] [n=256] [k=1024] [transA='N'] [transB='T']
//   transA / transB choose the storage of A (M,K) and B (N,K); gemm() implements only NT and TN.
// Returns 0 on success, or when the device is not compute capability 9.x. Returns -1 when
// querying the device fails or a transpose character is neither 'N' nor 'T'.
int main(int argc, char** argv)
{
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major != 9) {
    std::cout << "This example requires NVIDIA's Hopper Architecture GPU with compute capability 90a\n" << std::endl;
    return 0;
  }

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

  int m = 512;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 256;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 1024;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = cute::half_t;
  using TB = cute::half_t;
  using TC = cute::half_t;
  using TI = cute::half_t;

  TI alpha = TI(1.0f);
  TI beta  = TI(0.0f);

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  // Initialize the tensors
  for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1));
  for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1));
  for (int j = 0; j < m*n; ++j) h_C[j] = TC(0);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  // Leading dimensions follow from the storage selected by the transpose characters.
  // An invalid character is a hard error: continuing would launch with a leading dimension of 0
  // once NDEBUG removes the assert that used to guard this.
  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    std::cerr << "Error: transA must be 'N' or 'T', got '" << transA << "'" << std::endl;
    return -1;
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    std::cerr << "Error: transB must be 'N' or 'T', got '" << transB << "'" << std::endl;
    return -1;
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  // Copying the result back also waits for the first run before timing starts
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:     [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

#else

  std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif

  return 0;
}