forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BYOC] CUTLASS integration (apache#9261)
* byoc cutlass * add cmake and fix build * test worked but accuracy is bad * fixed argument printing properly * moving files * moving contents of cutlass_profiler into python/tvm/contrib/cutlass * run black * remove irrelavant codegen code * clang format * tried replacing sm 75 with 80, didn't help improve accuracy * remove irrelavant code from generator * tried dense + bias fusion but generated cu file does not compile * dense + bias worked after adding Leyuan's patch, bias + relu worked too * tried adding sm80 generator but accuracy is still off * remove GemmUniversal generator * cleanup partition and build * moved partition, profile and build function out of test * turned out the result match's TVM non-cutlass result. Numpy fp16 matmul is busted? * clean up test * LinearCombination can be reused for bias only epilogue * remove unsupported epilogues like gelu * removing deadcode * unify gemm templates for with or without beta scaling * supported gelu but accuracy is slightly off * gelu test passed with relaxed rtol * cleanup * remove unused stuff from library.py * move profiler template into its own file * removed gemm_profiler.py * move contents of compile_engine.py into gen_gemm.py * rename to profiler_template.cu to avoid CI issue * cleaning up trying to pass pylint * add missing asf header * run black * fixing many pylint issues except wildcard import * fixed wildcard warning * add missing CUTLASS.cmake file, restore gemm_profiler.py * pylint * minor fix * add license * start filling in TODO doc * rename GemmProfiler to GemmProfilerEmitter * more renaming and doc * add doc to the main compile API * refactored generator * run black * black fix * finish doc TODO * add test for 32 bit accum * fixed kernel generator to correctly handle fp32 accum * revise build-related API * add option to profile only one kernel * add option to enable parallel compilation * clean up gen_gemm * doc update * profile_cutlass_kernels -> tune_cutlass_kernels Co-authored-by: leyuan.wang <leyuan.wang@bytedance.com> Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
- Loading branch information
Showing
18 changed files
with
1,864 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
if(USE_CUTLASS) | ||
file(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) | ||
list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) | ||
|
||
message(STATUS "Build with CUTLASS") | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
* Redistributions of source code must retain the above copyright | ||
notice, this list of conditions and the following disclaimer. | ||
* Redistributions in binary form must reproduce the above copyright | ||
notice, this list of conditions and the following disclaimer in the | ||
documentation and/or other materials provided with the distribution. | ||
* Neither the name of the NVIDIA CORPORATION nor the | ||
names of its contributors may be used to endorse or promote products | ||
derived from this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""BYOC support for CUTLASS.""" | ||
from .build import tune_cutlass_kernels, build_cutlass_kernels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name | ||
"""Driver for partitioning and building a Relay module for CUTLASS offload.""" | ||
import tvm | ||
from tvm import runtime, relay | ||
from .gen_gemm import CutlassGemmProfiler | ||
|
||
|
||
class GemmAnnotator(tvm.relay.ExprVisitor): | ||
"""Annotates partitioned functions with shape and dtype information.""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.signature = {} | ||
|
||
def visit_call(self, call): | ||
op = call.op | ||
if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: | ||
self.signature["op_type"] = op.attrs["Composite"] | ||
for i, arg in enumerate(op.params): | ||
self.signature["arg%d_shape" % i] = arg.checked_type.shape | ||
self.signature["arg%d_dtype" % i] = arg.checked_type.dtype | ||
self.signature["ret_shape"] = op.ret_type.shape | ||
self.signature["ret_dtype"] = op.ret_type.dtype | ||
|
||
|
||
def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): | ||
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which | ||
kernels to emit. | ||
Parameters | ||
---------- | ||
mod : IRModule | ||
The Relay module with cutlass partitions. | ||
sm : int | ||
An integer specifying the compute capability. For example, 75 for Turing and | ||
80 or 86 for Ampere. | ||
profile_all : bool | ||
Whether or not profile all candidate kernels, or stop profiling after | ||
the first applicable kernel is found. | ||
use_multiprocessing : bool | ||
Whether or not compile profiler executables for different kernels in parallel. | ||
tmp_dir : string, optional | ||
A temporary directory where intermediate compiled artifacts will be stored. | ||
Returns | ||
------- | ||
mod : IRModule | ||
The updated module annotated with cutlass profiling information. | ||
num_cutlass_partition : int | ||
The number of partitioned functions created for CUTLASS. | ||
""" | ||
cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir) | ||
num_cutlass_partition = 0 | ||
for var in mod.get_global_vars(): | ||
fun_name = var.name_hint | ||
func = mod[fun_name] | ||
annotator = GemmAnnotator() | ||
if "cutlass" in fun_name: | ||
num_cutlass_partition += 1 | ||
annotator.visit(func) | ||
# call cutlass profiler to find best settings, update attr | ||
new_attrs = {} | ||
new_attrs.update(annotator.signature) | ||
for key in func.attrs.keys(): | ||
new_attrs[key] = func.attrs[key] | ||
# call profiler | ||
arg0_shape = new_attrs["arg0_shape"] | ||
arg1_shape = new_attrs["arg1_shape"] | ||
MM = arg0_shape[0] | ||
KK = arg0_shape[1] | ||
NN = arg1_shape[0] | ||
out = cutlass_profiler.profile( | ||
MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing | ||
) | ||
if new_attrs["op_type"] == "cutlass.dense": | ||
new_attrs["cutlass_op_def"] = out["opdef"] | ||
elif new_attrs["op_type"] == "cutlass.dense_bias": | ||
new_attrs["cutlass_op_def"] = out["opdef_bias"] | ||
elif new_attrs["op_type"] == "cutlass.dense_bias_relu": | ||
new_attrs["cutlass_op_def"] = out["opdef_bias_relu"] | ||
elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]: | ||
new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"] | ||
else: | ||
raise ValueError("%s pattern is not implemented." % new_attrs["op_type"]) | ||
new_attrs["cutlass_op_name"] = out["name"] | ||
|
||
print("The best kernel is " + new_attrs["cutlass_op_name"]) | ||
if new_attrs["cutlass_op_name"].find("_tn_align") > 0: | ||
new_attrs["lda"] = "K" | ||
new_attrs["ldb"] = "K" | ||
new_attrs["ldc"] = "N" | ||
elif new_attrs["cutlass_op_name"].find("_nt_align") > 0: | ||
new_attrs["lda"] = "M" | ||
new_attrs["ldb"] = "N" | ||
new_attrs["ldc"] = "N" | ||
else: | ||
raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"]) | ||
new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) | ||
new_func = relay.Function( | ||
func.params, | ||
func.body, | ||
ret_type=func.ret_type, | ||
type_params=func.type_params, | ||
attrs=new_attrs, | ||
) | ||
mod.update_func(var, new_func) | ||
|
||
return mod, num_cutlass_partition | ||
|
||
|
||
def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"): | ||
"""Compile CUTLASS kernels in lib and return the runtime module ready to run. | ||
Parameters | ||
---------- | ||
lib : GraphExecutorFactoryModule | ||
The output from relay.build containing compiled host code and non-cutlass kernels. | ||
sm : int | ||
An integer specifying the compute capability. For example, 75 for Turing and | ||
80 or 86 for Ampere. | ||
tmp_dir : string, optional | ||
A temporary directory where intermediate compiled artifacts will be stored. | ||
lib_path : string, optional | ||
The path to a shared library which will be generated as the result of the build process | ||
Returns | ||
------- | ||
updated_lib : runtime.Module | ||
The updated module with compiled cutlass kernels. | ||
""" | ||
cutlass_path = "../../../3rdparty/cutlass/include" | ||
cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include" | ||
|
||
kwargs = {} | ||
kwargs["cc"] = "nvcc" | ||
kwargs["options"] = [ | ||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", | ||
"-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm), | ||
"-Xcompiler=-fPIC", | ||
"-Xcompiler=-Wconversion", | ||
"-Xcompiler=-fno-strict-aliasing", | ||
"-O3", | ||
"-std=c++14", | ||
"-I" + cutlass_path, | ||
"-I" + cutlass_util_path, | ||
] | ||
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) | ||
return runtime.load_module(lib_path) |
Oops, something went wrong.