-
Notifications
You must be signed in to change notification settings - Fork 960
/
mma_traits.hpp
228 lines (196 loc) · 8.54 KB
/
mma_traits.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma.hpp>
#include <cute/tensor.hpp>
namespace cute
{
namespace detail {

// supports_output_scaling<Traits>::value
//
//   true  -> the expression `declval<Traits>().accumulate_` is well-formed,
//            i.e. Traits exposes an accessible member named `accumulate_`;
//   false -> otherwise (primary template).
//
// mma_unpack consults this trait to decide whether a pointer to
// `traits.accumulate_` is appended as the final argument of MMA_Op::fma.
template <class Traits, class Enable = void>
struct supports_output_scaling
{
  static constexpr bool value = false;
};

// Selected only when `declval<Traits>().accumulate_` is a valid expression.
template <class Traits>
struct supports_output_scaling<Traits, void_t<decltype(declval<Traits>().accumulate_)>>
{
  static constexpr bool value = true;
};

} // end namespace detail
/**
 * concept MMA_Traits
 * {
 *   using ValTypeD = // Logical D-value type
 *   using ValTypeA = // Logical A-value type
 *   using ValTypeB = // Logical B-value type
 *   using ValTypeC = // Logical C-value type (NOTE: Not used? Assumed == ValTypeD)
 *
 *   using FrgTypeA = // A-type consumed by MMA (if omitted, same as ValTypeA)
 *   using FrgTypeB = // B-type consumed by MMA (if omitted, same as ValTypeB)
 *   using FrgTypeC = // C-type consumed by MMA (if omitted, same as ValTypeC)
 *
 *   using Shape_MNK = // Logical MxNxK shape of the MMA
 *
 *   using ThrID = // Logical thread id (tid) -> tidx
 *
 *   using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord
 *   using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord
 *   using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord
 * };
 */
// Primary template: intentionally unusable. Every supported MMA operation must
// provide its own specialization of MMA_Traits.
template <class Operation, class... OperationArgs>
struct MMA_Traits
{
  // The condition depends on Operation, so the assertion is only evaluated
  // (and fails) when this primary template is actually instantiated.
  static_assert(0 == sizeof(Operation), "MMA_Traits not implemented for this MMA_Operation.");
};
// MMA_Traits for UniversalFMA: a 1x1x1 "MMA" performed by a single logical
// thread that holds exactly one value of each operand.
template <class TD, class TA, class TB, class TC>
struct MMA_Traits<UniversalFMA<TD,TA,TB,TC>>
{
  // Logical value types of the A, B, C and D operands
  using ValTypeA = TA;
  using ValTypeB = TB;
  using ValTypeC = TC;
  using ValTypeD = TD;

  // Logical MxNxK shape: M = N = K = 1
  using Shape_MNK = Shape<_1,_1,_1>;

  // A single logical thread: tid -> tidx
  using ThrID = Layout<_1>;

  // (tid,vid) -> coordinate maps, each covering a single point
  using ALayout = Layout<Shape<_1,_1>>;  // (tid,vid) -> (m,k)
  using BLayout = Layout<Shape<_1,_1>>;  // (tid,vid) -> (n,k)
  using CLayout = Layout<Shape<_1,_1>>;  // (tid,vid) -> (m,n)
};
//
// Generic mma_unpack for any MMA_Traits.
//
// Reinterprets the register-backed fragments as the register arrays declared by
// MMA_Op (ARegisters / BRegisters / CRegisters / DRegisters) and expands every
// element into a single call to MMA_Op::fma via detail::explode.
//
//   traits : the MMA_Traits instance; if it has an `accumulate_` member
//            (detail::supports_output_scaling), a pointer to it is appended as
//            the last fma argument.
//   D      : destination fragment (mutable).
//   A, B   : source fragments.
//   C      : accumulator fragment. Not read when the D register element type is
//            void; in that case D is used as both accumulator input and output.
//
template <class MMA_Op, class... MMA_Args,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE constexpr
void
mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits,
           Tensor<TD, DLayout>      & D,
           Tensor<TA, ALayout> const& A,
           Tensor<TB, BLayout> const& B,
           Tensor<TC, CLayout> const& C)
{
  // Every operand must already be register-backed.
  static_assert(is_rmem<TD>::value, "Expected registers in MMA_Atom::call");
  static_assert(is_rmem<TA>::value, "Expected registers in MMA_Atom::call");
  static_assert(is_rmem<TB>::value, "Expected registers in MMA_Atom::call");
  static_assert(is_rmem<TC>::value, "Expected registers in MMA_Atom::call");

  using Traits = MMA_Traits<MMA_Op, MMA_Args...>;
  // Depends on the template parameters, so the untaken branches below are discarded.
  constexpr bool has_output_scaling = detail::supports_output_scaling<Traits>::value;

  // Element type and element count of each register array MMA_Op declares
  using DReg = typename remove_extent<typename MMA_Op::DRegisters>::type;
  using AReg = typename remove_extent<typename MMA_Op::ARegisters>::type;
  using BReg = typename remove_extent<typename MMA_Op::BRegisters>::type;
  using CReg = typename remove_extent<typename MMA_Op::CRegisters>::type;

  [[maybe_unused]] constexpr int num_d = extent<typename MMA_Op::DRegisters>::value;
  constexpr int num_a = extent<typename MMA_Op::ARegisters>::value;
  constexpr int num_b = extent<typename MMA_Op::BRegisters>::value;
  constexpr int num_c = extent<typename MMA_Op::CRegisters>::value;

  // View A and B as MMA_Op's register arrays; the element counts must match exactly.
  Tensor regs_a = recast<AReg>(A);
  Tensor regs_b = recast<BReg>(B);
  CUTE_STATIC_ASSERT_V(size(regs_a) == Int<num_a>{});
  CUTE_STATIC_ASSERT_V(size(regs_b) == Int<num_b>{});

  if constexpr (!is_same<DReg, void>::value)
  {
    // Distinct D output registers: fma(d..., a..., b..., c... [, &accumulate_])
    Tensor regs_d = recast<DReg>(D);
    Tensor regs_c = recast<CReg>(C);
    CUTE_STATIC_ASSERT_V(size(regs_d) == Int<num_d>{});
    CUTE_STATIC_ASSERT_V(size(regs_c) == Int<num_c>{});

    if constexpr (has_output_scaling) {
      detail::explode(MMA_Op::fma,
                      regs_d, make_int_sequence<num_d>{},
                      regs_a, make_int_sequence<num_a>{},
                      regs_b, make_int_sequence<num_b>{},
                      regs_c, make_int_sequence<num_c>{},
                      &(traits.accumulate_), seq<0>{});
    } else {
      detail::explode(MMA_Op::fma,
                      regs_d, make_int_sequence<num_d>{},
                      regs_a, make_int_sequence<num_a>{},
                      regs_b, make_int_sequence<num_b>{},
                      regs_c, make_int_sequence<num_c>{});
    }
  }
  else
  {
    // No D registers: the accumulator is updated in place. Per the assertion
    // messages this path is presumably taken by GMMA-style ops. C must match D
    // in value_type and layout; the mutable D is what gets passed to fma.
    static_assert(is_same<typename TD::value_type, typename TC::value_type>::value, "GMMA C and D value_type must match.");
    static_assert(is_same<DLayout, CLayout>::value, "GMMA C and D layouts must match.");

    Tensor regs_c = recast<CReg>(D);
    // NOTE(review): size(regs_c) is not checked against num_c on this path (the
    // check is commented out in the original). Confirm whether it can be enabled.

    if constexpr (has_output_scaling) {
      detail::explode(MMA_Op::fma,
                      regs_a, make_int_sequence<num_a>{},
                      regs_b, make_int_sequence<num_b>{},
                      regs_c, make_int_sequence<num_c>{},
                      &(traits.accumulate_), seq<0>{});
    } else {
      detail::explode(MMA_Op::fma,
                      regs_a, make_int_sequence<num_a>{},
                      regs_b, make_int_sequence<num_b>{},
                      regs_c, make_int_sequence<num_c>{});
    }
  }
}
//
// Accept mutable temporaries: an rvalue D (e.g. a temporary Tensor) is bound to
// an lvalue reference and forwarded to the Tensor<TD,DLayout>& overload.
//
template <class MMA_Op, class... MMA_Args,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE constexpr
void
mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits,
           Tensor<TD, DLayout>     && D,
           Tensor<TA, ALayout> const& A,
           Tensor<TB, BLayout> const& B,
           Tensor<TC, CLayout> const& C)
{
  // Binding to an lvalue reference makes overload resolution pick the lvalue-D
  // overload instead of recursing into this one.
  Tensor<TD, DLayout>& d_lvalue = D;
  mma_unpack(traits, d_lvalue, A, B, C);
}
namespace detail {

// FrgTypeX_or_Default<Traits>::type resolves to Traits::FrgTypeX when that
// member type exists, and falls back to Traits::ValTypeX otherwise
// (X in {A, B, C}).

// ---- A operand ----
template <class Traits, class Enable = void>
struct FrgTypeA_or_Default
{
  using type = typename Traits::ValTypeA;
};

template <class Traits>
struct FrgTypeA_or_Default<Traits, void_t<typename Traits::FrgTypeA>>
{
  using type = typename Traits::FrgTypeA;
};

// ---- B operand ----
template <class Traits, class Enable = void>
struct FrgTypeB_or_Default
{
  using type = typename Traits::ValTypeB;
};

template <class Traits>
struct FrgTypeB_or_Default<Traits, void_t<typename Traits::FrgTypeB>>
{
  using type = typename Traits::FrgTypeB;
};

// ---- C operand ----
template <class Traits, class Enable = void>
struct FrgTypeC_or_Default
{
  using type = typename Traits::ValTypeC;
};

template <class Traits>
struct FrgTypeC_or_Default<Traits, void_t<typename Traits::FrgTypeC>>
{
  using type = typename Traits::FrgTypeC;
};

} // end namespace detail
} // namespace cute