v0.4.0
Highlights
We are excited to announce the 0.4 release of torchao! This release adds support for KV cache quantization, quantization-aware training (QAT), low-bit optimizers, composing quantization and sparsity, and more!
KV cache quantization (#532)
We've added support for KV cache quantization, which reduces peak memory from 19.7 GB to 19.2 GB on Llama3-8B at an 8192 context length. We plan to investigate Llama3.1 next.
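Conceptually, the cached key/value tensors are stored in int8 with per-row scales and dequantized when attention is computed. The snippet below is an illustrative plain-PyTorch sketch of that idea, not torchao's actual KV cache implementation:
import torch

def quantize_kv(x: torch.Tensor):
    # Symmetric int8 quantization along the head_dim axis.
    # x: [batch, num_heads, seq_len, head_dim] in bf16/fp16.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    # Recover an approximate high-precision cache entry for attention.
    return q.to(dtype) * scale.to(dtype)

k = torch.randn(1, 32, 1024, 128, dtype=torch.bfloat16)
k_q, k_scale = quantize_kv(k)        # stored in the cache: int8 values + scales
k_deq = dequantize_kv(k_q, k_scale)  # used when computing attention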
Quantization-Aware Training (QAT) (#383, #555)
We now support two QAT schemes for linear layers: int8 per-token dynamic activations + int4 per-group weights, and int4 per-group weights only (using the efficient tinygemm int4 kernel after training). Users can access this feature by transforming their models before and after training using the appropriate quantizer, for example:
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)
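# Train or fine-tune the model as usual here; the fake quantize ops
# let the weights adapt to the quantization error during training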
# Convert fake quantize to actual quantize operations
model = qat_quantizer.convert(model)
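The int4 weight-only scheme follows the same prepare/convert flow. A minimal sketch, assuming the Int4WeightOnlyQATQuantizer added in #555 with its default settings:
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
qat_quantizer = Int4WeightOnlyQATQuantizer()
model = qat_quantizer.prepare(model)
# ... train ...
model = qat_quantizer.convert(model)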
Initial evaluation results indicate that QAT in torchao can recover up to 96% of quantized accuracy degradation on hellaswag and up to 68% of quantized perplexity degradation on wikitext for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the README and this blog post.
Composing quantization and sparsity (#457, #473)
We've added support for composing int8 dynamic quantization with 2:4 sparsity, using the quantize_ API. We also added SAM benchmarks that show a 7% speedup over standalone sparsity / int8 dynamic quantization here.
from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
Community Contributions
Low-bit optimizer support (#478, #463, #482, #484, #538)
@gau-nernst added implementations for 4-bit, 8-bit, and FP8 Adam with FSDP2/FSDP support. Our API is a drop-in replacement for torch.optim.Adam and can be used as follows:
from torchao.prototype.low_bit_optim import Adam8bit, Adam4bit, AdamFp8
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
model = ...
optim = Adam8bit(model.parameters()) # replace with Adam4bit or AdamFp8 for the 4-bit / fp8 variants
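As a usage sketch (the tiny model, synthetic data, and hyperparameters below are placeholders, and a CUDA device is assumed), the optimizer slots into a standard training loop exactly like torch.optim.Adam:
import torch
import torch.nn.functional as F
from torchao.prototype.low_bit_optim import Adam8bit

model = torch.nn.Linear(128, 10).cuda().bfloat16()
optim = Adam8bit(model.parameters(), lr=1e-4)

for _ in range(10):
    x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
    y = torch.randint(0, 10, (32,), device="cuda")
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optim.step()       # optimizer states are stored in 8 bits
    optim.zero_grad()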
For more information about low-bit optimizer support, please refer to our README.
Improvements to 4-bit quantization (#517, #552, #544, #479)
@bdhirsh, @jeromeku, @yanbing-j, @manuelcandales, and @larryliu0820 added torch.compile support for NF4Tensor, custom CUDA int4 tinygemm unpacking ops, and several bug fixes to torchao.
BC breaking
quantize has been renamed to quantize_ #467
# for torchao 0.4
from torchao.quantization import quantize_, int8_weight_only
quantize_(model, int8_weight_only())
# for torchao 0.3
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())
apply_sparse_semi_structured has been deprecated in favor of sparsify_, which matches the quantize_ API #473
# for torchao 0.4
from torchao.sparsity import sparsify_, semi_sparse_weight
sparsify_(model, semi_sparse_weight())
# for torchao 0.3
from torchao.sparsity import apply_sparse_semi_structured
apply_sparse_semi_structured(model)
Deprecations
New Features
- Added kv_cache quantization #532
- Migrated float8_experimental to torchao.float8, enabling float8 training support #551 #529
- Added FP5 E2M2 #399
- Added 4-bit, 8-bit, and FP8 ADAM support #478 #463 #482
- Added FSDP2 support for low-bit optimizers #484
- [prototype] mixed-precision quantization and eval framework #531
- Added int4 weight-only QAT support #555, #383
- Added custom CUDA tinygemm unpacking ops #415
Improvements
- Composing quantization and sparsity now uses the unified AQT Layout #498
- Added default inductor config settings #423
- Better dtype and device handling for Int8DynActInt4WeightQuantizer and Int4WeightOnlyQuantizer #475 #479
- Enabled model.to for int4/int8 weight only quantized models #486 #522
- Added more logging to TensorCoreTiledAQTLayout #520
- Added general fake_quantize_affine op with mask support #492 #500 (see the sketch after this list)
- QAT now uses the shared fake_quantize_affine primitive #527
- Improved FSDP support for low-bit optimizers #538
- Custom op and inductor decomp registration now uses a decorator #434
- Updated torch version to no longer require unwrap_tensor_subclass #595
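For reference, fake quantization performs a quantize/dequantize round trip in the original dtype, so training sees the quantization error without any dtype change. The snippet below is a plain-PyTorch sketch of the idea, not the fake_quantize_affine implementation itself (a real QAT flow also adds a straight-through estimator so gradients can pass through the rounding):
import torch

def fake_quantize_symmetric(x: torch.Tensor, bits: int = 4, group_size: int = 32):
    # Quantize-dequantize per group along the last dim, staying in x.dtype.
    qmax = 2 ** (bits - 1) - 1
    orig_shape = x.shape
    xg = x.reshape(-1, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(xg / scale), -qmax - 1, qmax)
    return (q * scale).reshape(orig_shape)

w = torch.randn(256, 128)
w_fq = fake_quantize_symmetric(w)  # same shape/dtype as w, with quantization error applied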
Bug fixes
- Fixed import for TORCH_VERSION_AFTER_* #433
- Fixed crash when PYTORCH_VERSION is not defined #455
- Added torch.compile support for NF4Tensor #544
- Added fbcode check to fix torchtune in Genie #480
- Fixed int4pack_mm error #517
- Fixed cuda device check #536
- Weight shuffling now runs on CPU for int4 quantization due to an MPS memory issue #552
- Scale and input now are the same dtype for int8 weight only quantization #534
- Fixed FP6-LLM API #595
Performance
- Added segment-anything-fast benchmarks for composed quantization + sparsity #457
- Updated low-bit Adam benchmark #481
Docs
- Updated README.md #583 #438 #445 #460
- Updated installation instructions #447 #459
- Added more docs for int4_weight_only API #469
- Added developer guide notebook #588
- Added optimized model serialization/deserialization doc #524 #525
- Added new float8 feature tracker #557
- Added static quantization tutorial for calibration-based techniques #487
Devs
- Fixed numpy version in CI #537
- trymerge now uploads merge records to s3 #448
- Updated python version to 3.9 #488
- torchao no longer depends on torch #449
- benchmark_model now accepts args and kwargs and supports cpu and mps backends #586 #406
- Added git version suffix to package name #547
- Added validations to torchao #453 #454
- Parallel test support with pytest-xdist #518
- Quantizer now uses logging instead of print #472
Not user facing
- Refactored _replace_linear_8da4w #451
- Removed unused code from AQT implementation #476 #440 #441 #471
- Improved error message for lm_eval script #444
- Updated HF_TOKEN env variable #427
- Fixed typo in Quant-LLM #450
- Added a test for map_location="cpu" #497
- Removed sparse test collection warning #489
- Refactored layout implementation #491
- Refactored LinearActQuantizedTensor #542
New Contributors
- @qingquansong made their first contribution in #433
- @Hanxian97 made their first contribution in #451
- @larryliu0820 made their first contribution in #472
- @SLR722 made their first contribution in #480
- @jainapurva made their first contribution in #406
- @bdhirsh made their first contribution in #544
- @yanbing-j made their first contribution in #517
- @manuelcandales made their first contribution in #552
- @Valentine233 made their first contribution in #534
Full Changelog: v0.3.1-rc1...v0.4.0-rc1
We were able to close about 60% of the tasks planned for 0.4.0; the remainder will spill over into upcoming releases. We will post the list for 0.5.0 next, which we aim to release at the end of August 2024. We plan to follow a monthly release cadence until further notice.