-
Notifications
You must be signed in to change notification settings - Fork 26
/
rocwmma.hpp
387 lines (366 loc) · 16.5 KB
/
rocwmma.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
/*******************************************************************************
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef ROCWMMA_API_HPP
#define ROCWMMA_API_HPP
#include "internal/accessors.hpp"
#include "internal/io_traits.hpp"
#include "internal/pack_util.hpp"
#include "internal/types.hpp"
/**
* \mainpage
*
* ROCWMMA is a C++ library for facilitating GEMM, or GEMM-like 2D matrix multiplications
* leveraging AMD's GPU hardware through HIP.
* Specifically, the library enhances the portability of CUDA WMMA code to
* AMD's heterogeneous platform and provides an interface to use underlying
* hardware matrix multiplication.
* The ROCWMMA API exposes memory and MMA (Matrix Multiply Accumulate) functions
* that operate on blocks, or 'fragments', of data appropriately sized for
* wavefront (warp) execution.
* ROCWMMA code is templated for componentization and to provide the ability to
* make compile-time optimizations based on available metadata.
* This library is an ongoing Work-In-Progress (WIP).
*
* **Supported Hardware**
* - CDNA architecture: gfx908, gfx90a, gfx940, gfx941, gfx942 (gfx9)
* - RDNA3 architecture: gfx1100, gfx1101, gfx1102 (gfx11)
*
* **Supported Wave Sizes**
* - Wave 32 (gfx11 only)
* - Wave 64 (gfx9 only)
*
* **Supported Datatypes (gfx9)**
* - Native Data Types
* - float = f32
* - double = f64 (*only on gfx90a, gfx940, gfx941 & gfx942)
* - _Float16 = f16
* - int8
*
* - Non-Native Data Types
* - h16 = __half
* - bf16 = bfloat16
*
* **Supported Datatypes (gfx11)**
* - Native Data Types
* - _Float16 = f16
* - int8
*
* - Non-Native Data Types
* - h16 = __half
* - bf16 = bfloat16
*
* **Supported Thread Block Sizes**
* Total wave count of up to 4 ((TBlockX / WaveSize) * TBlockY <= 4)
* TBlockX | TBlockY |
* :---------:|:---------:|
* WaveSize | 1 |
* WaveSize | 2 |
* WaveSize | 4 |
* WaveSize*2 | 1 |
* WaveSize*2 | 2 |
* WaveSize*4 | 1 |
*
* @note TBlockX must be a multiple of WaveSize
*
*
* **Supported Matrix Layouts**
*
* Matrix Layout (N = col major, T = row major)
*
* LayoutA | LayoutB | LayoutC | LayoutD |
* :-------:|:---------:|:---------:|:----------:|
* N | N | N | N |
* N | T | N | N |
* T | N | N | N |
* T | T | N | N |
* N | N | T | T |
* N | T | T | T |
* T | N | T | T |
* T | T | T | T |
*
* **Data Types <Ti / To / Tc> = <InputType / OutputType / ComputeType >**
* \n
* **MMA Block Size = <BlockM, BlockN, BlockK>**
* @note gfx11 only supports BlockM/N = 16
* \n
* Ti / To / Tc | BlockM | BlockN | BlockK
* :-------------------:|:-----------:|:-----------:|:-----------:|
* i8/i32/i32 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* i8/i8/i32 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* f16/f32/f32 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* f16/f16/f32 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* f16/f16/f16 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* __half/f32/f32 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* __half/__half/f32 | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* __half/__half/__half | 16 | 16 | Min:16,pow2 |
* ^ | 32 | 32 | Min:8, pow2 |
* bf16/f32/f32 | 16 | 16 | Min:8, pow2 |
* ^ | 32 | 32 | Min:4, pow2 |
* bf16/bf16/f32 | 16 | 16 | Min:8, pow2 |
* ^ | 32 | 32 | Min:4, pow2 |
* bf16/bf16/bf16 | 16 | 16 | Min:8, pow2 |
* ^ | 32 | 32 | Min:4, pow2 |
* f32/f32/f32 | 16 | 16 | Min:4, pow2 |
* ^ | 32 | 32 | Min:2, pow2 |
* f64/f64/f64 | 16 | 16 | Min:4, pow2 |
*
*
* \n
* \n
* **Fragment:**
*
* **fill_fragment**
*
* Broadcast a desired value to all elements in the fragment.
*
* \n
* **load_matrix_sync / store_matrix_sync**
*
* Loads (stores) fragment data from (to) memory according to the matrix layout.
* Matrix A layout loads / stores matrix columns in the K direction
* (Matrix A = M x K, fragA = BlockM x BlockK)
* Matrix B layout loads / stores matrix rows in the K direction
* (Matrix B = K x N, fragB = BlockK x BlockN)
* Matrix C layout loads / stores matrix rows in the M direction
* (Matrix C = M x N, fragAcc = BlockM x BlockN)
*
* @note Fragments are stored in packed registers; however, elements have no guaranteed order.
*
* \n
* **mma_sync**
*
* MMA is performed with fragment data: the outer products of Fragment A columns
* with Fragment B rows are added into the accumulator fragment (D = A * B + C).
*
* **synchronize_workgroup**
*
* Synchronization point for all wavefronts in a workgroup.
*/
namespace rocwmma
{
/**
* \defgroup Rocwmma ROCWMMA Public API
*
* @brief ROCWMMA Fragment and its API function definitions.
* @{
*/
/*! \class fragment
* \brief Definition of the MFMA fragment: a block of matrix data held in registers,
* which the rocWMMA API functions below load, store, fill and multiply.
*
* @tparam MatrixT - fragment context (mma_sync uses matrix_a, matrix_b and accumulator)
* @tparam BlockM - block dimension M
* @tparam BlockN - block dimension N
* @tparam BlockK - block dimension K
* @tparam DataT - element data type
* @tparam DataLayout - in-memory layout as col_major or row_major. Defaults to void, which is
* the form taken by the load_matrix_sync / store_matrix_sync overloads that receive the
* layout at run time.
*
* IOTraits - Input/output traits specific to AMDGCN architecture, selected by IOConfig for
* this exact <MatrixT, BlockM, BlockN, BlockK, DataT, DataLayout> combination
* Traits::AccessT - Unpacked data storage: IOTraits::UnpackedSize elements of
* PackTraits<DataT>::UnpackedT
* Traits::StorageT - Packed data storage required for MFMA: IOTraits::PackedSize elements
* of PackTraits<DataT>::PackedT
* Traits::Size - element count of the unpacked view (exported as num_elements)
*
* @note Fragments are stored in packed registers; however, elements have no guaranteed order.
*/
template <typename MatrixT,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockK,
typename DataT,
typename DataLayout = void>
class __align__(4) fragment
{
public:
// Architecture-specific I/O traits for this fragment configuration.
using IOTraits =
typename IOConfig<MatrixT, BlockM, BlockN, BlockK, DataT, DataLayout>::IOTraits;
// Storage types derived from IOTraits and PackTraits<DataT>.
struct Traits
{
private:
// Element types of the packed and unpacked register views.
using PackedElementT = typename PackTraits<DataT>::PackedT;
using UnpackedElementT = typename PackTraits<DataT>::UnpackedT;
public:
// Unpacked view: IOTraits::UnpackedSize elements.
using AccessT = VecT<UnpackedElementT, IOTraits::UnpackedSize>;
// Packed view: IOTraits::PackedSize elements.
using StorageT = VecT<PackedElementT, IOTraits::PackedSize>;
// Number of elements in the unpacked view.
constexpr static uint32_t Size = IOTraits::UnpackedSize;
// A fragment configuration that maps to zero packed registers is rejected at compile time.
static_assert(IOTraits::PackedVRegCount >= 1,
"Fragments must occupy at least one packed register");
// The unpacked element count must be a whole multiple of the packed element count.
static_assert(IOTraits::UnpackedSize % IOTraits::PackedSize == 0,
"Unable to pack fragment elements");
};
// Construction and copy. The copy constructor and copy assignment are only declared here;
// their definitions are out of line (presumably in rocwmma_impl.hpp, included at the end
// of this header).
ROCWMMA_DEVICE fragment() = default;
ROCWMMA_DEVICE fragment(const fragment& other);
ROCWMMA_DEVICE fragment& operator=(const fragment& other);
// Accessors
// operator[] yields a DataT element by index; operator* yields the packed StorageT vector.
// NOTE(review): because element order is unspecified, an index does not map to a fixed
// (row, col) position - confirm against the out-of-line definitions.
ROCWMMA_DEVICE inline DataT& operator[](uint32_t index);
ROCWMMA_DEVICE inline DataT const& operator[](uint32_t index) const;
ROCWMMA_DEVICE inline typename Traits::StorageT& operator*();
ROCWMMA_DEVICE inline typename Traits::StorageT const& operator*() const;
// Traits
// Compile-time geometry queries; definitions are out of line and not visible here.
ROCWMMA_DEVICE constexpr static inline uint32_t height();
ROCWMMA_DEVICE constexpr static inline uint32_t width();
ROCWMMA_DEVICE constexpr static inline uint32_t blockDim();
ROCWMMA_DEVICE constexpr static inline uint32_t kDim();
ROCWMMA_DEVICE constexpr static inline uint32_t size();
// Compatibility with nvcuda::wmma
// The members below are three views of the same register data. `x`, together with
// num_elements and element_type, provides nvcuda::wmma-style member names.
union
{
typename Traits::StorageT mStorage; // Packed view
typename Traits::AccessT mAccess; // Unpacked view
typename Traits::AccessT::Native_vec_ x; // Nuanced access via AccessT's native vector type
// The aliasing views are only valid if they cover exactly the same number of bytes.
static_assert(sizeof(typename Traits::AccessT) == sizeof(typename Traits::StorageT),
"Storage type and access type should be views into the same raw data");
};
// Element count of the unpacked view (Traits::Size).
constexpr static uint32_t num_elements = Traits::Size;
using element_type = DataT;
};
//! Fills the entire fragment with the desired value.
/*!
\param frag Fragment of type MatrixT with its associated block sizes, data type and layout
\param value Value of type DataT, broadcast to every element of frag
\tparam MatrixT fragment context
\tparam BlockM block dimension M
\tparam BlockN block dimension N
\tparam BlockK block dimension K
\tparam DataT data type
\tparam DataLayout in-memory layout as col_major or row_major
*/
template <typename MatrixT,
uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockK,
typename DataT,
typename DataLayout>
ROCWMMA_DEVICE void
fill_fragment(fragment<MatrixT, BlockM, BlockN, BlockK, DataT, DataLayout>& frag,
DataT value);
/**
* @brief Loads the whole fragment from memory. Both the matrix context and the data
* layout are taken from the fragment's type.
*
* The source pointer may refer to either local or global memory.
*
* @param frag   Destination fragment (context MatrixT, block sizes, DataT, DataLayout)
* @param data   Pointer to the source data in global/local memory
* @param ldm    Leading dimension size
* @tparam MatrixT     Fragment context
* @tparam BlockM      Block dimension M
* @tparam BlockN      Block dimension N
* @tparam BlockK      Block dimension K
* @tparam DataT       Element data type
* @tparam DataLayout  In-memory layout as col_major or row_major
*/
template <typename MatrixT, uint32_t BlockM, uint32_t BlockN, uint32_t BlockK,
          typename DataT, typename DataLayout>
ROCWMMA_DEVICE void load_matrix_sync(
    fragment<MatrixT, BlockM, BlockN, BlockK, DataT, DataLayout>& frag,
    const DataT* data,
    uint32_t ldm);
//! Loads the entire fragment from the data pointer according to its matrix context, with the in-memory data layout chosen at run time. Data pointer may point to either local or global memory.
/*!
\param frag Fragment of type MatrixT with its associated block sizes and data type; its DataLayout template argument is left at the default (void)
\param data Data pointer to global/local memory
\param ldm Leading dimension size
\param layout Run-time in-memory layout of the source data
\tparam MatrixT fragment context
\tparam BlockM block dimension M
\tparam BlockN block dimension N
\tparam BlockK block dimension K
\tparam DataT data type
\note Unlike the compile-time overload, this overload has no DataLayout template parameter; the
in-memory layout is supplied through \p layout instead.
*/
template <typename MatrixT, uint32_t BlockM, uint32_t BlockN, uint32_t BlockK, typename DataT>
ROCWMMA_DEVICE void load_matrix_sync(fragment<MatrixT, BlockM, BlockN, BlockK, DataT>& frag,
const DataT* data,
uint32_t ldm,
layout_t layout);
/**
* @brief Stores the whole fragment to memory. Both the matrix context and the data
* layout are taken from the fragment's type.
*
* The destination pointer may refer to either local or global memory.
*
* @param data   Pointer to the destination in global/local memory
* @param frag   Source fragment (context MatrixT, block sizes, DataT, DataLayout)
* @param ldm    Leading dimension size
* @tparam MatrixT     Fragment context
* @tparam BlockM      Block dimension M
* @tparam BlockN      Block dimension N
* @tparam BlockK      Block dimension K
* @tparam DataT       Element data type
* @tparam DataLayout  In-memory layout as col_major or row_major
*/
template <typename MatrixT, uint32_t BlockM, uint32_t BlockN, uint32_t BlockK,
          typename DataT, typename DataLayout>
ROCWMMA_DEVICE void store_matrix_sync(
    DataT* data,
    fragment<MatrixT, BlockM, BlockN, BlockK, DataT, DataLayout> const& frag,
    uint32_t ldm);
//! Stores the entire fragment to the data pointer according to its matrix context, with the in-memory data layout chosen at run time. Data pointer may point to either local or global memory.
/*!
\param data Data pointer to global/local memory
\param frag Fragment of type MatrixT with its associated block sizes and data type; its DataLayout template argument is left at the default (void)
\param ldm Leading dimension size
\param layout Run-time in-memory layout used when writing the destination data
\tparam MatrixT fragment context
\tparam BlockM block dimension M
\tparam BlockN block dimension N
\tparam BlockK block dimension K
\tparam DataT data type
\note Unlike the compile-time overload, this overload has no DataLayout template parameter; the
in-memory layout is supplied through \p layout instead.
*/
template <typename MatrixT, uint32_t BlockM, uint32_t BlockN, uint32_t BlockK, typename DataT>
ROCWMMA_DEVICE void
store_matrix_sync(DataT* data,
fragment<MatrixT, BlockM, BlockN, BlockK, DataT> const& frag,
uint32_t ldm,
layout_t layout);
//! Performs the matrix multiply-accumulate D = A * B + C on fragments A, B, C and D.
/*!
\param d Accumulator output fragment D
\param a Input fragment A (matrix_a context)
\param b Input fragment B (matrix_b context)
\param c Input accumulator fragment C
\tparam BlockM block dimension M
\tparam BlockN block dimension N
\tparam BlockK block dimension K
\tparam InputT data type of input fragments A and B
\tparam ComputeT data type of accumulator fragments C and D
\tparam LayoutA in-memory layout of fragment A as col_major or row_major
\tparam LayoutB in-memory layout of fragment B as col_major or row_major
\tparam LayoutC layout of accumulator fragment C (an independent template parameter, so it may differ from LayoutD)
\tparam LayoutD layout of accumulator fragment D
\note Passing the same fragment as both c and d (c = d) is valid.
*/
template <uint32_t BlockM,
uint32_t BlockN,
uint32_t BlockK,
typename InputT,
typename ComputeT,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename LayoutD>
ROCWMMA_DEVICE void
mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>& d,
fragment<matrix_a, BlockM, BlockN, BlockK, InputT, LayoutA> const& a,
fragment<matrix_b, BlockM, BlockN, BlockK, InputT, LayoutB> const& b,
fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutC> const& c);
//! Synchronization point for all wavefronts in a workgroup.
/*!
Declared here and defined out of line (presumably in rocwmma_impl.hpp, included at the end
of this header). NOTE(review): the exact barrier semantics, such as which memory operations
are made visible, are not visible in this header - confirm against the implementation.
*/
ROCWMMA_DEVICE void synchronize_workgroup();
/** @}*/
} // namespace rocwmma
#include "rocwmma_impl.hpp"
#endif // ROCWMMA_API_HPP