SmoothQuant using tensor subclassing (#1030)
* SmoothQuant using tensor subclassing
* Update UT
* Add SmoothQuant example
* Remove duplicate implementation of int_scaled_matmul for CPU
* Update example.py
* Remove unused code
* Implement with LinearActivationQuantizedTensor
* Fix load/save
* Fix device mismatch in observer
* Fix fp16 overflow issue in int_scaled_matmul
* Add linear_activation_scale_quantized.py for torch.compile
* Quantize act/wei to 7 bit on old CPU platforms
* Fix device mismatch
* Fix UT failures
* Fix UT
* Don't use torch._int_mm for CPU now because it may overflow
* Remove reduce_range
* Refine code
* Remove torch.compile from example
* Add torch.compile in example
* Debug CI failures
* Debug CI failures (1)
* Debug CI failures (2)
* Debug CI failures (3)
* Work with torch.compile
* Update torchao/kernel/intmm.py
* Update readme.md
* Update readme.md
* Debug CI failures (4)
* Reimplement with nested tensor subclassing
* Test torch.compile only with PyTorch >= 2.5
* Debug CI failures (5)
* Debug CI failures (6)
* Debug CI failures (7)
* Use MovingAvg observer for activation; Update UT and readme
* Revert changes to test_spinquant.py; refine readme
* Debug CI failures (8)
* Debug CI failures (9)
* Fix CI failure
* Refactor SmoothQuantObserver
* Rename readme.md -> README.md
* Rename insert_smooth_quant_observer -> insert_smooth_quant_observer_ to indicate inplace
* Fix device mismatch in observer
* Fall back to conventional quantization if alpha is None
* Update README.md to provide more benchmark data; fix CI
* Fix CI failures
* Add a comment in affine_quantized_tensor.py
1 parent: 75d0693, commit: 629aee1
Showing 8 changed files with 797 additions and 3 deletions.
@@ -0,0 +1,165 @@
from copy import deepcopy
import pytest
import torch
import tempfile
from torchao.quantization import quantize_
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_2,
    TORCH_VERSION_AT_LEAST_2_4,
    TORCH_VERSION_AT_LEAST_2_5,
)
from torchao.quantization.utils import (
    dynamically_quantize_per_channel,
    dequantize_per_channel,
)
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    smooth_quant,
    SmoothQuantObservedLinear,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
)


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [
            torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device)
            for j in range(batch_size)
        ]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


bias_list = [True, False]
alpha_list = [None, 0.5, 0.75]
quant_mode_list = ["static", "dynamic"]
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")
idtypes = (torch.float, torch.bfloat16, torch.half)

if TORCH_VERSION_AT_LEAST_2_5:
    # This test case will trigger recompilation many times, so set a large cache_size_limit here
    torch._dynamo.config.cache_size_limit = 128


@pytest.mark.parametrize("bias", bias_list)
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_compute(bias, alpha, quant_mode, device, idtype):
    class Linear(torch.nn.Module):
        def __init__(self, bias: bool):
            super().__init__()
            self.fc = torch.nn.Linear(32, 32, bias)
            self.fc.weight.data = torch.randn_like(self.fc.weight.data)

        def forward(self, x):
            return self.fc(x)

    m = Linear(bias).eval().to(idtype).to(device)
    m_ref = deepcopy(m)
    data = torch.randn(2, 32, dtype=idtype, device=device)

    # calibrate
    insert_smooth_quant_observer_(m, alpha, quant_mode)
    m(data)
    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m, smooth_quant(), is_observed_linear)
    with torch.inference_mode():
        if TORCH_VERSION_AT_LEAST_2_5:
            m = torch.compile(m, fullgraph=True)
        out = m(data)

    # reference
    weight = m_ref.fc.weight.data.float()
    b = m_ref.fc.bias if bias else None
    x_abs_max_per_ic = torch.abs(data).max(dim=0).values
    w_abs_max_per_ic = torch.abs(weight).max(dim=0).values
    smoothing_factor = 1 if alpha is None else (
        torch.pow(x_abs_max_per_ic, alpha) / torch.pow(w_abs_max_per_ic, 1 - alpha)
    )
    act = data / smoothing_factor
    wei = weight * smoothing_factor
    qw, w_scales, w_zps = dynamically_quantize_per_channel(
        wei, -127, 127, torch.int8
    )
    fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype)
    if quant_mode == "static":
        # activation is quantized per-tensor
        act_min, act_max = torch.aminmax(act.float())
        max_val_pos = torch.max(-act_min, act_max)
        act_scale = max_val_pos / 127.0
        fq_act = torch.quantize_per_tensor(
            act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8
        ).dequantize().to(idtype)
        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)
    else:
        # activation is quantized per-row (batch * sequence_length)
        qx, x_scales, x_zps = dynamically_quantize_per_channel(
            act.float(), -127, 127, torch.int8
        )
        fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype)
        out_ref = torch.nn.functional.linear(fq_act, fq_wei, b)

    # BFloat16 and Float16 have larger errors
    atol = 0.1 if idtype == torch.float else (
        0.2 if idtype == torch.half else 0.3
    )
    assert torch.allclose(out, out_ref.to(idtype), atol=atol)


@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("idtype", idtypes)
def test_save_load_recipe(alpha, quant_mode, device, idtype):
    dataset_size = 20
    l1, l2, l3 = 512, 256, 128
    original_dtype = idtype
    n_calib_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m_save_load = deepcopy(m)

    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calib_examples]

    # calibrate
    insert_smooth_quant_observer_(m, alpha, quant_mode)
    insert_smooth_quant_observer_(m_save_load, alpha, quant_mode)

    for example in calibration_data:
        m(example.to(device))
        m_save_load(example.to(device))

    with tempfile.NamedTemporaryFile() as fp:
        save_path = fp.name
        save_smooth_quant_recipe(m_save_load, save_path)
        load_smooth_quant_recipe(m_save_load, save_path)

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    quantize_(m, smooth_quant(), is_observed_linear)
    if TORCH_VERSION_AT_LEAST_2_5:
        # earlier versions are not compatible
        m = torch.compile(m, fullgraph=True)
        m_save_load = torch.compile(m_save_load, fullgraph=True)
    out_list = [m(data.squeeze(0)) for data in dataset]
    out = torch.cat(out_list)
    save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset]
    save_load_out = torch.cat(save_load_out_list)

    assert out is not None
    assert save_load_out is not None
    assert torch.allclose(out, save_load_out)
@@ -0,0 +1,98 @@
# SmoothQuant quantization
This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438).

In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization, while activations are smoothed and quantized to int8 at runtime. Activation quantization is done either dynamically or statically: with dynamic quantization, qparams (i.e., scales) are found at runtime, whereas with static quantization they are found ahead of time, during quantization. For dynamic quantization, activations are quantized per token; for static quantization, they are quantized per tensor. Generally, dynamic quantization produces better accuracy while static quantization has better latency. In both cases, weights and activations are quantized symmetrically.

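The core idea can be illustrated with a short, self-contained sketch (for intuition only; this is not the prototype's implementation, and the function and variable names are illustrative):
```python
import torch

def smooth_and_fake_quantize(x: torch.Tensor, w: torch.Tensor, alpha: float = 0.5):
    """Sketch of SmoothQuant: x is (tokens, in_features), w is (out_features, in_features)."""
    # Per-input-channel absolute maxima of activation and weight.
    x_max = x.abs().amax(dim=0)
    w_max = w.abs().amax(dim=0)
    # Smoothing factor from the paper: s = max|X|^alpha / max|W|^(1 - alpha).
    s = torch.pow(x_max, alpha) / torch.pow(w_max, 1 - alpha)
    # x @ w.T == (x / s) @ (w * s).T, but the smoothed activation is easier to quantize
    # because its outlier channels have been shifted into the weight.
    x_s, w_s = x / s, w * s
    # Symmetric int8 quantization: per tensor for the activation (as in static mode),
    # per output channel for the weight.
    x_scale = x_s.abs().amax() / 127.0
    w_scale = w_s.abs().amax(dim=1, keepdim=True) / 127.0
    q_x = torch.clamp(torch.round(x_s / x_scale), -127, 127).to(torch.int8)
    q_w = torch.clamp(torch.round(w_s / w_scale), -127, 127).to(torch.int8)
    # Dequantize and run the linear op to inspect the fake-quantized result.
    return torch.nn.functional.linear(q_x.float() * x_scale, q_w.float() * w_scale)
```
In dynamic mode the activation scale would instead be computed per token at runtime; the reference computation in the test above follows the same math.
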
## Quick start
Run the example code with
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static>
# An example
python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic
```
To use `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance.
```bash
TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile
```
To save a quantized model for reuse, specify `--model-save-path`
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt
```
And load it with `--model-load-path`
```bash
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt
```

## Usage of API
The following APIs are provided:
- insert_smooth_quant_observer_
- smooth_quant
- save_smooth_quant_recipe (advanced)
- load_smooth_quant_recipe (advanced)

`insert_smooth_quant_observer_` inserts observers into the model to be quantized. For example:
```python
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
```
After insertion, run the model on a calibration dataset, or (advanced) load a recipe. A minimal calibration loop is sketched below.

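The loop itself can be as simple as the following; `calibration_data` is a placeholder for whatever set of representative inputs is used, not part of the API:
```python
# Hypothetical calibration loop: run representative inputs through the observed
# model so the inserted observers can record activation statistics.
with torch.no_grad():
    for batch in calibration_data:
        model(batch)
```
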
`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
```python
from torchao.prototype.smoothquant import SmoothQuantObservedLinear
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear)
```
`is_observed_linear` is a filter so that only observed linear layers are quantized.

(Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` save or load a recipe for a model.

A recipe contains the smoothing factors and the quantization parameters of weights and activations for all linear layers to be quantized. Advanced users can save these parameters and modify them to get better accuracy, e.g., using a different alpha for different layers. Users can even leave some linear layers unquantized by deleting those layers from the recipe (see the sketch after the save/load examples below). Such modifications can be published as a recipe; by loading the recipe, it can be reused and calibration is no longer needed.

To save a recipe, users should insert observers and run calibration first. For example,
```python
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
for data in dataset_for_calibration:
    model(data)
save_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
```
To load a recipe, users should insert observers first. For example,
```python
insert_smooth_quant_observer_(model)
load_smooth_quant_recipe(model, "./smooth_quant_recipe.json")
```

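To leave a particular layer unquantized, an advanced user could edit the saved recipe before loading it back. The snippet below is only a sketch: the exact recipe schema is defined by the prototype, and the assumption that entries are keyed by a layer's fully-qualified name (as well as the layer name itself) is hypothetical.
```python
import json

with open("./smooth_quant_recipe.json") as f:
    recipe = json.load(f)
# Hypothetical: assumes the recipe is a dict keyed by layer fully-qualified names.
recipe.pop("model.layers.0.self_attn.q_proj", None)
with open("./smooth_quant_recipe.json", "w") as f:
    json.dump(recipe, f)
```
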
## Benchmark
The example was run with `torch.compile` on an NVIDIA A10G GPU.
### meta-llama/Llama-2-7b-hf
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 8.1872 | 7.4257 | 7.2518 | 7.5509 |
| Static | 43.8051 | 11.2984 | 7.5791 | 19.5050 |

Note*: Conventional quantization without SmoothQuant

### meta-llama/Meta-Llama-3-8B
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 21.2475 | 8.8288 | 9.6514 | 8.3574 |
| Static | 301.7118 | 18.0617 | 10.8343 | 278.9819 |

Note*: Conventional quantization without SmoothQuant

### Test method
**Commands**
```bash
# dynamic quant
TORCHINDUCTOR_FREEZING=1 python example.py -m <model_id> --device=cuda --quant-mode=dynamic --compile
# static quant
TORCHINDUCTOR_FREEZING=1 python example.py -m <model_id> --device=cuda --quant-mode=static --compile
```
Use `--alpha` to specify the alpha parameter. Add `--disable-smooth-quant` to run quantization without SmoothQuant.

**Environment**
- AWS g5.12xlarge instance
- torch==2.6.0.dev20241017+cu124
- python==3.12.6
@@ -0,0 +1,7 @@
from .api import (
    insert_smooth_quant_observer_,
    smooth_quant,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
)
from .core import SmoothQuantObservedLinear