[reland] Move QAT out of prototype (#1152)
Move QAT out of prototype

Summary: Move QAT out of prototype so we can provide stronger BC guarantees moving forward.

**(Future) BC-breaking notes**

Note: This commit itself doesn't break BC yet. A future PR will do that. The following is just to save this BC-breaking note somewhere.

Before:
```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:
```
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```

Test Plan: python test/quantization/test_qat.py

ghstack-source-id: add9dcac61e45f3b4ddeed07c300cc78ee3fd23c
Pull Request resolved: #1091
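For downstream users, one way to stay compatible across the move is to try the new location first and fall back to the prototype path on older releases. This try/except shim is an illustrative sketch, not part of this commit:

```python
# Illustrative compatibility shim (not part of this commit): prefer the new
# torchao.quantization.qat location, fall back to the legacy prototype path.
try:
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
```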
1 parent 629aee1 · commit cf573bd · Showing 20 changed files with 1,703 additions and 1,604 deletions.
@@ -1,125 +1,3 @@
# Quantization-Aware Training (QAT)

Quantization-Aware Training (QAT) refers to applying fake quantization during the
training or fine-tuning process, such that the final quantized model will exhibit
higher accuracies and lower perplexities. Fake quantization refers to rounding the float
values to quantized values without actually casting them to dtypes with lower
bit-widths, in contrast to post-training quantization (PTQ), which does cast the
quantized values to lower bit-width dtypes, e.g.:

```
# PTQ: x_q is quantized and cast to int8
# scale and zero point (zp) refer to parameters used to quantize x_float
# qmin and qmax refer to the range of quantized values
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)

# QAT: x_fq is still in float
# Fake quantize simulates the numerics of quantize + dequantize
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
```
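The snippet above is pseudocode. For illustration, here is a minimal runnable sketch of the same numerics in plain PyTorch; the per-tensor affine scale and zero point below are illustrative and not torchao's per-group scheme:

```python
import torch

# Illustrative per-tensor affine quantization parameters
x_float = torch.randn(4, 8)
qmin, qmax = -128, 127
scale = (x_float.max() - x_float.min()) / (qmax - qmin)
zp = torch.round(qmin - x_float.min() / scale)

# PTQ: round, clamp, and actually cast to int8
x_q = torch.clamp(torch.round(x_float / scale + zp), qmin, qmax).to(torch.int8)

# QAT fake quantization: same rounding and clamping, but the result stays in float
x_fq = torch.clamp(torch.round(x_float / scale + zp), qmin, qmax)
x_fq = (x_fq - zp) * scale

print(x_q.dtype)   # torch.int8
print(x_fq.dtype)  # torch.float32
```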

## API

torchao currently supports two QAT schemes for linear layers (see the import sketch below):
- int8 per token dynamic activations + int4 per group weights
- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
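Each scheme maps to its own quantizer class. A minimal sketch of instantiating both with default constructor arguments; the import path below matches this README's prototype location, which this commit relocates to `torchao.quantization.qat`:

```python
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,  # int8 per token dynamic activations + int4 per group weights
    Int4WeightOnlyQATQuantizer,        # int4 per group weights (tinygemm path after training)
)

# Both quantizers expose the same prepare/convert workflow shown below
int8da_int4w_quantizer = Int8DynActInt4WeightQATQuantizer()
int4w_quantizer = Int4WeightOnlyQATQuantizer()
```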

QAT typically involves applying a transformation to your model before and after training.
In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
fake quantize operations into linear layers, and (2) convert transforms the fake quantize
operations to actual quantize and dequantize operations after training, thereby producing
a quantized model (dequantize operations are typically fused with linear after lowering).
Between these two steps, training can proceed exactly as before.

![qat](images/qat_diagram.png)

To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
training, then apply the convert step after training for inference or generation.
For example, on a single GPU:
```python
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Smaller version of llama3 to fit in a single GPU
model = llama3(
    vocab_size=4096,
    num_layers=16,
    num_heads=16,
    num_kv_heads=4,
    embed_dim=2048,
    max_seq_len=2048,
).cuda()

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(10):
    example = torch.randint(0, 4096, (2, 16)).cuda()
    target = torch.randn((2, 16, 4096)).cuda()
    output = model(example)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert fake quantize to actual quantize operations.
# The quantized model has the exact same structure as the
# one produced by the corresponding PTQ flow through
# `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# inference or generate
```
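For comparison, here is a sketch of the PTQ-only flow through `Int8DynActInt4WeightQuantizer` referenced in the comment above. The import path and single-call `quantize` API are assumptions about the surrounding torchao release, so verify them against your installed version:

```python
from torchao.quantization import Int8DynActInt4WeightQuantizer

# PTQ quantizes the already-trained model directly, with no fake-quantize phase
ptq_quantizer = Int8DynActInt4WeightQuantizer()
ptq_model = ptq_quantizer.quantize(model)
```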

Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
and apply quantization-aware fine-tuning as follows:

```
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
```

For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
## Evaluation Results

Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
for 5000 steps using a group size of 256 for the weights. Note that extensive
hyperparameter tuning may further improve these results.

Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:

|                  | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
| ---------------- | ------ | ------ | ------ | ------ | ------ |
| No quantization  | 57.86% | 76.60% | 8.905  | 1.505  | 0.590  |
| PTQ              | 51.74% | 70.66% | 11.878 | 1.588  | 0.668  |
| QAT (quantized)  | 57.25% | 76.51% | 9.859  | 1.534  | 0.617  |
| PTQ degradation  | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
| QAT degradation  | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |

Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).

|                  | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
| ---------------- | ------ | ------ | ------ | ------ | ------ |
| No quantization  | 57.16% | 77.02% | 8.858  | 1.504  | 0.589  |
| PTQ              | 55.06% | 74.24% | 10.311 | 1.547  | 0.630  |
| QAT (quantized)  | 55.86% | 75.06% | 10.134 | 1.542  | 0.625  |
| PTQ degradation  | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
| QAT degradation  | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |
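The degradation rows in both tables are simply the quantized metric minus the unquantized baseline, e.g. for the QAT (quantized) hellaswag accuracy directly above:

$$\text{degradation} = \text{metric}_{\text{quantized}} - \text{metric}_{\text{float baseline}} = 55.86\% - 57.16\% = -1.30\%$$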

For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).

Note: QAT has been moved to torchao/quantization/qat.
This is a legacy folder only for backward compatibility
and will be removed in the near future.