From 0b7184931c7873a3015ad31ad54397db7afc4dde Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 11 Oct 2024 14:01:40 -0700 Subject: [PATCH] Add generic fake quantized linear for QAT Summary: This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=8) fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. Test Plan: python test/quantization/test_qat.py -k test_fake_quantize_config_granularity python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w ghstack-source-id: e6f4e1096a98e2ee725b8b51a6cfcfcf545c4ca9 Pull Request resolved: https://github.com/pytorch/ao/pull/1020 --- test/quantization/test_qat.py | 334 +++++++++++++++--- torchao/quantization/README.md | 5 +- torchao/quantization/__init__.py | 5 +- torchao/quantization/granularity.py | 24 +- torchao/quantization/prototype/qat/api.py | 215 ++++++++++- .../prototype/qat/fake_quantizer.py | 121 +++++++ torchao/quantization/prototype/qat/linear.py | 332 +++++++++-------- torchao/quantization/prototype/qat/utils.py | 10 +- torchao/quantization/quant_primitives.py | 61 +++- 9 files changed, 896 insertions(+), 211 deletions(-) create mode 100644 torchao/quantization/prototype/qat/fake_quantizer.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index f0c5601ab7..1f45e84d74 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -11,17 +11,32 @@ import unittest import torch +import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torchao.dtypes import ( TensorCoreTiledLayout, ) +from torchao.quantization.granularity import ( + PerAxis, + PerGroup, + PerRow, + PerToken, +) from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, + FakeQuantizeConfig, +) +from torchao.quantization.prototype.qat.fake_quantizer import ( + FakeQuantizer, +) +from torchao.quantization.prototype.qat.linear import ( + FakeQuantizedLinear, ) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, + _get_qmin_qmax, _GenericFakeQuantize, ) from torchao.quantization.quant_api import ( @@ -31,6 +46,7 @@ from torchao.quantization.quant_primitives import ( fake_quantize_affine, MappingType, + TorchAODType, ZeroPointDomain, ) from torchao.quantization.unified import ( @@ -92,15 +108,10 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - def _get_qmin_qmax(self, n_bit: int): - qmin = -(2 ** (n_bit - 1)) - qmax = 2 ** (n_bit - 1) - 1 - return (qmin, qmax) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_channel_group(self): n_bit = 4 - (qmin, qmax) = self._get_qmin_qmax(n_bit) + (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = 128 torch.manual_seed(self.SEED) @@ -126,7 +137,7 @@ def test_fake_quantize_per_channel_group(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_token(self): - (qmin, qmax) = self._get_qmin_qmax(8) + (qmin, qmax) = _get_qmin_qmax(8) torch.manual_seed(self.SEED) x = torch.randn(100, 256).requires_grad_() @@ -165,11 +176,11 @@ def _set_ptq_weight( Int4WeightOnlyQATLinear, ) n_bit = 4 - (qmin, qmax) = self._get_qmin_qmax(n_bit) + (qmin, qmax) = _get_qmin_qmax(n_bit) + group_size = qat_linear.weight_fake_quantizer.config.group_size if isinstance(ptq_linear, Int8DynActInt4WeightLinear): assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear) fp32_weight = qat_linear.weight - group_size = qat_linear.groupsize (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, @@ -180,7 +191,7 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, qat_linear.groupsize, + qat_linear.weight, n_bit, group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( q_weight.to("cuda"), qat_linear.inner_k_tiles, @@ -218,31 +229,36 @@ def test_qat_8da4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer group_size = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) - subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - subclass_model = subclass_quantizer.prepare(m) - module_swap_model = module_swap_quantizer.prepare(m2) + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) # Compare model values torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) # Convert QAT model and compare model values - subclass_model = subclass_quantizer.convert(subclass_model) - module_swap_model = module_swap_quantizer.convert(module_swap_model) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): @@ -275,9 +291,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self.assertFalse(qat_model.linear1._fake_quant_enabled) - self.assertFalse(qat_model.linear2._fake_quant_enabled) - self.assertFalse(qat_model.sub.linear._fake_quant_enabled) + self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled) + self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled) + self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled) # Disabled fake quant is just a normal linear m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) @@ -292,9 +311,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - self.assertTrue(qat_model.linear1._fake_quant_enabled) - self.assertTrue(qat_model.linear2._fake_quant_enabled) - self.assertTrue(qat_model.sub.linear._fake_quant_enabled) + self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled) + self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled) + self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) @@ -407,7 +429,7 @@ def test_qat_generic_fake_quantize(self): the numerics of existing fake quantize ops in Pytorch in both the forward and the backward passes. """ - (qmin, qmax) = self._get_qmin_qmax(4) + (qmin, qmax) = _get_qmin_qmax(4) py_input = torch.randn(16, 64).float().requires_grad_() py_s = torch.randn(16).float() py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) @@ -521,7 +543,7 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer group_size = 32 inner_k_tiles = 8 @@ -530,29 +552,34 @@ def test_qat_4w_quantizer(self): torch.manual_seed(self.SEED) m = M().to(device).to(dtype) m2 = copy.deepcopy(m) - subclass_quantizer = Int4WeightOnlyQATQuantizer( + qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - module_swap_quantizer = Int4WeightOnlyQATQuantizer( + ptq_quantizer = Int4WeightOnlyQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - subclass_model = subclass_quantizer.prepare(m) - module_swap_model = module_swap_quantizer.prepare(m2) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) # Compare model values torch.manual_seed(self.SEED) x = [i.to(device).to(dtype) for i in m.example_inputs()] x2 = copy.deepcopy(x) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + self._assert_close_4w(qat_out, ptq_out) # Convert QAT model and compare model values - subclass_model = subclass_quantizer.convert(subclass_model) - module_swap_model = module_swap_quantizer.convert(module_swap_model) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) class _MyQATQuantizer(TwoStepQuantizer): """ @@ -603,5 +630,226 @@ def test_qat_4w_embedding(self): converted = quantizer.convert(model) converted_out = converted(*x) + def test_fake_quantize_config_granularity(self): + """ + Test initialization and property setting of `FakeQuantizeConfig`'s granularity. + """ + # per token + per_token_config1 = FakeQuantizeConfig(torch.int8, PerToken()) + per_token_config2 = FakeQuantizeConfig(torch.int8, "per_token") + self.assertIsInstance(per_token_config1.granularity, PerToken) + self.assertIsInstance(per_token_config2.granularity, PerToken) + + # per channel + per_channel_config1 = FakeQuantizeConfig(torch.int8, PerAxis(0)) + per_channel_config2 = FakeQuantizeConfig(torch.int8, "per_channel") + self.assertIsInstance(per_channel_config1.granularity, PerAxis) + self.assertIsInstance(per_channel_config2.granularity, PerAxis) + self.assertEqual(per_channel_config1.granularity.axis, 0) + self.assertEqual(per_channel_config2.granularity.axis, 0) + + # per group + per_group_config1 = FakeQuantizeConfig(torch.int8, PerGroup(32)) + per_group_config2 = FakeQuantizeConfig(torch.int8, "per_group", group_size=32) + per_group_config3 = FakeQuantizeConfig(torch.int8, group_size=32) + self.assertIsInstance(per_group_config1.granularity, PerGroup) + self.assertIsInstance(per_group_config2.granularity, PerGroup) + self.assertIsInstance(per_group_config3.granularity, PerGroup) + self.assertEqual(per_group_config1.group_size, 32) + self.assertEqual(per_group_config2.group_size, 32) + self.assertEqual(per_group_config3.group_size, 32) + + # set `group_size` after initialization + per_token_config1.group_size = 64 + per_channel_config1.group_size = 64 + per_group_config1.group_size = 64 + self.assertIsInstance(per_token_config1.granularity, PerGroup) + self.assertIsInstance(per_channel_config1.granularity, PerGroup) + self.assertIsInstance(per_group_config1.granularity, PerGroup) + self.assertEqual(per_token_config1.granularity.group_size, 64) + self.assertEqual(per_channel_config1.granularity.group_size, 64) + self.assertEqual(per_group_config1.granularity.group_size, 64) + + def test_fake_quantize_config_granularity_error_cases(self): + """ + Test incorrect settings of `FakeQuantizeConfig`'s granularity. + """ + # no granularity provided + with self.assertRaisesRegex(ValueError, "`granularity` or `group_size` must be set"): + FakeQuantizeConfig(torch.int8) + + # group_size with conflicting granularity + msg = "`group_size` conflicts with granularity" + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int8, PerToken(), group_size=32) + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int8, PerGroup(64), group_size=32) + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int8, "per_token", group_size=32) + + # 'per_group' but no group_size + msg = "Granularity was 'per_group' but no `group_size` was set" + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int8, "per_group") + + # not supported + with self.assertRaisesRegex(ValueError, "not supported"): + FakeQuantizeConfig(torch.int8, PerRow()) + with self.assertRaisesRegex(ValueError, "Only axis=0 is supported"): + FakeQuantizeConfig(torch.int8, PerAxis(1)) + with self.assertRaisesRegex(ValueError, "Unexpected granularity"): + FakeQuantizeConfig(torch.int8, "blah") + with self.assertRaisesRegex(ValueError, "unexpected type"): + FakeQuantizeConfig(torch.int8, 1234) + + def test_fake_quantize_config_mapping_type(self): + """ + Test initialization and property setting of `FakeQuantizeConfig`'s mapping type. + """ + # symmetric + symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") + symmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=True) + symmetric_config3 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC) + self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) + self.assertEqual(symmetric_config2.mapping_type, MappingType.SYMMETRIC) + self.assertEqual(symmetric_config3.mapping_type, MappingType.SYMMETRIC) + self.assertTrue(symmetric_config1.is_symmetric) + self.assertTrue(symmetric_config2.is_symmetric) + self.assertTrue(symmetric_config3.is_symmetric) + + # asymmetric + asymmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + asymmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.ASYMMETRIC) + self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) + self.assertEqual(asymmetric_config2.mapping_type, MappingType.ASYMMETRIC) + self.assertFalse(asymmetric_config1.is_symmetric) + self.assertFalse(asymmetric_config2.is_symmetric) + + # set `is_symmetric` after initialization + asymmetric_config1.is_symmetric = True + self.assertEqual(asymmetric_config1.mapping_type, MappingType.SYMMETRIC) + self.assertTrue(asymmetric_config1.is_symmetric) + + # bad config1: both mapping_type and is_symmetric are set + msg = "Cannot set both `mapping_type` and `is_symmetric`" + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False) + + # bad config2: not supported + with self.assertRaisesRegex(ValueError, "not supported"): + FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR) + + def test_fake_quantize_config_dtype(self): + """ + Test that unsupported dtypes are caught in `FakeQuantizeConfig`. + """ + msg = "Unsupported dtype" + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int16, "per_token") + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.int32, "per_token") + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.bfloat16, "per_token") + with self.assertRaisesRegex(ValueError, msg): + FakeQuantizeConfig(torch.float32, "per_token") + # OK + FakeQuantizeConfig(torch.uint1, "per_token") + FakeQuantizeConfig(torch.uint2, "per_token") + FakeQuantizeConfig(torch.uint3, "per_token") + FakeQuantizeConfig(torch.uint4, "per_token") + FakeQuantizeConfig(torch.uint5, "per_token") + FakeQuantizeConfig(torch.uint6, "per_token") + FakeQuantizeConfig(torch.uint7, "per_token") + FakeQuantizeConfig(torch.uint8, "per_token") + FakeQuantizeConfig(TorchAODType.INT1, "per_token") + FakeQuantizeConfig(TorchAODType.INT2, "per_token") + FakeQuantizeConfig(TorchAODType.INT3, "per_token") + FakeQuantizeConfig(TorchAODType.INT4, "per_token") + FakeQuantizeConfig(TorchAODType.INT5, "per_token") + FakeQuantizeConfig(TorchAODType.INT6, "per_token") + FakeQuantizeConfig(TorchAODType.INT7, "per_token") + FakeQuantizeConfig(torch.int8, "per_token") + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantized_linear_8da4w(self): + """ + Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. + """ + group_size = 128 + torch.manual_seed(self.SEED) + fq_linear = FakeQuantizedLinear( + 256, + 688, + bias=False, + activation_config=FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False), + weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), + ) + + def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. + """ + # activations + (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (qmin, qmax) = _get_qmin_qmax(8) + x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) + + # weights + (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) + zp = zp.to(torch.int32) + (qmin, qmax) = _get_qmin_qmax(4) + w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + return F.linear(x_fq, w_fq) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + fq_out = fq_linear(x) + baseline_out = linear_forward_8da4w(x2, fq_linear.weight) + torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantized_linear_4w(self): + """ + Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. + """ + group_size = 128 + weight_config = FakeQuantizeConfig( + dtype=torch.uint4, + group_size=group_size, + is_symmetric=False, + zero_point_domain=ZeroPointDomain.FLOAT, + ) + torch.manual_seed(self.SEED) + fq_linear = FakeQuantizedLinear( + 256, + 688, + bias=False, + activation_config=None, + weight_config=weight_config, + ) + + def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Baseline for int4 weight only fake quantization that simulates the tinygemm kernel. + """ + (qmin, qmax) = _get_qmin_qmax(4, symmetric=False) + (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) + zp = zp.to(torch.int32) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + ) + return F.linear(x, w_fq) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + fq_out = fq_linear(x) + baseline_out = linear_forward_4w(x2, fq_linear.weight) + torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 9d7b049470..f420f0f857 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -136,8 +136,7 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.4+ -from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight -from torchao.quantization.quant_api import PerTensor +from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor())) ``` @@ -321,7 +320,7 @@ This API works today but has not been extensively tested and benchmarked yet. Ha ```python # for torch 2.5+ -from torchao.quantization.quant_api import quantize_, PerRow, float8_dynamic_activation_float8_weight +from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow())) ``` diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 31757e7ee6..295741c9fe 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -12,11 +12,12 @@ from .weight_only import * # noqa: F403 from .unified import * from .autoquant import * -from .linear_activation_quantized_tensor import ( # noqat: F403 +from .granularity import * +from .linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, to_linear_activation_quantized, ) -from .linear_activation_scale import ( # noqat: F403 +from .linear_activation_scale import ( to_weight_tensor_with_linear_activation_scale_metadata, ) diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index 5251c7865e..a4d1d0a2de 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -22,7 +22,7 @@ class PerTensor(Granularity): """ Represents per-tensor granularity in quantization. - This granularity type calcualtes the quantization parameters + This granularity type calculates the quantization parameters based off the entire tensor. """ pass @@ -32,26 +32,24 @@ class PerAxis(Granularity): """ Represents per-axis granularity in quantization. - This granularity type calcualtes different quantization parameters + This granularity type calculates different quantization parameters along a specified axis of the tensor. For example if the input tensor is shape [8, 16] and axis=0, then the quantization parameters are calculated for each row of the tensor. Giving a total of 8 quantization parameters. - Attributes: axis (int): The axis along which reduction is performed. """ axis: int @dataclass(frozen=True) - class PerGroup(Granularity): """ Represents per-channel group granularity in quantization. - This granularity type calcualtes different quantization parameters + This granularity type calculates different quantization parameters for each group of elements. For example if the input tensor is shape [8, 16], and the group size is 4, then @@ -74,3 +72,19 @@ class PerRow(Granularity): is quantized with a block_size of (1, weight.shape[1]). """ pass + +class PerToken(Granularity): + """ + Represents per-token granularity in quantization. + + This granularity type calculates a different set of quantization parameters + for each token, which is represented as the last dimension of the tensor. + + For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens + with 4 elements each, and we will calculate 6 sets of quantization parameters, + one for each token. + + If the input tensor has only two dimensions, e.g. [8, 16], then this is + equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. + """ + pass diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 93717271bb..60d45c8b17 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,11 +4,224 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional, Union import torch +from torchao.quantization.granularity import ( + Granularity, + PerAxis, + PerGroup, + PerToken, +) from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.quant_primitives import ( + _SUB_BYTE_INT_BOUNDS, + _SUB_BYTE_UINT_BOUNDS, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +@dataclass +class FakeQuantizeConfig: + """ + Config for how to fake quantize weights or activations. + + args: + dtype: dtype to simulate during fake quantization, e.g. torch.int8. + For PyTorch versions older than 2.6, you may use `TorchAODType` to represent + torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. + granularity: granularity of scales and zero points, e.g. PerGroup(32). + We also support the following strings: + 1) 'per_token': equivalent to PerToken() + 2) 'per_channel': equivalent to PerAxis(0) + 3) 'per_group': equivalent to PerGroup(group_size), must be combined + with separate `group_size` kwarg, Alternatively, just set the + `group_size` kwarg and leave this field empty. + mapping_type: whether to use symmetric (default) or asymmetric quantization + Alternatively, set `is_symmetric` (bool) and leave this field empty. + scale_precision: scale dtype (default torch.fp32) + zero_point_precision: zero point dtype (default torch.int32) + zero_point_domain: whether zero point is in integer (default) or float domain + is_dynamic: whether to use dynamic (defualt) or static scale and zero points + range_learning: whether to learn scale and zero points during training (coming soon) + + kwargs (optional): + group_size: size of each group in per group fake quantization, + can be set instead of `granularity` + is_symmetric: whether to use symmetric or asymmetric quantization, + can be set instead of `mapping_type` + + Example usage:: + + # Per token asymmetric quantization + FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) + + # Per channel symmetric quantization + FakeQuantizeConfig(torch.int4, "per_channel") + FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) + FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) + + # Per group symmetric quantization + FakeQuantizeConfig(torch.int4, group_size=32) + FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) + FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) + FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) + """ + dtype: Union[torch.dtype, TorchAODType] + granularity: Granularity + mapping_type: MappingType + scale_precision: torch.dtype + zero_point_precision: torch.dtype + zero_point_domain: ZeroPointDomain + is_dynamic: bool = True + range_learning: bool = False + + def __init__( + self, + dtype: Union[torch.dtype, TorchAODType], + granularity: Union[Granularity, str, None] = None, + mapping_type: Optional[MappingType] = None, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + is_dynamic: bool = True, + range_learning: bool = False, + *, + group_size: Optional[int] = None, + is_symmetric: Optional[bool] = None, + ): + self.dtype = dtype + self.granularity = self._get_granularity(granularity, group_size) + self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) + self.scale_precision = scale_precision + self.zero_point_precision = zero_point_precision + self.zero_point_domain = zero_point_domain + self.is_dynamic = is_dynamic + self.range_learning = range_learning + + # Validate dtype + all_dtypes = [torch.int8, torch.uint8] + all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) + all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) + if dtype not in all_dtypes: + raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)) + + def _get_granularity( + self, + granularity: Union[Granularity, str, None], + group_size: Optional[int], + ) -> Granularity: + """ + Parse the `Granularity` represented in the args. + + Granularity can be specified in one of three ways: + 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) + 2) str: one of 'per_token', 'per_channel', and 'per_group' + 3) None: `group_size` must be set instead, represents per group granularity + """ + # If group_size is set, then granularity must be either "per_group" or None + if group_size is not None and granularity != "per_group" and granularity is not None: + raise ValueError("`group_size` conflicts with granularity '%s'" % granularity) + + # Case 1: Granularity object + if isinstance(granularity, Granularity): + if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): + raise ValueError("Granularity '%s' is not supported" % granularity) + if isinstance(granularity, PerAxis) and granularity.axis != 0: + raise ValueError("Only axis=0 is supported for PerAxis granularity") + return granularity + + # Case 2: str granularity + if granularity == "per_token": + return PerToken() + elif granularity == "per_channel": + return PerAxis(axis=0) + elif granularity == "per_group": + if group_size is None: + raise ValueError("Granularity was 'per_group' but no `group_size` was set") + return PerGroup(group_size) + elif isinstance(granularity, str): + raise ValueError( + "Unexpected granularity: '%s', must be one of %s" % + (granularity, ["per_token", "per_channel", "per_group"]) + ) + + # Case 3: None granularity + group_size was specified + if granularity is not None: + raise ValueError( + "Granularity '%s' has unexpected type %s" % (granularity, type(granularity)) + ) + if group_size is None: + raise ValueError("At least one of `granularity` or `group_size` must be set") + return PerGroup(group_size) + + def _get_mapping_type( + self, + mapping_type: Optional[MappingType], + is_symmetric: Optional[bool], + ) -> MappingType: + """ + Parse the `MappingType` represented in the args. + + Mapping type can be specified in one of two ways: + 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC + 2): is_symmetric bool + """ + if mapping_type is not None and is_symmetric is not None: + raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") + + # Case 0: Default to symmetric + if mapping_type is None and is_symmetric is None: + return MappingType.SYMMETRIC + + # Case 1: MappingType object + if mapping_type is not None: + if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: + raise ValueError("MappingType '%s' is not supported" % mapping_type) + return mapping_type + + # Case 2: is_symmetric flag + assert is_symmetric is not None + if is_symmetric: + return MappingType.SYMMETRIC + else: + return MappingType.ASYMMETRIC + + @property + def group_size(self) -> int: + """ + If this is per group granularity, return the group size. + Otherwise, throw an error. + """ + if isinstance(self.granularity, PerGroup): + return self.granularity.group_size + else: + raise ValueError("`group_size` is undefined for %s granularity" % self.granularity) + + @property + def is_symmetric(self) -> bool: + """ + Return True if mapping type is symmetric, else False (asymmetric). + """ + return self.mapping_type == MappingType.SYMMETRIC + + def __setattr__(self, name: str, value: Any): + """ + Support setting `group_size` and `is_symmetric`. + """ + if name == "group_size": + super().__setattr__("granularity", PerGroup(value)) + elif name == "is_symmetric": + mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC + super().__setattr__("mapping_type", mapping_type) + else: + super().__setattr__(name, value) class ComposableQATQuantizer(TwoStepQuantizer): diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py new file mode 100644 index 0000000000..eb42dcf047 --- /dev/null +++ b/torchao/quantization/prototype/qat/fake_quantizer.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch + +from torchao.quantization.granularity import ( + PerAxis, + PerGroup, + PerToken, +) +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_BIT_WIDTH, + _DTYPE_TO_QVALUE_BOUNDS, +) +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + get_groupwise_affine_qparams, +) +from .api import ( + FakeQuantizeConfig, +) +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +class FakeQuantizer(torch.nn.Module): + """ + Generic module for applying fake quantization to a tensor, as specified in the config. + """ + def __init__(self, config: FakeQuantizeConfig): + super().__init__() + self.config = config + self.enabled = True + self.scale: Optional[torch.Tensor] = None + self.zero_point: Optional[torch.Tensor] = None + + # TODO: support range learinng + if self.config.range_learning: + raise NotImplementedError("Range learning is not supported yet") + + def forward(self, x: torch.Tensor): + """ + Apply fake quantization to the tensor based on the bit-width, + granularity, symmetry, and other properties specified in the config. + """ + if not self.enabled: + return x + + if isinstance(self.config.granularity, PerToken): + return self._per_token_forward(x) + elif isinstance(self.config.granularity, (PerAxis, PerGroup)): + return self._per_channel_or_group_forward(x) + else: + raise ValueError("Unknown granularity '%s'" % self.config.granularity) + + def _per_token_forward(self, x: torch.Tensor): + """ + Perform per token fake quantization on the tensor. + """ + if self.config.is_symmetric: + raise NotImplementedError("Symmetric per token is not supported yet") + if self._should_compute_qparams(): + (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric( + x, self.config.scale_precision, self.config.zero_point_precision, + ) + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] + return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax) + + def _per_channel_or_group_forward(self, x: torch.Tensor): + """ + Perform per channel or per group fake quantization on the tensor. + We express per channel using per group where the group size is the size + of the last dimension of the tensor. + """ + granularity = self.config.granularity + scale_precision = self.config.scale_precision + zero_point_precision = self.config.zero_point_precision + zero_point_domain = self.config.zero_point_domain + is_symmetric = self.config.is_symmetric + + # get group size + if isinstance(granularity, PerAxis): + assert granularity.axis == 0 + group_size = x.size()[-1] + elif isinstance(granularity, PerGroup): + group_size = granularity.group_size + else: + raise ValueError("Unexpected granularity '%s'" % granularity) + + # get scales and zero points + if self._should_compute_qparams(): + bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype] + if is_symmetric: + (self.scale, self.zero_point) = get_group_qparams_symmetric( + x, bit_width, group_size, scale_precision, + ) + else: + (self.scale, self.zero_point) = get_groupwise_affine_qparams( + x, bit_width, group_size, scale_precision, + ) + self.zero_point = self.zero_point.to(zero_point_precision) + + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] + return _fake_quantize_per_channel_group( + x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain, + ) + + def _should_compute_qparams(self) -> bool: + """ + Return whether we need to compute new scales and zero points. + """ + return self.config.is_dynamic or self.scale is None or self.zero_point is None diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py index 07276ba84c..ef1714808c 100644 --- a/torchao/quantization/prototype/qat/linear.py +++ b/torchao/quantization/prototype/qat/linear.py @@ -18,9 +18,14 @@ Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) -from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.quantization.quant_primitives import ( + TorchAODType, + ZeroPointDomain, +) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric +from .api import FakeQuantizeConfig +from .fake_quantizer import FakeQuantizer from .utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, @@ -29,6 +34,79 @@ ) +class FakeQuantizedLinear(torch.nn.Linear): + """ + General linear layer with fake quantized weights and activations. + + Specific target dtypes, granularity, schemes etc. are specified + through separate configs for weights and activations. + + Example usage:: + + activation_config = FakeQuantizeConfig( + dtype=torch.int8, + granularity="per_token", + is_symmetric=False, + ) + weight_config = FakeQuantizeConfig( + dtype=torch.int4, + group_size=8, + is_symmetric=True, + ) + fq_linear = FakeQuantizedLinear( + 16, 32, False, activation_config, weight_config, + ) + fq_linear(torch.randn(16)) + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + activation_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + *args, + **kwargs, + ) + if bias: + raise NotImplementedError("bias not supported yet") + + # initialize activation fake quantizer + if activation_config is not None: + self.activation_fake_quantizer = FakeQuantizer(activation_config) + else: + self.activation_fake_quantizer = None + + # initialize weight fake quantizer + if weight_config is not None: + group_size = weight_config.group_size + if group_size is not None and in_features % group_size != 0: + raise ValueError( + "in_features (%s) % group_size (%s) must be == 0" % + (in_features, group_size) + ) + self.weight_fake_quantizer = FakeQuantizer(weight_config) + else: + self.weight_fake_quantizer = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.activation_fake_quantizer is not None: + x = self.activation_fake_quantizer(x) + if self.weight_fake_quantizer is not None: + w = self.weight_fake_quantizer(self.weight) + else: + w = self.weight + return F.linear(x, w) + + # ========================================================= # | Linear int8 dynamic activations + int4 weight QAT | # ========================================================= @@ -77,42 +155,42 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) + self._convert_qat_linear_8da4w(model) return model - -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + def _convert_qat_linear_8da4w(self, module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + config = child.weight_fake_quantizer.config + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=config.group_size, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, config.group_size, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + self._convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear): """ This module implements a linear layer with int8 dynamic per token fake quantized activations with int4 fake quantized grouped per channel weights. @@ -133,63 +211,39 @@ def __init__( precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, ) -> None: + activation_config = FakeQuantizeConfig( + dtype=torch.int8, + granularity="per_token", + is_symmetric=False, + is_dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) + weight_config = FakeQuantizeConfig( + dtype=TorchAODType.INT4, + group_size=groupsize, + is_symmetric=True, + is_dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) super().__init__( in_features, out_features, bias, + activation_config, + weight_config, device=device, dtype=precision, ) - assert ( - in_features % groupsize == 0 - ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" - assert not bias, "require bias=False" - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - # TODO: make this configurable? - self.zero_points_precision = torch.int32 - self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled def disable_fake_quant(self): self.enable_fake_quant(False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # activations: int8 dynamic asymmetric quant - if self._fake_quant_enabled: - (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( - x, self.scales_precision, self.zero_points_precision, - ) - (act_qmin, act_qmax) = _get_qmin_qmax(8) - x_fq = _fake_quantize_per_token( - x, act_scales, act_zp, act_qmin, act_qmax, - ) - else: - x_fq = x - - # weights: int4 grouped per channel symmetric quant - if self._fake_quant_enabled: - (weight_scales, weight_zp) = get_group_qparams_symmetric( - self.weight, 4, self.groupsize, self.scales_precision, - ) - # TODO: pass zp dtype to `get_group_qparams_symmetric` instead - weight_zp = weight_zp.to(self.zero_points_precision) - (weight_qmin, weight_qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group( - self.weight, - weight_scales, - weight_zp, - weight_qmin, - weight_qmax, - self.groupsize, - ) - else: - w_fq = self.weight - return F.linear(x_fq, w_fq) - def enable_8da4w_fake_quant(mod: torch.nn.Module): """ @@ -257,46 +311,45 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_4w(model) + self._convert_qat_linear_4w(model) return model - -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(torch.nn.Linear): + def _convert_qat_linear_4w(self, module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + inner_k_tiles = child.inner_k_tiles + config = child.weight_fake_quantizer.config + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=config.group_size, + inner_k_tiles=inner_k_tiles, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, config.group_size, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + self._convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(FakeQuantizedLinear): """ This module implements a linear layer with int4 fake quantized grouped per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, @@ -319,47 +372,36 @@ def __init__( precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16, ) -> None: + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.inner_k_tiles = inner_k_tiles + weight_config = FakeQuantizeConfig( + dtype=torch.uint4, + group_size=groupsize, + is_symmetric=False, + is_dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + zero_point_domain=ZeroPointDomain.FLOAT, + ) super().__init__( in_features, out_features, bias, + activation_config=None, + weight_config=weight_config, device=device, dtype=precision, ) - assert not bias, "require bias=False" - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled def disable_fake_quant(self): self.enable_fake_quant(False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_bit = 4 - qmin = 0 - qmax = 2 ** n_bit - 1 - scales, zero_points = get_groupwise_affine_qparams( - self.weight, n_bit, self.groupsize, self.scales_precision, - ) - w_fq = _fake_quantize_per_channel_group( - self.weight, - scales, - zero_points, - qmin, - qmax, - self.groupsize, - ZeroPointDomain.FLOAT, - ) - return F.linear(x, w_fq) - def enable_4w_fake_quant(mod: torch.nn.Module): """ diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 354475e655..8f2dd9d13f 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -181,7 +181,11 @@ def _choose_qparams_per_token_asymmetric( return scale.to(scales_precision), zero_point.to(zero_points_precision) -def _get_qmin_qmax(n_bit: int): - qmin = -(2 ** (n_bit - 1)) - qmax = 2 ** (n_bit - 1) - 1 +def _get_qmin_qmax(n_bit: int, symmetric: bool=True): + if symmetric: + qmin = -(2 ** (n_bit - 1)) + qmax = 2 ** (n_bit - 1) - 1 + else: + qmin = 0 + qmax = 2 ** n_bit - 1 return (qmin, qmax) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index ea3a9d54c5..0a1227b23a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -64,6 +64,19 @@ class ZeroPointDomain(Enum): INT = auto() FLOAT = auto() +class TorchAODType(Enum): + """ + Placeholder for dtypes that do not exist in PyTorch core yet. + """ + # torch.int1 to torch.int7 will be added to PyTorch 2.6 + # These will remain here for BC with older PyTorch versions + INT1 = auto() + INT2 = auto() + INT3 = auto() + INT4 = auto() + INT5 = auto() + INT6 = auto() + INT7 = auto() if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) @@ -79,16 +92,17 @@ class ZeroPointDomain(Enum): Map from dtype to the bound value of integers TODO: maybe can replace this with call to torch.iinfo """ -_DTYPE_TO_QVALUE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = { +_DTYPE_TO_QVALUE_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = { torch.uint8: (0, 255), torch.int8: (-128, 127), torch.int16: (-(2**15), 2**15 - 1), torch.int32: (-(2**31), 2**31 - 1), } -_SUB_BYTE_DTYPE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = {} +_SUB_BYTE_UINT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {} +_SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {} if TORCH_VERSION_AT_LEAST_2_3: - _SUB_BYTE_DTYPE_BOUNDS = { + _SUB_BYTE_UINT_BOUNDS = { torch.uint1: (0, 2**1-1), torch.uint2: (0, 2**2-1), torch.uint3: (0, 2**3-1), @@ -97,10 +111,39 @@ class ZeroPointDomain(Enum): torch.uint6: (0, 2**6-1), torch.uint7: (0, 2**7-1), } - _DTYPE_TO_QVALUE_BOUNDS.update( - _SUB_BYTE_DTYPE_BOUNDS - ) - + _SUB_BYTE_INT_BOUNDS = { + TorchAODType.INT1: (-(2**0), 2**0 - 1), + TorchAODType.INT2: (-(2**1), 2**1 - 1), + TorchAODType.INT3: (-(2**2), 2**2 - 1), + TorchAODType.INT4: (-(2**3), 2**3 - 1), + TorchAODType.INT5: (-(2**4), 2**4 - 1), + TorchAODType.INT6: (-(2**5), 2**5 - 1), + TorchAODType.INT7: (-(2**6), 2**6 - 1), + } + _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) + _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) + +_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, + torch.uint8: 8, + TorchAODType.INT1: 1, + TorchAODType.INT2: 2, + TorchAODType.INT3: 3, + TorchAODType.INT4: 4, + TorchAODType.INT5: 5, + TorchAODType.INT6: 6, + TorchAODType.INT7: 7, + torch.int8: 8, + torch.int16: 16, + torch.int32: 32, +} +assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys() _ONES_TABLE = [_n_ones(i) for i in range(8)] @@ -251,7 +294,7 @@ def _quantize_affine( quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) # workaround for uintx dtypes, since we don't have native Uintx dtype connected with # torch.uintx dtypes yet - if output_dtype in _SUB_BYTE_DTYPE_BOUNDS: + if output_dtype in _SUB_BYTE_UINT_BOUNDS: output_dtype = torch.uint8 return _quantize_affine_no_dtype_cast( input, @@ -377,7 +420,7 @@ def _dequantize_affine( """op definition that has compatible signatures with custom op library """ # TODO: validate scale/zero_point dimensions are compatible with block_size - if input_dtype not in _SUB_BYTE_DTYPE_BOUNDS: + if input_dtype not in _SUB_BYTE_UINT_BOUNDS: assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}" assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)