diff --git a/README.md b/README.md
index 71fb25fa24..ba48dbf451 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 9186977240..29f833c9ab 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -22,22 +22,22 @@
PerRow,
PerToken,
)
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
)
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
FakeQuantizedEmbedding,
)
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear
)
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
@@ -181,7 +181,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
- from torchao.quantization.prototype.qat.linear import (
+ from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -213,7 +213,7 @@ def _set_ptq_weight(
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
- from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+ from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
group_size = 128
@@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
- from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+ from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
group_size = 16
@@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
- from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+ from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
with torch.device("meta"):
m = M()
@@ -287,7 +287,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
"""
- from torchao.quantization.prototype.qat import (
+ from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
enable_8da4w_fake_quant,
@@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
"""
- from torchao.quantization.prototype.qat import (
+ from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
)
@@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_gradients(self):
- from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+ from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)
@@ -518,7 +518,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
- from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+ from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
group_size = 128
@@ -545,14 +545,14 @@ def test_qat_4w_linear(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
- from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+ from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
- from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+ from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
group_size = 32
@@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_embedding(self):
- from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+ from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
model = M2()
x = model.example_inputs()
out = model(*x)
@@ -937,6 +937,59 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+ @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+ def test_qat_prototype_bc(self):
+ """
+ Just to make sure we can import all the old prototype paths.
+ We will remove this test in the near future when we actually break BC.
+ """
+ from torchao.quantization.prototype.qat import (
+ disable_4w_fake_quant,
+ disable_8da4w_fake_quant,
+ enable_4w_fake_quant,
+ enable_8da4w_fake_quant,
+ ComposableQATQuantizer,
+ Int8DynActInt4WeightQATLinear,
+ Int4WeightOnlyEmbeddingQATQuantizer,
+ Int4WeightOnlyQATQuantizer,
+ Int8DynActInt4WeightQATQuantizer,
+ )
+ from torchao.quantization.prototype.qat._module_swap_api import (
+ disable_4w_fake_quant_module_swap,
+ enable_4w_fake_quant_module_swap,
+ disable_8da4w_fake_quant_module_swap,
+ enable_8da4w_fake_quant_module_swap,
+ Int4WeightOnlyQATQuantizerModuleSwap,
+ Int8DynActInt4WeightQATQuantizerModuleSwap,
+ )
+ from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+ AffineFakeQuantizedTensor,
+ to_affine_fake_quantized,
+ )
+ from torchao.quantization.prototype.qat.api import (
+ ComposableQATQuantizer,
+ FakeQuantizeConfig,
+ )
+ from torchao.quantization.prototype.qat.embedding import (
+ FakeQuantizedEmbedding,
+ Int4WeightOnlyEmbeddingQATQuantizer,
+ Int4WeightOnlyEmbedding,
+ Int4WeightOnlyQATEmbedding,
+ )
+ from torchao.quantization.prototype.qat.fake_quantizer import (
+ FakeQuantizer,
+ )
+ from torchao.quantization.prototype.qat.linear import (
+ disable_4w_fake_quant,
+ disable_8da4w_fake_quant,
+ enable_4w_fake_quant,
+ enable_8da4w_fake_quant,
+ FakeQuantizedLinear,
+ Int4WeightOnlyQATLinear,
+ Int4WeightOnlyQATQuantizer,
+ Int8DynActInt4WeightQATLinear,
+ Int8DynActInt4WeightQATQuantizer,
+ )
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/torchao/quantization/prototype/qat/README.md b/torchao/quantization/prototype/qat/README.md
index 2869322297..dbce4d48e1 100644
--- a/torchao/quantization/prototype/qat/README.md
+++ b/torchao/quantization/prototype/qat/README.md
@@ -1,125 +1,3 @@
-# Quantization-Aware Training (QAT)
-
-Quantization-Aware Training (QAT) refers to applying fake quantization during the
-training or fine-tuning process, such that the final quantized model will exhibit
-higher accuracies and perplexities. Fake quantization refers to rounding the float
-values to quantized values without actually casting them to dtypes with lower
-bit-widths, in contrast to post-training quantization (PTQ), which does cast the
-quantized values to lower bit-width dtypes, e.g.:
-
-```
-# PTQ: x_q is quantized and cast to int8
-# scale and zero point (zp) refer to parameters used to quantize x_float
-# qmin and qmax refer to the range of quantized values
-x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
-
-# QAT: x_fq is still in float
-# Fake quantize simulates the numerics of quantize + dequantize
-x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
-x_fq = (x_fq - zp) * scale
-```
-
-## API
-
-torchao currently supports two QAT schemes for linear layers:
-- int8 per token dynamic activations + int4 per group weights
-- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
-
-QAT typically involves applying a transformation to your model before and after training.
-In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
-fake quantize operations into linear layers, and (2) convert transforms the fake quantize
-operations to actual quantize and dequantize operations after training, thereby producing
-a quantized model (dequantize operations are typically fused with linear after lowering).
-Between these two steps, training can proceed exactly as before.
-
-![qat](images/qat_diagram.png)
-
-To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
-training, then apply the convert step after training for inference or generation.
-For example, on a single GPU:
-
-```python
-import torch
-from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-
-# Smaller version of llama3 to fit in a single GPU
-model = llama3(
- vocab_size=4096,
- num_layers=16,
- num_heads=16,
- num_kv_heads=4,
- embed_dim=2048,
- max_seq_len=2048,
-).cuda()
-
-# Quantizer for int8 dynamic per token activations +
-# int4 grouped per channel weights, only for linear layers
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
-
-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics during
-# training without performing any dtype casting
-model = qat_quantizer.prepare(model)
-
-# Standard training loop
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
-loss_fn = torch.nn.CrossEntropyLoss()
-for i in range(10):
- example = torch.randint(0, 4096, (2, 16)).cuda()
- target = torch.randn((2, 16, 4096)).cuda()
- output = model(example)
- loss = loss_fn(output, target)
- loss.backward()
- optimizer.step()
- optimizer.zero_grad()
-
-# Convert fake quantize to actual quantize operations
-# The quantized model has the exact same structure as the
-# quantized model produced in the corresponding PTQ flow
-# through `Int8DynActInt4WeightQuantizer`
-model = qat_quantizer.convert(model)
-
-# inference or generate
-```
-
-Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
-and apply quantized-aware fine-tuning as follows:
-
-```
-tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
-```
-
-For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
-
-
-## Evaluation Results
-
-Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
-integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
-for 5000 steps using a group size of 256 for the weights. Note that extensive
-hyperparameter tuning may further improve these results.
-
-Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | ------ | ------ | ------ | ------ | ------ |
-| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
-| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
-| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
-| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
-| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |
-
-Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
-quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | -------- | ------- | ------ | ------ | ------ |
-| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
-| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
-| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
-| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
-| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |
-
-For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).
+Note: QAT has been moved to torchao/quantization/qat.
+This is a legacy folder only for backward compatibility
+and will be removed in the near future.
diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py
index 09ea6e708d..f6c4b1c8ce 100644
--- a/torchao/quantization/prototype/qat/__init__.py
+++ b/torchao/quantization/prototype/qat/__init__.py
@@ -1,17 +1,15 @@
-from .api import (
+from torchao.quantization.qat import (
ComposableQATQuantizer,
+ Int4WeightOnlyEmbeddingQATQuantizer,
+ Int4WeightOnlyQATQuantizer,
+ Int8DynActInt4WeightQATQuantizer,
)
-from .linear import (
+from torchao.quantization.qat.linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
- Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATLinear,
- Int8DynActInt4WeightQATQuantizer,
-)
-from .embedding import (
- Int4WeightOnlyEmbeddingQATQuantizer,
)
__all__ = [
diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py
index 0b44974f21..d6aa4a3ae5 100644
--- a/torchao/quantization/prototype/qat/_module_swap_api.py
+++ b/torchao/quantization/prototype/qat/_module_swap_api.py
@@ -1,7 +1,7 @@
# For backward compatibility only
# These will be removed in the future
-from .linear import (
+from torchao.quantization.qat.linear import (
Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap,
Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap,
enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap,
diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
index 9d7cdccfc3..ff09ca842e 100644
--- a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
+++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
@@ -1,328 +1,4 @@
-import torch
-import torch.utils._pytree as pytree
-from typing import Callable, Optional, Tuple
-from torchao.quantization.quant_primitives import (
- _get_and_check_qmin_qmax,
- choose_qparams_affine,
- fake_quantize_affine,
- ZeroPointDomain,
- MappingType,
+from torchao.quantization.qat.affine_fake_quantized_tensor import (
+ AffineFakeQuantizedTensor,
+ to_affine_fake_quantized,
)
-from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.utils import TorchAOBaseTensor
-from .utils import (
- _GenericFakeQuantize,
- _UnwrapAffineFakeQuantizedTensor,
-)
-
-aten = torch.ops.aten
-
-
-class _ToAffineFakeQuantized(torch.autograd.Function):
- """
- Differentiable constructor for `AffineFakeQuantizedTensor`,
- needed for input activation fake quantization.
- """
-
- @staticmethod
- def forward(
- ctx: torch.autograd.function.FunctionCtx,
- original_tensor: torch.Tensor,
- mapping_type: MappingType,
- block_size: Tuple[int, ...],
- target_dtype: torch.dtype,
- quant_min: Optional[int] = None,
- quant_max: Optional[int] = None,
- eps: Optional[float] = None,
- scale_dtype: Optional[torch.dtype] = None,
- zero_point_dtype: Optional[torch.dtype] = None,
- preserve_zero: bool = True,
- zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
- ) -> "AffineFakeQuantizedTensor":
- def apply_fake_quant_fn(t: torch.Tensor):
- assert isinstance(t, AffineFakeQuantizedTensor)
- qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
- scale, zero_point = choose_qparams_affine(
- t.original_tensor,
- mapping_type,
- block_size,
- target_dtype,
- qmin,
- qmax,
- eps,
- scale_dtype,
- zero_point_dtype,
- preserve_zero,
- zero_point_domain,
- )
- fq = _GenericFakeQuantize.apply(
- t,
- block_size,
- scale,
- zero_point,
- qmin,
- qmax,
- zero_point_domain,
- )
- return fq
- return AffineFakeQuantizedTensor(
- original_tensor,
- apply_fake_quant_fn,
- fake_quant_enabled=True,
- )
-
- @staticmethod
- def backward(ctx, gy):
- return gy, None, None, None, None, None, None, None, None, None, None
-
-
-class AffineFakeQuantizedTensor(TorchAOBaseTensor):
- """
- Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor
- with an affine transformation:
- quantized_tensor = float_tensor / scale + zero_point
-
- Fake quantization refers to performing the quantization math without actually casting the floating point
- tensor into lower bit-width dtypes. It is commonly used for quantization-aware training (QAT).
-
- The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
- regardless of the internal representation's type or orientation.
-
- fields:
- original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later
- apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values
- """
-
- @staticmethod
- def __new__(
- cls,
- original_tensor: torch.Tensor,
- apply_fake_quant_fn: Callable,
- fake_quant_enabled: bool = True,
- **kwargs,
- ):
- kwargs.setdefault("dtype", original_tensor.dtype)
- kwargs.setdefault("device", original_tensor.device)
- kwargs.setdefault("requires_grad", original_tensor.requires_grad)
- return torch.Tensor._make_wrapper_subclass(
- cls,
- original_tensor.shape,
- **kwargs,
- )
-
- def __init__(
- self,
- original_tensor: torch.Tensor,
- apply_fake_quant_fn: Callable,
- fake_quant_enabled: bool = True,
- **kwargs
- ):
- self.original_tensor = original_tensor
- self.apply_fake_quant_fn = apply_fake_quant_fn
- self.fake_quant_enabled = fake_quant_enabled
-
- def __tensor_flatten__(self):
- return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled]
-
- @classmethod
- def __tensor_unflatten__(
- cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride,
- ):
- original_tensor = tensor_data_dict["original_tensor"]
- (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes
- return cls(
- original_tensor,
- apply_fake_quant_fn,
- fake_quant_enabled,
- )
-
- @classmethod
- def from_float(
- cls,
- original_input: torch.Tensor,
- mapping_type: MappingType,
- block_size: Tuple[int, ...],
- target_dtype: torch.dtype,
- quant_min: Optional[int] = None,
- quant_max: Optional[int] = None,
- eps: Optional[float] = None,
- scale_dtype: Optional[torch.dtype] = None,
- zero_point_dtype: Optional[torch.dtype] = None,
- preserve_zero: bool = True,
- zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
- ):
- return _ToAffineFakeQuantized.apply(
- original_input,
- mapping_type,
- block_size,
- target_dtype,
- quant_min,
- quant_max,
- eps,
- scale_dtype,
- zero_point_dtype,
- preserve_zero,
- zero_point_domain,
- )
-
- def get_value(self) -> torch.Tensor:
- if self.fake_quant_enabled:
- return self.apply_fake_quant_fn(self)
- else:
- return _UnwrapAffineFakeQuantizedTensor.apply(self)
-
- def _get_to_kwargs(self, *args, **kwargs):
- device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
- device = self.device if device is None else device
- dtype = self.dtype if dtype is None else dtype
- memory_format = (
- memory_format if memory_format is not None else torch.preserve_format
- )
- kwargs = {
- "device": device,
- "dtype": dtype,
- "memory_format": memory_format,
- "requires_grad": self.requires_grad,
- }
- return kwargs
-
- def to(self, *args, **kwargs):
- kwargs = self._get_to_kwargs(*args, **kwargs)
- device = kwargs.pop("device")
- # not supported yet
- kwargs.pop("memory_format")
- return self.__class__(
- self.original_tensor.to(device),
- self.apply_fake_quant_fn,
- self.fake_quant_enabled,
- **kwargs,
- )
-
- def _apply_fn_to_data(self, fn: Callable):
- """
- Create a new `AffineFakeQuantizedTensor` with `fn` applied to the
- original tensor, to be called within __torch_dispatch__.
- """
- return self._create_new(fn(self.original_tensor))
-
- def _create_new(self, new_value: torch.Tensor):
- """
- Create a new `AffineFakeQuantizedTensor` with a new value,
- to be called within __torch_dispatch__.
-
- Note: `requires_grad` must be False here because tensors created
- in `__torch_dispatch__` cannot produce gradients, since autograd
- will try to attach autograd metadata to these tensors when we exit
- `__torch_dispatch__`, but if these tensors already have metadata
- attached then autograd will throw an error.
- """
- return self.__class__(
- new_value,
- self.apply_fake_quant_fn,
- self.fake_quant_enabled,
- requires_grad=False,
- )
-
-implements = AffineFakeQuantizedTensor.implements
-
-
-@implements(torch.nn.functional.linear)
-def _(func, types, args, kwargs):
- input_tensor, weight_tensor, bias = (
- args[0],
- args[1],
- args[2] if len(args) > 2 else None,
- )
- if isinstance(input_tensor, AffineFakeQuantizedTensor):
- input_tensor = input_tensor.get_value()
- if isinstance(weight_tensor, AffineFakeQuantizedTensor):
- weight_tensor = weight_tensor.get_value()
- return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-
-
-@implements(aten.mm.default)
-def _(func, types, args, kwargs):
- bias = None
- input_tensor = args[0]
- weight_tensor = args[1]
- if isinstance(input_tensor, AffineFakeQuantizedTensor):
- input_tensor = input_tensor.get_value()
- if isinstance(weight_tensor, AffineFakeQuantizedTensor):
- weight_tensor = weight_tensor.get_value()
- return func(input_tensor, weight_tensor)
-
-
-@implements(aten.addmm.default)
-def _(func, types, args, kwargs):
- bias = args[0]
- input_tensor = args[1]
- weight_tensor = args[2]
- if isinstance(input_tensor, AffineFakeQuantizedTensor):
- input_tensor = input_tensor.get_value()
- if isinstance(weight_tensor, AffineFakeQuantizedTensor):
- weight_tensor = weight_tensor.get_value()
- return func(bias, input_tensor, weight_tensor)
-
-
-@implements(aten.detach.default)
-def _(func, types, args, kwargs):
- return return_and_correct_aliasing(
- func, args, kwargs, args[0]._apply_fn_to_data(torch.detach),
- )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
- return return_and_correct_aliasing(
- func, args, kwargs, args[0]._apply_fn_to_data(torch.clone),
- )
-
-
-@implements(aten.t.default)
-def _(func, types, args, kwargs):
- return return_and_correct_aliasing(
- func, args, kwargs, args[0]._apply_fn_to_data(torch.t),
- )
-
-
-@implements([
- aten.add.Tensor,
- aten.add_.Tensor,
- aten.mul_.Tensor,
- aten.copy_.default,
-])
-def _(func, types, args, kwargs):
- assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}"
- new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args)
- first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1]
- new_value = func(*new_args, **kwargs)
- out = first_afq_tensor._create_new(new_value)
- return return_and_correct_aliasing(func, args, kwargs, out)
-
-
-# Needed by FSDP:
-
-@implements(aten.empty_like.default)
-def _(func, types, args, kwargs):
- out = torch.empty_like(args[0].original_tensor, **kwargs)
- return return_and_correct_aliasing(func, args, kwargs, out)
-
-
-@implements(aten.split.Tensor)
-def _(func, types, args, kwargs):
- new_values = torch.split(args[0].original_tensor, *args[1:], **kwargs)
-
- def make_new_tensor(value):
- out = args[0]._create_new(value)
- return return_and_correct_aliasing(func, args, kwargs, out)
-
- return list(map(make_new_tensor, new_values))
-
-
-@implements(aten.new_zeros.default)
-def _(func, types, args, kwargs):
- out = args[0].original_tensor.new_zeros(*args[1:], **kwargs)
- return return_and_correct_aliasing(func, args, kwargs, out)
-
-
-to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py
index 60d45c8b17..e3e18e9e30 100644
--- a/torchao/quantization/prototype/qat/api.py
+++ b/torchao/quantization/prototype/qat/api.py
@@ -1,262 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, List, Optional, Union
-
-import torch
-
-from torchao.quantization.granularity import (
- Granularity,
- PerAxis,
- PerGroup,
- PerToken,
+from torchao.quantization.qat.api import (
+ ComposableQATQuantizer,
+ FakeQuantizeConfig,
)
-from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.quant_primitives import (
- _SUB_BYTE_INT_BOUNDS,
- _SUB_BYTE_UINT_BOUNDS,
- MappingType,
- TorchAODType,
- ZeroPointDomain,
-)
-
-
-@dataclass
-class FakeQuantizeConfig:
- """
- Config for how to fake quantize weights or activations.
-
- args:
- dtype: dtype to simulate during fake quantization, e.g. torch.int8.
- For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
- torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
- granularity: granularity of scales and zero points, e.g. PerGroup(32).
- We also support the following strings:
- 1) 'per_token': equivalent to PerToken()
- 2) 'per_channel': equivalent to PerAxis(0)
- 3) 'per_group': equivalent to PerGroup(group_size), must be combined
- with separate `group_size` kwarg, Alternatively, just set the
- `group_size` kwarg and leave this field empty.
- mapping_type: whether to use symmetric (default) or asymmetric quantization
- Alternatively, set `is_symmetric` (bool) and leave this field empty.
- scale_precision: scale dtype (default torch.fp32)
- zero_point_precision: zero point dtype (default torch.int32)
- zero_point_domain: whether zero point is in integer (default) or float domain
- is_dynamic: whether to use dynamic (defualt) or static scale and zero points
- range_learning: whether to learn scale and zero points during training (coming soon)
-
- kwargs (optional):
- group_size: size of each group in per group fake quantization,
- can be set instead of `granularity`
- is_symmetric: whether to use symmetric or asymmetric quantization,
- can be set instead of `mapping_type`
-
- Example usage::
-
- # Per token asymmetric quantization
- FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
- FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
-
- # Per channel symmetric quantization
- FakeQuantizeConfig(torch.int4, "per_channel")
- FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
- FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
-
- # Per group symmetric quantization
- FakeQuantizeConfig(torch.int4, group_size=32)
- FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
- FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
- FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
- """
- dtype: Union[torch.dtype, TorchAODType]
- granularity: Granularity
- mapping_type: MappingType
- scale_precision: torch.dtype
- zero_point_precision: torch.dtype
- zero_point_domain: ZeroPointDomain
- is_dynamic: bool = True
- range_learning: bool = False
-
- def __init__(
- self,
- dtype: Union[torch.dtype, TorchAODType],
- granularity: Union[Granularity, str, None] = None,
- mapping_type: Optional[MappingType] = None,
- scale_precision: torch.dtype = torch.float32,
- zero_point_precision: torch.dtype = torch.int32,
- zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
- is_dynamic: bool = True,
- range_learning: bool = False,
- *,
- group_size: Optional[int] = None,
- is_symmetric: Optional[bool] = None,
- ):
- self.dtype = dtype
- self.granularity = self._get_granularity(granularity, group_size)
- self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
- self.scale_precision = scale_precision
- self.zero_point_precision = zero_point_precision
- self.zero_point_domain = zero_point_domain
- self.is_dynamic = is_dynamic
- self.range_learning = range_learning
-
- # Validate dtype
- all_dtypes = [torch.int8, torch.uint8]
- all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
- all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
- if dtype not in all_dtypes:
- raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes))
-
- def _get_granularity(
- self,
- granularity: Union[Granularity, str, None],
- group_size: Optional[int],
- ) -> Granularity:
- """
- Parse the `Granularity` represented in the args.
-
- Granularity can be specified in one of three ways:
- 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size)
- 2) str: one of 'per_token', 'per_channel', and 'per_group'
- 3) None: `group_size` must be set instead, represents per group granularity
- """
- # If group_size is set, then granularity must be either "per_group" or None
- if group_size is not None and granularity != "per_group" and granularity is not None:
- raise ValueError("`group_size` conflicts with granularity '%s'" % granularity)
-
- # Case 1: Granularity object
- if isinstance(granularity, Granularity):
- if not isinstance(granularity, (PerToken, PerAxis, PerGroup)):
- raise ValueError("Granularity '%s' is not supported" % granularity)
- if isinstance(granularity, PerAxis) and granularity.axis != 0:
- raise ValueError("Only axis=0 is supported for PerAxis granularity")
- return granularity
-
- # Case 2: str granularity
- if granularity == "per_token":
- return PerToken()
- elif granularity == "per_channel":
- return PerAxis(axis=0)
- elif granularity == "per_group":
- if group_size is None:
- raise ValueError("Granularity was 'per_group' but no `group_size` was set")
- return PerGroup(group_size)
- elif isinstance(granularity, str):
- raise ValueError(
- "Unexpected granularity: '%s', must be one of %s" %
- (granularity, ["per_token", "per_channel", "per_group"])
- )
-
- # Case 3: None granularity + group_size was specified
- if granularity is not None:
- raise ValueError(
- "Granularity '%s' has unexpected type %s" % (granularity, type(granularity))
- )
- if group_size is None:
- raise ValueError("At least one of `granularity` or `group_size` must be set")
- return PerGroup(group_size)
-
- def _get_mapping_type(
- self,
- mapping_type: Optional[MappingType],
- is_symmetric: Optional[bool],
- ) -> MappingType:
- """
- Parse the `MappingType` represented in the args.
-
- Mapping type can be specified in one of two ways:
- 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC
- 2): is_symmetric bool
- """
- if mapping_type is not None and is_symmetric is not None:
- raise ValueError("Cannot set both `mapping_type` and `is_symmetric`")
-
- # Case 0: Default to symmetric
- if mapping_type is None and is_symmetric is None:
- return MappingType.SYMMETRIC
-
- # Case 1: MappingType object
- if mapping_type is not None:
- if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]:
- raise ValueError("MappingType '%s' is not supported" % mapping_type)
- return mapping_type
-
- # Case 2: is_symmetric flag
- assert is_symmetric is not None
- if is_symmetric:
- return MappingType.SYMMETRIC
- else:
- return MappingType.ASYMMETRIC
-
- @property
- def group_size(self) -> int:
- """
- If this is per group granularity, return the group size.
- Otherwise, throw an error.
- """
- if isinstance(self.granularity, PerGroup):
- return self.granularity.group_size
- else:
- raise ValueError("`group_size` is undefined for %s granularity" % self.granularity)
-
- @property
- def is_symmetric(self) -> bool:
- """
- Return True if mapping type is symmetric, else False (asymmetric).
- """
- return self.mapping_type == MappingType.SYMMETRIC
-
- def __setattr__(self, name: str, value: Any):
- """
- Support setting `group_size` and `is_symmetric`.
- """
- if name == "group_size":
- super().__setattr__("granularity", PerGroup(value))
- elif name == "is_symmetric":
- mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC
- super().__setattr__("mapping_type", mapping_type)
- else:
- super().__setattr__(name, value)
-
-
-class ComposableQATQuantizer(TwoStepQuantizer):
- """
- Composable quantizer that users can use to apply multiple QAT quantizers easily.
- Quantizers will be applied in the order they are specified in the constructor.
-
- Note: the quantizers provided must apply to different modules in the model,
- e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined.
-
- Example usage::
-
- my_quantizer = ComposableQATQuantizer([
- QATQuantizer1(),
- QATQuantizer2(),
- QATQuantizer3(),
- ])
- model = my_quantizer.prepare(model)
- train(model)
- model = my_quantizer.convert(model)
- """
-
- def __init__(self, quantizers: List[TwoStepQuantizer]):
- self.quantizers = quantizers
-
- def prepare(
- self, model: torch.nn.Module, *args: Any, **kwargs: Any
- ) -> torch.nn.Module:
- for quantizer in self.quantizers:
- model = quantizer.prepare(model)
- return model
-
- def convert(
- self, model: torch.nn.Module, *args: Any, **kwargs: Any
- ) -> torch.nn.Module:
- for quantizer in self.quantizers:
- model = quantizer.convert(model)
- return model
diff --git a/torchao/quantization/prototype/qat/embedding.py b/torchao/quantization/prototype/qat/embedding.py
index 1f471fa490..9168291202 100644
--- a/torchao/quantization/prototype/qat/embedding.py
+++ b/torchao/quantization/prototype/qat/embedding.py
@@ -1,325 +1,6 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Any, Optional
-
-import torch
-import torch.nn.functional as F
-
-from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import get_group_qparams_symmetric
-from torchao.quantization.quant_api import (
- _replace_with_custom_fn_if_matches_filter,
+from torchao.quantization.qat.embedding import (
+ FakeQuantizedEmbedding,
+ Int4WeightOnlyEmbeddingQATQuantizer,
+ Int4WeightOnlyEmbedding,
+ Int4WeightOnlyQATEmbedding,
)
-from torchao.quantization.quant_primitives import TorchAODType
-from .api import FakeQuantizeConfig
-from .fake_quantizer import FakeQuantizer
-from .utils import (
- _fake_quantize_per_channel_group,
- _get_qmin_qmax,
-)
-
-
-class FakeQuantizedEmbedding(torch.nn.Embedding):
- """
- General embedding layer with fake quantized weights.
-
- Specific target dtypes, granularity, schemes etc. are specified
- through separate configs for weights and activations.
-
- Example usage::
-
- weight_config = FakeQuantizeConfig(
- dtype=torch.int4,
- group_size=8,
- symmetric=True,
- )
- fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config)
- fq_embedding(torch.LongTensor([3]))
- """
-
- def __init__(
- self,
- num_embeddings: int,
- embedding_dim: int,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False,
- weight_config: Optional[FakeQuantizeConfig] = None,
- *args,
- **kwargs,
- ) -> None:
- super().__init__(
- num_embeddings,
- embedding_dim,
- padding_idx,
- max_norm,
- norm_type,
- scale_grad_by_freq,
- sparse,
- *args,
- **kwargs,
- )
- if weight_config is not None:
- self.weight_fake_quantizer = FakeQuantizer(weight_config)
- else:
- self.weight_fake_quantizer = None
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if self.weight_fake_quantizer is not None:
- w = self.weight_fake_quantizer(self.weight)
- else:
- w = self.weight
- return F.embedding(
- x, w, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse,
- )
-
-
-# ======================================
-# | Embedding int4 weight-only QAT |
-# ======================================
-
-class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer):
- """
- Quantizer for performing QAT on a model, where embedding layers have
- int4 fake quantized grouped per channel weights.
- """
-
- def __init__(
- self,
- group_size: int = 256,
- scale_precision: torch.dtype = torch.float32,
- zero_point_precision: torch.dtype = torch.int32,
- ) -> None:
- super().__init__()
- self.bit_width = 4
- self.group_size: int = group_size
- self.scale_precision: torch.dtype = scale_precision
- self.zero_point_precision: torch.dtype = zero_point_precision
-
- def prepare(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any
- ) -> torch.nn.Module:
- """
- Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.
- """
- def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
- return isinstance(child, torch.nn.Embedding)
-
- def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
- new_embedding = Int4WeightOnlyQATEmbedding(
- # nn.Embedding args
- num_embeddings=child.num_embeddings,
- embedding_dim=child.embedding_dim,
- padding_idx=child.padding_idx,
- max_norm=child.max_norm,
- norm_type=child.norm_type,
- scale_grad_by_freq=child.scale_grad_by_freq,
- sparse=child.sparse,
- # quantization args
- group_size=self.group_size,
- scale_precision=self.scale_precision,
- zero_point_precision=self.zero_point_precision,
- device=child.weight.device,
- )
- # In distributed training, the model may be instantiated
- # on the meta device, in which case there is no need to
- # copy the weights, and doing so will result in an error
- if child.weight.device != torch.device("meta"):
- new_embedding.weight = child.weight
- return new_embedding
-
- _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn)
- return model
-
- def convert(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any
- ) -> torch.nn.Module:
- """
- Swap all `Int4WeightOnlyQATEmbedding` modules with `Int4WeightOnlyEmbedding`.
- """
- self._convert_helper(model)
- return model
-
- def _convert_helper(self, module: torch.nn.Module):
- """
- Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
- modules with `Int4WeightOnlyEmbedding`
- """
- from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
- for name, child in module.named_children():
- if isinstance(child, Int4WeightOnlyQATEmbedding):
- group_size = child.weight_fake_quantizer.config.group_size
- scale_precision = child.weight_fake_quantizer.config.scale_precision
- zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
- quantized_embedding = Int4WeightOnlyEmbedding(
- # nn.Embedding args
- num_embeddings=child.num_embeddings,
- embedding_dim=child.embedding_dim,
- padding_idx=child.padding_idx,
- max_norm=child.max_norm,
- norm_type=child.norm_type,
- scale_grad_by_freq=child.scale_grad_by_freq,
- sparse=child.sparse,
- # quantization args
- group_size=group_size,
- scale_precision=scale_precision,
- zero_point_precision=zero_point_precision,
- device=child.weight.device,
- )
- setattr(module, name, quantized_embedding)
-
- # Load weights and qparams into quantized embedding
- (qmin, qmax) = _get_qmin_qmax(self.bit_width)
- (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
- q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
- child.weight, s, zp, qmin, qmax, torch.int8, group_size,
- )
- quantized_embedding.weight = q_weight
- quantized_embedding.scales = s
- quantized_embedding.zeros = zp
- else:
- self._convert_helper(child)
-
-
-class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
- """
- This module implements a embedding layer with int4 fake quantized
- grouped per channel weights.
-
- args:
- group_size: the number of elements in each quantized group for weights
- scale_precision: precision of per group scales
- zero_point_precision: precision of per group zero points
- """
-
- def __init__(
- self,
- num_embeddings: int,
- embedding_dim: int,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False,
- group_size: int = 32,
- scale_precision: torch.dtype = torch.float32,
- zero_point_precision: torch.dtype = torch.int32,
- *args,
- **kwargs,
- ):
- weight_config = FakeQuantizeConfig(
- dtype=TorchAODType.INT4,
- group_size=group_size,
- is_symmetric=True,
- is_dynamic=True,
- scale_precision=scale_precision,
- zero_point_precision=zero_point_precision,
- )
- super().__init__(
- num_embeddings,
- embedding_dim,
- padding_idx,
- max_norm,
- norm_type,
- scale_grad_by_freq,
- sparse,
- weight_config,
- *args,
- **kwargs,
- )
-
- def enable_fake_quant(self, enabled: bool = True):
- self.weight_fake_quantizer.enabled = enabled
-
- def disable_fake_quant(self):
- self.enable_fake_quant(False)
-
-
-class Int4WeightOnlyEmbedding(torch.nn.Module):
- """
- This module implements a embedding layer with int4 quantized
- grouped per channel weights.
- """
- def __init__(
- self,
- num_embeddings: int,
- embedding_dim: int,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False,
- group_size: int = 32,
- scale_precision: torch.dtype = torch.float32,
- zero_point_precision: torch.dtype = torch.int32,
- device: torch.device = None,
- ):
- super().__init__()
-
- # nn.Embedding args
- self.num_embeddings = num_embeddings
- self.embedding_dim = embedding_dim
- self.padding_idx = padding_idx
- self.max_norm = max_norm
- self.norm_type = norm_type
- self.scale_grad_by_freq = scale_grad_by_freq
- self.sparse = sparse
-
- # quantization args
- self.bit_width = 4
- self.group_size = group_size
- self.scale_precision = scale_precision
- self.zero_point_precision = zero_point_precision
-
- # currently storing unpacked int8 weights
- self.register_buffer(
- "weight",
- torch.empty((num_embeddings, embedding_dim), dtype=torch.int8, device=device),
- )
- self.register_buffer(
- "scale",
- torch.empty(
- (num_embeddings, embedding_dim // group_size),
- dtype=scale_precision,
- device=device,
- ),
- )
- self.register_buffer(
- "zero_point",
- torch.empty(
- (num_embeddings, embedding_dim // group_size),
- dtype=zero_point_precision,
- device=device,
- ),
- )
-
- def forward(self, x):
- from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper
- qmin, qmax = _get_qmin_qmax(self.bit_width)
- w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(
- self.weight,
- self.scale,
- self.zero_point,
- qmin,
- qmax,
- torch.int8,
- self.group_size,
- x.dtype,
- )
- return F.embedding(
- x, w_dq, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse,
- )
diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py
index eb42dcf047..584a46573b 100644
--- a/torchao/quantization/prototype/qat/fake_quantizer.py
+++ b/torchao/quantization/prototype/qat/fake_quantizer.py
@@ -1,121 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Optional, Tuple
-
-import torch
-
-from torchao.quantization.granularity import (
- PerAxis,
- PerGroup,
- PerToken,
+from torchao.quantization.qat.fake_quantizer import (
+ FakeQuantizer,
)
-from torchao.quantization.quant_primitives import (
- _DTYPE_TO_BIT_WIDTH,
- _DTYPE_TO_QVALUE_BOUNDS,
-)
-from torchao.quantization.utils import (
- get_group_qparams_symmetric,
- get_groupwise_affine_qparams,
-)
-from .api import (
- FakeQuantizeConfig,
-)
-from .utils import (
- _choose_qparams_per_token_asymmetric,
- _fake_quantize_per_channel_group,
- _fake_quantize_per_token,
- _get_qmin_qmax,
-)
-
-
-class FakeQuantizer(torch.nn.Module):
- """
- Generic module for applying fake quantization to a tensor, as specified in the config.
- """
- def __init__(self, config: FakeQuantizeConfig):
- super().__init__()
- self.config = config
- self.enabled = True
- self.scale: Optional[torch.Tensor] = None
- self.zero_point: Optional[torch.Tensor] = None
-
- # TODO: support range learinng
- if self.config.range_learning:
- raise NotImplementedError("Range learning is not supported yet")
-
- def forward(self, x: torch.Tensor):
- """
- Apply fake quantization to the tensor based on the bit-width,
- granularity, symmetry, and other properties specified in the config.
- """
- if not self.enabled:
- return x
-
- if isinstance(self.config.granularity, PerToken):
- return self._per_token_forward(x)
- elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
- return self._per_channel_or_group_forward(x)
- else:
- raise ValueError("Unknown granularity '%s'" % self.config.granularity)
-
- def _per_token_forward(self, x: torch.Tensor):
- """
- Perform per token fake quantization on the tensor.
- """
- if self.config.is_symmetric:
- raise NotImplementedError("Symmetric per token is not supported yet")
- if self._should_compute_qparams():
- (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric(
- x, self.config.scale_precision, self.config.zero_point_precision,
- )
- qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
- return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
-
- def _per_channel_or_group_forward(self, x: torch.Tensor):
- """
- Perform per channel or per group fake quantization on the tensor.
- We express per channel using per group where the group size is the size
- of the last dimension of the tensor.
- """
- granularity = self.config.granularity
- scale_precision = self.config.scale_precision
- zero_point_precision = self.config.zero_point_precision
- zero_point_domain = self.config.zero_point_domain
- is_symmetric = self.config.is_symmetric
-
- # get group size
- if isinstance(granularity, PerAxis):
- assert granularity.axis == 0
- group_size = x.size()[-1]
- elif isinstance(granularity, PerGroup):
- group_size = granularity.group_size
- else:
- raise ValueError("Unexpected granularity '%s'" % granularity)
-
- # get scales and zero points
- if self._should_compute_qparams():
- bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype]
- if is_symmetric:
- (self.scale, self.zero_point) = get_group_qparams_symmetric(
- x, bit_width, group_size, scale_precision,
- )
- else:
- (self.scale, self.zero_point) = get_groupwise_affine_qparams(
- x, bit_width, group_size, scale_precision,
- )
- self.zero_point = self.zero_point.to(zero_point_precision)
-
- qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
- return _fake_quantize_per_channel_group(
- x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain,
- )
-
- def _should_compute_qparams(self) -> bool:
- """
- Return whether we need to compute new scales and zero points.
- """
- return self.config.is_dynamic or self.scale is None or self.zero_point is None
diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py
index ef1714808c..465b1c1e51 100644
--- a/torchao/quantization/prototype/qat/linear.py
+++ b/torchao/quantization/prototype/qat/linear.py
@@ -1,419 +1,11 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Any, Optional
-
-import torch
-import torch.nn.functional as F
-
-from torchao.quantization.GPTQ import (
- _check_linear_int4_k,
- _replace_linear_int4,
- _replace_linear_8da4w,
- get_groupwise_affine_qparams,
- groupwise_affine_quantize_tensor,
- Int8DynActInt4WeightLinear,
- WeightOnlyInt4Linear,
+from torchao.quantization.qat.linear import (
+ disable_4w_fake_quant,
+ disable_8da4w_fake_quant,
+ enable_4w_fake_quant,
+ enable_8da4w_fake_quant,
+ FakeQuantizedLinear,
+ Int4WeightOnlyQATLinear,
+ Int4WeightOnlyQATQuantizer,
+ Int8DynActInt4WeightQATLinear,
+ Int8DynActInt4WeightQATQuantizer,
)
-from torchao.quantization.quant_primitives import (
- TorchAODType,
- ZeroPointDomain,
-)
-from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import get_group_qparams_symmetric
-from .api import FakeQuantizeConfig
-from .fake_quantizer import FakeQuantizer
-from .utils import (
- _choose_qparams_per_token_asymmetric,
- _fake_quantize_per_channel_group,
- _fake_quantize_per_token,
- _get_qmin_qmax,
-)
-
-
-class FakeQuantizedLinear(torch.nn.Linear):
- """
- General linear layer with fake quantized weights and activations.
-
- Specific target dtypes, granularity, schemes etc. are specified
- through separate configs for weights and activations.
-
- Example usage::
-
- activation_config = FakeQuantizeConfig(
- dtype=torch.int8,
- granularity="per_token",
- is_symmetric=False,
- )
- weight_config = FakeQuantizeConfig(
- dtype=torch.int4,
- group_size=8,
- is_symmetric=True,
- )
- fq_linear = FakeQuantizedLinear(
- 16, 32, False, activation_config, weight_config,
- )
- fq_linear(torch.randn(16))
- """
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- activation_config: Optional[FakeQuantizeConfig] = None,
- weight_config: Optional[FakeQuantizeConfig] = None,
- *args,
- **kwargs,
- ) -> None:
- super().__init__(
- in_features,
- out_features,
- bias,
- *args,
- **kwargs,
- )
- if bias:
- raise NotImplementedError("bias not supported yet")
-
- # initialize activation fake quantizer
- if activation_config is not None:
- self.activation_fake_quantizer = FakeQuantizer(activation_config)
- else:
- self.activation_fake_quantizer = None
-
- # initialize weight fake quantizer
- if weight_config is not None:
- group_size = weight_config.group_size
- if group_size is not None and in_features % group_size != 0:
- raise ValueError(
- "in_features (%s) % group_size (%s) must be == 0" %
- (in_features, group_size)
- )
- self.weight_fake_quantizer = FakeQuantizer(weight_config)
- else:
- self.weight_fake_quantizer = None
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if self.activation_fake_quantizer is not None:
- x = self.activation_fake_quantizer(x)
- if self.weight_fake_quantizer is not None:
- w = self.weight_fake_quantizer(self.weight)
- else:
- w = self.weight
- return F.linear(x, w)
-
-
-# =========================================================
-# | Linear int8 dynamic activations + int4 weight QAT |
-# =========================================================
-
-
-class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
- """
- Quantizer for performing QAT on a model, where linear layers have int8
- dynamic per token fake quantized activations and int4 fake quantized
- grouped per channel weights.
- """
-
- def __init__(
- self,
- groupsize: int = 256,
- padding_allowed: bool = False,
- precision: torch.dtype = torch.float32,
- scales_precision: torch.dtype = torch.float32,
- ) -> None:
- super().__init__()
- self.groupsize: int = groupsize
- self.padding_allowed: bool = padding_allowed
- self.precision: torch.dtype = precision
- self.scales_precision: torch.dtype = scales_precision
-
- def prepare(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any
- ) -> torch.nn.Module:
- _replace_linear_8da4w(
- model,
- self.groupsize,
- self.padding_allowed,
- self.precision,
- self.scales_precision,
- Int8DynActInt4WeightQATLinear,
- copy_weights=True,
- )
- return model
-
- def convert(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any
- ) -> torch.nn.Module:
- self._convert_qat_linear_8da4w(model)
- return model
-
- def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
- """
- Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
- """
- for name, child in module.named_children():
- if isinstance(child, Int8DynActInt4WeightQATLinear):
- config = child.weight_fake_quantizer.config
- quantized_linear = Int8DynActInt4WeightLinear(
- child.in_features,
- child.out_features,
- bias=False,
- groupsize=config.group_size,
- precision=child.weight.dtype,
- scales_precision=config.scale_precision,
- )
- setattr(module, name, quantized_linear)
-
- # Load weights and qparams into quantized linear
- n_bit = 4
- (qmin, qmax) = _get_qmin_qmax(n_bit)
- (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size)
- from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
- q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
- child.weight, s, zp, qmin, qmax, torch.int8, config.group_size,
- )
- quantized_linear.weight = q_weight
- quantized_linear.scales = s
- quantized_linear.zeros = zp
- else:
- self._convert_qat_linear_8da4w(child)
-
-
-class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
- """
- This module implements a linear layer with int8 dynamic per token fake
- quantized activations with int4 fake quantized grouped per channel weights.
-
- args:
- groupsize: the number of elements in each quantized group for weights
- precision: precision of weights
- scales_precision: precision of per group scales and zero points
- """
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- device: torch.device = None,
- groupsize: int = 256,
- precision: torch.dtype = torch.float32,
- scales_precision: torch.dtype = torch.float32,
- ) -> None:
- activation_config = FakeQuantizeConfig(
- dtype=torch.int8,
- granularity="per_token",
- is_symmetric=False,
- is_dynamic=True,
- scale_precision=scales_precision,
- zero_point_precision=scales_precision,
- )
- weight_config = FakeQuantizeConfig(
- dtype=TorchAODType.INT4,
- group_size=groupsize,
- is_symmetric=True,
- is_dynamic=True,
- scale_precision=scales_precision,
- zero_point_precision=scales_precision,
- )
- super().__init__(
- in_features,
- out_features,
- bias,
- activation_config,
- weight_config,
- device=device,
- dtype=precision,
- )
-
- def enable_fake_quant(self, enabled: bool = True):
- self.activation_fake_quantizer.enabled = enabled
- self.weight_fake_quantizer.enabled = enabled
-
- def disable_fake_quant(self):
- self.enable_fake_quant(False)
-
-
-def enable_8da4w_fake_quant(mod: torch.nn.Module):
- """
- Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
- """
- if isinstance(mod, Int8DynActInt4WeightQATLinear):
- mod.enable_fake_quant()
-
-
-def disable_8da4w_fake_quant(mod: torch.nn.Module):
- """
- Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
- """
- if isinstance(mod, Int8DynActInt4WeightQATLinear):
- mod.disable_fake_quant()
-
-
-# ===================================
-# | Linear int4 weight-only QAT |
-# ===================================
-
-
-class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
- """
- Quantizer for performing QAT on a model, where linear layers have
- int4 fake quantized grouped per channel weights.
- """
-
- def __init__(
- self,
- groupsize: int = 256,
- inner_k_tiles: Optional[int] = 8,
- precision: torch.dtype = torch.bfloat16,
- scales_precision: torch.dtype = torch.bfloat16,
- ) -> None:
- super().__init__()
- assert inner_k_tiles in [2, 4, 8]
- assert groupsize in [32, 64, 128, 256]
- self.inner_k_tiles = inner_k_tiles
- self.groupsize = groupsize
- self.precision = precision
- self.scales_precision = scales_precision
-
- def prepare(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any
- ) -> torch.nn.Module:
- _replace_linear_int4(
- model,
- self.groupsize,
- self.inner_k_tiles,
- padding_allowed=True,
- precision=self.precision,
- scales_precision=self.scales_precision,
- linear_class=Int4WeightOnlyQATLinear,
- copy_weights=True,
- )
- return model
-
- def convert(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any
- ) -> torch.nn.Module:
- self._convert_qat_linear_4w(model)
- return model
-
- def _convert_qat_linear_4w(self, module: torch.nn.Module):
- """
- Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`.
- """
- for name, child in module.named_children():
- if isinstance(child, Int4WeightOnlyQATLinear):
- in_features = child.in_features
- out_features = child.out_features
- inner_k_tiles = child.inner_k_tiles
- config = child.weight_fake_quantizer.config
- quantized_linear = WeightOnlyInt4Linear(
- in_features,
- out_features,
- bias=False,
- groupsize=config.group_size,
- inner_k_tiles=inner_k_tiles,
- precision=child.weight.dtype,
- scales_precision=config.scale_precision,
- )
- setattr(module, name, quantized_linear)
-
- # Load weights and qparams into quantized linear
- n_bit = 4
- (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
- child.weight, n_bit, config.group_size,
- )
- q_weight = torch.ops.aten._convert_weight_to_int4pack(
- q_weight.to(child.weight.device), child.inner_k_tiles,
- )
- quantized_linear.weight = q_weight
- quantized_linear.scales_and_zeros = scales_and_zeros
- else:
- self._convert_qat_linear_4w(child)
-
-
-class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
- """
- This module implements a linear layer with int4 fake quantized grouped
- per channel weights, with forward numerics matching `WeightOnlyInt4Linear`,
- which uses the efficient int4 tinygemm kernel.
-
- args:
- groupsize: the number of elements in each quantized group for weights
- precision: precision of weights
- scales_precision: precision of per group scales and zero points
- """
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- device: torch.device = None,
- groupsize: int = 256,
- inner_k_tiles: int = 8,
- precision: torch.dtype = torch.bfloat16,
- scales_precision: torch.dtype = torch.bfloat16,
- ) -> None:
- assert scales_precision == torch.bfloat16, "only bf16 is supported for scales"
- if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
- raise ValueError("Padding for QAT 4w is not supported yet")
- self.inner_k_tiles = inner_k_tiles
- weight_config = FakeQuantizeConfig(
- dtype=torch.uint4,
- group_size=groupsize,
- is_symmetric=False,
- is_dynamic=True,
- scale_precision=scales_precision,
- zero_point_precision=scales_precision,
- zero_point_domain=ZeroPointDomain.FLOAT,
- )
- super().__init__(
- in_features,
- out_features,
- bias,
- activation_config=None,
- weight_config=weight_config,
- device=device,
- dtype=precision,
- )
-
- def enable_fake_quant(self, enabled: bool = True):
- self.activation_fake_quantizer.enabled = enabled
- self.weight_fake_quantizer.enabled = enabled
-
- def disable_fake_quant(self):
- self.enable_fake_quant(False)
-
-
-def enable_4w_fake_quant(mod: torch.nn.Module):
- """
- Enable fake quantization for `Int4WeightOnlyQATLinear`.
- """
- if isinstance(mod, Int4WeightOnlyQATLinear):
- mod.enable_fake_quant()
-
-
-def disable_4w_fake_quant(mod: torch.nn.Module):
- """
- Disable fake quantization for `Int4WeightOnlyQATLinear`.
- """
- if isinstance(mod, Int4WeightOnlyQATLinear):
- mod.disable_fake_quant()
diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md
new file mode 100644
index 0000000000..6ecccd2b18
--- /dev/null
+++ b/torchao/quantization/qat/README.md
@@ -0,0 +1,125 @@
+# Quantization-Aware Training (QAT)
+
+Quantization-Aware Training (QAT) refers to applying fake quantization during the
+training or fine-tuning process, such that the final quantized model exhibits
+higher accuracy and lower perplexity. Fake quantization refers to rounding float
+values to quantized values without actually casting them to dtypes with lower
+bit-widths, in contrast to post-training quantization (PTQ), which does cast the
+quantized values to lower bit-width dtypes, e.g.:
+
+```
+# PTQ: x_q is quantized and cast to int8
+# scale and zero point (zp) refer to parameters used to quantize x_float
+# qmin and qmax refer to the range of quantized values
+x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
+
+# QAT: x_fq is still in float
+# Fake quantize simulates the numerics of quantize + dequantize
+x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
+x_fq = (x_fq - zp) * scale
+```
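+
+For example, with a toy scale of 0.1, a zero point of 0, and an int4 range of
+[-8, 7] (values chosen purely for illustration), fake quantization rounds each
+float to the nearest representable value while keeping the result in float:
+
+```python
+import torch
+
+x_float = torch.tensor([0.34, -1.2, 2.7])
+scale, zp, qmin, qmax = 0.1, 0, -8, 7
+x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
+x_fq = (x_fq - zp) * scale
+# tensor([ 0.3000, -0.8000,  0.7000]): 0.34 is rounded, -1.2 and 2.7 are clamped to the range
+```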
+
+## API
+
+torchao currently supports two QAT schemes for linear layers:
+- int8 per token dynamic activations + int4 per group weights
+- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
+
+QAT typically involves applying a transformation to your model before and after training.
+In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
+fake quantize operations into linear layers, and (2) convert transforms the fake quantize
+operations to actual quantize and dequantize operations after training, thereby producing
+a quantized model (dequantize operations are typically fused with linear after lowering).
+Between these two steps, training can proceed exactly as before.
+
+![qat](images/qat_diagram.png)
+
+To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
+training, then apply the convert step after training for inference or generation.
+For example, on a single GPU:
+
+```python
+import torch
+from torchtune.models.llama3 import llama3
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+
+# Smaller version of llama3 to fit in a single GPU
+model = llama3(
+ vocab_size=4096,
+ num_layers=16,
+ num_heads=16,
+ num_kv_heads=4,
+ embed_dim=2048,
+ max_seq_len=2048,
+).cuda()
+
+# Quantizer for int8 dynamic per token activations +
+# int4 grouped per channel weights, only for linear layers
+qat_quantizer = Int8DynActInt4WeightQATQuantizer()
+
+# Insert "fake quantize" operations into linear layers.
+# These operations simulate quantization numerics during
+# training without performing any dtype casting
+model = qat_quantizer.prepare(model)
+
+# Standard training loop
+optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+loss_fn = torch.nn.CrossEntropyLoss()
+for i in range(10):
+ example = torch.randint(0, 4096, (2, 16)).cuda()
+ target = torch.randn((2, 16, 4096)).cuda()
+ output = model(example)
+ loss = loss_fn(output, target)
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+# Convert fake quantize operations to actual quantize operations.
+# The resulting quantized model has the exact same structure as the
+# one produced by the corresponding PTQ flow through
+# `Int8DynActInt4WeightQuantizer`
+model = qat_quantizer.convert(model)
+
+# inference or generate
+```
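+
+The same two-step flow applies to the int4 weight-only scheme. A minimal sketch,
+reusing the `model` from above (for this scheme the weights and scales are
+expected to be in bf16 to match the tinygemm kernel):
+
+```python
+from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
+
+# prepare: swap linear layers for Int4WeightOnlyQATLinear (fake quantized weights only)
+qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=256)
+model = qat_quantizer.prepare(model)
+
+# ... train as before ...
+
+# convert: swap in WeightOnlyInt4Linear, which uses the int4 tinygemm kernel
+model = qat_quantizer.convert(model)
+```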
+
+Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
+and apply quantization-aware fine-tuning as follows:
+
+```
+tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
+```
+
+For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
+
+
+## Evaluation Results
+
+Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
+integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
+on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
+for 5000 steps using a group size of 256 for the weights. Note that extensive
+hyperparameter tuning may further improve these results.
+
+Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:
+
+| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
+| ---------------- | ------ | ------ | ------ | ------ | ------ |
+| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
+| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
+| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
+| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
+| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |
+
+Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
+quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).
+
+| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
+| ---------------- | -------- | ------- | ------ | ------ | ------ |
+| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
+| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
+| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
+| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
+| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |
+
+For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
new file mode 100644
index 0000000000..09ef10af67
--- /dev/null
+++ b/torchao/quantization/qat/__init__.py
@@ -0,0 +1,17 @@
+from .api import (
+ ComposableQATQuantizer,
+)
+from .linear import (
+ Int4WeightOnlyQATQuantizer,
+ Int8DynActInt4WeightQATQuantizer,
+)
+from .embedding import (
+ Int4WeightOnlyEmbeddingQATQuantizer,
+)
+
+__all__ = [
+ "ComposableQATQuantizer",
+ "Int4WeightOnlyQATQuantizer",
+ "Int4WeightOnlyEmbeddingQATQuantizer"
+ "Int8DynActInt4WeightQATQuantizer",
+]
diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py
new file mode 100644
index 0000000000..9d7cdccfc3
--- /dev/null
+++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py
@@ -0,0 +1,328 @@
+import torch
+import torch.utils._pytree as pytree
+from typing import Callable, Optional, Tuple
+from torchao.quantization.quant_primitives import (
+ _get_and_check_qmin_qmax,
+ choose_qparams_affine,
+ fake_quantize_affine,
+ ZeroPointDomain,
+ MappingType,
+)
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import TorchAOBaseTensor
+from .utils import (
+ _GenericFakeQuantize,
+ _UnwrapAffineFakeQuantizedTensor,
+)
+
+aten = torch.ops.aten
+
+
+class _ToAffineFakeQuantized(torch.autograd.Function):
+ """
+ Differentiable constructor for `AffineFakeQuantizedTensor`,
+ needed for input activation fake quantization.
+ """
+
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ original_tensor: torch.Tensor,
+ mapping_type: MappingType,
+ block_size: Tuple[int, ...],
+ target_dtype: torch.dtype,
+ quant_min: Optional[int] = None,
+ quant_max: Optional[int] = None,
+ eps: Optional[float] = None,
+ scale_dtype: Optional[torch.dtype] = None,
+ zero_point_dtype: Optional[torch.dtype] = None,
+ preserve_zero: bool = True,
+ zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+ ) -> "AffineFakeQuantizedTensor":
+ def apply_fake_quant_fn(t: torch.Tensor):
+ assert isinstance(t, AffineFakeQuantizedTensor)
+ qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
+ scale, zero_point = choose_qparams_affine(
+ t.original_tensor,
+ mapping_type,
+ block_size,
+ target_dtype,
+ qmin,
+ qmax,
+ eps,
+ scale_dtype,
+ zero_point_dtype,
+ preserve_zero,
+ zero_point_domain,
+ )
+ fq = _GenericFakeQuantize.apply(
+ t,
+ block_size,
+ scale,
+ zero_point,
+ qmin,
+ qmax,
+ zero_point_domain,
+ )
+ return fq
+ return AffineFakeQuantizedTensor(
+ original_tensor,
+ apply_fake_quant_fn,
+ fake_quant_enabled=True,
+ )
+
+ @staticmethod
+ def backward(ctx, gy):
+ return gy, None, None, None, None, None, None, None, None, None, None
+
+
+class AffineFakeQuantizedTensor(TorchAOBaseTensor):
+ """
+ Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor
+ with an affine transformation:
+ quantized_tensor = float_tensor / scale + zero_point
+
+ Fake quantization refers to performing the quantization math without actually casting the floating point
+ tensor into lower bit-width dtypes. It is commonly used for quantization-aware training (QAT).
+
+ The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
+ regardless of the internal representation's type or orientation.
+
+ fields:
+ original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later
+ apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values
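+
+    Example usage (a rough sketch; the block size `(1, 64)` here means one scale
+    and zero point per row of a 32 x 64 weight)::
+
+        from torchao.quantization.quant_primitives import MappingType
+        w = torch.randn(32, 64)
+        w_fq = AffineFakeQuantizedTensor.from_float(
+            w, MappingType.SYMMETRIC, (1, 64), torch.int8,
+        )
+        w_fq.get_value()  # fake quantized values, same shape and dtype as `w`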
+ """
+
+ @staticmethod
+ def __new__(
+ cls,
+ original_tensor: torch.Tensor,
+ apply_fake_quant_fn: Callable,
+ fake_quant_enabled: bool = True,
+ **kwargs,
+ ):
+ kwargs.setdefault("dtype", original_tensor.dtype)
+ kwargs.setdefault("device", original_tensor.device)
+ kwargs.setdefault("requires_grad", original_tensor.requires_grad)
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ original_tensor.shape,
+ **kwargs,
+ )
+
+ def __init__(
+ self,
+ original_tensor: torch.Tensor,
+ apply_fake_quant_fn: Callable,
+ fake_quant_enabled: bool = True,
+ **kwargs
+ ):
+ self.original_tensor = original_tensor
+ self.apply_fake_quant_fn = apply_fake_quant_fn
+ self.fake_quant_enabled = fake_quant_enabled
+
+ def __tensor_flatten__(self):
+ return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled]
+
+ @classmethod
+ def __tensor_unflatten__(
+ cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride,
+ ):
+ original_tensor = tensor_data_dict["original_tensor"]
+ (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes
+ return cls(
+ original_tensor,
+ apply_fake_quant_fn,
+ fake_quant_enabled,
+ )
+
+ @classmethod
+ def from_float(
+ cls,
+ original_input: torch.Tensor,
+ mapping_type: MappingType,
+ block_size: Tuple[int, ...],
+ target_dtype: torch.dtype,
+ quant_min: Optional[int] = None,
+ quant_max: Optional[int] = None,
+ eps: Optional[float] = None,
+ scale_dtype: Optional[torch.dtype] = None,
+ zero_point_dtype: Optional[torch.dtype] = None,
+ preserve_zero: bool = True,
+ zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+ ):
+ return _ToAffineFakeQuantized.apply(
+ original_input,
+ mapping_type,
+ block_size,
+ target_dtype,
+ quant_min,
+ quant_max,
+ eps,
+ scale_dtype,
+ zero_point_dtype,
+ preserve_zero,
+ zero_point_domain,
+ )
+
+ def get_value(self) -> torch.Tensor:
+ if self.fake_quant_enabled:
+ return self.apply_fake_quant_fn(self)
+ else:
+ return _UnwrapAffineFakeQuantizedTensor.apply(self)
+
+ def _get_to_kwargs(self, *args, **kwargs):
+ device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+ device = self.device if device is None else device
+ dtype = self.dtype if dtype is None else dtype
+ memory_format = (
+ memory_format if memory_format is not None else torch.preserve_format
+ )
+ kwargs = {
+ "device": device,
+ "dtype": dtype,
+ "memory_format": memory_format,
+ "requires_grad": self.requires_grad,
+ }
+ return kwargs
+
+ def to(self, *args, **kwargs):
+ kwargs = self._get_to_kwargs(*args, **kwargs)
+ device = kwargs.pop("device")
+ # not supported yet
+ kwargs.pop("memory_format")
+ return self.__class__(
+ self.original_tensor.to(device),
+ self.apply_fake_quant_fn,
+ self.fake_quant_enabled,
+ **kwargs,
+ )
+
+ def _apply_fn_to_data(self, fn: Callable):
+ """
+ Create a new `AffineFakeQuantizedTensor` with `fn` applied to the
+ original tensor, to be called within __torch_dispatch__.
+ """
+ return self._create_new(fn(self.original_tensor))
+
+ def _create_new(self, new_value: torch.Tensor):
+ """
+ Create a new `AffineFakeQuantizedTensor` with a new value,
+ to be called within __torch_dispatch__.
+
+ Note: `requires_grad` must be False here because tensors created
+ in `__torch_dispatch__` cannot produce gradients, since autograd
+ will try to attach autograd metadata to these tensors when we exit
+ `__torch_dispatch__`, but if these tensors already have metadata
+ attached then autograd will throw an error.
+ """
+ return self.__class__(
+ new_value,
+ self.apply_fake_quant_fn,
+ self.fake_quant_enabled,
+ requires_grad=False,
+ )
+
+implements = AffineFakeQuantizedTensor.implements
+
+
+@implements(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+ input_tensor, weight_tensor, bias = (
+ args[0],
+ args[1],
+ args[2] if len(args) > 2 else None,
+ )
+ if isinstance(input_tensor, AffineFakeQuantizedTensor):
+ input_tensor = input_tensor.get_value()
+ if isinstance(weight_tensor, AffineFakeQuantizedTensor):
+ weight_tensor = weight_tensor.get_value()
+ return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.mm.default)
+def _(func, types, args, kwargs):
+ bias = None
+ input_tensor = args[0]
+ weight_tensor = args[1]
+ if isinstance(input_tensor, AffineFakeQuantizedTensor):
+ input_tensor = input_tensor.get_value()
+ if isinstance(weight_tensor, AffineFakeQuantizedTensor):
+ weight_tensor = weight_tensor.get_value()
+ return func(input_tensor, weight_tensor)
+
+
+@implements(aten.addmm.default)
+def _(func, types, args, kwargs):
+ bias = args[0]
+ input_tensor = args[1]
+ weight_tensor = args[2]
+ if isinstance(input_tensor, AffineFakeQuantizedTensor):
+ input_tensor = input_tensor.get_value()
+ if isinstance(weight_tensor, AffineFakeQuantizedTensor):
+ weight_tensor = weight_tensor.get_value()
+ return func(bias, input_tensor, weight_tensor)
+
+
+@implements(aten.detach.default)
+def _(func, types, args, kwargs):
+ return return_and_correct_aliasing(
+ func, args, kwargs, args[0]._apply_fn_to_data(torch.detach),
+ )
+
+
+@implements(aten.clone.default)
+def _(func, types, args, kwargs):
+ return return_and_correct_aliasing(
+ func, args, kwargs, args[0]._apply_fn_to_data(torch.clone),
+ )
+
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+ return return_and_correct_aliasing(
+ func, args, kwargs, args[0]._apply_fn_to_data(torch.t),
+ )
+
+
+@implements([
+ aten.add.Tensor,
+ aten.add_.Tensor,
+ aten.mul_.Tensor,
+ aten.copy_.default,
+])
+def _(func, types, args, kwargs):
+ assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}"
+ new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args)
+ first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1]
+ new_value = func(*new_args, **kwargs)
+ out = first_afq_tensor._create_new(new_value)
+ return return_and_correct_aliasing(func, args, kwargs, out)
+
+
+# Needed by FSDP:
+
+@implements(aten.empty_like.default)
+def _(func, types, args, kwargs):
+ out = torch.empty_like(args[0].original_tensor, **kwargs)
+ return return_and_correct_aliasing(func, args, kwargs, out)
+
+
+@implements(aten.split.Tensor)
+def _(func, types, args, kwargs):
+ new_values = torch.split(args[0].original_tensor, *args[1:], **kwargs)
+
+ def make_new_tensor(value):
+ out = args[0]._create_new(value)
+ return return_and_correct_aliasing(func, args, kwargs, out)
+
+ return list(map(make_new_tensor, new_values))
+
+
+@implements(aten.new_zeros.default)
+def _(func, types, args, kwargs):
+ out = args[0].original_tensor.new_zeros(*args[1:], **kwargs)
+ return return_and_correct_aliasing(func, args, kwargs, out)
+
+
+to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
new file mode 100644
index 0000000000..60d45c8b17
--- /dev/null
+++ b/torchao/quantization/qat/api.py
@@ -0,0 +1,262 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, List, Optional, Union
+
+import torch
+
+from torchao.quantization.granularity import (
+ Granularity,
+ PerAxis,
+ PerGroup,
+ PerToken,
+)
+from torchao.quantization.unified import TwoStepQuantizer
+from torchao.quantization.quant_primitives import (
+ _SUB_BYTE_INT_BOUNDS,
+ _SUB_BYTE_UINT_BOUNDS,
+ MappingType,
+ TorchAODType,
+ ZeroPointDomain,
+)
+
+
+@dataclass
+class FakeQuantizeConfig:
+ """
+ Config for how to fake quantize weights or activations.
+
+ args:
+ dtype: dtype to simulate during fake quantization, e.g. torch.int8.
+ For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
+ torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
+ granularity: granularity of scales and zero points, e.g. PerGroup(32).
+ We also support the following strings:
+ 1) 'per_token': equivalent to PerToken()
+ 2) 'per_channel': equivalent to PerAxis(0)
+            3) 'per_group': equivalent to PerGroup(group_size), must be combined
+               with a separate `group_size` kwarg. Alternatively, just set the
+               `group_size` kwarg and leave this field empty.
+        mapping_type: whether to use symmetric (default) or asymmetric quantization.
+            Alternatively, set `is_symmetric` (bool) and leave this field empty.
+        scale_precision: scale dtype (default torch.float32)
+        zero_point_precision: zero point dtype (default torch.int32)
+        zero_point_domain: whether the zero point is in the integer (default) or float domain
+        is_dynamic: whether to use dynamic (default) or static scales and zero points
+ range_learning: whether to learn scale and zero points during training (coming soon)
+
+ kwargs (optional):
+ group_size: size of each group in per group fake quantization,
+ can be set instead of `granularity`
+ is_symmetric: whether to use symmetric or asymmetric quantization,
+ can be set instead of `mapping_type`
+
+ Example usage::
+
+ # Per token asymmetric quantization
+ FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+ FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
+
+ # Per channel symmetric quantization
+ FakeQuantizeConfig(torch.int4, "per_channel")
+ FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
+ FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
+
+ # Per group symmetric quantization
+ FakeQuantizeConfig(torch.int4, group_size=32)
+ FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
+ FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
+ FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
+ """
+ dtype: Union[torch.dtype, TorchAODType]
+ granularity: Granularity
+ mapping_type: MappingType
+ scale_precision: torch.dtype
+ zero_point_precision: torch.dtype
+ zero_point_domain: ZeroPointDomain
+ is_dynamic: bool = True
+ range_learning: bool = False
+
+ def __init__(
+ self,
+ dtype: Union[torch.dtype, TorchAODType],
+ granularity: Union[Granularity, str, None] = None,
+ mapping_type: Optional[MappingType] = None,
+ scale_precision: torch.dtype = torch.float32,
+ zero_point_precision: torch.dtype = torch.int32,
+ zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+ is_dynamic: bool = True,
+ range_learning: bool = False,
+ *,
+ group_size: Optional[int] = None,
+ is_symmetric: Optional[bool] = None,
+ ):
+ self.dtype = dtype
+ self.granularity = self._get_granularity(granularity, group_size)
+ self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
+ self.scale_precision = scale_precision
+ self.zero_point_precision = zero_point_precision
+ self.zero_point_domain = zero_point_domain
+ self.is_dynamic = is_dynamic
+ self.range_learning = range_learning
+
+ # Validate dtype
+ all_dtypes = [torch.int8, torch.uint8]
+ all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
+ all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
+ if dtype not in all_dtypes:
+ raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes))
+
+ def _get_granularity(
+ self,
+ granularity: Union[Granularity, str, None],
+ group_size: Optional[int],
+ ) -> Granularity:
+ """
+ Parse the `Granularity` represented in the args.
+
+ Granularity can be specified in one of three ways:
+ 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size)
+ 2) str: one of 'per_token', 'per_channel', and 'per_group'
+ 3) None: `group_size` must be set instead, represents per group granularity
+ """
+ # If group_size is set, then granularity must be either "per_group" or None
+ if group_size is not None and granularity != "per_group" and granularity is not None:
+ raise ValueError("`group_size` conflicts with granularity '%s'" % granularity)
+
+ # Case 1: Granularity object
+ if isinstance(granularity, Granularity):
+ if not isinstance(granularity, (PerToken, PerAxis, PerGroup)):
+ raise ValueError("Granularity '%s' is not supported" % granularity)
+ if isinstance(granularity, PerAxis) and granularity.axis != 0:
+ raise ValueError("Only axis=0 is supported for PerAxis granularity")
+ return granularity
+
+ # Case 2: str granularity
+ if granularity == "per_token":
+ return PerToken()
+ elif granularity == "per_channel":
+ return PerAxis(axis=0)
+ elif granularity == "per_group":
+ if group_size is None:
+ raise ValueError("Granularity was 'per_group' but no `group_size` was set")
+ return PerGroup(group_size)
+ elif isinstance(granularity, str):
+ raise ValueError(
+ "Unexpected granularity: '%s', must be one of %s" %
+ (granularity, ["per_token", "per_channel", "per_group"])
+ )
+
+ # Case 3: None granularity + group_size was specified
+ if granularity is not None:
+ raise ValueError(
+ "Granularity '%s' has unexpected type %s" % (granularity, type(granularity))
+ )
+ if group_size is None:
+ raise ValueError("At least one of `granularity` or `group_size` must be set")
+ return PerGroup(group_size)
+
+ def _get_mapping_type(
+ self,
+ mapping_type: Optional[MappingType],
+ is_symmetric: Optional[bool],
+ ) -> MappingType:
+ """
+ Parse the `MappingType` represented in the args.
+
+ Mapping type can be specified in one of two ways:
+        1) `MappingType` object: one of SYMMETRIC or ASYMMETRIC
+        2) `is_symmetric` bool
+ """
+ if mapping_type is not None and is_symmetric is not None:
+ raise ValueError("Cannot set both `mapping_type` and `is_symmetric`")
+
+ # Case 0: Default to symmetric
+ if mapping_type is None and is_symmetric is None:
+ return MappingType.SYMMETRIC
+
+ # Case 1: MappingType object
+ if mapping_type is not None:
+ if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]:
+ raise ValueError("MappingType '%s' is not supported" % mapping_type)
+ return mapping_type
+
+ # Case 2: is_symmetric flag
+ assert is_symmetric is not None
+ if is_symmetric:
+ return MappingType.SYMMETRIC
+ else:
+ return MappingType.ASYMMETRIC
+
+ @property
+ def group_size(self) -> int:
+ """
+ If this is per group granularity, return the group size.
+ Otherwise, throw an error.
+ """
+ if isinstance(self.granularity, PerGroup):
+ return self.granularity.group_size
+ else:
+ raise ValueError("`group_size` is undefined for %s granularity" % self.granularity)
+
+ @property
+ def is_symmetric(self) -> bool:
+ """
+ Return True if mapping type is symmetric, else False (asymmetric).
+ """
+ return self.mapping_type == MappingType.SYMMETRIC
+
+ def __setattr__(self, name: str, value: Any):
+ """
+ Support setting `group_size` and `is_symmetric`.
+ """
+ if name == "group_size":
+ super().__setattr__("granularity", PerGroup(value))
+ elif name == "is_symmetric":
+ mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC
+ super().__setattr__("mapping_type", mapping_type)
+ else:
+ super().__setattr__(name, value)
+
+
+class ComposableQATQuantizer(TwoStepQuantizer):
+ """
+ Composable quantizer that users can use to apply multiple QAT quantizers easily.
+ Quantizers will be applied in the order they are specified in the constructor.
+
+ Note: the quantizers provided must apply to different modules in the model,
+ e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined.
+
+ Example usage::
+
+ my_quantizer = ComposableQATQuantizer([
+ QATQuantizer1(),
+ QATQuantizer2(),
+ QATQuantizer3(),
+ ])
+ model = my_quantizer.prepare(model)
+ train(model)
+ model = my_quantizer.convert(model)
+ """
+
+ def __init__(self, quantizers: List[TwoStepQuantizer]):
+ self.quantizers = quantizers
+
+ def prepare(
+ self, model: torch.nn.Module, *args: Any, **kwargs: Any
+ ) -> torch.nn.Module:
+ for quantizer in self.quantizers:
+ model = quantizer.prepare(model)
+ return model
+
+ def convert(
+ self, model: torch.nn.Module, *args: Any, **kwargs: Any
+ ) -> torch.nn.Module:
+ for quantizer in self.quantizers:
+ model = quantizer.convert(model)
+ return model
diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py
new file mode 100644
index 0000000000..1f471fa490
--- /dev/null
+++ b/torchao/quantization/qat/embedding.py
@@ -0,0 +1,325 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+
+from torchao.quantization.unified import TwoStepQuantizer
+from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.quant_api import (
+ _replace_with_custom_fn_if_matches_filter,
+)
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
+from .utils import (
+ _fake_quantize_per_channel_group,
+ _get_qmin_qmax,
+)
+
+
+class FakeQuantizedEmbedding(torch.nn.Embedding):
+ """
+ General embedding layer with fake quantized weights.
+
+ Specific target dtypes, granularity, schemes etc. are specified
+ through separate configs for weights and activations.
+
+ Example usage::
+
+ weight_config = FakeQuantizeConfig(
+ dtype=torch.int4,
+ group_size=8,
+            is_symmetric=True,
+ )
+ fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config)
+ fq_embedding(torch.LongTensor([3]))
+ """
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ weight_config: Optional[FakeQuantizeConfig] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ *args,
+ **kwargs,
+ )
+ if weight_config is not None:
+ self.weight_fake_quantizer = FakeQuantizer(weight_config)
+ else:
+ self.weight_fake_quantizer = None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.weight_fake_quantizer is not None:
+ w = self.weight_fake_quantizer(self.weight)
+ else:
+ w = self.weight
+ return F.embedding(
+ x, w, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse,
+ )
+
+
+# ======================================
+# | Embedding int4 weight-only QAT |
+# ======================================
+
+class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer):
+ """
+ Quantizer for performing QAT on a model, where embedding layers have
+ int4 fake quantized grouped per channel weights.
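+
+    Example usage (a sketch)::
+
+        quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=32)
+        model = quantizer.prepare(model)   # nn.Embedding -> Int4WeightOnlyQATEmbedding
+        # ... train ...
+        model = quantizer.convert(model)   # -> Int4WeightOnlyEmbedding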
+ """
+
+ def __init__(
+ self,
+ group_size: int = 256,
+ scale_precision: torch.dtype = torch.float32,
+ zero_point_precision: torch.dtype = torch.int32,
+ ) -> None:
+ super().__init__()
+ self.bit_width = 4
+ self.group_size: int = group_size
+ self.scale_precision: torch.dtype = scale_precision
+ self.zero_point_precision: torch.dtype = zero_point_precision
+
+ def prepare(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any
+ ) -> torch.nn.Module:
+ """
+ Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.
+ """
+ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
+ return isinstance(child, torch.nn.Embedding)
+
+ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+ new_embedding = Int4WeightOnlyQATEmbedding(
+ # nn.Embedding args
+ num_embeddings=child.num_embeddings,
+ embedding_dim=child.embedding_dim,
+ padding_idx=child.padding_idx,
+ max_norm=child.max_norm,
+ norm_type=child.norm_type,
+ scale_grad_by_freq=child.scale_grad_by_freq,
+ sparse=child.sparse,
+ # quantization args
+ group_size=self.group_size,
+ scale_precision=self.scale_precision,
+ zero_point_precision=self.zero_point_precision,
+ device=child.weight.device,
+ )
+ # In distributed training, the model may be instantiated
+ # on the meta device, in which case there is no need to
+ # copy the weights, and doing so will result in an error
+ if child.weight.device != torch.device("meta"):
+ new_embedding.weight = child.weight
+ return new_embedding
+
+ _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn)
+ return model
+
+ def convert(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any
+ ) -> torch.nn.Module:
+ """
+ Swap all `Int4WeightOnlyQATEmbedding` modules with `Int4WeightOnlyEmbedding`.
+ """
+ self._convert_helper(model)
+ return model
+
+ def _convert_helper(self, module: torch.nn.Module):
+ """
+ Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
+ modules with `Int4WeightOnlyEmbedding`
+ """
+ from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
+ for name, child in module.named_children():
+ if isinstance(child, Int4WeightOnlyQATEmbedding):
+ group_size = child.weight_fake_quantizer.config.group_size
+ scale_precision = child.weight_fake_quantizer.config.scale_precision
+ zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
+ quantized_embedding = Int4WeightOnlyEmbedding(
+ # nn.Embedding args
+ num_embeddings=child.num_embeddings,
+ embedding_dim=child.embedding_dim,
+ padding_idx=child.padding_idx,
+ max_norm=child.max_norm,
+ norm_type=child.norm_type,
+ scale_grad_by_freq=child.scale_grad_by_freq,
+ sparse=child.sparse,
+ # quantization args
+ group_size=group_size,
+ scale_precision=scale_precision,
+ zero_point_precision=zero_point_precision,
+ device=child.weight.device,
+ )
+ setattr(module, name, quantized_embedding)
+
+ # Load weights and qparams into quantized embedding
+ (qmin, qmax) = _get_qmin_qmax(self.bit_width)
+ (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
+ q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+ child.weight, s, zp, qmin, qmax, torch.int8, group_size,
+ )
+                quantized_embedding.weight = q_weight
+                # assign to the `scale` and `zero_point` buffers registered in Int4WeightOnlyEmbedding
+                quantized_embedding.scale = s.to(scale_precision)
+                quantized_embedding.zero_point = zp.to(zero_point_precision)
+ else:
+ self._convert_helper(child)
+
+
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
+ """
+    This module implements an embedding layer with int4 fake quantized
+ grouped per channel weights.
+
+ args:
+ group_size: the number of elements in each quantized group for weights
+ scale_precision: precision of per group scales
+ zero_point_precision: precision of per group zero points
+ """
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ group_size: int = 32,
+ scale_precision: torch.dtype = torch.float32,
+ zero_point_precision: torch.dtype = torch.int32,
+ *args,
+ **kwargs,
+ ):
+ weight_config = FakeQuantizeConfig(
+ dtype=TorchAODType.INT4,
+ group_size=group_size,
+ is_symmetric=True,
+ is_dynamic=True,
+ scale_precision=scale_precision,
+ zero_point_precision=zero_point_precision,
+ )
+ super().__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ weight_config,
+ *args,
+ **kwargs,
+ )
+
+ def enable_fake_quant(self, enabled: bool = True):
+ self.weight_fake_quantizer.enabled = enabled
+
+ def disable_fake_quant(self):
+ self.enable_fake_quant(False)
+
+
+class Int4WeightOnlyEmbedding(torch.nn.Module):
+ """
+    This module implements an embedding layer with int4 quantized
+ grouped per channel weights.
+ """
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ group_size: int = 32,
+ scale_precision: torch.dtype = torch.float32,
+ zero_point_precision: torch.dtype = torch.int32,
+ device: torch.device = None,
+ ):
+ super().__init__()
+
+ # nn.Embedding args
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ self.sparse = sparse
+
+ # quantization args
+ self.bit_width = 4
+ self.group_size = group_size
+ self.scale_precision = scale_precision
+ self.zero_point_precision = zero_point_precision
+
+ # currently storing unpacked int8 weights
+ self.register_buffer(
+ "weight",
+ torch.empty((num_embeddings, embedding_dim), dtype=torch.int8, device=device),
+ )
+ self.register_buffer(
+ "scale",
+ torch.empty(
+ (num_embeddings, embedding_dim // group_size),
+ dtype=scale_precision,
+ device=device,
+ ),
+ )
+ self.register_buffer(
+ "zero_point",
+ torch.empty(
+ (num_embeddings, embedding_dim // group_size),
+ dtype=zero_point_precision,
+ device=device,
+ ),
+ )
+
+ def forward(self, x):
+ from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper
+ qmin, qmax = _get_qmin_qmax(self.bit_width)
+ w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(
+ self.weight,
+ self.scale,
+ self.zero_point,
+ qmin,
+ qmax,
+ torch.int8,
+ self.group_size,
+ x.dtype,
+ )
+ return F.embedding(
+ x, w_dq, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse,
+ )
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
new file mode 100644
index 0000000000..eb42dcf047
--- /dev/null
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+import torch
+
+from torchao.quantization.granularity import (
+ PerAxis,
+ PerGroup,
+ PerToken,
+)
+from torchao.quantization.quant_primitives import (
+ _DTYPE_TO_BIT_WIDTH,
+ _DTYPE_TO_QVALUE_BOUNDS,
+)
+from torchao.quantization.utils import (
+ get_group_qparams_symmetric,
+ get_groupwise_affine_qparams,
+)
+from .api import (
+ FakeQuantizeConfig,
+)
+from .utils import (
+ _choose_qparams_per_token_asymmetric,
+ _fake_quantize_per_channel_group,
+ _fake_quantize_per_token,
+ _get_qmin_qmax,
+)
+
+
+class FakeQuantizer(torch.nn.Module):
+ """
+ Generic module for applying fake quantization to a tensor, as specified in the config.
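+
+    Example usage (a sketch, assuming `TorchAODType` is imported from
+    `torchao.quantization.quant_primitives`)::
+
+        config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32, is_symmetric=True)
+        fake_quantizer = FakeQuantizer(config)
+        w_fq = fake_quantizer(torch.randn(64, 128))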
+ """
+ def __init__(self, config: FakeQuantizeConfig):
+ super().__init__()
+ self.config = config
+ self.enabled = True
+ self.scale: Optional[torch.Tensor] = None
+ self.zero_point: Optional[torch.Tensor] = None
+
+        # TODO: support range learning
+ if self.config.range_learning:
+ raise NotImplementedError("Range learning is not supported yet")
+
+ def forward(self, x: torch.Tensor):
+ """
+ Apply fake quantization to the tensor based on the bit-width,
+ granularity, symmetry, and other properties specified in the config.
+ """
+ if not self.enabled:
+ return x
+
+ if isinstance(self.config.granularity, PerToken):
+ return self._per_token_forward(x)
+ elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
+ return self._per_channel_or_group_forward(x)
+ else:
+ raise ValueError("Unknown granularity '%s'" % self.config.granularity)
+
+ def _per_token_forward(self, x: torch.Tensor):
+ """
+ Perform per token fake quantization on the tensor.
+ """
+ if self.config.is_symmetric:
+ raise NotImplementedError("Symmetric per token is not supported yet")
+ if self._should_compute_qparams():
+ (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric(
+ x, self.config.scale_precision, self.config.zero_point_precision,
+ )
+ qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+ return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
+
+ def _per_channel_or_group_forward(self, x: torch.Tensor):
+ """
+ Perform per channel or per group fake quantization on the tensor.
+ We express per channel using per group where the group size is the size
+ of the last dimension of the tensor.
+ """
+ granularity = self.config.granularity
+ scale_precision = self.config.scale_precision
+ zero_point_precision = self.config.zero_point_precision
+ zero_point_domain = self.config.zero_point_domain
+ is_symmetric = self.config.is_symmetric
+
+ # get group size
+ if isinstance(granularity, PerAxis):
+ assert granularity.axis == 0
+ group_size = x.size()[-1]
+ elif isinstance(granularity, PerGroup):
+ group_size = granularity.group_size
+ else:
+ raise ValueError("Unexpected granularity '%s'" % granularity)
+
+ # get scales and zero points
+ if self._should_compute_qparams():
+ bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype]
+ if is_symmetric:
+ (self.scale, self.zero_point) = get_group_qparams_symmetric(
+ x, bit_width, group_size, scale_precision,
+ )
+ else:
+ (self.scale, self.zero_point) = get_groupwise_affine_qparams(
+ x, bit_width, group_size, scale_precision,
+ )
+ self.zero_point = self.zero_point.to(zero_point_precision)
+
+ qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+ return _fake_quantize_per_channel_group(
+ x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain,
+ )
+
+ def _should_compute_qparams(self) -> bool:
+ """
+ Return whether we need to compute new scales and zero points.
+ """
+ return self.config.is_dynamic or self.scale is None or self.zero_point is None
diff --git a/torchao/quantization/prototype/qat/images/qat_diagram.png b/torchao/quantization/qat/images/qat_diagram.png
similarity index 100%
rename from torchao/quantization/prototype/qat/images/qat_diagram.png
rename to torchao/quantization/qat/images/qat_diagram.png
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
new file mode 100644
index 0000000000..ef1714808c
--- /dev/null
+++ b/torchao/quantization/qat/linear.py
@@ -0,0 +1,419 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+
+from torchao.quantization.GPTQ import (
+ _check_linear_int4_k,
+ _replace_linear_int4,
+ _replace_linear_8da4w,
+ get_groupwise_affine_qparams,
+ groupwise_affine_quantize_tensor,
+ Int8DynActInt4WeightLinear,
+ WeightOnlyInt4Linear,
+)
+from torchao.quantization.quant_primitives import (
+ TorchAODType,
+ ZeroPointDomain,
+)
+from torchao.quantization.unified import TwoStepQuantizer
+from torchao.quantization.utils import get_group_qparams_symmetric
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
+from .utils import (
+ _choose_qparams_per_token_asymmetric,
+ _fake_quantize_per_channel_group,
+ _fake_quantize_per_token,
+ _get_qmin_qmax,
+)
+
+
+class FakeQuantizedLinear(torch.nn.Linear):
+ """
+ General linear layer with fake quantized weights and activations.
+
+ Specific target dtypes, granularity, schemes etc. are specified
+ through separate configs for weights and activations.
+
+ Example usage::
+
+ activation_config = FakeQuantizeConfig(
+ dtype=torch.int8,
+ granularity="per_token",
+ is_symmetric=False,
+ )
+ weight_config = FakeQuantizeConfig(
+ dtype=torch.int4,
+ group_size=8,
+ is_symmetric=True,
+ )
+ fq_linear = FakeQuantizedLinear(
+ 16, 32, False, activation_config, weight_config,
+ )
+ fq_linear(torch.randn(16))
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = False,
+ activation_config: Optional[FakeQuantizeConfig] = None,
+ weight_config: Optional[FakeQuantizeConfig] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ in_features,
+ out_features,
+ bias,
+ *args,
+ **kwargs,
+ )
+ if bias:
+ raise NotImplementedError("bias not supported yet")
+
+ # initialize activation fake quantizer
+ if activation_config is not None:
+ self.activation_fake_quantizer = FakeQuantizer(activation_config)
+ else:
+ self.activation_fake_quantizer = None
+
+ # initialize weight fake quantizer
+ if weight_config is not None:
+ group_size = weight_config.group_size
+ if group_size is not None and in_features % group_size != 0:
+ raise ValueError(
+ "in_features (%s) % group_size (%s) must be == 0" %
+ (in_features, group_size)
+ )
+ self.weight_fake_quantizer = FakeQuantizer(weight_config)
+ else:
+ self.weight_fake_quantizer = None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.activation_fake_quantizer is not None:
+ x = self.activation_fake_quantizer(x)
+ if self.weight_fake_quantizer is not None:
+ w = self.weight_fake_quantizer(self.weight)
+ else:
+ w = self.weight
+ return F.linear(x, w)
+
+
+# =========================================================
+# | Linear int8 dynamic activations + int4 weight QAT |
+# =========================================================
+
+
+class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
+ """
+ Quantizer for performing QAT on a model, where linear layers have int8
+ dynamic per token fake quantized activations and int4 fake quantized
+ grouped per channel weights.
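+
+    Example usage (a sketch; see the README in this directory for a full training loop)::
+
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
+        model = quantizer.prepare(model)   # nn.Linear -> Int8DynActInt4WeightQATLinear
+        # ... train ...
+        model = quantizer.convert(model)   # -> Int8DynActInt4WeightLinear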
+ """
+
+ def __init__(
+ self,
+ groupsize: int = 256,
+ padding_allowed: bool = False,
+ precision: torch.dtype = torch.float32,
+ scales_precision: torch.dtype = torch.float32,
+ ) -> None:
+ super().__init__()
+ self.groupsize: int = groupsize
+ self.padding_allowed: bool = padding_allowed
+ self.precision: torch.dtype = precision
+ self.scales_precision: torch.dtype = scales_precision
+
+ def prepare(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any
+ ) -> torch.nn.Module:
+ _replace_linear_8da4w(
+ model,
+ self.groupsize,
+ self.padding_allowed,
+ self.precision,
+ self.scales_precision,
+ Int8DynActInt4WeightQATLinear,
+ copy_weights=True,
+ )
+ return model
+
+ def convert(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any
+ ) -> torch.nn.Module:
+ self._convert_qat_linear_8da4w(model)
+ return model
+
+ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
+ """
+ Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
+ """
+ for name, child in module.named_children():
+ if isinstance(child, Int8DynActInt4WeightQATLinear):
+ config = child.weight_fake_quantizer.config
+ quantized_linear = Int8DynActInt4WeightLinear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=config.group_size,
+ precision=child.weight.dtype,
+ scales_precision=config.scale_precision,
+ )
+ setattr(module, name, quantized_linear)
+
+ # Load weights and qparams into quantized linear
+ n_bit = 4
+ (qmin, qmax) = _get_qmin_qmax(n_bit)
+ (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size)
+ from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
+ q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+ child.weight, s, zp, qmin, qmax, torch.int8, config.group_size,
+ )
+ quantized_linear.weight = q_weight
+ quantized_linear.scales = s
+ quantized_linear.zeros = zp
+ else:
+ self._convert_qat_linear_8da4w(child)
+
+
+class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
+ """
+    This module implements a linear layer with int8 dynamic per token fake
+    quantized activations and int4 fake quantized grouped per channel weights.
+
+ args:
+ groupsize: the number of elements in each quantized group for weights
+ precision: precision of weights
+ scales_precision: precision of per group scales and zero points
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = False,
+ device: torch.device = None,
+ groupsize: int = 256,
+ precision: torch.dtype = torch.float32,
+ scales_precision: torch.dtype = torch.float32,
+ ) -> None:
+ activation_config = FakeQuantizeConfig(
+ dtype=torch.int8,
+ granularity="per_token",
+ is_symmetric=False,
+ is_dynamic=True,
+ scale_precision=scales_precision,
+ zero_point_precision=scales_precision,
+ )
+ weight_config = FakeQuantizeConfig(
+ dtype=TorchAODType.INT4,
+ group_size=groupsize,
+ is_symmetric=True,
+ is_dynamic=True,
+ scale_precision=scales_precision,
+ zero_point_precision=scales_precision,
+ )
+ super().__init__(
+ in_features,
+ out_features,
+ bias,
+ activation_config,
+ weight_config,
+ device=device,
+ dtype=precision,
+ )
+
+ def enable_fake_quant(self, enabled: bool = True):
+ self.activation_fake_quantizer.enabled = enabled
+ self.weight_fake_quantizer.enabled = enabled
+
+ def disable_fake_quant(self):
+ self.enable_fake_quant(False)
+
+
+def enable_8da4w_fake_quant(mod: torch.nn.Module):
+ """
+ Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
+ """
+ if isinstance(mod, Int8DynActInt4WeightQATLinear):
+ mod.enable_fake_quant()
+
+
+def disable_8da4w_fake_quant(mod: torch.nn.Module):
+ """
+ Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
+ """
+ if isinstance(mod, Int8DynActInt4WeightQATLinear):
+ mod.disable_fake_quant()
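+
+# A sketch of how the toggles above are typically used, e.g. to delay fake
+# quantization for the first N fine-tuning steps:
+#
+#   model.apply(disable_8da4w_fake_quant)   # train in plain float for a while
+#   ...
+#   model.apply(enable_8da4w_fake_quant)    # then turn fake quantization back on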
+
+
+# ===================================
+# | Linear int4 weight-only QAT |
+# ===================================
+
+
+class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
+ """
+ Quantizer for performing QAT on a model, where linear layers have
+ int4 fake quantized grouped per channel weights.
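+
+    Example usage (a sketch; weights and scales should be in bf16 for the tinygemm path)::
+
+        quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
+        model = quantizer.prepare(model)   # nn.Linear -> Int4WeightOnlyQATLinear
+        # ... train ...
+        model = quantizer.convert(model)   # -> WeightOnlyInt4Linear (int4 tinygemm)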
+ """
+
+ def __init__(
+ self,
+ groupsize: int = 256,
+ inner_k_tiles: Optional[int] = 8,
+ precision: torch.dtype = torch.bfloat16,
+ scales_precision: torch.dtype = torch.bfloat16,
+ ) -> None:
+ super().__init__()
+ assert inner_k_tiles in [2, 4, 8]
+ assert groupsize in [32, 64, 128, 256]
+ self.inner_k_tiles = inner_k_tiles
+ self.groupsize = groupsize
+ self.precision = precision
+ self.scales_precision = scales_precision
+
+ def prepare(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any
+ ) -> torch.nn.Module:
+ _replace_linear_int4(
+ model,
+ self.groupsize,
+ self.inner_k_tiles,
+ padding_allowed=True,
+ precision=self.precision,
+ scales_precision=self.scales_precision,
+ linear_class=Int4WeightOnlyQATLinear,
+ copy_weights=True,
+ )
+ return model
+
+ def convert(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any
+ ) -> torch.nn.Module:
+ self._convert_qat_linear_4w(model)
+ return model
+
+ def _convert_qat_linear_4w(self, module: torch.nn.Module):
+ """
+ Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`.
+ """
+ for name, child in module.named_children():
+ if isinstance(child, Int4WeightOnlyQATLinear):
+ in_features = child.in_features
+ out_features = child.out_features
+ inner_k_tiles = child.inner_k_tiles
+ config = child.weight_fake_quantizer.config
+ quantized_linear = WeightOnlyInt4Linear(
+ in_features,
+ out_features,
+ bias=False,
+ groupsize=config.group_size,
+ inner_k_tiles=inner_k_tiles,
+ precision=child.weight.dtype,
+ scales_precision=config.scale_precision,
+ )
+ setattr(module, name, quantized_linear)
+
+ # Load weights and qparams into quantized linear
+ n_bit = 4
+ (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+ child.weight, n_bit, config.group_size,
+ )
+ q_weight = torch.ops.aten._convert_weight_to_int4pack(
+ q_weight.to(child.weight.device), child.inner_k_tiles,
+ )
+ quantized_linear.weight = q_weight
+ quantized_linear.scales_and_zeros = scales_and_zeros
+ else:
+ self._convert_qat_linear_4w(child)
+
+
+class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
+ """
+ This module implements a linear layer with int4 fake quantized grouped
+ per channel weights, with forward numerics matching `WeightOnlyInt4Linear`,
+ which uses the efficient int4 tinygemm kernel.
+
+ args:
+ groupsize: the number of elements in each quantized group for weights
+ precision: precision of weights
+ scales_precision: precision of per group scales and zero points
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = False,
+ device: torch.device = None,
+ groupsize: int = 256,
+ inner_k_tiles: int = 8,
+ precision: torch.dtype = torch.bfloat16,
+ scales_precision: torch.dtype = torch.bfloat16,
+ ) -> None:
+ assert scales_precision == torch.bfloat16, "only bf16 is supported for scales"
+ if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
+ raise ValueError("Padding for QAT 4w is not supported yet")
+ self.inner_k_tiles = inner_k_tiles
+ weight_config = FakeQuantizeConfig(
+ dtype=torch.uint4,
+ group_size=groupsize,
+ is_symmetric=False,
+ is_dynamic=True,
+ scale_precision=scales_precision,
+ zero_point_precision=scales_precision,
+ zero_point_domain=ZeroPointDomain.FLOAT,
+ )
+ super().__init__(
+ in_features,
+ out_features,
+ bias,
+ activation_config=None,
+ weight_config=weight_config,
+ device=device,
+ dtype=precision,
+ )
+
+    def enable_fake_quant(self, enabled: bool = True):
+        # this weight-only module has no activation fake quantizer, so guard against None
+        if self.activation_fake_quantizer is not None:
+            self.activation_fake_quantizer.enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled
+
+ def disable_fake_quant(self):
+ self.enable_fake_quant(False)
+
+
+def enable_4w_fake_quant(mod: torch.nn.Module):
+ """
+ Enable fake quantization for `Int4WeightOnlyQATLinear`.
+ """
+ if isinstance(mod, Int4WeightOnlyQATLinear):
+ mod.enable_fake_quant()
+
+
+def disable_4w_fake_quant(mod: torch.nn.Module):
+ """
+ Disable fake quantization for `Int4WeightOnlyQATLinear`.
+ """
+ if isinstance(mod, Int4WeightOnlyQATLinear):
+ mod.disable_fake_quant()
diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/qat/utils.py
similarity index 97%
rename from torchao/quantization/prototype/qat/utils.py
rename to torchao/quantization/qat/utils.py
index 8f2dd9d13f..e2234a2556 100644
--- a/torchao/quantization/prototype/qat/utils.py
+++ b/torchao/quantization/qat/utils.py
@@ -46,7 +46,7 @@ def forward(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
# avoid circular dependencies
- from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+ from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
@@ -88,7 +88,7 @@ def forward(
input: torch.Tensor,
) -> torch.Tensor:
# avoid circular dependencies
- from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+ from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
assert isinstance(input, AffineFakeQuantizedTensor)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e8e7261fbd..9895ea084d 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -221,7 +221,7 @@ def _replace_with_custom_fn_if_matches_filter(
def _is_linear(mod, *args):
# avoid circular dependencies
- from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+ from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)