diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 6e26256e96..5f8680a509 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -10,6 +10,7 @@ import torch from torchao.quantization.quant_primitives import ( fake_quantize_affine, + fake_quantize_affine_cachemask, quantize_affine, dequantize_affine, choose_qparams_affine, @@ -523,5 +524,28 @@ def test_fake_quantize_affine(self): fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) torch.testing.assert_close(dequantized, fake_quantized) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantize_affine_cachemask(self): + input = torch.randn(10, 10) + + mapping_type = MappingType.SYMMETRIC + block_size = list(input.shape) + for i in range(len(block_size) - 1): + block_size[i] = 1 + dtype = torch.int8 + eps = 1e-5 + quant_min = -127 + quant_max = 127 + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + + quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) + (fake_quantized, mask) = fake_quantize_affine_cachemask( + input, block_size, scale, zero_point, dtype, quant_min, quant_max, + ) + expected_mask = torch.full(input.shape, True) + torch.testing.assert_close(dequantized, fake_quantized) + torch.testing.assert_close(expected_mask, mask) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 8f860917f4..9745a26d1a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -24,6 +24,7 @@ "quantize_affine", "dequantize_affine", 
"fake_quantize_affine", + "fake_quantize_affine_cachemask", ] class MappingType(Enum): @@ -411,6 +412,76 @@ def fake_quantize_affine( value during quantization default is ZeroPointDomain.INT """ + (_, fq) = _do_fake_quantize_affine( + input, + block_size, + scale, + zero_point, + quant_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return fq + + +def fake_quantize_affine_cachemask( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + General fake quantize op for quantization-aware training (QAT). + This is equivalent to calling `quantize_affine` + `dequantize_affine` + but without the dtype casts. + + Note: Compared to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`, + this consumes more memory and returns an additional outlier mask for + intermediate quantized values. + + Args: + Same as :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`. 
+ + Returns: + A 2-tuple of ( + final fake quantized values, + outlier mask for intermediate quantized values + ) + + """ + quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max) + (q, dq) = _do_fake_quantize_affine( + input, + block_size, + scale, + zero_point, + quant_dtype, + quant_min, quant_max, + zero_point_domain, + ) + mask = torch.logical_and((q >= quant_min), (q <= quant_max)) + return (dq, mask) + + + def _do_fake_quantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Helper function for `fake_quantize_affine` that returns both the + intermediate quantized values and the final dequantized values. + """ input_dtype = input.dtype quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max) q = _quantize_affine_no_dtype_cast( @@ -432,7 +503,7 @@ def fake_quantize_affine( zero_point_domain.name, output_dtype=input_dtype, ) - return dq + return (q, dq) def choose_qparams_affine(