diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 90fd8f8bf0..6a99ace160 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@ import torch
 from torchao.quantization.quant_primitives import (
     get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -56,8 +57,8 @@ def test_get_group_qparams_symmetric(self):
         scale_obs = scale_obs.reshape(weight.shape[0], -1)

         # assert that scales are identical
-        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
-        torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
+        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
+        torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)

     def test_choose_qparams_group_sym(self):
         """Note: groupwise asymmetric quant is using a different way of computing zero_points, so
@@ -88,7 +89,7 @@ def test_choose_qparams_token_asym(self):
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()

-        torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
+        torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
         self.assertTrue(torch.equal(zero_point, zp_ref))

     def test_choose_qparams_tensor_asym(self):
@@ -257,7 +258,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
         dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
         # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
-        torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
+        torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

     def test_choose_qparams_tensor_asym_eps(self):
         input = torch.zeros(10, 10)
@@ -279,5 +280,42 @@ def test_get_group_qparams_symmetric_memory(self):
         after_choose_qparams_mem_use = torch.cuda.memory_allocated()
         self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)

+    def test_tinygemm_get_groupwise_affine_qparams(self):
+        input = torch.randn(10, 256)
+        n_bit = 4
+        scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
+
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = 0
+        quant_max = 2**n_bit - 1
+        eps = 1e-6
+        scale_dtype = torch.bfloat16
+        zero_point_dtype = torch.bfloat16
+        scale, zero_point = \
+            choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                _is_zero_exactly_representable=False,
+            )
+
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
+            return (quant_min - zero_point + mid_point) * scale
+
+        mid_point = 2 ** (n_bit - 1)
+        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=(torch.max(scale) * 0.03).item())
+
+
 if __name__ == "__main__":
     unittest.main()
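For context on the int_zero_point_to_float helper above: with quant_min = 0, the non-rounded path added below computes zero_point = -min_val / scale, so (quant_min - zero_point + mid_point) * scale = min_val + mid_point * scale, which lines up with the tinygemm-style zero point convention returned by get_groupwise_affine_qparams. A minimal standalone sketch of that identity (illustrative only, not part of the patch; shapes and values are arbitrary):

import torch

# Assumed setup: 4-bit asymmetric quantization with quant_min = 0.
n_bit = 4
quant_min, quant_max = 0, 2**n_bit - 1
mid_point = 2 ** (n_bit - 1)

x = torch.randn(1, 128)
min_val, max_val = x.amin(), x.amax()

# Affine qparams with a floating-point zero_point (no rounding/clamping),
# mirroring the _is_zero_exactly_representable=False branch below.
scale = (max_val - min_val) / float(quant_max - quant_min)
zero_point = quant_min - min_val / scale

# Converting back recovers the tinygemm-style form: min_val + scale * mid_point.
zero_point_float = (quant_min - zero_point + mid_point) * scale
assert torch.allclose(zero_point_float, min_val + scale * mid_point)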
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index f59144becd..7f0b496fb7 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -251,19 +251,24 @@ def choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
+    _is_zero_exactly_representable: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
       input (torch.Tensor): fp32, bf16, fp16 input Tensor
       mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
-      block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
-      e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
+      block_size (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparam,
+        e.g. when the block size matches the input tensor shape, we are doing per-tensor quantization
       target_dtype (torch.dtype): dtype for target quantized Tensor
       quant_min (Optional[int]): minimum quantized value for target quantized Tensor
       quant_max (Optional[int]): maximum quantized value for target quantized Tensor
       eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
       scale_dtype (torch.dtype): dtype for scale Tensor
       zero_point_dtype (torch.dtype): dtype for zero_point Tensor
+      _is_zero_exactly_representable (bool): a private flag indicating whether zero needs to be exactly
+        representable. This is typically required for ops that need zero padding, like convolution;
+        it is less important for ops without zero padding in the op itself, like linear. When zero does
+        not need to be exactly representable, we skip the rounding and clamping of zero_point.

     Output:
         Tuple of scales and zero_points Tensor with requested dtype
@@ -283,17 +288,26 @@ def choose_qparams_affine(
     min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
     max_val = torch.amax(input, dim=reduction_dims, keepdim=False)

-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    if _is_zero_exactly_representable:
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    else:
+        min_val_neg = min_val
+        max_val_pos = max_val

     if mapping_type == MappingType.SYMMETRIC:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        assert _is_zero_exactly_representable, "non-representable zero path is not implemented for symmetric quantization"
         zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
     else:
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-        zero_point = quant_min - torch.round(min_val_neg / scale)
-        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        if _is_zero_exactly_representable:
+            zero_point = quant_min - torch.round(min_val_neg / scale)
+            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        else:
+            zero_point = quant_min - min_val_neg / scale
+
     if eps is None:
         eps = torch.finfo(input.dtype).eps
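A side note on why the default path rounds and clamps zero_point: an integer zero_point guarantees that the real value 0.0 maps exactly onto one of the quantized codes, which is what zero-padded ops like convolution rely on. A quick illustrative sketch, assuming the usual affine dequantization formula (q - zero_point) * scale:

import torch

scale = torch.tensor(0.1)
codes = torch.arange(0, 16)  # 4-bit codes

# Default path: integer zero_point, so the code q == zero_point dequantizes to exactly 0.0.
zero_point = torch.tensor(3)
assert ((codes - zero_point) * scale == 0.0).any()

# _is_zero_exactly_representable=False path: fractional zero_point, so in general
# no code lands exactly on 0.0.
zero_point_f = torch.tensor(3.4)
print(((codes - zero_point_f) * scale).abs().min())  # 0.04 here, never exactly 0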