[BYOC] CUTLASS integration #9261
Conversation
Thanks @Laurawly, I'll work on this as my top priority!
Really cool work @Laurawly!
One question for @comaniac: currently, this integration is implemented with the C source backend. But one big use case for CUTLASS BYOC is dynamic workloads, which we cannot handle well at the moment. I think the C source backend is not compatible with dynamic inputs, is that right? On the other hand, with the json codegen/runtime, I can see a way to support dynamic inputs. However, since CUTLASS is a template library, the only way it could work with the json codegen/runtime is to instantiate all possible templates that we care about ahead of time, and build them together with …
IMHO, CUTLASS doesn't naturally benefit dynamic workloads, for exactly the reason you mentioned. We internally use CUTLASS for training, and it works well because we JIT kernels with known shapes at runtime. In the case of CUTLASS with BYOC in TVM for inference, my impression is we could leverage the high-performance kernel templates while 1) keeping the binary self-contained, 2) fusing ops, and 3) having lightweight tuning (e.g., ~10 trials, similar to cuDNN). On the other hand, dynamic workloads are still challenging, and hopefully our ongoing efforts on dynamic kernel tuning and generation can land soon to make that happen.
Hmm, are you concerned about slow performance due to lack of tuning, or the integration problem that I brought up? I wouldn't worry about the former, because cuBLAS is fast without tuning, so I expect CUTLASS to perform equally well. Since our cuBLAS offload supports dynamic inputs, not supporting dynamic inputs for CUTLASS would be a big bummer imo. So I want to discuss the integration problem first and investigate performance issues later.
I think even the C source backend should be able to handle dynamic shapes, since it only expects tensors at runtime, or I might be forgetting something here. But in general, the json format is more recommended, as it is easier to debug and maintain. It is also friendlier for handling constants.
Exactly, I've looked at the code a bit and it seems we currently generate an API that takes raw pointers (see `tvm/src/relay/backend/contrib/codegen_c/codegen_c.h`, lines 176 to 185 at 3cb838d).
I'm glad to find at least one path that supports all use cases. For the json codegen/runtime path, I'm not sure how to integrate CUTLASS's JIT codegen + compile approach with it.
Ah sorry, I didn't make it clear. The interface of the C source codegen does deal with dynamic workloads, because it takes raw pointers that could be any size at runtime. What I meant was how to generate CUTLASS kernels that perform well across all shapes. In @Laurawly's post, they generate lots of kernels to cover the possible shapes, which results in a 7 GB binary. I assume they also generate runtime dispatching logic (also in the generated C source code) to determine which kernel should be used given the shape known at runtime. Obviously, the binary size will definitely be an issue for this solution. For JSON codegen/runtime, it would be similar to TensorRT: we simply dump a JSON graph in codegen without doing anything else. Meanwhile, we have a custom runtime that JITs/caches CUTLASS kernels based on known shapes. This results in a much smaller binary, but the first execution (or an execution with new shapes) may take several seconds or even a minute to JIT all kernels.
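The JIT/cache behavior described for the JSON runtime path can be sketched as a shape-keyed kernel cache: the first execution with a new shape pays the compile cost, and later executions reuse the cached kernel. This is a minimal illustrative sketch; `KernelCache` and `jit_compile` are hypothetical names, not actual TVM or CUTLASS APIs.

```python
class KernelCache:
    """Shape-keyed cache: JIT a kernel the first time a shape is seen."""

    def __init__(self, jit_compile):
        self._jit = jit_compile  # expensive: builds one kernel per shape
        self._cache = {}         # (M, N, K) -> compiled kernel

    def get(self, shape):
        if shape not in self._cache:               # slow path: new shape,
            self._cache[shape] = self._jit(shape)  # pay the JIT cost once
        return self._cache[shape]                  # fast path: cache hit

# Stand-in for a real JIT compiler, so the caching behavior is observable.
compile_log = []

def fake_jit(shape):
    compile_log.append(shape)
    return "gemm_%dx%dx%d" % shape

cache = KernelCache(fake_jit)
k1 = cache.get((128, 256, 64))
k2 = cache.get((128, 256, 64))  # same shape: no recompilation
```

This is why the first run (or a run with new shapes) is slow while the binary stays small: the compiled kernels live in the cache rather than in the shipped artifact.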
At this moment, CUTLASS does not provide heuristics to choose which kernel to use based on runtime information; the user of CUTLASS needs to pick the kernels themselves.
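Because CUTLASS leaves kernel selection to the caller, the pre-instantiated C-source path discussed above would need its own runtime dispatch, e.g. bucketing the problem size to one of the ahead-of-time compiled kernels. A hypothetical sketch; the kernel names and bucket boundaries are purely illustrative:

```python
def pick_kernel(m, n, k):
    """Map a runtime GEMM shape to a pre-compiled kernel by bucketing
    the problem size (illustrative heuristic, not a CUTLASS API)."""
    size = m * n * k
    if size <= 1 << 16:        # tiny problems
        return "cutlass_gemm_small"
    if size <= 1 << 22:        # mid-sized problems
        return "cutlass_gemm_medium"
    return "cutlass_gemm_large"
```

With many shape buckets this dispatch table, plus the kernels behind it, is exactly what blows up the binary size in the all-templates-ahead-of-time approach.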
Otherwise LGTM.
python/tvm/contrib/cutlass/build.py
self.signature["ret_dtype"] = op.ret_type.dtype
def profile_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
I'm wondering whether we should use "profile" here. My impression of "profile" is that it is more of an analysis than a transformation. For example, profile_executor sounds like getting a report of detailed execution latencies. Maybe "tune_cutlass_kernels" (similar to AutoTVM/Ansor) or "search_cutlass_kernels" (similar to cuDNN) might be better? Would like to hear from others as well.
OK, I'll change it to `tune_cutlass_kernels`. `profile` is what CUTLASS uses in their repo (e.g. `cutlass_profiler`), but yeah, it might be confusing to TVM users.
My only concern with `tune` is that people might get the wrong impression that `tune_cutlass_kernels` is an optional thing, since in Ansor / AutoTVM tuning is optional. We can add "default kernels" later to avoid having tuning as a hard requirement. We need default kernels for dynamic workloads anyway.
That's a fair concern and I agree with your plan.
Thanks @Laurawly @comaniac @junrushao1994 @zhiics @hwu36, this is merged!
This is so cool. Thank you everyone.
* byoc cutlass
* add cmake and fix build
* test worked but accuracy is bad
* fixed argument printing properly
* moving files
* moving contents of cutlass_profiler into python/tvm/contrib/cutlass
* run black
* remove irrelevant codegen code
* clang format
* tried replacing sm 75 with 80, didn't help improve accuracy
* remove irrelevant code from generator
* tried dense + bias fusion but generated cu file does not compile
* dense + bias worked after adding Leyuan's patch, bias + relu worked too
* tried adding sm80 generator but accuracy is still off
* remove GemmUniversal generator
* cleanup partition and build
* moved partition, profile and build function out of test
* turned out the result matches TVM non-cutlass result. Numpy fp16 matmul is busted?
* clean up test
* LinearCombination can be reused for bias-only epilogue
* remove unsupported epilogues like gelu
* removing dead code
* unify gemm templates for with or without beta scaling
* supported gelu but accuracy is slightly off
* gelu test passed with relaxed rtol
* cleanup
* remove unused stuff from library.py
* move profiler template into its own file
* removed gemm_profiler.py
* move contents of compile_engine.py into gen_gemm.py
* rename to profiler_template.cu to avoid CI issue
* cleaning up, trying to pass pylint
* add missing asf header
* run black
* fixing many pylint issues except wildcard import
* fixed wildcard warning
* add missing CUTLASS.cmake file, restore gemm_profiler.py
* pylint
* minor fix
* add license
* start filling in TODO doc
* rename GemmProfiler to GemmProfilerEmitter
* more renaming and doc
* add doc to the main compile API
* refactored generator
* run black
* black fix
* finish doc TODO
* add test for 32-bit accum
* fixed kernel generator to correctly handle fp32 accum
* revise build-related API
* add option to profile only one kernel
* add option to enable parallel compilation
* clean up gen_gemm
* doc update
* profile_cutlass_kernels -> tune_cutlass_kernels

Co-authored-by: leyuan.wang <leyuan.wang@bytedance.com>
Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
As discussed in the RFC (https://discuss.tvm.apache.org/t/rfc-byoc-nvidia-cutlass-integration/9147), this PR integrates CUTLASS GEMM into TVM via BYOC. It also includes a profiler to search for the best kernel parameters in CUTLASS.
@masahi Please take over this PR, Thanks!