From 8c72b21e51321412c84884332668e7282971ab67 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 24 Sep 2024 15:24:18 -0700 Subject: [PATCH] Add composable QAT quantizer Summary: This is a utility for users who wish to apply multiple QAT quantizers to their models. In the near future, we expect to add an embedding QAT quantizer that composes with the existing linear QAT quantizers. Test Plan: python test/quantization/test_qat.py -k test_composable_qat_quantizer --- test/quantization/test_qat.py | 42 +++++++++++++++++ .../quantization/prototype/qat/__init__.py | 2 + torchao/quantization/prototype/qat/api.py | 46 +++++++++++++++++-- torchao/quantization/unified.py | 6 +-- 4 files changed, 89 insertions(+), 7 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 397283a59b..457e3a060f 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -15,6 +15,9 @@ from torchao.dtypes import ( TensorCoreTiledLayoutType, ) +from torchao.quantization.prototype.qat.api import ( + ComposableQATQuantizer, +) from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( AffineFakeQuantizedTensor, ) @@ -34,6 +37,9 @@ MappingType, ZeroPointDomain, ) +from torchao.quantization.unified import ( + TwoStepQuantizer, +) from torchao.quantization.utils import ( get_group_qparams_symmetric, get_groupwise_affine_qparams, @@ -626,6 +632,42 @@ def test_qat_4w_quantizer_module_swap(self): module_swap_out = module_swap_model(*x2) torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + class _MyQATQuantizer(TwoStepQuantizer): + """ + Dummy quantizer that attaches a certain value to each nn.Linear's + `_temp_quantizer_values` attribute. + """ + ATTR_NAME = "_temp_quantizer_values" + + def __init__(self, value: str): + self.value = value + + def _attach_value(self, module: torch.nn.Module): + if isinstance(module, torch.nn.Linear): + if not hasattr(module, self.ATTR_NAME): + setattr(module, self.ATTR_NAME, []) + getattr(module, self.ATTR_NAME).append(self.value) + + def prepare(self, model: torch.nn.Module): + model.apply(self._attach_value) + return model + + def convert(self, model: torch.nn.Module): + model.apply(self._attach_value) + return model + + def test_composable_qat_quantizer(self): + quantizer1 = self._MyQATQuantizer("quantizer1") + quantizer2 = self._MyQATQuantizer("quantizer2") + composable_quantizer = ComposableQATQuantizer([quantizer1, quantizer2]) + model = M() + model = composable_quantizer.prepare(model) + self.assertTrue(hasattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)) + values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) + self.assertEqual(values_list, ["quantizer1", "quantizer2"]) + composable_quantizer.convert(model) + values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) + self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"]) if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index c16b3ead44..9f8dd74e44 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -5,6 +5,7 @@ enable_8da4w_fake_quant, int4_weight_only_fake_quantize, int8_dynamic_activation_int4_weight_fake_quantize, + ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, ) @@ -20,6 +21,7 @@ "enable_8da4w_fake_quant", "int4_weight_only_fake_quantize", "int8_dynamic_activation_int4_weight_fake_quantize", + "ComposableQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", "Int8DynActInt4WeightQATLinear", diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 2f3368ff1c..e1c5221e1e 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Optional +from typing import Any, List, Optional import torch import torch.nn.functional as F @@ -34,6 +34,44 @@ ) +class ComposableQATQuantizer(TwoStepQuantizer): + """ + Composable quantizer that users can use to apply multiple QAT quantizers easily. + Quantizers will be applied in the order they are specified in the constructor. + + Note: the quantizers provided must apply to different modules in the model, + e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined. + + Example usage:: + + my_quantizer = ComposableQATQuantizer([ + QATQuantizer1(), + QATQuantizer2(), + QATQuantizer3(), + ]) + model = my_quantizer.prepare(model) + train(model) + model = my_quantizer.convert(model) + """ + + def __init__(self, quantizers: List[TwoStepQuantizer]): + self.quantizers = quantizers + + def prepare( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + for quantizer in self.quantizers: + model = quantizer.prepare(model) + return model + + def convert( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + for quantizer in self.quantizers: + model = quantizer.convert(model) + return model + + # ================= # | 8da4w QAT | # ================= @@ -44,7 +82,8 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): int4 per group weight symmetric fake quantization to linear. Please see :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. - Example usage: + Example usage:: + from torchao.quantization import quantize_ quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) """ @@ -151,7 +190,8 @@ def int4_weight_only_fake_quantize(group_size=128): Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. Please see :func:`~torchao.quantization.int4_weight_only` for more details. - Example usage: + Example usage:: + from torchao.quantization import quantize_ quantize_(model, int4_weight_only_fake_quantize(group_size=32)) """ diff --git a/torchao/quantization/unified.py b/torchao/quantization/unified.py index 7da915dec7..1bd62b8979 100644 --- a/torchao/quantization/unified.py +++ b/torchao/quantization/unified.py @@ -1,5 +1,5 @@ import torch -from typing import Any +from typing import Any, List from abc import ABC, abstractmethod """ @@ -17,7 +17,6 @@ class Quantizer(ABC): def quantize( self, model: torch.nn.Module, *args: Any, **kwargs: Any ) -> torch.nn.Module: - pass @@ -27,11 +26,10 @@ class TwoStepQuantizer: def prepare( self, model: torch.nn.Module, *args: Any, **kwargs: Any ) -> torch.nn.Module: - pass + @abstractmethod def convert( self, model: torch.nn.Module, *args: Any, **kwargs: Any ) -> torch.nn.Module: - pass