diff --git a/README.md b/README.md index 71fb25fa24..ba48dbf451 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to * Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/) ```python -from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer +from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer qat_quantizer = Int8DynActInt4WeightQATQuantizer() diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 9186977240..29f833c9ab 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -22,22 +22,22 @@ PerRow, PerToken, ) -from torchao.quantization.prototype.qat.api import ( +from torchao.quantization.qat.api import ( ComposableQATQuantizer, FakeQuantizeConfig, ) -from torchao.quantization.prototype.qat.fake_quantizer import ( +from torchao.quantization.qat.fake_quantizer import ( FakeQuantizer, ) -from torchao.quantization.prototype.qat.embedding import ( +from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) -from torchao.quantization.prototype.qat.linear import ( +from torchao.quantization.qat.linear import ( FakeQuantizedLinear, Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear ) -from torchao.quantization.prototype.qat.utils import ( +from torchao.quantization.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, @@ -181,7 +181,7 @@ def _set_ptq_weight( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) - from torchao.quantization.prototype.qat.linear import ( + from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, ) @@ -213,7 +213,7 @@ def _set_ptq_weight( @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): - from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear + from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear group_size = 128 @@ -238,7 +238,7 @@ def test_qat_8da4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer group_size = 16 @@ -272,7 +272,7 @@ def test_qat_8da4w_quantizer(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer with torch.device("meta"): m = M() @@ -287,7 +287,7 @@ def 
test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATQuantizer, disable_8da4w_fake_quant, enable_8da4w_fake_quant, @@ -346,7 +346,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATQuantizer, disable_8da4w_fake_quant, ) @@ -428,7 +428,7 @@ def _test_qat_quantized_gradients(self, quantizer): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_gradients(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) self._test_qat_quantized_gradients(quantizer) @@ -518,7 +518,7 @@ def test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear + from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear group_size = 128 @@ -545,14 +545,14 @@ def test_qat_4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_quantizer_gradients(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer group_size = 32 @@ -630,7 +630,7 @@ def test_composable_qat_quantizer(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_embedding(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer + from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer model = M2() x = model.example_inputs() out = model(*x) @@ -937,6 +937,59 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_qat_prototype_bc(self): + """ + Just to make sure we can import all the old prototype paths. + We will remove this test in the near future when we actually break BC. 
+ """ + from torchao.quantization.prototype.qat import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + ComposableQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, + ) + from torchao.quantization.prototype.qat._module_swap_api import ( + disable_4w_fake_quant_module_swap, + enable_4w_fake_quant_module_swap, + disable_8da4w_fake_quant_module_swap, + enable_8da4w_fake_quant_module_swap, + Int4WeightOnlyQATQuantizerModuleSwap, + Int8DynActInt4WeightQATQuantizerModuleSwap, + ) + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + to_affine_fake_quantized, + ) + from torchao.quantization.prototype.qat.api import ( + ComposableQATQuantizer, + FakeQuantizeConfig, + ) + from torchao.quantization.prototype.qat.embedding import ( + FakeQuantizedEmbedding, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyEmbedding, + Int4WeightOnlyQATEmbedding, + ) + from torchao.quantization.prototype.qat.fake_quantizer import ( + FakeQuantizer, + ) + from torchao.quantization.prototype.qat.linear import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + FakeQuantizedLinear, + Int4WeightOnlyQATLinear, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, + ) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/torchao/quantization/prototype/qat/README.md b/torchao/quantization/prototype/qat/README.md index 2869322297..dbce4d48e1 100644 --- a/torchao/quantization/prototype/qat/README.md +++ b/torchao/quantization/prototype/qat/README.md @@ -1,125 +1,3 @@ -# Quantization-Aware Training (QAT) - -Quantization-Aware Training (QAT) refers to applying fake quantization during the -training or fine-tuning process, such that the final quantized model will exhibit -higher accuracies and perplexities. Fake quantization refers to rounding the float -values to quantized values without actually casting them to dtypes with lower -bit-widths, in contrast to post-training quantization (PTQ), which does cast the -quantized values to lower bit-width dtypes, e.g.: - -``` -# PTQ: x_q is quantized and cast to int8 -# scale and zero point (zp) refer to parameters used to quantize x_float -# qmin and qmax refer to the range of quantized values -x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8) - -# QAT: x_fq is still in float -# Fake quantize simulates the numerics of quantize + dequantize -x_fq = (x_float / scale + zp).round().clamp(qmin, qmax) -x_fq = (x_fq - zp) * scale -``` - -## API - -torchao currently supports two QAT schemes for linear layers: -- int8 per token dynamic activations + int4 per group weights -- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training) - -QAT typically involves applying a transformation to your model before and after training. 
-In torchao, these are represented as the prepare and convert steps: (1) prepare inserts -fake quantize operations into linear layers, and (2) convert transforms the fake quantize -operations to actual quantize and dequantize operations after training, thereby producing -a quantized model (dequantize operations are typically fused with linear after lowering). -Between these two steps, training can proceed exactly as before. - -![qat](images/qat_diagram.png) - -To use QAT in torchao, apply the prepare step using the appropriate Quantizer before -training, then apply the convert step after training for inference or generation. -For example, on a single GPU: - -```python -import torch -from torchtune.models.llama3 import llama3 -from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - -# Smaller version of llama3 to fit in a single GPU -model = llama3( - vocab_size=4096, - num_layers=16, - num_heads=16, - num_kv_heads=4, - embed_dim=2048, - max_seq_len=2048, -).cuda() - -# Quantizer for int8 dynamic per token activations + -# int4 grouped per channel weights, only for linear layers -qat_quantizer = Int8DynActInt4WeightQATQuantizer() - -# Insert "fake quantize" operations into linear layers. -# These operations simulate quantization numerics during -# training without performing any dtype casting -model = qat_quantizer.prepare(model) - -# Standard training loop -optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) -loss_fn = torch.nn.CrossEntropyLoss() -for i in range(10): - example = torch.randint(0, 4096, (2, 16)).cuda() - target = torch.randn((2, 16, 4096)).cuda() - output = model(example) - loss = loss_fn(output, target) - loss.backward() - optimizer.step() - optimizer.zero_grad() - -# Convert fake quantize to actual quantize operations -# The quantized model has the exact same structure as the -# quantized model produced in the corresponding PTQ flow -# through `Int8DynActInt4WeightQuantizer` -model = qat_quantizer.convert(model) - -# inference or generate -``` - -Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune) -and apply quantized-aware fine-tuning as follows: - -``` -tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full -``` - -For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html). - - -## Evaluation Results - -Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT -integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) -on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset) -for 5000 steps using a group size of 256 for the weights. Note that extensive -hyperparameter tuning may further improve these results. - -Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5: - -| | hellaswag
(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>
(bits_per_byte) | -| ---------------- | ------ | ------ | ------ | ------ | ------ | -| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 | -| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 | -| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 | -| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 | -| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 | - -Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the -quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097). - -| | hellaswag
(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>
(bits_per_byte) | -| ---------------- | -------- | ------- | ------ | ------ | ------ | -| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 | -| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 | -| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 | -| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 | -| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 | - -For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training). +Note: QAT has been moved to torchao/quantization/qat. +This is a legacy folder only for backward compatibility +and will be removed in the near future. diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index 09ea6e708d..f6c4b1c8ce 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -1,17 +1,15 @@ -from .api import ( +from torchao.quantization.qat import ( ComposableQATQuantizer, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, ) -from .linear import ( +from torchao.quantization.qat.linear import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, - Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATLinear, - Int8DynActInt4WeightQATQuantizer, -) -from .embedding import ( - Int4WeightOnlyEmbeddingQATQuantizer, ) __all__ = [ diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py index 0b44974f21..d6aa4a3ae5 100644 --- a/torchao/quantization/prototype/qat/_module_swap_api.py +++ b/torchao/quantization/prototype/qat/_module_swap_api.py @@ -1,7 +1,7 @@ # For backward compatibility only # These will be removed in the future -from .linear import ( +from torchao.quantization.qat.linear import ( Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap, Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap, enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap, diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py index 9d7cdccfc3..ff09ca842e 100644 --- a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py @@ -1,328 +1,4 @@ -import torch -import torch.utils._pytree as pytree -from typing import Callable, Optional, Tuple -from torchao.quantization.quant_primitives import ( - _get_and_check_qmin_qmax, - choose_qparams_affine, - fake_quantize_affine, - ZeroPointDomain, - MappingType, +from torchao.quantization.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + to_affine_fake_quantized, ) -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor -from .utils import ( - _GenericFakeQuantize, - _UnwrapAffineFakeQuantizedTensor, -) - -aten = torch.ops.aten - - -class _ToAffineFakeQuantized(torch.autograd.Function): - """ - Differentiable constructor for `AffineFakeQuantizedTensor`, - needed for input activation fake quantization. 
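Earlier in this hunk, `_module_swap_api.py` likewise becomes a pure alias layer over the relocated linear module. A hedged sketch of what that preserves for existing callers (illustrative only; assumes this change is installed):

```python
# Hedged sketch: the legacy "module swap" names are plain aliases of the
# relocated QAT quantizers, so old call sites keep working unchanged.
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat._module_swap_api import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)

assert Int8DynActInt4WeightQATQuantizerModuleSwap is Int8DynActInt4WeightQATQuantizer
```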
- """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - original_tensor: torch.Tensor, - mapping_type: MappingType, - block_size: Tuple[int, ...], - target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - ) -> "AffineFakeQuantizedTensor": - def apply_fake_quant_fn(t: torch.Tensor): - assert isinstance(t, AffineFakeQuantizedTensor) - qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) - scale, zero_point = choose_qparams_affine( - t.original_tensor, - mapping_type, - block_size, - target_dtype, - qmin, - qmax, - eps, - scale_dtype, - zero_point_dtype, - preserve_zero, - zero_point_domain, - ) - fq = _GenericFakeQuantize.apply( - t, - block_size, - scale, - zero_point, - qmin, - qmax, - zero_point_domain, - ) - return fq - return AffineFakeQuantizedTensor( - original_tensor, - apply_fake_quant_fn, - fake_quant_enabled=True, - ) - - @staticmethod - def backward(ctx, gy): - return gy, None, None, None, None, None, None, None, None, None, None - - -class AffineFakeQuantizedTensor(TorchAOBaseTensor): - """ - Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor - with an affine transformation: - quantized_tensor = float_tensor / scale + zero_point - - Fake quantization refers to performing the quantization math without actually casting the floating point - tensor into lower bit-width dtypes. It is commonly used for quantization-aware training (QAT). - - The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, - regardless of the internal representation's type or orientation. 
- - fields: - original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later - apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values - """ - - @staticmethod - def __new__( - cls, - original_tensor: torch.Tensor, - apply_fake_quant_fn: Callable, - fake_quant_enabled: bool = True, - **kwargs, - ): - kwargs.setdefault("dtype", original_tensor.dtype) - kwargs.setdefault("device", original_tensor.device) - kwargs.setdefault("requires_grad", original_tensor.requires_grad) - return torch.Tensor._make_wrapper_subclass( - cls, - original_tensor.shape, - **kwargs, - ) - - def __init__( - self, - original_tensor: torch.Tensor, - apply_fake_quant_fn: Callable, - fake_quant_enabled: bool = True, - **kwargs - ): - self.original_tensor = original_tensor - self.apply_fake_quant_fn = apply_fake_quant_fn - self.fake_quant_enabled = fake_quant_enabled - - def __tensor_flatten__(self): - return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride, - ): - original_tensor = tensor_data_dict["original_tensor"] - (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes - return cls( - original_tensor, - apply_fake_quant_fn, - fake_quant_enabled, - ) - - @classmethod - def from_float( - cls, - original_input: torch.Tensor, - mapping_type: MappingType, - block_size: Tuple[int, ...], - target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - ): - return _ToAffineFakeQuantized.apply( - original_input, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - scale_dtype, - zero_point_dtype, - preserve_zero, - zero_point_domain, - ) - - def get_value(self) -> torch.Tensor: - if self.fake_quant_enabled: - return self.apply_fake_quant_fn(self) - else: - return _UnwrapAffineFakeQuantizedTensor.apply(self) - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - "requires_grad": self.requires_grad, - } - return kwargs - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - # not supported yet - kwargs.pop("memory_format") - return self.__class__( - self.original_tensor.to(device), - self.apply_fake_quant_fn, - self.fake_quant_enabled, - **kwargs, - ) - - def _apply_fn_to_data(self, fn: Callable): - """ - Create a new `AffineFakeQuantizedTensor` with `fn` applied to the - original tensor, to be called within __torch_dispatch__. - """ - return self._create_new(fn(self.original_tensor)) - - def _create_new(self, new_value: torch.Tensor): - """ - Create a new `AffineFakeQuantizedTensor` with a new value, - to be called within __torch_dispatch__. 
- - Note: `requires_grad` must be False here because tensors created - in `__torch_dispatch__` cannot produce gradients, since autograd - will try to attach autograd metadata to these tensors when we exit - `__torch_dispatch__`, but if these tensors already have metadata - attached then autograd will throw an error. - """ - return self.__class__( - new_value, - self.apply_fake_quant_fn, - self.fake_quant_enabled, - requires_grad=False, - ) - -implements = AffineFakeQuantizedTensor.implements - - -@implements(torch.nn.functional.linear) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - if isinstance(input_tensor, AffineFakeQuantizedTensor): - input_tensor = input_tensor.get_value() - if isinstance(weight_tensor, AffineFakeQuantizedTensor): - weight_tensor = weight_tensor.get_value() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - -@implements(aten.mm.default) -def _(func, types, args, kwargs): - bias = None - input_tensor = args[0] - weight_tensor = args[1] - if isinstance(input_tensor, AffineFakeQuantizedTensor): - input_tensor = input_tensor.get_value() - if isinstance(weight_tensor, AffineFakeQuantizedTensor): - weight_tensor = weight_tensor.get_value() - return func(input_tensor, weight_tensor) - - -@implements(aten.addmm.default) -def _(func, types, args, kwargs): - bias = args[0] - input_tensor = args[1] - weight_tensor = args[2] - if isinstance(input_tensor, AffineFakeQuantizedTensor): - input_tensor = input_tensor.get_value() - if isinstance(weight_tensor, AffineFakeQuantizedTensor): - weight_tensor = weight_tensor.get_value() - return func(bias, input_tensor, weight_tensor) - - -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach), - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone), - ) - - -@implements(aten.t.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.t), - ) - - -@implements([ - aten.add.Tensor, - aten.add_.Tensor, - aten.mul_.Tensor, - aten.copy_.default, -]) -def _(func, types, args, kwargs): - assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}" - new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args) - first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1] - new_value = func(*new_args, **kwargs) - out = first_afq_tensor._create_new(new_value) - return return_and_correct_aliasing(func, args, kwargs, out) - - -# Needed by FSDP: - -@implements(aten.empty_like.default) -def _(func, types, args, kwargs): - out = torch.empty_like(args[0].original_tensor, **kwargs) - return return_and_correct_aliasing(func, args, kwargs, out) - - -@implements(aten.split.Tensor) -def _(func, types, args, kwargs): - new_values = torch.split(args[0].original_tensor, *args[1:], **kwargs) - - def make_new_tensor(value): - out = args[0]._create_new(value) - return return_and_correct_aliasing(func, args, kwargs, out) - - return list(map(make_new_tensor, new_values)) - - -@implements(aten.new_zeros.default) -def _(func, types, args, kwargs): - out = args[0].original_tensor.new_zeros(*args[1:], **kwargs) - return return_and_correct_aliasing(func, args, kwargs, out) - 
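The tensor subclass removed above (now living under `torchao.quantization.qat.affine_fake_quantized_tensor`) exists to simulate quantize → dequantize without ever leaving floating point. A self-contained, per-tensor sketch of those numerics in plain PyTorch (illustrative only; torchao itself uses per-token and per-group variants of the same idea):

```python
import torch

def fake_quantize_int8(x: torch.Tensor) -> torch.Tensor:
    """Per-tensor asymmetric int8 fake quantization (illustrative only)."""
    qmin, qmax = -128, 127
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = torch.round(qmin - x.min() / scale)
    x_q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (x_q - zero_point) * scale  # still float, values snapped to the int8 grid

x = torch.randn(4, 16)
x_fq = fake_quantize_int8(x)
assert x_fq.dtype == x.dtype  # no dtype change, unlike real PTQ
```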
- -to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 60d45c8b17..e3e18e9e30 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -1,262 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass -from enum import Enum -from typing import Any, List, Optional, Union - -import torch - -from torchao.quantization.granularity import ( - Granularity, - PerAxis, - PerGroup, - PerToken, +from torchao.quantization.qat.api import ( + ComposableQATQuantizer, + FakeQuantizeConfig, ) -from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.quant_primitives import ( - _SUB_BYTE_INT_BOUNDS, - _SUB_BYTE_UINT_BOUNDS, - MappingType, - TorchAODType, - ZeroPointDomain, -) - - -@dataclass -class FakeQuantizeConfig: - """ - Config for how to fake quantize weights or activations. - - args: - dtype: dtype to simulate during fake quantization, e.g. torch.int8. - For PyTorch versions older than 2.6, you may use `TorchAODType` to represent - torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. - granularity: granularity of scales and zero points, e.g. PerGroup(32). - We also support the following strings: - 1) 'per_token': equivalent to PerToken() - 2) 'per_channel': equivalent to PerAxis(0) - 3) 'per_group': equivalent to PerGroup(group_size), must be combined - with separate `group_size` kwarg, Alternatively, just set the - `group_size` kwarg and leave this field empty. - mapping_type: whether to use symmetric (default) or asymmetric quantization - Alternatively, set `is_symmetric` (bool) and leave this field empty. 
- scale_precision: scale dtype (default torch.fp32) - zero_point_precision: zero point dtype (default torch.int32) - zero_point_domain: whether zero point is in integer (default) or float domain - is_dynamic: whether to use dynamic (defualt) or static scale and zero points - range_learning: whether to learn scale and zero points during training (coming soon) - - kwargs (optional): - group_size: size of each group in per group fake quantization, - can be set instead of `granularity` - is_symmetric: whether to use symmetric or asymmetric quantization, - can be set instead of `mapping_type` - - Example usage:: - - # Per token asymmetric quantization - FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) - - # Per channel symmetric quantization - FakeQuantizeConfig(torch.int4, "per_channel") - FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) - - # Per group symmetric quantization - FakeQuantizeConfig(torch.int4, group_size=32) - FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) - """ - dtype: Union[torch.dtype, TorchAODType] - granularity: Granularity - mapping_type: MappingType - scale_precision: torch.dtype - zero_point_precision: torch.dtype - zero_point_domain: ZeroPointDomain - is_dynamic: bool = True - range_learning: bool = False - - def __init__( - self, - dtype: Union[torch.dtype, TorchAODType], - granularity: Union[Granularity, str, None] = None, - mapping_type: Optional[MappingType] = None, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - is_dynamic: bool = True, - range_learning: bool = False, - *, - group_size: Optional[int] = None, - is_symmetric: Optional[bool] = None, - ): - self.dtype = dtype - self.granularity = self._get_granularity(granularity, group_size) - self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) - self.scale_precision = scale_precision - self.zero_point_precision = zero_point_precision - self.zero_point_domain = zero_point_domain - self.is_dynamic = is_dynamic - self.range_learning = range_learning - - # Validate dtype - all_dtypes = [torch.int8, torch.uint8] - all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) - all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) - if dtype not in all_dtypes: - raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)) - - def _get_granularity( - self, - granularity: Union[Granularity, str, None], - group_size: Optional[int], - ) -> Granularity: - """ - Parse the `Granularity` represented in the args. 
- - Granularity can be specified in one of three ways: - 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) - 2) str: one of 'per_token', 'per_channel', and 'per_group' - 3) None: `group_size` must be set instead, represents per group granularity - """ - # If group_size is set, then granularity must be either "per_group" or None - if group_size is not None and granularity != "per_group" and granularity is not None: - raise ValueError("`group_size` conflicts with granularity '%s'" % granularity) - - # Case 1: Granularity object - if isinstance(granularity, Granularity): - if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): - raise ValueError("Granularity '%s' is not supported" % granularity) - if isinstance(granularity, PerAxis) and granularity.axis != 0: - raise ValueError("Only axis=0 is supported for PerAxis granularity") - return granularity - - # Case 2: str granularity - if granularity == "per_token": - return PerToken() - elif granularity == "per_channel": - return PerAxis(axis=0) - elif granularity == "per_group": - if group_size is None: - raise ValueError("Granularity was 'per_group' but no `group_size` was set") - return PerGroup(group_size) - elif isinstance(granularity, str): - raise ValueError( - "Unexpected granularity: '%s', must be one of %s" % - (granularity, ["per_token", "per_channel", "per_group"]) - ) - - # Case 3: None granularity + group_size was specified - if granularity is not None: - raise ValueError( - "Granularity '%s' has unexpected type %s" % (granularity, type(granularity)) - ) - if group_size is None: - raise ValueError("At least one of `granularity` or `group_size` must be set") - return PerGroup(group_size) - - def _get_mapping_type( - self, - mapping_type: Optional[MappingType], - is_symmetric: Optional[bool], - ) -> MappingType: - """ - Parse the `MappingType` represented in the args. - - Mapping type can be specified in one of two ways: - 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC - 2): is_symmetric bool - """ - if mapping_type is not None and is_symmetric is not None: - raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") - - # Case 0: Default to symmetric - if mapping_type is None and is_symmetric is None: - return MappingType.SYMMETRIC - - # Case 1: MappingType object - if mapping_type is not None: - if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: - raise ValueError("MappingType '%s' is not supported" % mapping_type) - return mapping_type - - # Case 2: is_symmetric flag - assert is_symmetric is not None - if is_symmetric: - return MappingType.SYMMETRIC - else: - return MappingType.ASYMMETRIC - - @property - def group_size(self) -> int: - """ - If this is per group granularity, return the group size. - Otherwise, throw an error. - """ - if isinstance(self.granularity, PerGroup): - return self.granularity.group_size - else: - raise ValueError("`group_size` is undefined for %s granularity" % self.granularity) - - @property - def is_symmetric(self) -> bool: - """ - Return True if mapping type is symmetric, else False (asymmetric). - """ - return self.mapping_type == MappingType.SYMMETRIC - - def __setattr__(self, name: str, value: Any): - """ - Support setting `group_size` and `is_symmetric`. 
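The `group_size` and `is_symmetric` conveniences described above map onto `granularity` and `mapping_type` under the hood. A hedged usage sketch against the relocated `torchao.quantization.qat.api` module (assumes a PyTorch version where `torch.int4` is available; otherwise use `TorchAODType.INT4`):

```python
import torch
from torchao.quantization.qat.api import FakeQuantizeConfig

config = FakeQuantizeConfig(torch.int4, group_size=32)
print(config.granularity)     # PerGroup(group_size=32)
print(config.group_size)      # 32
print(config.is_symmetric)    # True -- symmetric mapping is the default

config.group_size = 64        # rewrites granularity to PerGroup(64)
config.is_symmetric = False   # rewrites mapping_type to ASYMMETRIC
```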
- """ - if name == "group_size": - super().__setattr__("granularity", PerGroup(value)) - elif name == "is_symmetric": - mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC - super().__setattr__("mapping_type", mapping_type) - else: - super().__setattr__(name, value) - - -class ComposableQATQuantizer(TwoStepQuantizer): - """ - Composable quantizer that users can use to apply multiple QAT quantizers easily. - Quantizers will be applied in the order they are specified in the constructor. - - Note: the quantizers provided must apply to different modules in the model, - e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined. - - Example usage:: - - my_quantizer = ComposableQATQuantizer([ - QATQuantizer1(), - QATQuantizer2(), - QATQuantizer3(), - ]) - model = my_quantizer.prepare(model) - train(model) - model = my_quantizer.convert(model) - """ - - def __init__(self, quantizers: List[TwoStepQuantizer]): - self.quantizers = quantizers - - def prepare( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - for quantizer in self.quantizers: - model = quantizer.prepare(model) - return model - - def convert( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - for quantizer in self.quantizers: - model = quantizer.convert(model) - return model diff --git a/torchao/quantization/prototype/qat/embedding.py b/torchao/quantization/prototype/qat/embedding.py index 1f471fa490..9168291202 100644 --- a/torchao/quantization/prototype/qat/embedding.py +++ b/torchao/quantization/prototype/qat/embedding.py @@ -1,325 +1,6 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Optional - -import torch -import torch.nn.functional as F - -from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import get_group_qparams_symmetric -from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, +from torchao.quantization.qat.embedding import ( + FakeQuantizedEmbedding, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyEmbedding, + Int4WeightOnlyQATEmbedding, ) -from torchao.quantization.quant_primitives import TorchAODType -from .api import FakeQuantizeConfig -from .fake_quantizer import FakeQuantizer -from .utils import ( - _fake_quantize_per_channel_group, - _get_qmin_qmax, -) - - -class FakeQuantizedEmbedding(torch.nn.Embedding): - """ - General embedding layer with fake quantized weights. - - Specific target dtypes, granularity, schemes etc. are specified - through separate configs for weights and activations. 
- - Example usage:: - - weight_config = FakeQuantizeConfig( - dtype=torch.int4, - group_size=8, - symmetric=True, - ) - fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config) - fq_embedding(torch.LongTensor([3])) - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - weight_config: Optional[FakeQuantizeConfig] = None, - *args, - **kwargs, - ) -> None: - super().__init__( - num_embeddings, - embedding_dim, - padding_idx, - max_norm, - norm_type, - scale_grad_by_freq, - sparse, - *args, - **kwargs, - ) - if weight_config is not None: - self.weight_fake_quantizer = FakeQuantizer(weight_config) - else: - self.weight_fake_quantizer = None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.weight_fake_quantizer is not None: - w = self.weight_fake_quantizer(self.weight) - else: - w = self.weight - return F.embedding( - x, w, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse, - ) - - -# ====================================== -# | Embedding int4 weight-only QAT | -# ====================================== - -class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where embedding layers have - int4 fake quantized grouped per channel weights. - """ - - def __init__( - self, - group_size: int = 256, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - ) -> None: - super().__init__() - self.bit_width = 4 - self.group_size: int = group_size - self.scale_precision: torch.dtype = scale_precision - self.zero_point_precision: torch.dtype = zero_point_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - """ - Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`. - """ - def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool: - return isinstance(child, torch.nn.Embedding) - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_embedding = Int4WeightOnlyQATEmbedding( - # nn.Embedding args - num_embeddings=child.num_embeddings, - embedding_dim=child.embedding_dim, - padding_idx=child.padding_idx, - max_norm=child.max_norm, - norm_type=child.norm_type, - scale_grad_by_freq=child.scale_grad_by_freq, - sparse=child.sparse, - # quantization args - group_size=self.group_size, - scale_precision=self.scale_precision, - zero_point_precision=self.zero_point_precision, - device=child.weight.device, - ) - # In distributed training, the model may be instantiated - # on the meta device, in which case there is no need to - # copy the weights, and doing so will result in an error - if child.weight.device != torch.device("meta"): - new_embedding.weight = child.weight - return new_embedding - - _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - """ - Swap all `Int4WeightOnlyQATEmbedding` modules with `Int4WeightOnlyEmbedding`. 
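For embeddings, the prepare/convert flow mirrors the linear quantizers. A hedged end-to-end sketch on a toy model (module names and sizes here are illustrative, not from the codebase):

```python
import torch
from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer

# Toy model; embedding_dim must be divisible by the group size.
model = torch.nn.Sequential(torch.nn.Embedding(128, 256))

quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=32)
model = quantizer.prepare(model)        # nn.Embedding -> Int4WeightOnlyQATEmbedding

tokens = torch.randint(0, 128, (2, 16))
out = model(tokens)                     # fake quantized forward, still float

# ... fine-tune as usual, then:
model = quantizer.convert(model)        # -> Int4WeightOnlyEmbedding (int4 qparams)
```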
- """ - self._convert_helper(model) - return model - - def _convert_helper(self, module: torch.nn.Module): - """ - Helper function to recursively swap `Int4WeightOnlyQATEmbedding` - modules with `Int4WeightOnlyEmbedding` - """ - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATEmbedding): - group_size = child.weight_fake_quantizer.config.group_size - scale_precision = child.weight_fake_quantizer.config.scale_precision - zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision - quantized_embedding = Int4WeightOnlyEmbedding( - # nn.Embedding args - num_embeddings=child.num_embeddings, - embedding_dim=child.embedding_dim, - padding_idx=child.padding_idx, - max_norm=child.max_norm, - norm_type=child.norm_type, - scale_grad_by_freq=child.scale_grad_by_freq, - sparse=child.sparse, - # quantization args - group_size=group_size, - scale_precision=scale_precision, - zero_point_precision=zero_point_precision, - device=child.weight.device, - ) - setattr(module, name, quantized_embedding) - - # Load weights and qparams into quantized embedding - (qmin, qmax) = _get_qmin_qmax(self.bit_width) - (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size) - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, group_size, - ) - quantized_embedding.weight = q_weight - quantized_embedding.scales = s - quantized_embedding.zeros = zp - else: - self._convert_helper(child) - - -class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding): - """ - This module implements a embedding layer with int4 fake quantized - grouped per channel weights. - - args: - group_size: the number of elements in each quantized group for weights - scale_precision: precision of per group scales - zero_point_precision: precision of per group zero points - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - group_size: int = 32, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - *args, - **kwargs, - ): - weight_config = FakeQuantizeConfig( - dtype=TorchAODType.INT4, - group_size=group_size, - is_symmetric=True, - is_dynamic=True, - scale_precision=scale_precision, - zero_point_precision=zero_point_precision, - ) - super().__init__( - num_embeddings, - embedding_dim, - padding_idx, - max_norm, - norm_type, - scale_grad_by_freq, - sparse, - weight_config, - *args, - **kwargs, - ) - - def enable_fake_quant(self, enabled: bool = True): - self.weight_fake_quantizer.enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - -class Int4WeightOnlyEmbedding(torch.nn.Module): - """ - This module implements a embedding layer with int4 quantized - grouped per channel weights. 
- """ - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - group_size: int = 32, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - device: torch.device = None, - ): - super().__init__() - - # nn.Embedding args - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.max_norm = max_norm - self.norm_type = norm_type - self.scale_grad_by_freq = scale_grad_by_freq - self.sparse = sparse - - # quantization args - self.bit_width = 4 - self.group_size = group_size - self.scale_precision = scale_precision - self.zero_point_precision = zero_point_precision - - # currently storing unpacked int8 weights - self.register_buffer( - "weight", - torch.empty((num_embeddings, embedding_dim), dtype=torch.int8, device=device), - ) - self.register_buffer( - "scale", - torch.empty( - (num_embeddings, embedding_dim // group_size), - dtype=scale_precision, - device=device, - ), - ) - self.register_buffer( - "zero_point", - torch.empty( - (num_embeddings, embedding_dim // group_size), - dtype=zero_point_precision, - device=device, - ), - ) - - def forward(self, x): - from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper - qmin, qmax = _get_qmin_qmax(self.bit_width) - w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper( - self.weight, - self.scale, - self.zero_point, - qmin, - qmax, - torch.int8, - self.group_size, - x.dtype, - ) - return F.embedding( - x, w_dq, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse, - ) diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py index eb42dcf047..584a46573b 100644 --- a/torchao/quantization/prototype/qat/fake_quantizer.py +++ b/torchao/quantization/prototype/qat/fake_quantizer.py @@ -1,121 +1,3 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional, Tuple - -import torch - -from torchao.quantization.granularity import ( - PerAxis, - PerGroup, - PerToken, +from torchao.quantization.qat.fake_quantizer import ( + FakeQuantizer, ) -from torchao.quantization.quant_primitives import ( - _DTYPE_TO_BIT_WIDTH, - _DTYPE_TO_QVALUE_BOUNDS, -) -from torchao.quantization.utils import ( - get_group_qparams_symmetric, - get_groupwise_affine_qparams, -) -from .api import ( - FakeQuantizeConfig, -) -from .utils import ( - _choose_qparams_per_token_asymmetric, - _fake_quantize_per_channel_group, - _fake_quantize_per_token, - _get_qmin_qmax, -) - - -class FakeQuantizer(torch.nn.Module): - """ - Generic module for applying fake quantization to a tensor, as specified in the config. 
- """ - def __init__(self, config: FakeQuantizeConfig): - super().__init__() - self.config = config - self.enabled = True - self.scale: Optional[torch.Tensor] = None - self.zero_point: Optional[torch.Tensor] = None - - # TODO: support range learinng - if self.config.range_learning: - raise NotImplementedError("Range learning is not supported yet") - - def forward(self, x: torch.Tensor): - """ - Apply fake quantization to the tensor based on the bit-width, - granularity, symmetry, and other properties specified in the config. - """ - if not self.enabled: - return x - - if isinstance(self.config.granularity, PerToken): - return self._per_token_forward(x) - elif isinstance(self.config.granularity, (PerAxis, PerGroup)): - return self._per_channel_or_group_forward(x) - else: - raise ValueError("Unknown granularity '%s'" % self.config.granularity) - - def _per_token_forward(self, x: torch.Tensor): - """ - Perform per token fake quantization on the tensor. - """ - if self.config.is_symmetric: - raise NotImplementedError("Symmetric per token is not supported yet") - if self._should_compute_qparams(): - (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric( - x, self.config.scale_precision, self.config.zero_point_precision, - ) - qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] - return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax) - - def _per_channel_or_group_forward(self, x: torch.Tensor): - """ - Perform per channel or per group fake quantization on the tensor. - We express per channel using per group where the group size is the size - of the last dimension of the tensor. - """ - granularity = self.config.granularity - scale_precision = self.config.scale_precision - zero_point_precision = self.config.zero_point_precision - zero_point_domain = self.config.zero_point_domain - is_symmetric = self.config.is_symmetric - - # get group size - if isinstance(granularity, PerAxis): - assert granularity.axis == 0 - group_size = x.size()[-1] - elif isinstance(granularity, PerGroup): - group_size = granularity.group_size - else: - raise ValueError("Unexpected granularity '%s'" % granularity) - - # get scales and zero points - if self._should_compute_qparams(): - bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype] - if is_symmetric: - (self.scale, self.zero_point) = get_group_qparams_symmetric( - x, bit_width, group_size, scale_precision, - ) - else: - (self.scale, self.zero_point) = get_groupwise_affine_qparams( - x, bit_width, group_size, scale_precision, - ) - self.zero_point = self.zero_point.to(zero_point_precision) - - qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] - return _fake_quantize_per_channel_group( - x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain, - ) - - def _should_compute_qparams(self) -> bool: - """ - Return whether we need to compute new scales and zero points. - """ - return self.config.is_dynamic or self.scale is None or self.zero_point is None diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py index ef1714808c..465b1c1e51 100644 --- a/torchao/quantization/prototype/qat/linear.py +++ b/torchao/quantization/prototype/qat/linear.py @@ -1,419 +1,11 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Any, Optional - -import torch -import torch.nn.functional as F - -from torchao.quantization.GPTQ import ( - _check_linear_int4_k, - _replace_linear_int4, - _replace_linear_8da4w, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor, - Int8DynActInt4WeightLinear, - WeightOnlyInt4Linear, +from torchao.quantization.qat.linear import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + FakeQuantizedLinear, + Int4WeightOnlyQATLinear, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, ) -from torchao.quantization.quant_primitives import ( - TorchAODType, - ZeroPointDomain, -) -from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import get_group_qparams_symmetric -from .api import FakeQuantizeConfig -from .fake_quantizer import FakeQuantizer -from .utils import ( - _choose_qparams_per_token_asymmetric, - _fake_quantize_per_channel_group, - _fake_quantize_per_token, - _get_qmin_qmax, -) - - -class FakeQuantizedLinear(torch.nn.Linear): - """ - General linear layer with fake quantized weights and activations. - - Specific target dtypes, granularity, schemes etc. are specified - through separate configs for weights and activations. - - Example usage:: - - activation_config = FakeQuantizeConfig( - dtype=torch.int8, - granularity="per_token", - is_symmetric=False, - ) - weight_config = FakeQuantizeConfig( - dtype=torch.int4, - group_size=8, - is_symmetric=True, - ) - fq_linear = FakeQuantizedLinear( - 16, 32, False, activation_config, weight_config, - ) - fq_linear(torch.randn(16)) - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, - *args, - **kwargs, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - *args, - **kwargs, - ) - if bias: - raise NotImplementedError("bias not supported yet") - - # initialize activation fake quantizer - if activation_config is not None: - self.activation_fake_quantizer = FakeQuantizer(activation_config) - else: - self.activation_fake_quantizer = None - - # initialize weight fake quantizer - if weight_config is not None: - group_size = weight_config.group_size - if group_size is not None and in_features % group_size != 0: - raise ValueError( - "in_features (%s) % group_size (%s) must be == 0" % - (in_features, group_size) - ) - self.weight_fake_quantizer = FakeQuantizer(weight_config) - else: - self.weight_fake_quantizer = None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.activation_fake_quantizer is not None: - x = self.activation_fake_quantizer(x) - if self.weight_fake_quantizer is not None: - w = self.weight_fake_quantizer(self.weight) - else: - w = self.weight - return F.linear(x, w) - - -# ========================================================= -# | Linear int8 dynamic activations + int4 weight QAT | -# ========================================================= - - -class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. 
- """ - - def __init__( - self, - groupsize: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - self.groupsize: int = groupsize - self.padding_allowed: bool = padding_allowed - self.precision: torch.dtype = precision - self.scales_precision: torch.dtype = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_8da4w( - model, - self.groupsize, - self.padding_allowed, - self.precision, - self.scales_precision, - Int8DynActInt4WeightQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - self._convert_qat_linear_8da4w(model) - return model - - def _convert_qat_linear_8da4w(self, module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - config = child.weight_fake_quantizer.config - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=config.group_size, - precision=child.weight.dtype, - scales_precision=config.scale_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, config.group_size, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - self._convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear): - """ - This module implements a linear layer with int8 dynamic per token fake - quantized activations with int4 fake quantized grouped per channel weights. - - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - activation_config = FakeQuantizeConfig( - dtype=torch.int8, - granularity="per_token", - is_symmetric=False, - is_dynamic=True, - scale_precision=scales_precision, - zero_point_precision=scales_precision, - ) - weight_config = FakeQuantizeConfig( - dtype=TorchAODType.INT4, - group_size=groupsize, - is_symmetric=True, - is_dynamic=True, - scale_precision=scales_precision, - zero_point_precision=scales_precision, - ) - super().__init__( - in_features, - out_features, - bias, - activation_config, - weight_config, - device=device, - dtype=precision, - ) - - def enable_fake_quant(self, enabled: bool = True): - self.activation_fake_quantizer.enabled = enabled - self.weight_fake_quantizer.enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - -def enable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. 
- """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.enable_fake_quant() - - -def disable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. - """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.disable_fake_quant() - - -# =================================== -# | Linear int4 weight-only QAT | -# =================================== - - -class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. - """ - - def __init__( - self, - groupsize: int = 256, - inner_k_tiles: Optional[int] = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__() - assert inner_k_tiles in [2, 4, 8] - assert groupsize in [32, 64, 128, 256] - self.inner_k_tiles = inner_k_tiles - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - padding_allowed=True, - precision=self.precision, - scales_precision=self.scales_precision, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - self._convert_qat_linear_4w(model) - return model - - def _convert_qat_linear_4w(self, module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - inner_k_tiles = child.inner_k_tiles - config = child.weight_fake_quantizer.config - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=config.group_size, - inner_k_tiles=inner_k_tiles, - precision=child.weight.dtype, - scales_precision=config.scale_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, config.group_size, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - self._convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(FakeQuantizedLinear): - """ - This module implements a linear layer with int4 fake quantized grouped - per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, - which uses the efficient int4 tinygemm kernel. 
- - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - inner_k_tiles: int = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.inner_k_tiles = inner_k_tiles - weight_config = FakeQuantizeConfig( - dtype=torch.uint4, - group_size=groupsize, - is_symmetric=False, - is_dynamic=True, - scale_precision=scales_precision, - zero_point_precision=scales_precision, - zero_point_domain=ZeroPointDomain.FLOAT, - ) - super().__init__( - in_features, - out_features, - bias, - activation_config=None, - weight_config=weight_config, - device=device, - dtype=precision, - ) - - def enable_fake_quant(self, enabled: bool = True): - self.activation_fake_quantizer.enabled = enabled - self.weight_fake_quantizer.enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - -def enable_4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.enable_fake_quant() - - -def disable_4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.disable_fake_quant() diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md new file mode 100644 index 0000000000..6ecccd2b18 --- /dev/null +++ b/torchao/quantization/qat/README.md @@ -0,0 +1,125 @@ +# Quantization-Aware Training (QAT) + +Quantization-Aware Training (QAT) refers to applying fake quantization during the +training or fine-tuning process, such that the final quantized model will exhibit +higher accuracies and perplexities. Fake quantization refers to rounding the float +values to quantized values without actually casting them to dtypes with lower +bit-widths, in contrast to post-training quantization (PTQ), which does cast the +quantized values to lower bit-width dtypes, e.g.: + +``` +# PTQ: x_q is quantized and cast to int8 +# scale and zero point (zp) refer to parameters used to quantize x_float +# qmin and qmax refer to the range of quantized values +x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8) + +# QAT: x_fq is still in float +# Fake quantize simulates the numerics of quantize + dequantize +x_fq = (x_float / scale + zp).round().clamp(qmin, qmax) +x_fq = (x_fq - zp) * scale +``` + +## API + +torchao currently supports two QAT schemes for linear layers: +- int8 per token dynamic activations + int4 per group weights +- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training) + +QAT typically involves applying a transformation to your model before and after training. 
+In torchao, these are represented as the prepare and convert steps: (1) prepare inserts +fake quantize operations into linear layers, and (2) convert transforms the fake quantize +operations to actual quantize and dequantize operations after training, thereby producing +a quantized model (dequantize operations are typically fused with linear after lowering). +Between these two steps, training can proceed exactly as before. + +![qat](images/qat_diagram.png) + +To use QAT in torchao, apply the prepare step using the appropriate Quantizer before +training, then apply the convert step after training for inference or generation. +For example, on a single GPU: + +```python +import torch +from torchtune.models.llama3 import llama3 +from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + +# Smaller version of llama3 to fit in a single GPU +model = llama3( + vocab_size=4096, + num_layers=16, + num_heads=16, + num_kv_heads=4, + embed_dim=2048, + max_seq_len=2048, +).cuda() + +# Quantizer for int8 dynamic per token activations + +# int4 grouped per channel weights, only for linear layers +qat_quantizer = Int8DynActInt4WeightQATQuantizer() + +# Insert "fake quantize" operations into linear layers. +# These operations simulate quantization numerics during +# training without performing any dtype casting +model = qat_quantizer.prepare(model) + +# Standard training loop +optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) +loss_fn = torch.nn.CrossEntropyLoss() +for i in range(10): + example = torch.randint(0, 4096, (2, 16)).cuda() + target = torch.randn((2, 16, 4096)).cuda() + output = model(example) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + +# Convert fake quantize to actual quantize operations +# The quantized model has the exact same structure as the +# quantized model produced in the corresponding PTQ flow +# through `Int8DynActInt4WeightQuantizer` +model = qat_quantizer.convert(model) + +# inference or generate +``` + +Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune) +and apply quantized-aware fine-tuning as follows: + +``` +tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full +``` + +For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html). + + +## Evaluation Results + +Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT +integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) +on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset) +for 5000 steps using a group size of 256 for the weights. Note that extensive +hyperparameter tuning may further improve these results. + +Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5: + +| | hellaswag
(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
+| ---------------- | ------ | ------ | ------ | ------ | ------ |
+| No quantization  | 57.86% | 76.60% | 8.905  | 1.505  | 0.590  |
+| PTQ              | 51.74% | 70.66% | 11.878 | 1.588  | 0.668  |
+| QAT (quantized)  | 57.25% | 76.51% | 9.859  | 1.534  | 0.617  |
+| PTQ degradation  | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
+| QAT degradation  | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |
+
+Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
+quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).
+
+| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext
(bits_per_byte) | +| ---------------- | -------- | ------- | ------ | ------ | ------ | +| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 | +| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 | +| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 | +| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 | +| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 | + +For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training). diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py new file mode 100644 index 0000000000..09ef10af67 --- /dev/null +++ b/torchao/quantization/qat/__init__.py @@ -0,0 +1,17 @@ +from .api import ( + ComposableQATQuantizer, +) +from .linear import ( + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, +) +from .embedding import ( + Int4WeightOnlyEmbeddingQATQuantizer, +) + +__all__ = [ + "ComposableQATQuantizer", + "Int4WeightOnlyQATQuantizer", + "Int4WeightOnlyEmbeddingQATQuantizer" + "Int8DynActInt4WeightQATQuantizer", +] diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py new file mode 100644 index 0000000000..9d7cdccfc3 --- /dev/null +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -0,0 +1,328 @@ +import torch +import torch.utils._pytree as pytree +from typing import Callable, Optional, Tuple +from torchao.quantization.quant_primitives import ( + _get_and_check_qmin_qmax, + choose_qparams_affine, + fake_quantize_affine, + ZeroPointDomain, + MappingType, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.utils import TorchAOBaseTensor +from .utils import ( + _GenericFakeQuantize, + _UnwrapAffineFakeQuantizedTensor, +) + +aten = torch.ops.aten + + +class _ToAffineFakeQuantized(torch.autograd.Function): + """ + Differentiable constructor for `AffineFakeQuantizedTensor`, + needed for input activation fake quantization. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + original_tensor: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ) -> "AffineFakeQuantizedTensor": + def apply_fake_quant_fn(t: torch.Tensor): + assert isinstance(t, AffineFakeQuantizedTensor) + qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + scale, zero_point = choose_qparams_affine( + t.original_tensor, + mapping_type, + block_size, + target_dtype, + qmin, + qmax, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + fq = _GenericFakeQuantize.apply( + t, + block_size, + scale, + zero_point, + qmin, + qmax, + zero_point_domain, + ) + return fq + return AffineFakeQuantizedTensor( + original_tensor, + apply_fake_quant_fn, + fake_quant_enabled=True, + ) + + @staticmethod + def backward(ctx, gy): + return gy, None, None, None, None, None, None, None, None, None, None + + +class AffineFakeQuantizedTensor(TorchAOBaseTensor): + """ + Affine fake quantized tensor subclass. 
Affine quantization means we quantize the floating point tensor + with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point + + Fake quantization refers to performing the quantization math without actually casting the floating point + tensor into lower bit-width dtypes. It is commonly used for quantization-aware training (QAT). + + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, + regardless of the internal representation's type or orientation. + + fields: + original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later + apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values + """ + + @staticmethod + def __new__( + cls, + original_tensor: torch.Tensor, + apply_fake_quant_fn: Callable, + fake_quant_enabled: bool = True, + **kwargs, + ): + kwargs.setdefault("dtype", original_tensor.dtype) + kwargs.setdefault("device", original_tensor.device) + kwargs.setdefault("requires_grad", original_tensor.requires_grad) + return torch.Tensor._make_wrapper_subclass( + cls, + original_tensor.shape, + **kwargs, + ) + + def __init__( + self, + original_tensor: torch.Tensor, + apply_fake_quant_fn: Callable, + fake_quant_enabled: bool = True, + **kwargs + ): + self.original_tensor = original_tensor + self.apply_fake_quant_fn = apply_fake_quant_fn + self.fake_quant_enabled = fake_quant_enabled + + def __tensor_flatten__(self): + return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride, + ): + original_tensor = tensor_data_dict["original_tensor"] + (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes + return cls( + original_tensor, + apply_fake_quant_fn, + fake_quant_enabled, + ) + + @classmethod + def from_float( + cls, + original_input: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ): + return _ToAffineFakeQuantized.apply( + original_input, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + + def get_value(self) -> torch.Tensor: + if self.fake_quant_enabled: + return self.apply_fake_quant_fn(self) + else: + return _UnwrapAffineFakeQuantizedTensor.apply(self) + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + "requires_grad": self.requires_grad, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + # not supported yet + kwargs.pop("memory_format") + return self.__class__( + self.original_tensor.to(device), + self.apply_fake_quant_fn, + self.fake_quant_enabled, + **kwargs, + ) + + def _apply_fn_to_data(self, fn: Callable): + 
""" + Create a new `AffineFakeQuantizedTensor` with `fn` applied to the + original tensor, to be called within __torch_dispatch__. + """ + return self._create_new(fn(self.original_tensor)) + + def _create_new(self, new_value: torch.Tensor): + """ + Create a new `AffineFakeQuantizedTensor` with a new value, + to be called within __torch_dispatch__. + + Note: `requires_grad` must be False here because tensors created + in `__torch_dispatch__` cannot produce gradients, since autograd + will try to attach autograd metadata to these tensors when we exit + `__torch_dispatch__`, but if these tensors already have metadata + attached then autograd will throw an error. + """ + return self.__class__( + new_value, + self.apply_fake_quant_fn, + self.fake_quant_enabled, + requires_grad=False, + ) + +implements = AffineFakeQuantizedTensor.implements + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(aten.mm.default) +def _(func, types, args, kwargs): + bias = None + input_tensor = args[0] + weight_tensor = args[1] + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return func(input_tensor, weight_tensor) + + +@implements(aten.addmm.default) +def _(func, types, args, kwargs): + bias = args[0] + input_tensor = args[1] + weight_tensor = args[2] + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return func(bias, input_tensor, weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach), + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone), + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t), + ) + + +@implements([ + aten.add.Tensor, + aten.add_.Tensor, + aten.mul_.Tensor, + aten.copy_.default, +]) +def _(func, types, args, kwargs): + assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}" + new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args) + first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1] + new_value = func(*new_args, **kwargs) + out = first_afq_tensor._create_new(new_value) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# Needed by FSDP: + +@implements(aten.empty_like.default) +def _(func, types, args, kwargs): + out = torch.empty_like(args[0].original_tensor, **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@implements(aten.split.Tensor) +def _(func, types, args, kwargs): + new_values = torch.split(args[0].original_tensor, *args[1:], **kwargs) + + def make_new_tensor(value): 
+ out = args[0]._create_new(value) + return return_and_correct_aliasing(func, args, kwargs, out) + + return list(map(make_new_tensor, new_values)) + + +@implements(aten.new_zeros.default) +def _(func, types, args, kwargs): + out = args[0].original_tensor.new_zeros(*args[1:], **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + +to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py new file mode 100644 index 0000000000..60d45c8b17 --- /dev/null +++ b/torchao/quantization/qat/api.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional, Union + +import torch + +from torchao.quantization.granularity import ( + Granularity, + PerAxis, + PerGroup, + PerToken, +) +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.quant_primitives import ( + _SUB_BYTE_INT_BOUNDS, + _SUB_BYTE_UINT_BOUNDS, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +@dataclass +class FakeQuantizeConfig: + """ + Config for how to fake quantize weights or activations. + + args: + dtype: dtype to simulate during fake quantization, e.g. torch.int8. + For PyTorch versions older than 2.6, you may use `TorchAODType` to represent + torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. + granularity: granularity of scales and zero points, e.g. PerGroup(32). + We also support the following strings: + 1) 'per_token': equivalent to PerToken() + 2) 'per_channel': equivalent to PerAxis(0) + 3) 'per_group': equivalent to PerGroup(group_size), must be combined + with separate `group_size` kwarg, Alternatively, just set the + `group_size` kwarg and leave this field empty. + mapping_type: whether to use symmetric (default) or asymmetric quantization + Alternatively, set `is_symmetric` (bool) and leave this field empty. 
+ scale_precision: scale dtype (default torch.fp32) + zero_point_precision: zero point dtype (default torch.int32) + zero_point_domain: whether zero point is in integer (default) or float domain + is_dynamic: whether to use dynamic (defualt) or static scale and zero points + range_learning: whether to learn scale and zero points during training (coming soon) + + kwargs (optional): + group_size: size of each group in per group fake quantization, + can be set instead of `granularity` + is_symmetric: whether to use symmetric or asymmetric quantization, + can be set instead of `mapping_type` + + Example usage:: + + # Per token asymmetric quantization + FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) + + # Per channel symmetric quantization + FakeQuantizeConfig(torch.int4, "per_channel") + FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) + FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) + + # Per group symmetric quantization + FakeQuantizeConfig(torch.int4, group_size=32) + FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) + FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) + FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) + """ + dtype: Union[torch.dtype, TorchAODType] + granularity: Granularity + mapping_type: MappingType + scale_precision: torch.dtype + zero_point_precision: torch.dtype + zero_point_domain: ZeroPointDomain + is_dynamic: bool = True + range_learning: bool = False + + def __init__( + self, + dtype: Union[torch.dtype, TorchAODType], + granularity: Union[Granularity, str, None] = None, + mapping_type: Optional[MappingType] = None, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + is_dynamic: bool = True, + range_learning: bool = False, + *, + group_size: Optional[int] = None, + is_symmetric: Optional[bool] = None, + ): + self.dtype = dtype + self.granularity = self._get_granularity(granularity, group_size) + self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) + self.scale_precision = scale_precision + self.zero_point_precision = zero_point_precision + self.zero_point_domain = zero_point_domain + self.is_dynamic = is_dynamic + self.range_learning = range_learning + + # Validate dtype + all_dtypes = [torch.int8, torch.uint8] + all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) + all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) + if dtype not in all_dtypes: + raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)) + + def _get_granularity( + self, + granularity: Union[Granularity, str, None], + group_size: Optional[int], + ) -> Granularity: + """ + Parse the `Granularity` represented in the args. 
+ + Granularity can be specified in one of three ways: + 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) + 2) str: one of 'per_token', 'per_channel', and 'per_group' + 3) None: `group_size` must be set instead, represents per group granularity + """ + # If group_size is set, then granularity must be either "per_group" or None + if group_size is not None and granularity != "per_group" and granularity is not None: + raise ValueError("`group_size` conflicts with granularity '%s'" % granularity) + + # Case 1: Granularity object + if isinstance(granularity, Granularity): + if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): + raise ValueError("Granularity '%s' is not supported" % granularity) + if isinstance(granularity, PerAxis) and granularity.axis != 0: + raise ValueError("Only axis=0 is supported for PerAxis granularity") + return granularity + + # Case 2: str granularity + if granularity == "per_token": + return PerToken() + elif granularity == "per_channel": + return PerAxis(axis=0) + elif granularity == "per_group": + if group_size is None: + raise ValueError("Granularity was 'per_group' but no `group_size` was set") + return PerGroup(group_size) + elif isinstance(granularity, str): + raise ValueError( + "Unexpected granularity: '%s', must be one of %s" % + (granularity, ["per_token", "per_channel", "per_group"]) + ) + + # Case 3: None granularity + group_size was specified + if granularity is not None: + raise ValueError( + "Granularity '%s' has unexpected type %s" % (granularity, type(granularity)) + ) + if group_size is None: + raise ValueError("At least one of `granularity` or `group_size` must be set") + return PerGroup(group_size) + + def _get_mapping_type( + self, + mapping_type: Optional[MappingType], + is_symmetric: Optional[bool], + ) -> MappingType: + """ + Parse the `MappingType` represented in the args. + + Mapping type can be specified in one of two ways: + 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC + 2): is_symmetric bool + """ + if mapping_type is not None and is_symmetric is not None: + raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") + + # Case 0: Default to symmetric + if mapping_type is None and is_symmetric is None: + return MappingType.SYMMETRIC + + # Case 1: MappingType object + if mapping_type is not None: + if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: + raise ValueError("MappingType '%s' is not supported" % mapping_type) + return mapping_type + + # Case 2: is_symmetric flag + assert is_symmetric is not None + if is_symmetric: + return MappingType.SYMMETRIC + else: + return MappingType.ASYMMETRIC + + @property + def group_size(self) -> int: + """ + If this is per group granularity, return the group size. + Otherwise, throw an error. + """ + if isinstance(self.granularity, PerGroup): + return self.granularity.group_size + else: + raise ValueError("`group_size` is undefined for %s granularity" % self.granularity) + + @property + def is_symmetric(self) -> bool: + """ + Return True if mapping type is symmetric, else False (asymmetric). + """ + return self.mapping_type == MappingType.SYMMETRIC + + def __setattr__(self, name: str, value: Any): + """ + Support setting `group_size` and `is_symmetric`. 
+ """ + if name == "group_size": + super().__setattr__("granularity", PerGroup(value)) + elif name == "is_symmetric": + mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC + super().__setattr__("mapping_type", mapping_type) + else: + super().__setattr__(name, value) + + +class ComposableQATQuantizer(TwoStepQuantizer): + """ + Composable quantizer that users can use to apply multiple QAT quantizers easily. + Quantizers will be applied in the order they are specified in the constructor. + + Note: the quantizers provided must apply to different modules in the model, + e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined. + + Example usage:: + + my_quantizer = ComposableQATQuantizer([ + QATQuantizer1(), + QATQuantizer2(), + QATQuantizer3(), + ]) + model = my_quantizer.prepare(model) + train(model) + model = my_quantizer.convert(model) + """ + + def __init__(self, quantizers: List[TwoStepQuantizer]): + self.quantizers = quantizers + + def prepare( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + for quantizer in self.quantizers: + model = quantizer.prepare(model) + return model + + def convert( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + for quantizer in self.quantizers: + model = quantizer.convert(model) + return model diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py new file mode 100644 index 0000000000..1f471fa490 --- /dev/null +++ b/torchao/quantization/qat/embedding.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, +) +from torchao.quantization.quant_primitives import TorchAODType +from .api import FakeQuantizeConfig +from .fake_quantizer import FakeQuantizer +from .utils import ( + _fake_quantize_per_channel_group, + _get_qmin_qmax, +) + + +class FakeQuantizedEmbedding(torch.nn.Embedding): + """ + General embedding layer with fake quantized weights. + + Specific target dtypes, granularity, schemes etc. are specified + through separate configs for weights and activations. 
+ + Example usage:: + + weight_config = FakeQuantizeConfig( + dtype=torch.int4, + group_size=8, + symmetric=True, + ) + fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config) + fq_embedding(torch.LongTensor([3])) + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + weight_config: Optional[FakeQuantizeConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + *args, + **kwargs, + ) + if weight_config is not None: + self.weight_fake_quantizer = FakeQuantizer(weight_config) + else: + self.weight_fake_quantizer = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.weight_fake_quantizer is not None: + w = self.weight_fake_quantizer(self.weight) + else: + w = self.weight + return F.embedding( + x, w, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse, + ) + + +# ====================================== +# | Embedding int4 weight-only QAT | +# ====================================== + +class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where embedding layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + group_size: int = 256, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + ) -> None: + super().__init__() + self.bit_width = 4 + self.group_size: int = group_size + self.scale_precision: torch.dtype = scale_precision + self.zero_point_precision: torch.dtype = zero_point_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + """ + Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`. + """ + def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool: + return isinstance(child, torch.nn.Embedding) + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_embedding = Int4WeightOnlyQATEmbedding( + # nn.Embedding args + num_embeddings=child.num_embeddings, + embedding_dim=child.embedding_dim, + padding_idx=child.padding_idx, + max_norm=child.max_norm, + norm_type=child.norm_type, + scale_grad_by_freq=child.scale_grad_by_freq, + sparse=child.sparse, + # quantization args + group_size=self.group_size, + scale_precision=self.scale_precision, + zero_point_precision=self.zero_point_precision, + device=child.weight.device, + ) + # In distributed training, the model may be instantiated + # on the meta device, in which case there is no need to + # copy the weights, and doing so will result in an error + if child.weight.device != torch.device("meta"): + new_embedding.weight = child.weight + return new_embedding + + _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + """ + Swap all `Int4WeightOnlyQATEmbedding` modules with `Int4WeightOnlyEmbedding`. 
+ """ + self._convert_helper(model) + return model + + def _convert_helper(self, module: torch.nn.Module): + """ + Helper function to recursively swap `Int4WeightOnlyQATEmbedding` + modules with `Int4WeightOnlyEmbedding` + """ + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATEmbedding): + group_size = child.weight_fake_quantizer.config.group_size + scale_precision = child.weight_fake_quantizer.config.scale_precision + zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision + quantized_embedding = Int4WeightOnlyEmbedding( + # nn.Embedding args + num_embeddings=child.num_embeddings, + embedding_dim=child.embedding_dim, + padding_idx=child.padding_idx, + max_norm=child.max_norm, + norm_type=child.norm_type, + scale_grad_by_freq=child.scale_grad_by_freq, + sparse=child.sparse, + # quantization args + group_size=group_size, + scale_precision=scale_precision, + zero_point_precision=zero_point_precision, + device=child.weight.device, + ) + setattr(module, name, quantized_embedding) + + # Load weights and qparams into quantized embedding + (qmin, qmax) = _get_qmin_qmax(self.bit_width) + (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size) + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, group_size, + ) + quantized_embedding.weight = q_weight + quantized_embedding.scales = s + quantized_embedding.zeros = zp + else: + self._convert_helper(child) + + +class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding): + """ + This module implements a embedding layer with int4 fake quantized + grouped per channel weights. + + args: + group_size: the number of elements in each quantized group for weights + scale_precision: precision of per group scales + zero_point_precision: precision of per group zero points + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + group_size: int = 32, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + *args, + **kwargs, + ): + weight_config = FakeQuantizeConfig( + dtype=TorchAODType.INT4, + group_size=group_size, + is_symmetric=True, + is_dynamic=True, + scale_precision=scale_precision, + zero_point_precision=zero_point_precision, + ) + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + weight_config, + *args, + **kwargs, + ) + + def enable_fake_quant(self, enabled: bool = True): + self.weight_fake_quantizer.enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + +class Int4WeightOnlyEmbedding(torch.nn.Module): + """ + This module implements a embedding layer with int4 quantized + grouped per channel weights. 
+ """ + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + group_size: int = 32, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + device: torch.device = None, + ): + super().__init__() + + # nn.Embedding args + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + # quantization args + self.bit_width = 4 + self.group_size = group_size + self.scale_precision = scale_precision + self.zero_point_precision = zero_point_precision + + # currently storing unpacked int8 weights + self.register_buffer( + "weight", + torch.empty((num_embeddings, embedding_dim), dtype=torch.int8, device=device), + ) + self.register_buffer( + "scale", + torch.empty( + (num_embeddings, embedding_dim // group_size), + dtype=scale_precision, + device=device, + ), + ) + self.register_buffer( + "zero_point", + torch.empty( + (num_embeddings, embedding_dim // group_size), + dtype=zero_point_precision, + device=device, + ), + ) + + def forward(self, x): + from torchao._executorch_ops import _quantized_decomposed_dequantize_per_channel_group_wrapper + qmin, qmax = _get_qmin_qmax(self.bit_width) + w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper( + self.weight, + self.scale, + self.zero_point, + qmin, + qmax, + torch.int8, + self.group_size, + x.dtype, + ) + return F.embedding( + x, w_dq, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse, + ) diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py new file mode 100644 index 0000000000..eb42dcf047 --- /dev/null +++ b/torchao/quantization/qat/fake_quantizer.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch + +from torchao.quantization.granularity import ( + PerAxis, + PerGroup, + PerToken, +) +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_BIT_WIDTH, + _DTYPE_TO_QVALUE_BOUNDS, +) +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + get_groupwise_affine_qparams, +) +from .api import ( + FakeQuantizeConfig, +) +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +class FakeQuantizer(torch.nn.Module): + """ + Generic module for applying fake quantization to a tensor, as specified in the config. + """ + def __init__(self, config: FakeQuantizeConfig): + super().__init__() + self.config = config + self.enabled = True + self.scale: Optional[torch.Tensor] = None + self.zero_point: Optional[torch.Tensor] = None + + # TODO: support range learinng + if self.config.range_learning: + raise NotImplementedError("Range learning is not supported yet") + + def forward(self, x: torch.Tensor): + """ + Apply fake quantization to the tensor based on the bit-width, + granularity, symmetry, and other properties specified in the config. 
+ """ + if not self.enabled: + return x + + if isinstance(self.config.granularity, PerToken): + return self._per_token_forward(x) + elif isinstance(self.config.granularity, (PerAxis, PerGroup)): + return self._per_channel_or_group_forward(x) + else: + raise ValueError("Unknown granularity '%s'" % self.config.granularity) + + def _per_token_forward(self, x: torch.Tensor): + """ + Perform per token fake quantization on the tensor. + """ + if self.config.is_symmetric: + raise NotImplementedError("Symmetric per token is not supported yet") + if self._should_compute_qparams(): + (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric( + x, self.config.scale_precision, self.config.zero_point_precision, + ) + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] + return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax) + + def _per_channel_or_group_forward(self, x: torch.Tensor): + """ + Perform per channel or per group fake quantization on the tensor. + We express per channel using per group where the group size is the size + of the last dimension of the tensor. + """ + granularity = self.config.granularity + scale_precision = self.config.scale_precision + zero_point_precision = self.config.zero_point_precision + zero_point_domain = self.config.zero_point_domain + is_symmetric = self.config.is_symmetric + + # get group size + if isinstance(granularity, PerAxis): + assert granularity.axis == 0 + group_size = x.size()[-1] + elif isinstance(granularity, PerGroup): + group_size = granularity.group_size + else: + raise ValueError("Unexpected granularity '%s'" % granularity) + + # get scales and zero points + if self._should_compute_qparams(): + bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype] + if is_symmetric: + (self.scale, self.zero_point) = get_group_qparams_symmetric( + x, bit_width, group_size, scale_precision, + ) + else: + (self.scale, self.zero_point) = get_groupwise_affine_qparams( + x, bit_width, group_size, scale_precision, + ) + self.zero_point = self.zero_point.to(zero_point_precision) + + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] + return _fake_quantize_per_channel_group( + x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain, + ) + + def _should_compute_qparams(self) -> bool: + """ + Return whether we need to compute new scales and zero points. + """ + return self.config.is_dynamic or self.scale is None or self.zero_point is None diff --git a/torchao/quantization/prototype/qat/images/qat_diagram.png b/torchao/quantization/qat/images/qat_diagram.png similarity index 100% rename from torchao/quantization/prototype/qat/images/qat_diagram.png rename to torchao/quantization/qat/images/qat_diagram.png diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py new file mode 100644 index 0000000000..ef1714808c --- /dev/null +++ b/torchao/quantization/qat/linear.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from torchao.quantization.GPTQ import ( + _check_linear_int4_k, + _replace_linear_int4, + _replace_linear_8da4w, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, +) +from torchao.quantization.quant_primitives import ( + TorchAODType, + ZeroPointDomain, +) +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.utils import get_group_qparams_symmetric +from .api import FakeQuantizeConfig +from .fake_quantizer import FakeQuantizer +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +class FakeQuantizedLinear(torch.nn.Linear): + """ + General linear layer with fake quantized weights and activations. + + Specific target dtypes, granularity, schemes etc. are specified + through separate configs for weights and activations. + + Example usage:: + + activation_config = FakeQuantizeConfig( + dtype=torch.int8, + granularity="per_token", + is_symmetric=False, + ) + weight_config = FakeQuantizeConfig( + dtype=torch.int4, + group_size=8, + is_symmetric=True, + ) + fq_linear = FakeQuantizedLinear( + 16, 32, False, activation_config, weight_config, + ) + fq_linear(torch.randn(16)) + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + activation_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + *args, + **kwargs, + ) + if bias: + raise NotImplementedError("bias not supported yet") + + # initialize activation fake quantizer + if activation_config is not None: + self.activation_fake_quantizer = FakeQuantizer(activation_config) + else: + self.activation_fake_quantizer = None + + # initialize weight fake quantizer + if weight_config is not None: + group_size = weight_config.group_size + if group_size is not None and in_features % group_size != 0: + raise ValueError( + "in_features (%s) % group_size (%s) must be == 0" % + (in_features, group_size) + ) + self.weight_fake_quantizer = FakeQuantizer(weight_config) + else: + self.weight_fake_quantizer = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.activation_fake_quantizer is not None: + x = self.activation_fake_quantizer(x) + if self.weight_fake_quantizer is not None: + w = self.weight_fake_quantizer(self.weight) + else: + w = self.weight + return F.linear(x, w) + + +# ========================================================= +# | Linear int8 dynamic activations + int4 weight QAT | +# ========================================================= + + +class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have int8 + dynamic per token fake quantized activations and int4 fake quantized + grouped per channel weights. 
+ """ + + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.scales_precision: torch.dtype = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.scales_precision, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + self._convert_qat_linear_8da4w(model) + return model + + def _convert_qat_linear_8da4w(self, module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + config = child.weight_fake_quantizer.config + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=config.group_size, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, config.group_size, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + self._convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear): + """ + This module implements a linear layer with int8 dynamic per token fake + quantized activations with int4 fake quantized grouped per channel weights. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + activation_config = FakeQuantizeConfig( + dtype=torch.int8, + granularity="per_token", + is_symmetric=False, + is_dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) + weight_config = FakeQuantizeConfig( + dtype=TorchAODType.INT4, + group_size=groupsize, + is_symmetric=True, + is_dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) + super().__init__( + in_features, + out_features, + bias, + activation_config, + weight_config, + device=device, + dtype=precision, + ) + + def enable_fake_quant(self, enabled: bool = True): + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + +def enable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int8DynActInt4WeightQATLinear`. 
+ """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.enable_fake_quant() + + +def disable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.disable_fake_quant() + + +# =================================== +# | Linear int4 weight-only QAT | +# =================================== + + +class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + inner_k_tiles: Optional[int] = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + self.inner_k_tiles = inner_k_tiles + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + padding_allowed=True, + precision=self.precision, + scales_precision=self.scales_precision, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + self._convert_qat_linear_4w(model) + return model + + def _convert_qat_linear_4w(self, module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + inner_k_tiles = child.inner_k_tiles + config = child.weight_fake_quantizer.config + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=config.group_size, + inner_k_tiles=inner_k_tiles, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, config.group_size, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + self._convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(FakeQuantizedLinear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. 
+ + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.inner_k_tiles = inner_k_tiles + weight_config = FakeQuantizeConfig( + dtype=torch.uint4, + group_size=groupsize, + is_symmetric=False, + is_dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + zero_point_domain=ZeroPointDomain.FLOAT, + ) + super().__init__( + in_features, + out_features, + bias, + activation_config=None, + weight_config=weight_config, + device=device, + dtype=precision, + ) + + def enable_fake_quant(self, enabled: bool = True): + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/qat/utils.py similarity index 97% rename from torchao/quantization/prototype/qat/utils.py rename to torchao/quantization/qat/utils.py index 8f2dd9d13f..e2234a2556 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -46,7 +46,7 @@ def forward( zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, ) @@ -88,7 +88,7 @@ def forward( input: torch.Tensor, ) -> torch.Tensor: # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, ) assert isinstance(input, AffineFakeQuantizedTensor) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e8e7261fbd..9895ea084d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -221,7 +221,7 @@ def _replace_with_custom_fn_if_matches_filter( def _is_linear(mod, *args): # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, )
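For illustration, here is a minimal sketch (not part of the patch above) of how the prepare/convert flow might be combined with the `enable_8da4w_fake_quant` / `disable_8da4w_fake_quant` helpers added in `torchao/quantization/qat/linear.py`, for example to delay fake quantization until after a few warmup steps. The toy model, group size, and warmup schedule are assumptions made purely for this sketch:

```python
import torch

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import (
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

# Toy model (assumption): in_features must be divisible by the chosen groupsize
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False),
    torch.nn.Linear(256, 256, bias=False),
)

# prepare: swap nn.Linear modules for Int8DynActInt4WeightQATLinear
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=128)
model = quantizer.prepare(model)

# Optionally start fine-tuning with fake quantization disabled...
model.apply(disable_8da4w_fake_quant)
# ... train for some warmup steps (schedule is an assumption) ...

# ...then enable it for the remaining steps
model.apply(enable_8da4w_fake_quant)
# ... continue training ...

# convert: replace the QAT linears with actually quantized linears for inference
model = quantizer.convert(model)
```

The helpers are written to be used with `Module.apply`: each one checks `isinstance(mod, Int8DynActInt4WeightQATLinear)` and only toggles the activation and weight fake quantizers on matching modules, so applying them across the whole model is safe.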