support W4A8 Marlin kernel #1113
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1113
✅ No failures as of commit 2690ff4 with merge base 39f16f4.
can we do some comparisons between this and #880?
thanks, looks pretty good. Can you add the benchmark code in https://github.com/pytorch/ao/tree/main/benchmarks as well? And what GPU are the kernels benchmarked on?
@drisspg @jerryzh168 I am working on the benchmark and will provide the comparisons with #880. The kernel is benchmarked on an A100-80G GPU, and it works on SM ≥ 8.0.
@jerryzh168 @drisspg @msaroufim
w4a8-cutlass is great work. In comparison, we believe marlin_qqq_w4a8 can support weight per-group quantization in addition to weight per-channel quantization. However, marlin_qqq_w4a8 does have some limitations: it only supports symmetric quantization, and the output dtype can only be fp16. In addition, we also provide performance numbers for marlin_qqq_w4a8 for comparison.
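For a concrete picture of that scheme, here is a minimal, self-contained sketch of symmetric per-group INT4 weight quantization plus per-token INT8 activation quantization. It only illustrates the general idea; it is not the QQQ algorithm or the torchao API, and the group size and clipping ranges are assumptions:

```python
import torch

def quantize_weight_int4_per_group(w: torch.Tensor, group_size: int = 128):
    # w: [out_features, in_features]; symmetric quantization within each group of columns
    out_f, in_f = w.shape
    w_grouped = w.reshape(out_f, in_f // group_size, group_size)
    scales = w_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(w_grouped / scales), -8, 7).to(torch.int8)
    # INT4 values kept in int8 storage, plus one scale per (row, group)
    return q.reshape(out_f, in_f), scales.squeeze(-1)

def quantize_activation_int8_per_token(x: torch.Tensor):
    # x: [tokens, in_features]; one symmetric scale per token (row)
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return q, scales
```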
@jerryzh168 @msaroufim @drisspg I have resolved the conflicts. I look forward to your further feedback.
Hey @HandH1998, re-kicking off the internal CI/CD here that failed. Overall this work looks very good; I will do a more thorough review.
It seems that you also applied formatting to a lot of files. This makes it pretty hard to review since all the changes get mixed together; I opened #1226. Would you mind opening a separate PR for some of the other files you touched? Let's first add the formatting, then we can merge in your changes and make the review easier.
@drisspg I have removed the extra formatting in this PR, which should now simplify the review process.
I've added a benchmarking script to #880 that makes it possible to compare the performance between the two W4A8 kernels. As the CUTLASS-based version doesn't support group quantization, at the moment it is only possible to make the comparison with per-channel weight quantization. Let me return the compliment by stating that this Marlin-based kernel is great work too. In particular, for me it clearly shows where the CUTLASS-based kernel should be improved.
Overall looking really good. Would you mind reporting the lib size increase from this PR? I plan to take another once-over of the CUDA code tomorrow, and once CI is good this should be good to go :)
@drisspg @jerryzh168 Thanks for your reviews. I have resolved most of the issues according to your advice. The resulting lib size is about 5M. Let's go forward :)
For provenance, the lib size increases from 3.5M to 5.0M. I think this is acceptable.
@HandH1998 Looks like there is 1 true failure on your PR. Would you mind fixing it, and then we can land?
I will try to fix it soon.
Can you add a table similar to https://github.com/pytorch/ao/tree/main/torchao/quantization#sparse-marlin to the README to show the performance? Otherwise looks good to me.
just need to fix lint
torchao/_models/llama/generate.py
```python
quantize_(model, int4_weight_only(group_size=groupsize))
if "marlin" in quantization:
    # NOTE(HandH1998): `marlin_qqq` should be put before `marlin` to avoid going to the wrong branch
```
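As a side note, the ordering that NOTE describes can be illustrated as follows; the strings and print statements are placeholders, not the actual generate.py code:

```python
# "marlin" is a substring of "marlin_qqq", so the more specific check must come
# first, otherwise a "marlin_qqq" request would fall into the plain Marlin branch.
quantization = "marlin_qqq"  # hypothetical CLI value

if "marlin_qqq" in quantization:
    print("dispatch to the W4A8 Marlin QQQ path")
elif "marlin" in quantization:
    print("dispatch to the existing Marlin path")
```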
note to self: this is the real code and it seems reasonable - rest is linting changes
I have added it.
thanks!
@HandH1998 It's so cool!
support Marlin W4A8 kernel
Summary
We introduce a mixed-precision GEMM kernel for INT4 weights and INT8 activations, implemented on top of Marlin GEMM. The kernel is designed to support our W4A8 quantization method QQQ. For more details on the kernel implementation, you can refer to our paper. The kernel demonstrates excellent performance and has been merged into the official vLLM project (see vllm-project/vllm#5218).
We hope the W4A8 GEMM can also provide a practical speedup for other W4A8 quantization methods in the community.
Additionally, since torchao is widely used in frameworks like SGLang, we can extend support for W4A8 once the kernel is integrated into torchao.
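For intuition about what such a kernel fuses, here is a rough dequantize-then-matmul reference of a W4A8 per-channel GEMM in plain PyTorch. It is a sketch of the general scheme, not the Marlin kernel itself, and the scale layouts are assumptions:

```python
import torch

def w4a8_gemm_reference(a_int8: torch.Tensor,   # [M, K] INT8 activations
                        a_scale: torch.Tensor,  # [M, 1] per-token scales
                        w_int4: torch.Tensor,   # [N, K] INT4 values stored in int8
                        w_scale: torch.Tensor   # [N, 1] per-channel scales
                        ) -> torch.Tensor:
    # Integer matmul accumulated in INT32, then rescaled and cast to FP16
    acc = a_int8.to(torch.int32) @ w_int4.to(torch.int32).t()          # [M, N]
    return (acc.to(torch.float32) * a_scale * w_scale.t()).to(torch.float16)
```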
Performance
Here is the speedup of all GEMMs over PyTorch FP16 GEMM (which calls CUTLASS) under different numbers of input tokens. The weight matrix size is (N=8192, K=21760). You can reproduce the benchmark results using bench_w4a8.py in my repo.
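For reference, below is a minimal sketch of how the FP16 baseline for these shapes can be timed with torch.utils.benchmark; bench_w4a8.py remains the authoritative script, and the token counts chosen here are illustrative assumptions:

```python
# Times only the FP16 GEMM baseline; the fused W4A8 kernel call is omitted
# because its Python binding depends on the torchao build.
import torch
from torch.utils import benchmark

N, K = 8192, 21760  # weight matrix size quoted above
for m in (1, 16, 64, 256, 1024):  # number of input tokens (illustrative values)
    a = torch.randn(m, K, dtype=torch.float16, device="cuda")
    w = torch.randn(N, K, dtype=torch.float16, device="cuda")
    t = benchmark.Timer(
        stmt="torch.nn.functional.linear(a, w)",
        globals={"torch": torch, "a": a, "w": w},
    ).blocked_autorange(min_run_time=1)
    print(f"M={m}: FP16 linear median {t.median * 1e6:.1f} us")
```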