
support W4A8 Marlin kernel #1113

Merged: 1 commit merged into pytorch:main on Nov 14, 2024

Conversation

HandH1998
Contributor

Summary

We introduce a mixed-precision GEMM kernel for INT4 weights and INT8 activations (W4A8), implemented on top of Marlin GEMM. The kernel is designed to support our W4A8 quantization method QQQ. For more details on the kernel implementation, you can refer to our paper. The kernel demonstrates excellent performance and has already been merged into the official vLLM project (see vllm-project/vllm#5218).

We hope the W4A8 GEMM can also provide a practical speedup for other W4A8 quantization methods in the community.
Additionally, since torchao is widely used by frameworks such as SGLang, those frameworks can gain W4A8 support once the kernel is integrated into torchao.
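
To make the setup concrete, below is a minimal, unfused PyTorch sketch of the arithmetic the kernel fuses: per-token symmetric INT8 activations, per-channel symmetric INT4 weights, and an FP16 output. This is illustrative only; it is not the kernel or the QQQ algorithm itself, and the names and shapes are made up for the example.

```python
import torch

def w4a8_reference_gemm(x_fp16: torch.Tensor, w_fp16: torch.Tensor) -> torch.Tensor:
    """Unfused reference for a symmetric W4A8 GEMM: per-token INT8 activations,
    per-channel INT4 weights, FP16 output. The fused Marlin QQQ kernel performs
    all of this in a single INT8 tensor-core GEMM."""
    # Per-token symmetric INT8 quantization of the (M, K) activations.
    x_scale = x_fp16.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
    x_int8 = torch.clamp(torch.round(x_fp16 / x_scale), -128, 127)

    # Per-output-channel symmetric INT4 quantization of the (K, N) weights
    # (the real kernel additionally packs these values two per byte).
    w_scale = w_fp16.abs().amax(dim=0, keepdim=True).clamp(min=1e-5) / 7.0
    w_int4 = torch.clamp(torch.round(w_fp16 / w_scale), -8, 7)

    # The kernel does this product on INT8 tensor cores with INT32 accumulation;
    # here the same integer math is emulated in float for portability.
    acc = x_int8.float() @ w_int4.float()
    return (acc * x_scale * w_scale).to(torch.float16)

# Quick sanity check: the quantized result should be close to the FP16 GEMM.
x = torch.randn(32, 256, dtype=torch.float16)
w = torch.randn(256, 128, dtype=torch.float16)
print((w4a8_reference_gemm(x, w).float() - x.float() @ w.float()).abs().max())
```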

Performance

Here is the speedup over PyTorch FP16 GEMM (which calls CUTLASS) for all GEMMs across different numbers of input tokens. The weight matrix size is (N=8192, K=21760). You can reproduce the benchmark results with bench_w4a8.py in my repo.
[Figure gemm_performance: speedup over PyTorch FP16 GEMM across different numbers of input tokens]
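
For reference, the FP16 baseline side of such a benchmark can be sketched with torch.utils.benchmark as below. This is a rough illustration only, not the actual bench_w4a8.py; it assumes a CUDA device and omits the W4A8 call, whose Python binding name depends on the integration.

```python
import torch
from torch.utils.benchmark import Timer

# Shapes taken from the benchmark description above.
N, K = 8192, 21760
weight = torch.randn(K, N, dtype=torch.float16, device="cuda")

for m in (1, 16, 32, 64, 128, 256, 512, 1024):  # number of input tokens
    x = torch.randn(m, K, dtype=torch.float16, device="cuda")
    fp16_us = Timer(
        "torch.matmul(x, weight)", globals={"x": x, "weight": weight}
    ).blocked_autorange().median * 1e6
    # The W4A8 side would time the fused Marlin QQQ op here and report
    # speedup = fp16_time / w4a8_time; the exact op name depends on the binding.
    print(f"M={m:5d}  FP16 GEMM: {fp16_us:.1f} us")
```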


pytorch-bot bot commented Oct 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1113

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 2690ff4 with merge base 39f16f4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 18, 2024
@drisspg
Contributor

drisspg commented Oct 18, 2024

can we do some comparisons between this and #880?

@jerryzh168 jerryzh168 requested a review from msaroufim October 21, 2024 18:34
@jerryzh168
Contributor

jerryzh168 commented Oct 21, 2024

Thanks, this looks pretty good. Can you add the benchmark code in https://github.com/pytorch/ao/tree/main/benchmarks as well?

And what GPU are the kernels benchmarked on?

@HandH1998
Contributor Author

@drisspg @jerryzh168 I am working on the benchmark and will provide the comparison with #880. The kernel is benchmarked on an A100-80GB GPU and works on SM >= 8.0.

@HandH1998
Contributor Author

HandH1998 commented Oct 25, 2024

@jerryzh168 @drisspg @msaroufim
I have made the following modifications (the code changes refer to #621 and #880):

  1. Added benchmark code for marlin_qqq_w4a8 GEMM in benchmarks/benchmark_marlin_qqq.py
  2. Summarized the main differences between marlin_qqq_w4a8 GEMM and marlin_w4a16 GEMM in torchao/quantization/marlin_qqq/README.md
  3. Supported marlin_qqq in torchao/quantization/quant_api.py (a usage sketch follows this list)
  4. Added some unit tests in test/quantization/test_marlin_qqq.py
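
As a rough usage illustration of item 3 above (hedged: the exact entry point and keyword names added in quant_api.py may differ from this sketch, and MarlinQQQLayout is assumed to be the layout exposed by this PR), quantizing a model to the W4A8 Marlin QQQ format could look roughly like this:

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
# Assumed export from this PR; the actual layout class name/location may differ.
from torchao.dtypes import MarlinQQQLayout

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).half().cuda()

# W4A8: dynamic per-token INT8 activations plus symmetric INT4 weights with
# group_size=128, packed into the Marlin QQQ layout so the fused kernel runs.
# Additional kwargs (e.g. symmetric mapping types) may be required in practice.
quantize_(
    model,
    int8_dynamic_activation_int4_weight(group_size=128, layout=MarlinQQQLayout()),
)

with torch.no_grad():
    y = model(torch.randn(16, 4096, dtype=torch.float16, device="cuda"))
```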

w4a8-cutlass is great work. In comparison, marlin_qqq_w4a8 supports weight per-group quantization in addition to weight per-channel quantization (a sketch of the per-group scheme follows). However, marlin_qqq_w4a8 does have some limitations: it only supports symmetric quantization, and the output dtype can only be torch.float16.
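
To make the per-channel vs. per-group distinction concrete, here is a small illustrative sketch of symmetric INT4 weight quantization with one scale shared per group of input channels. This is not the QQQ packing itself (which also reorders and packs the INT4 values for the kernel); the shapes and names are made up for the example.

```python
import torch

def quantize_weight_int4_symmetric(w: torch.Tensor, group_size: int = 128):
    """Symmetric INT4 quantization of a (K, N) weight.

    group_size=-1  -> per-channel: one scale per output column.
    group_size=128 -> per-group: one scale per 128 input channels per column.
    """
    K, N = w.shape
    g = K if group_size == -1 else group_size
    grouped = w.reshape(K // g, g, N)
    scales = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 7.0
    w_int4 = torch.clamp(torch.round(grouped / scales), -8, 7).to(torch.int8)
    return w_int4.reshape(K, N), scales.reshape(K // g, N)

w = torch.randn(1024, 512, dtype=torch.float16)
q_channel, s_channel = quantize_weight_int4_symmetric(w, group_size=-1)   # scales: (1, 512)
q_group, s_group = quantize_weight_int4_symmetric(w, group_size=128)      # scales: (8, 512)
```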

In addition, we also provide the performance of torchao/_models/llama/generate.py below. -g128 means weight per-group quantization with a group size of 128.

| -q parameter | Precision | Average tokens/sec | Average Bandwidth (GB/s) | Peak Memory Usage (GB) | Model Size (GB) |
| --- | --- | --- | --- | --- | --- |
| --compile | fp16 | 112.45 | 1486.00 | 13.93 | 13.21 |
| -q marlin_qqq --compile | w4a8 | 197.45 | 653.50 | 4.79 | 3.31 |
| -q marlin_qqq --compile | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 |

@HandH1998
Contributor Author

@jerryzh168 @msaroufim @drisspg I have resolved the conflicts. Looking forward to your further review.

@drisspg
Contributor

drisspg commented Nov 5, 2024

Hey @HandH1998, re-kicking off the internal CI/CD that failed here.

Overall this work looks very good; I will do a more thorough review.

@drisspg
Contributor

drisspg commented Nov 5, 2024

It seems that you also applied formatting to a lot of files. This makes it pretty hard to review since all the changes get mixed together.

I opened #1226

Would you mind opening a separate PR for some of the other files you touched? And let's first add the formatting and then we can merge in your changes and make it easier to review.

@HandH1998
Contributor Author

@drisspg I have removed the extra formatting in this PR, which should now simplify the review process.

@alexsamardzic
Collaborator

I've added a benchmarking script to #880 that makes it possible to compare the performance of the two W4A8 kernels. As the CUTLASS-based version doesn't support group quantization, at the moment the comparison is only possible with group_size=-1 in the Marlin-based version. The Marlin-based version performs clearly better for input sizes below 256, while the CUTLASS-based version is faster for input sizes of 256 and above. Consequently, the Marlin-based version performs better on the Llama generator too, with tokens/sec about 25% higher than the CUTLASS-based version (note that in both cases I ran the generator as python generate.py --compile --precision torch.float16 -q ...). Please note that the comparison is not completely apples-to-apples: besides group quantization support, there are other small differences between the kernels. Still, this seems to be pretty much the current state of affairs regarding performance.

Let me return the compliment by saying that this Marlin-based kernel is great work too. In particular, it clearly shows me where the CUTLASS-based kernel should be improved.

@drisspg
Contributor

drisspg commented Nov 7, 2024

Overall this is looking really good. Would you mind reporting the lib size increase from this PR? I plan to take another pass over the CUDA code tomorrow, and once CI is green this should be good to go :)

@HandH1998
Contributor Author

HandH1998 commented Nov 7, 2024

@drisspg @jerryzh168 Thanks for your reviews. I have resolved most of the issues according to your advice. The lib size increase is about 5 MB. Let's move forward :)

@drisspg added the inference and enhancement (New feature or request) labels on Nov 7, 2024
@drisspg
Contributor

drisspg commented Nov 7, 2024

For provenance: the lib size increases from 3.5 MB to 5.0 MB. I think this is acceptable.

@drisspg
Contributor

drisspg commented Nov 8, 2024

@HandH1998 It looks like there is one true failure on your PR. Would you mind fixing it so we can land?

@HandH1998
Contributor Author

> @HandH1998 It looks like there is one true failure on your PR. Would you mind fixing it so we can land?

I will try to fix it soon.

@jerryzh168
Contributor

jerryzh168 commented Nov 13, 2024

Can you add a table similar to https://github.com/pytorch/ao/tree/main/torchao/quantization#sparse-marlin to the README to show the performance? Otherwise this looks good to me.

@msaroufim added the topic: new feature label (use this tag if this PR adds a new feature) on Nov 13, 2024
@msaroufim msaroufim self-requested a review November 13, 2024 02:13
Member

@msaroufim msaroufim left a comment

Just need to fix lint.

```python
quantize_(model, int4_weight_only(group_size=groupsize))
if "marlin" in quantization:
    # NOTE(HandH1998): `marlin_qqq` should be put before `marlin` to avoid going to the wrong branch
```
Member


Note to self: this is the real code and it seems reasonable; the rest is linting changes.
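
For context on the NOTE in the snippet above: the quantization option string is matched by substring, and "marlin_qqq" also contains "marlin", so the more specific check has to come first. A hypothetical sketch of why the ordering matters (names here are illustrative, not the exact generate.py code):

```python
# Illustrative only: with substring matching, the marlin_qqq check must
# precede the generic marlin check, or the W4A8 path would never be taken.
quantization = "marlin_qqq"

if "marlin_qqq" in quantization:
    print("dispatch to the W4A8 Marlin QQQ path")
elif "marlin" in quantization:
    print("dispatch to the other Marlin paths (e.g. W4A16 sparse Marlin)")
```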

@HandH1998
Contributor Author

> Can you add a table similar to https://github.com/pytorch/ao/tree/main/torchao/quantization#sparse-marlin to the README to show the performance? Otherwise this looks good to me.

I have added it.

Contributor

@jerryzh168 jerryzh168 left a comment


thanks!


@msaroufim msaroufim merged commit 06e69f6 into pytorch:main Nov 14, 2024
18 checks passed
@zhyncs

zhyncs commented Nov 14, 2024

@HandH1998 It's so coooooool!

sunjiweiswift pushed a commit to sunjiweiswift/ao that referenced this pull request Nov 25, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024