Commit message:

* SpinQuant using R2 matrices
* Move Hadamard functions and matrices to separate file
* Add R4 rotation
* Reformat
* Do not wrap Linear layers but use nn.Sequential. Wrapping the Linear layers might mess with the quantization of the linear layers, so it's probably better to keep the linear layers the same and insert new layers alongside them.
* Add test
* Fix test and do small reformat of Hadamard code
* Fuse Layernorm params into linear layers. This is done for pre-norm LLMs like LLaMa to make them scale-invariant (see footnote 3 in the paper). However, in the current implementation it seems to hurt performance when quantization is used.
* Add R1 rotation
* Add option to load pretrained R1/R2 matrices
* Move Spinquant from `torchao/quantization` to `torchao/prototype/spinquant`
* Move Hadamard matrices to a separate file
* Move test
* Minor changes
* Reformat
* Only enable R4 as default setting. Random R1 and R2 matrices are showing worse results than just using R4, so the latter seems to be a better default option (at least for now).
* Add `__init__.py` to spinquant folder
* Do not fail if `fast_hadamard_transform` is not present
Commit 590f8fb (parent 76b6e36), showing 8 changed files with 99,838 additions and 1 deletion.
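The core idea behind the rotation layers listed above (and behind the new test's expectation that outputs are unchanged) is that an orthogonal rotation inserted between two adjacent linear maps cancels out mathematically while reshaping the weight and activation distributions that are later quantized. Below is a minimal, hypothetical sketch of that folding trick, using a random orthogonal matrix as a stand-in for the Hadamard-based R1/R2/R4 rotations; it is not the prototype's actual code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two adjacent linear maps: z = W2 @ (W1 @ x)
lin1 = nn.Linear(64, 64, bias=False)
lin2 = nn.Linear(64, 64, bias=False)

# Random orthogonal matrix (Q factor of a QR decomposition) as a stand-in
# for a Hadamard-based rotation.
R, _ = torch.linalg.qr(torch.randn(64, 64))

x = torch.randn(8, 64)
reference = lin2(lin1(x))

# Fold the rotation into the weights: W1 -> R @ W1 and W2 -> W2 @ R^T,
# so that W2 @ R^T @ R @ W1 == W2 @ W1 and the end-to-end function is unchanged.
with torch.no_grad():
    lin1.weight.copy_(R @ lin1.weight)
    lin2.weight.copy_(lin2.weight @ R.T)

rotated = lin2(lin1(x))
print(torch.allclose(reference, rotated, atol=1e-4))  # True: the rotations cancel out
```

In the prototype itself, per the commit message, the R4 rotation is not folded into the Linear weights but inserted as a separate module via `nn.Sequential` alongside them, so the Linear layers stay untouched for quantization.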
New test for SpinQuant (+37 lines):

```python
import pytest
import torch

from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant


def _init_model(name="7B", device="cpu", precision=torch.bfloat16):
    model = Transformer.from_name(name)
    model.to(device=device, dtype=precision)
    return model.eval()


# Always test on CPU; also test on CUDA when available.
_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_spinquant_no_quantization(device):
    model = _init_model(device=device)
    seq_len = 16
    batch_size = 1
    is_training = False

    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
    input_pos = None if is_training else torch.arange(seq_len).to(device)
    with torch.device(device):
        model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=is_training)

    with torch.no_grad():
        out = model(input_ids, input_pos)
        apply_spinquant(model)
        out_spinquant = model(input_ids, input_pos)

    # Output should be the same without quantization (the rotations cancel out)
    # TODO: not sure if these atol/rtol are excessively large (it fails for smaller values)
    torch.testing.assert_close(out, out_spinquant, atol=5e-2, rtol=1e-2)


# TODO: test GPTQ compatibility?
```
New `README.md` for the SpinQuant prototype (+11 lines):

# SpinQuant

Re-implementation of SpinQuant based on the official code implementation (https://github.com/facebookresearch/SpinQuant).

## Usage

Using this implementation with CUDA requires installing the Fast Hadamard Transform CUDA package, which can be done as follows:

```shell
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
```
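For orientation, here is a minimal usage sketch assembled from the test in this commit; `Transformer.from_name` and the `apply_spinquant(model)` call are taken from that test, while the device and dtype choices are only illustrative.

```python
import torch

from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant

# Build a LLaMA-style model from torchao's model definitions (as in the test).
model = Transformer.from_name("7B")
model = model.to(device="cuda", dtype=torch.bfloat16).eval()

# Insert the SpinQuant rotations. Without quantization applied on top, the
# model's outputs should stay (numerically) the same, since the rotations
# cancel out; the benefit shows up once weights/activations are quantized.
apply_spinquant(model)
```

On CUDA this relies on the Fast Hadamard Transform package installed above.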
New `__init__.py` for the spinquant folder (+1 line), exposing the public entry point:

```python
from .spinquant import apply_spinquant
```