From 9b1120b66ad91dfde1c30e7c49daa61098ca3503 Mon Sep 17 00:00:00 2001 From: Martin Yuan Date: Sun, 8 Sep 2024 17:10:49 -0700 Subject: [PATCH] Add symmetric quantization with no clipping error in the tensor subclass based API --- test/quantization/test_quant_api.py | 25 +++++++++++++++++++++++++ torchao/quantization/quant_api.py | 7 +++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e61ce7b4fb..b37665db19 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -526,6 +526,31 @@ def test_quantized_tensor_subclass_8da4w(self): self.assertTrue(torch.equal(res, ref)) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") + def test_quantized_tensor_subclass_8da4w_no_clipping_err(self): + group_size = 32 + m = ToyLinearModel().eval() + m_copy = copy.deepcopy(m) + example_inputs = m.example_inputs() + quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size, mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR)) + + assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) + assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + + # reference + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + + quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR) + m_copy = quantizer.quantize(m_copy) + assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear) + assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear) + + res = m(*example_inputs) + ref = m_copy(*example_inputs) + self.assertTrue(torch.equal(res, ref)) + # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int4(self): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 364463b5ad..24e6f2c3c6 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -473,14 +473,13 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: target_dtype = torch.int8 return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype) -def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32): +def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32, mapping_type=MappingType.SYMMETRIC): """This is defined here instead of local function to support serialization """ if weight.shape[-1] % group_size != 0: return weight # weight settings - mapping_type = MappingType.SYMMETRIC block_size = (1, group_size) target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -494,7 +493,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32): weight = to_linear_activation_quantized(weight, input_quant_func) return weight -def int8_dynamic_activation_int4_weight(group_size=32): +def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.SYMMETRIC): """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear This is used to produce a model for executorch backend, but currently executorch did not support lowering for the quantized model from this flow yet @@ -503,7 +502,7 @@ def int8_dynamic_activation_int4_weight(group_size=32): `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained """ - return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size) + return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type) def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False):