Commit
introduce gen_tensor_op.py
masahi committed Nov 13, 2021
1 parent 37bb918 commit ca1ae27
Showing 2 changed files with 255 additions and 192 deletions.
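This commit factors the reusable profiler pieces out of gen_gemm.py into the new gen_tensor_op.py: the ProfilerEngine class and the SM75/SM80 tensor-op kernel generators move there, while gen_gemm.py keeps only the GEMM-specific operator factory and passes it to the shared generators as an op_creator closure. A minimal sketch of the resulting call pattern, with the generator signature inferred from the call sites in this diff:

    from tvm.contrib.cutlass.gen_gemm import create_gemm_operator
    from tvm.contrib.cutlass.gen_tensor_op import generate_sm75_tensor_op_1688

    # Enumerate candidate SM75 (Turing) fp16 GEMM kernels; the closure returned
    # by create_gemm_operator(False) emits non-batched GEMM operator definitions.
    ops = generate_sm75_tensor_op_1688("float16", op_creator=create_gemm_operator(False))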
221 changes: 29 additions & 192 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,14 +16,18 @@
# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
import logging
import os
import re
import tempfile
import subprocess
import multiprocessing
from .gemm_operation import GemmOperation, EmitGemmInstance
from .gemm_profiler import GemmProfilerEmitter
from .gen_tensor_op import (
ProfilerEngine,
generate_sm75_tensor_op_1688,
generate_sm80_tensor_op_16816,
)
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
@@ -37,10 +41,8 @@
TileDescription,
)

logger = logging.getLogger("cutlass")


def create_gemm_operator(
def _create_gemm_operator(
layouts,
tile_descriptions,
data_type,
@@ -132,141 +134,32 @@ def create_gemm_operator(
return ret


def generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions, batched=False
):
"""Common kernel generator to be used by archtecture specific generators."""
ops = []
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
]
for math_inst in math_instructions:
tile_descriptions = get_tile_descriptions(math_inst)
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]

out = create_gemm_operator(
layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
def create_gemm_operator(batched):
def op_creator(
layouts,
tile_descriptions,
data_type,
alignment_constraints,
swizzling_functor=SwizzlingFunctor.Identity8,
):
return _create_gemm_operator(
layouts,
tile_descriptions,
data_type,
alignment_constraints,
swizzling_functor,
batched=batched,
)

ops.extend(out)

return ops


def generate_sm75_tensor_op_1688(out_dtype, batched=False):
"""Generate GEMM kernels for Turing."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
"float32": [
MathInstruction(
[16, 8, 8],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
],
"float16": [
MathInstruction(
[16, 8, 8],
DataType.f16,
DataType.f16,
DataType.f16,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
],
}[out_dtype]

alignment_constraints = [8, 4, 2, 1]

def get_tile_descriptions(math_inst):
min_cc = 75
max_cc = 1024
return [
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc),
]

return generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions, batched
)


def generate_sm80_tensor_op_16816(out_dtype, batched=False):
"""Generate GEMM kernels for Ampere."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
"float32": [
MathInstruction(
[16, 8, 16],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
],
"float16": [
MathInstruction(
[16, 8, 16],
DataType.f16,
DataType.f16,
DataType.f16,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
],
}[out_dtype]

alignment_constraints = [8, 4, 2]

def get_tile_descriptions(math_inst):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
return [
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
]

return generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions, batched
)
return op_creator
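# A sketch, not part of this commit: the generate_tensor_op_common removed
# above presumably now lives in gen_tensor_op.py with an op_creator parameter,
# so the shared loop calls the GEMM-specific closure roughly as:
#
#   out = op_creator(layouts, tile_descriptions, data_type, alignment_constraints)
#   ops.extend(out)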


GENERATOR_FUNC_TABLE = {
75: generate_sm75_tensor_op_1688,
80: generate_sm80_tensor_op_16816,
}


# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
75: {
@@ -280,66 +173,6 @@ def get_tile_descriptions(math_inst):
}


class ProfilerEngine:
"""Compile and run a given profiler executable."""

def __init__(self, cuda_arch, cutlass_path, binary_prefix):
self.cuda_arch = cuda_arch
self.binary_prefix = binary_prefix
self.cutlass = cutlass_path
self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
cutlass=cutlass_path
)
self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
arch=cuda_arch
)
self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
self.cmd = "nvcc {cflags} {src} -o {output}"

def _compile(self, op):
os.makedirs(self.binary_prefix, exist_ok=True)
opath = os.path.join(self.binary_prefix, op["name"])
if os.path.exists(opath):
return
fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
fi.write(op["src"])
fi.close()
cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
os.system(cmd)
os.unlink(fi.name)

def compile_all(self, ops, use_multiprocessing=False):
"""Compile all profiler executables."""
if use_multiprocessing:
pool = multiprocessing.Pool(multiprocessing.cpu_count())
pool.map(self._compile, ops)
else:
for op in ops:
self._compile(op)

def evaluate(self, op, args):
"""Run the profiler executable corresponding to op_name with args."""
op_name = op["name"]
opath = os.path.join(self.binary_prefix, op_name)
if not os.path.exists(opath):
self._compile(op)
cmd = [opath]
if args is not None:
cmd.append(str(args[0]))
cmd.append(str(args[1]))
cmd.append(str(args[2]))
if len(args) > 3:
cmd.append(str(args[3]))
try:
sp = subprocess.run(cmd, capture_output=True, check=True)
rt = float(sp.stdout)
logger.info("%s, %f", op_name, rt)
except subprocess.CalledProcessError:
rt = -1
return rt
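# A usage sketch for the engine above, assuming its interface is unchanged
# after the move to gen_tensor_op.py; the paths are hypothetical:
#
#   engine = ProfilerEngine(75, "/path/to/cutlass", "/tmp/cutlass_profiler_bin")
#   engine.compile_all(ops)  # nvcc-compiles each op["src"] into binary_prefix
#   rt = engine.evaluate(ops[0], [1024, 1024, 1024])  # float from stdout, -1 on error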


class CutlassGemmProfiler:
"""Profile all candidate kernels and select the best one."""

@@ -362,7 +195,9 @@ def get_default(self, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, op_creator=create_gemm_operator(batched)
)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
assert len(filtered) == 1
@@ -378,7 +213,9 @@ def profile(
if (M, N, K) in self.cache:
return self.cache[(M, N, K)]

ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, op_creator=create_gemm_operator(batched)
)
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

for op in ops:
226 changes: 226 additions & 0 deletions python/tvm/contrib/cutlass/gen_tensor_op.py (new file; diff not loaded)

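An end-to-end sketch of the profiler entry point after this change. The CutlassGemmProfiler constructor arguments are assumed to mirror ProfilerEngine's (sm, cutlass_path, binary_prefix), and profile's full parameter list is truncated in this view, so both are assumptions:

    from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler

    # sm=80 targets Ampere; both paths are hypothetical.
    profiler = CutlassGemmProfiler(80, "/path/to/cutlass", "/tmp/cutlass_profiler_bin")

    # Return the hard-coded default fp16 kernel without profiling anything.
    default = profiler.get_default("float16")

    # Compile and time every alignment-compatible candidate for a
    # 1024 x 1024 x 1024 GEMM; results are cached per (M, N, K).
    best = profiler.profile(1024, 1024, 1024, "float16")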