SpinQuant (#983)
* SpinQuant using R2 matrices

* Move Hadamard functions and matrices to separate file

* Add R4 rotation

* Reformat

* Do not wrap Linear layers but use nn.Sequential

Wrapping the Linear layers might interfere with their quantization, so it's probably better to keep the Linear layers unchanged and insert the new rotation layers alongside them, as sketched below.
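A minimal sketch of the idea; the module and tensor names here are illustrative, not the actual torchao code:

```python
import torch
from torch import nn

class HadamardMultiply(nn.Module):
    """Hypothetical rotation module: multiplies activations by a fixed orthogonal matrix."""
    def __init__(self, rotation: torch.Tensor):
        super().__init__()
        self.register_buffer("rotation", rotation)

    def forward(self, x):
        return x @ self.rotation

# Rather than hiding the nn.Linear inside a wrapper module (which quantization
# passes that look for nn.Linear instances may no longer recognize), keep it
# intact and place the rotation next to it:
linear = nn.Linear(64, 64)
rotation = torch.eye(64)  # placeholder; SpinQuant uses Hadamard/orthogonal matrices
block = nn.Sequential(HadamardMultiply(rotation), linear)
```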

* Add test

* Fix test and do small reformat of Hadamard code

* Fuse Layernorm params into linear layers

This is done for pre-norm LLMs like LLaMA to make them scale-invariant (see footnote 3 in the SpinQuant paper). However, in the current implementation it seems to hurt performance when quantization is used.
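For reference, a sketch of the fusion for an `nn.LayerNorm` followed by an `nn.Linear` — this shows the general technique, not a copy of the torchao code (a LLaMA-style RMSNorm has no bias, so only the weight scaling applies):

```python
import torch
from torch import nn

@torch.no_grad()
def fuse_layernorm_into_linear(norm: nn.LayerNorm, linear: nn.Linear) -> None:
    # y = W(g * n(x) + b_ln) + b  ==  (W diag(g)) n(x) + (W b_ln + b),
    # where n(x) is the parameter-free normalization.
    if norm.bias is not None:
        if linear.bias is None:
            linear.bias = nn.Parameter(
                torch.zeros(linear.out_features, dtype=linear.weight.dtype)
            )
        # Fold the norm bias through the (still unscaled) weight.
        linear.bias.add_(linear.weight @ norm.bias)
    # Scale the input columns of the weight by the norm's gamma.
    linear.weight.mul_(norm.weight)
    # Reset the norm's affine parameters so the block stays numerically equivalent.
    norm.weight.fill_(1.0)
    if norm.bias is not None:
        norm.bias.zero_()
```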

* Add R1 rotation

* Add option to load pretrained R1/R2 matrices

* Move Spinquant from `torchao/quantization` to `torchao/prototype/spinquant`

* Move Hadamard matrices to a separate file

* Move test

* Minor changes

* Reformat

* Only enable R4 as default setting

Random R1 and R2 matrices show worse results than just using R4, so the latter seems to be the better default option (at least for now).
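For context, one common way to draw a random rotation for R1/R2 — a sketch, not necessarily how this implementation samples them:

```python
import torch

def random_orthogonal(n: int, seed: int = 0) -> torch.Tensor:
    """Draw a Haar-uniform random orthogonal matrix via QR of a Gaussian matrix."""
    gen = torch.Generator().manual_seed(seed)
    q, r = torch.linalg.qr(torch.randn(n, n, generator=gen, dtype=torch.float64))
    # Fix column signs so the distribution is uniform over orthogonal matrices.
    return q * torch.sign(torch.diagonal(r))
```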

* Add __init__.py to spinquant folder

* Do not fail if fast_hadamard_transform is not present
tobiasvanderwerff authored Oct 10, 2024
1 parent 76b6e36 commit 590f8fb
Showing 8 changed files with 99,838 additions and 1 deletion.
37 changes: 37 additions & 0 deletions test/prototype/test_spinquant.py
@@ -0,0 +1,37 @@
import pytest
import torch
from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant


def _init_model(name="7B", device="cpu", precision=torch.bfloat16):
model = Transformer.from_name(name)
model.to(device=device, dtype=precision)
return model.eval()


_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_spinquant_no_quantization(device):
model = _init_model(device=device)
seq_len = 16
batch_size = 1
is_training = False
input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
input_pos = None if is_training else torch.arange(seq_len).to(device)
with torch.device(device):
model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=is_training)

with torch.no_grad():
out = model(input_ids, input_pos)
apply_spinquant(model)
out_spinquant = model(input_ids, input_pos)

# Output should be the same without quantization (the rotations cancel out)
# TODO: not sure if these atol/rtol are excessively large (it fails for smaller values)
torch.testing.assert_close(out, out_spinquant, atol=5e-2, rtol=1e-2)


# TODO: test GPTQ compatibility?
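The equivalence this test relies on: rotating activations by an orthogonal R while counter-rotating the following weight is a no-op in exact arithmetic, so outputs should match (up to floating-point error) until quantization enters. An illustrative check, not part of the test file:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64)   # activations
w = torch.randn(8, 8, dtype=torch.float64)   # linear weight (out, in)
r, _ = torch.linalg.qr(torch.randn(8, 8, dtype=torch.float64))  # random orthogonal R
# Rotating the input and folding R into the weight cancels exactly:
assert torch.allclose(x @ w.T, (x @ r) @ (w @ r).T)
```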
5 changes: 4 additions & 1 deletion torchao/_models/llama/eval.py
@@ -31,6 +31,7 @@
from tokenizer import get_tokenizer
import time
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.prototype.spinquant import apply_spinquant

def run_evaluation(
checkpoint_path: Path,
@@ -69,6 +70,8 @@ def run_evaluation(
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

if quantization:
if "spinquant" in quantization:
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
@@ -229,7 +232,7 @@ def run_evaluation(
help=(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
"int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
"float8wo, float8dq, float8saq"
),
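With this hook in place, SpinQuant composes with eval.py's existing quantization flow. A hypothetical invocation — the flag names are inferred from the signature above and the checkpoint path is illustrative:

```shell
python torchao/_models/llama/eval.py \
    --checkpoint_path checkpoints/meta-llama/Llama-2-7b/model.pth \
    --quantization spinquant
```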
1 change: 1 addition & 0 deletions torchao/prototype/README.md
@@ -11,6 +11,7 @@
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
- [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).
- [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406)

#### Roadmap

11 changes: 11 additions & 0 deletions torchao/prototype/spinquant/README.md
@@ -0,0 +1,11 @@
# SpinQuant

Re-implementation of SpinQuant, based on the [official implementation](https://github.com/facebookresearch/SpinQuant).

## Usage

Using this implementation with CUDA requires installing the Fast Hadamard Transform CUDA package, which can be done as follows:

```shell
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
```
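
Example usage, mirroring `test/prototype/test_spinquant.py` (the `Transformer` class is torchao's Llama implementation):

```python
import torch
from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant

model = Transformer.from_name("7B")
model.to(dtype=torch.bfloat16).eval()
apply_spinquant(model)  # inserts the rotations into the model in place
```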
1 change: 1 addition & 0 deletions torchao/prototype/spinquant/__init__.py
@@ -0,0 +1 @@
from .spinquant import apply_spinquant
