Add test for choose_qparams for tinygemm ops
Summary:
This is in preparation for replacing tinygemm q/dq ops with the unified quant primitive ops

Test Plan:
python test/quantization/test_quant_primitives.py -k test_tinygemm_get_groupwise_affine_qparams

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 8, 2024
1 parent b34d1ac commit 40d5a75
Showing 2 changed files with 62 additions and 10 deletions.
46 changes: 42 additions & 4 deletions test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@
import torch
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
@@ -56,8 +57,8 @@ def test_get_group_qparams_symmetric(self):
scale_obs = scale_obs.reshape(weight.shape[0], -1)

# assert that scales are identical
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)

def test_choose_qparams_group_sym(self):
"""Note: groupwise asymmetric quant is using a different way of computing zero_points, so
@@ -88,7 +89,7 @@ def test_choose_qparams_token_asym(self):
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_asym(self):
@@ -257,7 +258,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

def test_choose_qparams_tensor_asym_eps(self):
input = torch.zeros(10, 10)
@@ -279,5 +280,42 @@ def test_get_group_qparams_symmetric_memory(self):
after_choose_qparams_mem_use = torch.cuda.memory_allocated()
self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)

def test_tinygemm_get_groupwise_affine_qparams(self):
input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
quant_max = 2**n_bit - 1
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
_is_zero_exactly_representable=False,
)

def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)

self.assertTrue(torch.equal(scale, scale_ref))
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)


if __name__ == "__main__":
unittest.main()
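
For reference, here is a minimal standalone sketch (plain torch, not part of the commit) of why the int_zero_point_to_float helper in the test above recovers the tinygemm floating-point zero_point. The min_val/max_val tensors are made-up per-group ranges, and the formulas mirror the non-rounded asymmetric branch added to quant_primitives.py below.

import torch

# Made-up per-group ranges; assumes the non-rounded asymmetric qparam formulas.
n_bit = 4
quant_min, quant_max = 0, 2**n_bit - 1
mid_point = 2 ** (n_bit - 1)

min_val = torch.tensor([-1.3, -0.7])
max_val = torch.tensor([0.9, 1.1])
scale = (max_val - min_val) / float(quant_max - quant_min)
zero_point = quant_min - min_val / scale  # fractional, since zero need not be representable

# Converting back to the tinygemm floating-point convention:
# (quant_min - zero_point + mid_point) * scale = min_val + mid_point * scale
zero_point_float = (quant_min - zero_point + mid_point) * scale
torch.testing.assert_close(zero_point_float, min_val + mid_point * scale)
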
26 changes: 20 additions & 6 deletions torchao/quantization/quant_primitives.py
@@ -251,19 +251,24 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
_is_zero_exactly_representable: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input (torch.Tensor): fp32, bf16, fp16 input Tensor
mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
block_size: (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparam
    e.g. when the block size matches the input tensor shape, we are using per tensor quantization
target_dtype (torch.dtype): dtype for target quantized Tensor
quant_min (Optional[int]): minimum quantized value for target quantized Tensor
quant_max (Optional[int]): maximum quantized value for target quantized Tensor
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor
_is_zero_exactly_representable (bool): a private flag indicating whether zero must be exactly representable.
This is typically required for ops that need zero padding, like convolution, and is less important for ops
without zero padding in the op itself, like linear. If zero does not need to be exactly representable,
we skip the rounding and clamping of zero_point
Output:
Tuple of scale and zero_point Tensors with the requested dtypes
Expand All @@ -283,17 +288,26 @@ def choose_qparams_affine(
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)

min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
if _is_zero_exactly_representable:
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
else:
min_val_neg = min_val
max_val_pos = max_val

if mapping_type == MappingType.SYMMETRIC:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
assert _is_zero_exactly_representable, "non-representable zero path is not implemented for symmetric quantization"
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
if _is_zero_exactly_representable:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
zero_point = quant_min - min_val_neg / scale


if eps is None:
eps = torch.finfo(input.dtype).eps
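To illustrate the new flag, here is a minimal sketch (plain torch mirroring the asymmetric branch above, not the torchao API itself) of how _is_zero_exactly_representable changes the computed zero_point:

import torch

def asymmetric_qparams(x, quant_min=0, quant_max=15, zero_exact=True):
    min_val = x.min()
    max_val = x.max()
    if zero_exact:
        # extend the range so that 0.0 maps onto an exact integer value
        min_val = torch.min(min_val, torch.zeros_like(min_val))
        max_val = torch.max(max_val, torch.zeros_like(max_val))
    scale = (max_val - min_val) / float(quant_max - quant_min)
    if zero_exact:
        zero_point = torch.clamp(quant_min - torch.round(min_val / scale), quant_min, quant_max)
    else:
        # tinygemm-style path: fractional zero_point, no rounding or clamping
        zero_point = quant_min - min_val / scale
    return scale, zero_point

x = torch.randn(4, 128) + 1.0  # mostly-positive data makes the difference visible
print(asymmetric_qparams(x, zero_exact=True))   # integer-valued zero_point
print(asymmetric_qparams(x, zero_exact=False))  # fractional zero_point

With the fractional zero_point, quant_min dequantizes exactly back to min_val, since (quant_min - zero_point) * scale = min_val; the rounded-and-clamped path instead guarantees that 0.0 is exactly representable.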
