Commit 553ae02

masahi authored and ylc committed
[CUTLASS] Support more kernels: int8, tf32, and 3xtf32 (apache#9899)
* add int8 type in library
* wip
* adding test and plumbing data and weight dtype
* adding 3xtf32 support and refactor tile description enum
* add 3xtf32 test
* update gemm generator too
* int8 test worked
* 3xtf32 also works
* int8 and 3xtf32 gemm works
* clean up test
* support int8 in sm75
* refined int8 alignment constraints
* black
* support 3xtf32 in default kernel
* remove log
* refine dtype check
* support tf32
* leave TODO for alignment modification on int8 kernels
* tf32 test working
* fix default kernel for tf32
* workaround for compilation failure
* lint
1 parent 3a6b88b commit 553ae02
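The headline additions are int8, tf32, and 3xtf32 GEMM/conv2d kernels. "3xtf32" is the technique for emulating a float32 GEMM on TF32 tensor cores: each float32 operand is split into a TF32 "big" part plus a TF32 residual, and three TF32 products are accumulated (the residual-times-residual term is dropped). The NumPy sketch below only illustrates that numerical idea; it is not the CUTLASS implementation, and the helper names (to_tf32, gemm_3xtf32) are made up for this example.

import numpy as np

def to_tf32(x):
    # Simulate TF32 rounding by truncating float32 to a 10-bit mantissa
    # (zero out the low 13 mantissa bits).
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

def gemm_3xtf32(a, b):
    # Split each fp32 operand into a TF32 "big" part and a TF32 residual,
    # then accumulate three TF32 products (the small*small term is dropped).
    a_big, b_big = to_tf32(a), to_tf32(b)
    a_small, b_small = to_tf32(a - a_big), to_tf32(b - b_big)
    return a_big @ b_big + a_big @ b_small + a_small @ b_big

a = np.random.randn(64, 32).astype("float32")
b = np.random.randn(32, 16).astype("float32")
err_1xtf32 = np.abs(to_tf32(a) @ to_tf32(b) - a @ b).max()
err_3xtf32 = np.abs(gemm_3xtf32(a, b) - a @ b).max()
# err_3xtf32 is typically orders of magnitude smaller than err_1xtf32,
# which is why 3xtf32 is the default for float32 inputs in this commit.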

File tree

8 files changed: +445 −116 lines changed


python/tvm/contrib/cutlass/build.py

+83 −8
@@ -94,12 +94,25 @@ def visit_call(self, call):
 
 
 def select_gemm_kernel(
-    cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
+    cutlass_profiler,
+    op_type,
+    MM,
+    KK,
+    NN,
+    out_dtype,
+    arg0_dtype,
+    arg1_dtype,
+    use_3xtf32,
+    batched,
+    profile_all,
+    use_multiprocessing,
 ):
     """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
     workloads."""
     if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
-        out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched)
+        out = cutlass_profiler.get_default(
+            op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched
+        )
         name, cutlass_op_def = out["name"], out["opdef"]
         logger.info("Picked the default kernel %s", name)
     else:
@@ -109,6 +122,9 @@ def select_gemm_kernel(
             NN,
             KK,
             out_dtype,
+            arg0_dtype,
+            arg1_dtype,
+            use_3xtf32,
             batched=batched,
             profile_all=profile_all,
             use_multiprocessing=use_multiprocessing,
@@ -122,15 +138,35 @@ def select_gemm_kernel(
 
 
 def handle_batch_matmul(
-    cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
+    cutlass_profiler,
+    op_type,
+    arg0_shape,
+    arg1_shape,
+    out_dtype,
+    arg0_dtype,
+    arg1_dtype,
+    use_3xtf32,
+    profile_all,
+    use_multiprocessing,
 ):
     """Profile and select a kernel for batch_matmul op workload."""
     MM = arg0_shape[1]
     KK = arg0_shape[2]
     NN = arg1_shape[1]
 
     name, cutlass_op_def = select_gemm_kernel(
-        cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
+        cutlass_profiler,
+        op_type,
+        MM,
+        KK,
+        NN,
+        out_dtype,
+        arg0_dtype,
+        arg1_dtype,
+        use_3xtf32,
+        True,
+        profile_all,
+        use_multiprocessing,
     )
 
     return {
@@ -147,15 +183,35 @@ def handle_batch_matmul(
 
 
 def handle_dense(
-    cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
+    cutlass_profiler,
+    op_type,
+    arg0_shape,
+    arg1_shape,
+    out_dtype,
+    arg0_dtype,
+    arg1_dtype,
+    use_3xtf32,
+    profile_all,
+    use_multiprocessing,
 ):
     """Profile and select a kernel for dense op workload."""
     MM = arg0_shape[0]
     KK = arg0_shape[1]
     NN = arg1_shape[0]
 
     name, cutlass_op_def = select_gemm_kernel(
-        cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
+        cutlass_profiler,
+        op_type,
+        MM,
+        KK,
+        NN,
+        out_dtype,
+        arg0_dtype,
+        arg1_dtype,
+        use_3xtf32,
+        False,
+        profile_all,
+        use_multiprocessing,
     )
 
     assert "tn_align" in name, "Only supports (row_major, col_major) input layout for now."
@@ -178,12 +234,15 @@ def handle_conv2d(
     strides,
     dilation,
     out_dtype,
+    data_dtype,
+    weight_dtype,
+    use_3xtf32,
     profile_all,
     use_multiprocessing,
 ):
     """Profile and select a kernel for conv2d op workload."""
     if any(isinstance(s, tvm.tir.Any) for s in d_shape):
-        out = cutlass_profiler.get_default(op_type, out_dtype)
+        out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32)
         name, cutlass_op_def = out["name"], out["opdef"]
         logger.info("Picked the default kernel %s", name)
     else:
@@ -195,6 +254,9 @@ def handle_conv2d(
             strides,
             dilation,
             out_dtype,
+            data_dtype,
+            weight_dtype,
+            use_3xtf32,
             profile_all=profile_all,
             use_multiprocessing=use_multiprocessing,
         )
@@ -209,7 +271,9 @@ def handle_conv2d(
     }
 
 
-def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
+def tune_cutlass_kernels(
+    mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"
+):
     """Given a module partitioned for CUTLASS offloading, profile each workload to select which
     kernels to emit.
 
@@ -258,6 +322,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
         new_attrs.update(func.attrs)
         arg0_shape = new_attrs["arg0_shape"]
         arg1_shape = new_attrs["arg1_shape"]
+        arg0_dtype = new_attrs["arg0_dtype"]
+        arg1_dtype = new_attrs["arg1_dtype"]
 
         if "conv2d" in op_type:
             new_attrs["padding"] = annotator.op_attrs.padding
@@ -273,6 +339,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
                 annotator.op_attrs.strides,
                 annotator.op_attrs.dilation,
                 out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
                 profile_all,
                 use_multiprocessing,
             )
@@ -285,6 +354,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
                 arg0_shape,
                 arg1_shape,
                 out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
                 profile_all,
                 use_multiprocessing,
             )
@@ -297,6 +369,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
                 arg0_shape,
                 arg1_shape,
                 out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
                 profile_all,
                 use_multiprocessing,
             )
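The new use_3xtf32=True parameter on tune_cutlass_kernels is what threads the flag down to the profilers above. Below is a hedged sketch of how a caller might use it; the import paths, the toy Relay workload, and the (mod, num_cutlass_partition) return value follow the existing TVM CUTLASS workflow and are assumptions here, not part of this diff. Compilation and execution of the tuned module are omitted.

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels

# A toy float32 dense workload; in practice this would be a full model.
data = relay.var("data", shape=(16, 64), dtype="float32")
weight = relay.var("weight", shape=(32, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, out_dtype="float32"))

# Offload supported ops to CUTLASS, then profile/select kernels for sm_80.
# use_3xtf32=False would pick plain tf32 kernels for float32 inputs instead
# of the more accurate 3xtf32 emulation.
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
    mod, sm=80, use_3xtf32=True, profile_all=False, use_multiprocessing=False
)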

python/tvm/contrib/cutlass/gen_conv2d.py

+22 −7
@@ -153,8 +153,13 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
         self.cache = {}
 
-    def get_default(self, op_type, out_dtype):
-        gemm_profile_result = self.gemm_profiler.get_default(op_type, out_dtype)
+    def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
+        """Return the default kernel for the requested architecture.
+        For now, the default kernel was picked arbitrary.
+        """
+        gemm_profile_result = self.gemm_profiler.get_default(
+            op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32
+        )
         tile_description = gemm_profile_result["tile_description"]
         alignment = gemm_profile_result["alignment"]
         data_type = gemm_profile_result["data_type"]
@@ -165,9 +170,10 @@ def get_default(self, op_type, out_dtype):
 
     def check_align(self, op_name, C, K):
         """Filter out kernels that cannot be supported."""
-        aligns = re.findall(r"align[1|2|4|8]", op_name)
-        assert len(aligns) == 1
-        align = int(aligns[0][-1])
+        match = re.match(".*_align([1-9]+)", op_name)
+        assert match is not None and len(match.groups()) == 1
+        # The same alignment is used for all axes
+        align = int(match.groups()[0])
         return all([dim % align == 0 for dim in [C, K]])
 
     def select_op(
@@ -178,6 +184,9 @@ def select_op(
         stride,
         dilation,
         out_dtype,
+        data_dtype,
+        weight_dtype,
+        use_3xtf32,
         profile_all=True,
         use_multiprocessing=False,
     ):
@@ -207,9 +216,9 @@ def select_op(
             return self.cache[workload]
 
         ops = GENERATOR_FUNC_TABLE[self.sm](
-            out_dtype,
-            op_creator=enumerate_conv2d_operators,
+            out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32
         )
+
         ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
 
         if profile_all:
@@ -240,6 +249,9 @@ def profile(
         stride,
         dilation,
         out_dtype,
+        data_dtype,
+        weight_dtype,
+        use_3xtf32=True,
         profile_all=True,
         use_multiprocessing=False,
     ):
@@ -254,6 +266,9 @@ def profile(
             stride,
             dilation,
             out_dtype,
+            data_dtype,
+            weight_dtype,
+            use_3xtf32,
             profile_all=profile_all,
             use_multiprocessing=use_multiprocessing,
         )
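Both the conv2d and gemm profilers now parse the alignment out of the kernel name with re.match(".*_align([1-9]+)", ...) instead of re.findall(r"align[1|2|4|8]", ...). The old single-character class can only see one digit, which misreads multi-digit alignments such as the align16 variants used by int8 kernels (the kernel name below is illustrative, not taken from the commit). A small standalone reproduction of the difference:

import re

def old_alignment(op_name):
    # Old pattern: single-character class, so "align16" is misread as "align1".
    aligns = re.findall(r"align[1|2|4|8]", op_name)
    return int(aligns[0][-1])

def new_alignment(op_name):
    # New pattern: capture the full (possibly multi-digit) alignment suffix.
    match = re.match(".*_align([1-9]+)", op_name)
    assert match is not None
    return int(match.groups()[0])

name = "cutlass_tensorop_i8816gemm_s8_256x128_64x2_tn_align16"  # illustrative name
print(old_alignment(name))  # 1  (wrong: would let misaligned shapes through)
print(new_alignment(name))  # 16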

python/tvm/contrib/cutlass/gen_gemm.py

+53 −13
@@ -125,13 +125,18 @@ def enumerate_gemm_operators(
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
+        ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
     # align1 variants do not seem to be available for sm80
     80: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
+        ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
+        # two kernels for tf32 and 3xtf32
+        ("float32", "float32"): (
+            "cutlass_tensorop_s1688gemm_128x64_32x3_tn_align1",
+            "cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1",
+        ),
     },
 }
 
@@ -147,21 +152,31 @@ def __init__(self, sm, cutlass_path, binary_path):
 
     def check_align(self, op_name, M, N, K):
         """Filter out kernels that cannot be supported."""
-        aligns = re.findall(r"align[1|2|4|8]", op_name)
-        assert len(aligns) == 1
+        match = re.match(".*_align([1-9]+)", op_name)
+        assert match is not None and len(match.groups()) == 1
         # The same alignment is used for all axes
-        align = int(aligns[0][-1])
+        align = int(match.groups()[0])
         # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
         # See https://github.com/NVIDIA/cutlass/issues/362.
         # When the above issue is resolved, we can remove the alignment check on M below.
         return all([dim % align == 0 for dim in [M, N, K]])
 
-    def get_default(self, op_type, out_dtype, batched=False):
+    def get_default(
+        self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
+    ):
         """Return the default kernel for the requested architecture.
         For now, the default kernel was picked arbitrary.
         """
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=enumerate_gemm_operators)
-        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32
+        )
+        default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]
+
+        if arg0_dtype == "float32":
+            default_kernel_name = (
+                default_kernel_name[0] if not use_3xtf32 else default_kernel_name[1]
+            )
+
         filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
         assert len(filtered) == 1
         op = filtered[0]
@@ -176,7 +191,18 @@ def get_default(self, op_type, out_dtype, batched=False):
         op.update({"name": name, "opdef": opdef})
         return op
 
-    def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
+    def select_op(
+        self,
+        M,
+        N,
+        K,
+        out_dtype,
+        arg0_dtype,
+        arg1_dtype,
+        use_3xtf32,
+        profile_all=True,
+        use_multiprocessing=False,
+    ):
         """
         Profile and select the best kernel from candidate kernels.
         See the documentation for the profile method below.
@@ -187,7 +213,10 @@ def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=Fa
 
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype,
-            op_creator=enumerate_gemm_operators,
+            arg0_dtype,
+            arg1_dtype,
+            enumerate_gemm_operators,
+            use_3xtf32=use_3xtf32,
         )
         ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))
 
@@ -212,6 +241,9 @@ def profile(
         N,
         K,
         out_dtype,
+        arg0_dtype,
+        arg1_dtype,
+        use_3xtf32=True,
        profile_all=True,
         use_multiprocessing=False,
         batched=False,
@@ -221,7 +253,15 @@ def profile(
         If use_multiprocessing is True, compile all profiler executables in parallel.
         """
        op = self.select_op(
-            M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
+            M,
+            N,
+            K,
+            out_dtype,
+            arg0_dtype,
+            arg1_dtype,
+            use_3xtf32,
+            profile_all=profile_all,
+            use_multiprocessing=use_multiprocessing,
         )
 
         name, opdef = create_gemm_operator_with_epilogue(
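With DEFAULT_KERNELS now keyed by (input dtype, output dtype), and the float32/float32 entry holding a (tf32 kernel, 3xtf32 kernel) pair, the default-kernel choice in get_default reduces to the lookup below. This is a hypothetical standalone helper that mirrors the logic in the diff above; it is not code from the commit.

def pick_default_kernel(default_kernels, sm, arg0_dtype, out_dtype, use_3xtf32=True):
    # The table is keyed by (input dtype, output dtype).
    entry = default_kernels[sm][(arg0_dtype, out_dtype)]
    if arg0_dtype == "float32":
        # float32 entries store a pair: (plain tf32 kernel, 3xtf32 kernel).
        return entry[1] if use_3xtf32 else entry[0]
    return entry

# With the sm80 table from the diff above:
# pick_default_kernel(DEFAULT_KERNELS, 80, "float16", "float16")
#   -> "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1"
# pick_default_kernel(DEFAULT_KERNELS, 80, "float32", "float32", use_3xtf32=False)
#   -> "cutlass_tensorop_s1688gemm_128x64_32x3_tn_align1"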
