Spin Quant in TorchAO #579
Comments
I'd love to work on this @HDCharles
One question about this @HDCharles. The SpinQuant repo has a dependency on the CUDA fast Hadamard transform package for doing the actual Hadamard transform. Would it be acceptable to include this dependency in torch/ao?
merging a custom cuda kernel is fine and you can check Although as a baseline I'd much prefer we see if we can match the performance using Granted I also recall @swolchok working on a fast CPU fast hadamard transform so maybe he has some ideas on where the compiler was not doing its job
@tobiasvanderwerff thanks for the PR, we have some internal discussion on how to implement this as well, so we can work together on this
we have other tools like auto round or the llama eval stuff that require an external library and just don't work without the installed package, which I think is fine. that's probably a good place to start, and if this gets a lot of usage or we want to take this out of prototype, we can figure out at that time whether to add a dependency or add the kernel in.
@jerryzh168 I'd be happy to work together on this. For now, I'll be implementing the rotation matrices one by one and documenting the results in the PR. Let me know if this works for you or if you prefer a different approach.
this sounds good, please go ahead. internally we are interested in QAT as well, we can discuss later how this fits when QAT is introduced
Background:
The SpinQuant paper introduces a method that improves quantization accuracy by applying additional rotation matrices to the model weights.
While SpinQuant is a fairly sophisticated technique, some independent pieces could be implemented modularly to get incremental improvements on a smaller scale.
(see image)
https://imgur.com/jU60Iqs
In the above image, each rotation in both the (a) and (b) parts of the figure could be implemented independently to improve quantization accuracy in the model. These rotations are:
Rotations which can be fully absorbed by weight matrices of the linear ops and don’t introduce additional ops:
R2
Rotations which need a constant number of additional ops per model:
R1
Rotations which require additional ops per block:
R3, R4
While the second and third sets of rotations require adding additional ops to the model, the R2 rotation requires only a small change to the model weights and no additional ops.
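To illustrate why a rotation like R2 can be absorbed without extra ops, here is a minimal NumPy sketch. It assumes a toy setting of two adjacent linear maps with nothing nonlinear between them; in the paper, as I understand it, R2 sits between the value and output projections and is applied per attention head. The helper name `absorb_rotation` and all shapes are hypothetical and not part of torchao.

```python
import numpy as np


def absorb_rotation(w_in, w_out, rot):
    """Fold an orthogonal matrix into two adjacent linear weights.

    w_in:  (hidden, d_in)   first layer,  y = w_in @ x
    w_out: (d_out, hidden)  second layer, z = w_out @ y
    rot:   (hidden, hidden) orthogonal, so rot @ rot.T == I
    """
    # w_out @ rot @ rot.T @ w_in == w_out @ w_in, so the composed map is unchanged.
    return rot.T @ w_in, w_out @ rot


rng = np.random.default_rng(0)
w_in = rng.standard_normal((8, 4))
w_out = rng.standard_normal((3, 8))

# The Q factor of a Gaussian matrix is orthogonal (a stand-in for R2 here).
rot, _ = np.linalg.qr(rng.standard_normal((8, 8)))

w_in_rot, w_out_rot = absorb_rotation(w_in, w_out, rot)

x = rng.standard_normal(4)
out_ref = w_out @ (w_in @ x)
out_rot = w_out_rot @ (w_in_rot @ x)
print(np.allclose(out_ref, out_rot))
```

The rotated weights differ from the originals, so they quantize differently, but the unquantized output of the pair is the same up to floating-point error.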
Task:
Start by implementing the R2 rotation with a random Hadamard matrix (the paper indicates these perform fairly well) and demonstrate the improved quantization accuracy for int8 dynamic/weight-only and int4 weight-only quantization. Ideally we'd like to see improved eval performance compared to the non-SpinQuant version. Code would ideally go into a new file, torchao/quantization/spin_quant.py.
Additional rotations (and the necessary additional ops), or a rotation optimization procedure for R2 as used in SpinQuant, can follow after.
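As a rough sketch of what a random Hadamard matrix could look like, the snippet below builds one with the Sylvester construction and random column sign flips, assuming the dimension is a power of two. The function name `random_hadamard` is hypothetical; this is not the CUDA fast Hadamard transform kernel discussed in the comments, just a dense reference version.

```python
import numpy as np


def random_hadamard(n, seed=0):
    """Return an n x n orthogonal matrix H @ diag(signs) / sqrt(n).

    H is the Sylvester Hadamard matrix, so n must be a power of two.
    """
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = np.ones((1, 1))
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    signs = np.random.default_rng(seed).choice([-1.0, 1.0], size=n)
    # Broadcasting scales column j by signs[j], i.e. H @ diag(signs).
    return (h * signs) / np.sqrt(n)


had = random_hadamard(16)

# A single large outlier gets spread evenly over all 16 entries.
outlier = np.zeros(16)
outlier[0] = 100.0
spread = had @ outlier
print(np.abs(spread).max())  # 25.0, down from 100.0
```

Spreading outliers this way narrows the range a quantizer has to cover, which is the intuition behind using such rotations before int8 or int4 quantization.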