Add FP16Act-FP6Weight Linear #223

Merged: 44 commits merged into pytorch:main on May 14, 2024
Conversation

@gau-nernst (Collaborator) commented May 7, 2024

Closes #208

References:

TODO:

benchmarks/benchmark_fp6.py results - 4070 Ti SUPER, PyTorch 2.3, CUDA 12.1

| m | k | n | fp6_latency (ms) | fp16_latency (ms) | speedup (d/s) | correct |
|---|---|---|---|---|---|---|
| 1 | 10240 | 8192 | 0.10459 | 0.272524 | 2.60565 | 1 |
| 1 | 8192 | 8192 | 0.0259702 | 0.218884 | 8.42828 | 1 |
| 1 | 57344 | 8192 | 0.590541 | 1.49265 | 2.5276 | 1 |
| 1 | 8192 | 28672 | 0.286375 | 0.773923 | 2.70248 | 1 |
| 2 | 10240 | 8192 | 0.105072 | 0.272984 | 2.59807 | 1 |
| 2 | 8192 | 8192 | 0.0283768 | 0.221068 | 7.79045 | 1 |
| 2 | 57344 | 8192 | 0.588764 | 1.50597 | 2.55785 | 1 |
| 2 | 8192 | 28672 | 0.289398 | 0.749859 | 2.5911 | 1 |
| 4 | 10240 | 8192 | 0.106225 | 0.273336 | 2.57318 | 1 |
| 4 | 8192 | 8192 | 0.0352987 | 0.221425 | 6.27289 | 1 |
| 4 | 57344 | 8192 | 0.590647 | 1.51019 | 2.55685 | 1 |
| 4 | 8192 | 28672 | 0.294598 | 0.752468 | 2.55422 | 1 |
| 8 | 10240 | 8192 | 0.108011 | 0.27399 | 2.53669 | 1 |
| 8 | 8192 | 8192 | 0.0501209 | 0.22264 | 4.44206 | 1 |
| 8 | 57344 | 8192 | 0.588333 | 1.51822 | 2.58054 | 1 |
| 8 | 8192 | 28672 | 0.303168 | 0.759707 | 2.50589 | 1 |
| 16 | 10240 | 8192 | 0.112524 | 0.298142 | 2.64959 | 1 |
| 16 | 8192 | 8192 | 0.0690163 | 0.222848 | 3.22893 | 1 |
| 16 | 57344 | 8192 | 0.624313 | 1.53145 | 2.45302 | 1 |
| 16 | 8192 | 28672 | 0.319848 | 0.762582 | 2.38421 | 1 |
| 64 | 10240 | 8192 | 0.175586 | 0.287482 | 1.63727 | 1 |
| 64 | 8192 | 8192 | 0.126386 | 0.248988 | 1.97006 | 1 |
| 64 | 57344 | 8192 | 0.879407 | 1.58489 | 1.80222 | 1 |
| 64 | 8192 | 28672 | 0.482216 | 0.794717 | 1.64805 | 1 |
| 128 | 10240 | 8192 | 0.295514 | 0.305409 | 1.03349 | 1 |
| 128 | 8192 | 8192 | 0.226178 | 0.243452 | 1.07638 | 1 |
| 128 | 57344 | 8192 | 1.48975 | 1.64592 | 1.10483 | 1 |
| 128 | 8192 | 28672 | 0.850137 | 0.897592 | 1.05582 | 1 |
| 256 | 10240 | 8192 | 0.579169 | 0.528245 | 0.912073 | 1 |
| 256 | 8192 | 8192 | 0.454445 | 0.394432 | 0.867942 | 1 |
| 256 | 57344 | 8192 | 2.93721 | 2.77987 | 0.946432 | 1 |
| 256 | 8192 | 28672 | 1.66438 | 1.41421 | 0.84969 | 1 |
| 512 | 10240 | 8192 | 1.10103 | 0.984327 | 0.894003 | 1 |
| 512 | 8192 | 8192 | 1.03181 | 0.770269 | 0.746521 | 1 |
| 512 | 57344 | 8192 | 5.86696 | 5.47663 | 0.933469 | 1 |
| 512 | 8192 | 28672 | 3.40694 | 2.76967 | 0.812948 | 1 |
| 1024 | 10240 | 8192 | 2.08286 | 1.95155 | 0.936956 | 1 |
| 1024 | 8192 | 8192 | 1.81341 | 1.5645 | 0.862739 | 1 |
| 1024 | 57344 | 8192 | 11.6214 | 10.8689 | 0.935248 | 1 |
| 1024 | 8192 | 28672 | 6.27002 | 5.5432 | 0.88408 | 1 |
| 2048 | 10240 | 8192 | 4.17314 | 3.91984 | 0.939303 | 1 |
| 2048 | 8192 | 8192 | 3.34931 | 3.15769 | 0.942786 | 1 |
| 2048 | 57344 | 8192 | 23.4409 | 21.4201 | 0.913792 | 1 |
| 2048 | 8192 | 28672 | 11.4675 | 10.7142 | 0.934307 | 1 |
| 4096 | 10240 | 8192 | 8.37251 | 7.69253 | 0.918785 | 1 |
| 4096 | 8192 | 8192 | 6.71261 | 6.15112 | 0.916353 | 1 |
| 4096 | 57344 | 8192 | 46.804 | 42.3869 | 0.905626 | 1 |
| 4096 | 8192 | 28672 | 23.3502 | 21.1444 | 0.905533 | 1 |

benchmarks/benchmark_fp6.py results - 4090. Courtesy of @Iron-Bound

| m | k | n | fp6_latency (ms) | fp16_latency (ms) | speedup (d/s) | correct |
|---|---|---|---|---|---|---|
| 1 | 10240 | 8192 | 0.0249717 | 0.177861 | 7.12252 | 1 |
| 1 | 8192 | 8192 | 0.0185284 | 0.142879 | 7.71136 | 1 |
| 1 | 57344 | 8192 | 0.374375 | 0.989105 | 2.64202 | 1 |
| 1 | 8192 | 28672 | 0.192924 | 0.492164 | 2.55107 | 1 |
| 2 | 10240 | 8192 | 0.0251179 | 0.178913 | 7.12294 | 1 |
| 2 | 8192 | 8192 | 0.0185992 | 0.149051 | 8.01384 | 1 |
| 2 | 57344 | 8192 | 0.375186 | 0.983217 | 2.62061 | 1 |
| 2 | 8192 | 28672 | 0.194807 | 0.508954 | 2.61261 | 1 |
| 4 | 10240 | 8192 | 0.0251999 | 0.179157 | 7.10943 | 1 |
| 4 | 8192 | 8192 | 0.0187983 | 0.149632 | 7.95988 | 1 |
| 4 | 57344 | 8192 | 0.376361 | 0.983849 | 2.61411 | 1 |
| 4 | 8192 | 28672 | 0.197904 | 0.510254 | 2.57829 | 1 |
| 8 | 10240 | 8192 | 0.0257653 | 0.179633 | 6.97189 | 1 |
| 8 | 8192 | 8192 | 0.0195129 | 0.149917 | 7.683 | 1 |
| 8 | 57344 | 8192 | 0.378805 | 0.984245 | 2.59829 | 1 |
| 8 | 8192 | 28672 | 0.202335 | 0.513004 | 2.53542 | 1 |
| 16 | 10240 | 8192 | 0.0363614 | 0.180923 | 4.97567 | 1 |
| 16 | 8192 | 8192 | 0.0264505 | 0.150544 | 5.69153 | 1 |
| 16 | 57344 | 8192 | 0.383548 | 0.985312 | 2.56894 | 1 |
| 16 | 8192 | 28672 | 0.212429 | 0.515403 | 2.42624 | 1 |
| 64 | 10240 | 8192 | 0.121198 | 0.208701 | 1.72198 | 1 |
| 64 | 8192 | 8192 | 0.0838365 | 0.17458 | 2.08239 | 1 |
| 64 | 57344 | 8192 | 0.469056 | 1.07661 | 2.29527 | 1 |
| 64 | 8192 | 28672 | 0.308592 | 0.561562 | 1.81975 | 1 |
| 128 | 10240 | 8192 | 0.183175 | 0.207593 | 1.13331 | 1 |
| 128 | 8192 | 8192 | 0.157488 | 0.170052 | 1.07978 | 1 |
| 128 | 57344 | 8192 | 1.03939 | 1.11073 | 1.06863 | 1 |
| 128 | 8192 | 28672 | 0.490236 | 0.56882 | 1.1603 | 1 |
| 256 | 10240 | 8192 | 0.302165 | 0.272471 | 0.901729 | 1 |
| 256 | 8192 | 8192 | 0.240262 | 0.226496 | 0.942705 | 1 |
| 256 | 57344 | 8192 | 1.55872 | 1.4771 | 0.947637 | 1 |
| 256 | 8192 | 28672 | 0.961353 | 0.819177 | 0.852109 | 1 |
| 512 | 10240 | 8192 | 0.578808 | 0.525709 | 0.908262 | 1 |
| 512 | 8192 | 8192 | 0.561836 | 0.420924 | 0.749193 | 1 |
| 512 | 57344 | 8192 | 3.11518 | 2.83765 | 0.910911 | 1 |
| 512 | 8192 | 28672 | 1.91902 | 1.53363 | 0.799173 | 1 |
| 1024 | 10240 | 8192 | 1.11179 | 1.04132 | 0.936611 | 1 |
| 1024 | 8192 | 8192 | 0.966707 | 0.835691 | 0.864472 | 1 |
| 1024 | 57344 | 8192 | 6.17944 | 5.68417 | 0.919852 | 1 |
| 1024 | 8192 | 28672 | 3.48025 | 2.82627 | 0.812087 | 1 |
| 2048 | 10240 | 8192 | 2.22177 | 2.08143 | 0.936833 | 1 |
| 2048 | 8192 | 8192 | 1.78033 | 1.68593 | 0.946973 | 1 |
| 2048 | 57344 | 8192 | 12.5079 | 11.5538 | 0.923727 | 1 |
| 2048 | 8192 | 28672 | 6.2242 | 5.6579 | 0.909016 | 1 |
| 4096 | 10240 | 8192 | 4.46022 | 4.02399 | 0.902194 | 1 |
| 4096 | 8192 | 8192 | 3.56057 | 3.22559 | 0.905919 | 1 |
| 4096 | 57344 | 8192 | 24.7088 | 22.9535 | 0.928962 | 1 |
| 4096 | 8192 | 28672 | 12.4933 | 11.5338 | 0.9232 | 1 |
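
The two tables above come from a per-shape timing loop that compares the FP6 kernel against an FP16 `F.linear` baseline; the speedup column is the FP16 latency divided by the FP6 latency. As a rough illustration only (not the actual `benchmarks/benchmark_fp6.py`, and with `fp6_linear` standing in as an assumed callable for the FP16-activation/FP6-weight op), such a loop could look like this:

```python
# Illustrative per-shape benchmark loop; the real benchmarks/benchmark_fp6.py may differ.
# `fp6_linear` is an assumed callable standing in for the FP16-act / FP6-weight op.
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

def latency_ms(fn, *args) -> float:
    # Median latency of fn(*args) in milliseconds; Timer handles CUDA synchronization.
    return Timer("fn(*args)", globals={"fn": fn, "args": args}).blocked_autorange().median * 1e3

@torch.no_grad()
def compare(m: int, k: int, n: int, fp6_linear) -> None:
    x = torch.randn(m, k, dtype=torch.half, device="cuda")
    w = torch.randn(n, k, dtype=torch.half, device="cuda") * 0.01

    fp16_ms = latency_ms(F.linear, x, w)
    fp6_ms = latency_ms(fp6_linear, x, w)  # assumed to quantize/pack the weight internally
    correct = torch.allclose(fp6_linear(x, w), F.linear(x, w), rtol=1e-2, atol=1e-2)
    print(m, k, n, fp6_ms, fp16_ms, fp16_ms / fp6_ms, int(correct))
```

Each table row then corresponds to one `compare(m, k, n, ...)` call over the shapes listed.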

@facebook-github-bot added the "CLA Signed" label May 7, 2024
@msaroufim self-requested a review May 7, 2024 17:18
pytorch-bot commented May 7, 2024

🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/223. Note: links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure as of commit a8b4dd3 with merge base ad12663.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Iron-Bound

Hey, I was looking at the fp6 code as well and got stuck on TORCH_LIBRARY_IMPL, so good stuff solving that 👍🏼

One thing I was going to bring up: the code has a few utility/kernel files which would be reusable for implementing the next quant type. What do people think about having a folder for this now, or later?

@msaroufim (Member)

I'd opt for generalizing things in a future PR, but I'll let @gau-nernst decide what makes sense for them. @Iron-Bound, which future work were you hoping to build on top of this?

@Iron-Bound commented May 8, 2024

@msaroufim I could hack on CFloat8_1_4_3 and CFloat8_1_5_2 if people think it's valuable.

@msaroufim (Member)

I haven't followed our float8 work closely, but have you gotten a chance to take a look at https://github.com/pytorch-labs/float8_experimental?

Granted, I would like an API that looks like to(torch.float6/8), and that's one of the benefits of using tensor subclasses in this repo.
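
As a rough sketch of the tensor-subclass idea mentioned above (not the actual torchao implementation; `pack_fp6` and `fp16act_fp6weight_linear` are hypothetical placeholders for the packing helper and the custom op), a weight subclass could carry the packed FP6 payload and reroute `F.linear` to the custom kernel:

```python
import torch
import torch.nn.functional as F

class Fp6QuantizedWeight(torch.Tensor):
    """Sketch of a weight-only FP6 tensor subclass (illustrative, not torchao's API)."""

    @staticmethod
    def __new__(cls, fp16_weight: torch.Tensor):
        # Wrapper subclass: carries shape/dtype/device metadata, no dense fp16 storage.
        return torch.Tensor._make_wrapper_subclass(
            cls, fp16_weight.shape, dtype=fp16_weight.dtype, device=fp16_weight.device
        )

    def __init__(self, fp16_weight: torch.Tensor):
        # Hypothetical packing helper: would produce the prepacked 6-bit payload + scales.
        self.packed, self.scales = pack_fp6(fp16_weight)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear and isinstance(args[1], cls):
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            # Hypothetical custom op standing in for the FP16-act / FP6-weight kernel.
            out = fp16act_fp6weight_linear(x, w.packed, w.scales)
            return out if bias is None else out + bias
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

# Usage sketch: F.linear(activations, Fp6QuantizedWeight(weight)) would then hit the FP6 path.
```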

@gau-nernst (Collaborator, Author)

> The code has a few utility/kernel files which would be reusable for implementing the next quant type. What do people think about having a folder for this now, or later?

I will leave the refactor for a future PR. I don't understand much of the parts involved in the kernel, so I won't be touching them and will leave them as is.

Regarding the float dtype: the actual FP6 used in FP6_LLM is E3M2, without nan/inf. Two pointers:

  1. It would be good to signal E3M2 somehow in the code, since FP6 is obviously non-standard. Do we also need to signal whether nan/inf are represented?
  2. FP6_LLM re-arranges the fp6 weight layout to optimize data access (see weight_matrix_prepacking()). This "non-standard" layout may make a generalized float tensor subclass difficult, since how would users know the underlying layout? (Perhaps we need to keep track of it somehow?)

Also, another interesting thing to work on is to replicate qtorch.quant.float_quantize() from https://github.com/Tiiiger/QPyTorch.
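
As a point of reference for the E3M2 discussion above, here is a small self-contained sketch (an illustration only, not the kernel's packing code nor qtorch's implementation) of simulating FP6 E3M2 round-to-nearest in plain PyTorch, with no inf/nan so all exponent codes encode finite values:

```python
import torch

def quantize_e3m2(x: torch.Tensor) -> torch.Tensor:
    """Round x to the nearest FP6 E3M2 value (1 sign, 3 exponent, 2 mantissa bits, no inf/nan).

    With bias 3 and no inf/nan, the largest magnitude is 2^4 * 1.75 = 28
    and the smallest subnormal is 2^-4 = 0.0625.
    """
    ebits, mbits = 3, 2
    bias = 2 ** (ebits - 1) - 1                                        # 3
    max_val = 2.0 ** ((2 ** ebits - 1) - bias) * (2 - 2.0 ** -mbits)   # 28.0
    min_normal_exp = 1 - bias                                          # -2

    orig_dtype = x.dtype
    x = x.float()
    sign = torch.sign(x)
    mag = x.abs().clamp(max=max_val)              # saturate instead of overflowing

    # Per-value exponent, floored at the subnormal range.
    exp = torch.floor(torch.log2(mag.clamp(min=2.0 ** min_normal_exp)))
    step = 2.0 ** (exp - mbits)                   # spacing between representable values
    return (sign * torch.round(mag / step) * step).to(orig_dtype)
```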

fp6_test.py (Outdated)
@@ -0,0 +1,98 @@
# from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py

Review comment (Member): move the relevant files to either the benchmark or test folder

@msaroufim (Member) left a comment

OK, so I think to merge this what we can do is:

  1. Move the relevant benchmark and test files to either the benchmark or test folder
  2. In CI, do the numerics check either at an op level or at a macro eval level (ideally the first for now)
  3. We can worry about the subclass and torch.compile stuff in a future PR
  4. Make sure the acknowledgements to the original repos are crystal clear everywhere
  5. Make the speedup over the fp16/bf16 baselines clear in the PR description

@gau-nernst marked this pull request as ready for review May 10, 2024 16:11
@msaroufim (Member)

H100 benchmarks

[Screenshot: H100 benchmark results, 2024-05-10]

@msaroufim (Member)

OK, I think we're ready to merge this. The last thing is to add the limitations for small batch sizes (usyd-fsalab/fp6_llm#8) to the README and explain that this should be used to speed up autoregressive decoding.

And for the next PR, let's start to do evals with an end-to-end model; I'm hoping we can leverage this PR for that (#189).

@msaroufim self-requested a review May 14, 2024 14:39
@msaroufim merged commit 7734f79 into pytorch:main May 14, 2024 (12 of 13 checks passed)
@gau-nernst deleted the fp6 branch May 14, 2024 16:08
lancerts pushed a commit to lancerts/ao that referenced this pull request May 17, 2024
* add files from fp6_llm

* try to port weight packing first

* rename

* rename fp6 weight packing

* add fp16act_fp6weight_linear

* fix function def

* delete duplicate file

* move weight quant file

* rename

* add pytorch interface for fp6 weight dequant

* add fake_fp6 to fp6

* move weight_quant to csrc/cuda due to cuda_fp16.h dependency

* add fake_fp6_to_fp6 test

* add test for fp16act_fp6weight_linear

* add test for fp6_weight_dequant

* Fp6WeightOnlyQuantizedLinearWeight (not working yet)

* skip some tests, since the functions are not built w/o CUDA

* add the original test

* implement transpose and clone so that F.linear will work

* remove print

* remove dequantize

* add notes and some rename

* typo

* small cleanup

* improve tensor subclass and add test (which is failing for torch-compile)

* add note

* add note

* add qtorch as dev requirement

* update error message

* add __repr__ and fix transposed issue

* add fp6 perplexity test

* rename variables

* remove subclass

* add correctness test

* remove unwanted changes

* add apache 2.0 notice

* add benchmark script

* add note about FP6 kernel

* relax tolerance

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
@gau-nernst mentioned this pull request May 21, 2024
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
Labels: CLA Signed
Successfully merging this pull request may close these issues: FP6 dtype!
4 participants