Add test for choose_qparams for tinygemm ops
Summary:
This is in preparation for replacing tinygemm q/dq ops with the unified quant primitive ops

Test Plan:
python test/quantization/test_quant_primitives.py -k test_tinygemm_get_groupwise_affine_qparams

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 7, 2024
1 parent b34d1ac commit 251d09b
Showing 2 changed files with 51 additions and 5 deletions.
44 changes: 41 additions & 3 deletions test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@
import torch
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
@@ -57,7 +58,7 @@ def test_get_group_qparams_symmetric(self):

# assert that scales are identical
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)

def test_choose_qparams_group_sym(self):
"""Note: groupwise asymmetric quant is using a different way of computing zero_points, so
@@ -88,7 +89,7 @@ def test_choose_qparams_token_asym(self):
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_asym(self):
@@ -257,7 +258,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

def test_choose_qparams_tensor_asym_eps(self):
input = torch.zeros(10, 10)
@@ -279,5 +280,42 @@ def test_get_group_qparams_symmetric_memory(self):
after_choose_qparams_mem_use = torch.cuda.memory_allocated()
self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)

def test_tinygemm_get_groupwise_affine_qparams(self):
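"""Check that choose_qparams_affine with is_zero_exactly_representable=False reproduces
the groupwise qparams computed by the tinygemm-specific get_groupwise_affine_qparams."""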
input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

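# recompute the same qparams through the unified affine primitive: tinygemm uses
# asymmetric 4-bit groupwise quantization with bfloat16 scale and zero_point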
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
quant_max = 2**n_bit - 1
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
is_zero_exactly_representable=False
)

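# choose_qparams_affine returns an integer-domain zero_point, while tinygemm keeps
# zero_point in the floating point domain (roughly min_val + mid_point * scale),
# so map the integer zero_point back to float before comparing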
def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)

self.assertTrue(torch.equal(scale, scale_ref))
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.1, atol=(torch.max(scale) * 0.5).item())


if __name__ == "__main__":
unittest.main()
12 changes: 10 additions & 2 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
@@ -251,6 +251,7 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
is_zero_exactly_representable: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -264,6 +265,9 @@
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor
is_zero_exactly_representable (bool): flag to indicate whether zero must be exactly
representable in the quantized domain. This is typically required for ops that need
zero padding, like convolution; it is less important for ops that don't do zero padding
themselves, like linear.
Output:
Tuple of scales and zero_points Tensor with requested dtype
@@ -283,8 +287,12 @@
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)

min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
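# when zero must be exactly representable, extend the observed range to include 0
# so that some integer value dequantizes exactly to 0.0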
if is_zero_exactly_representable:
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
else:
min_val_neg = min_val
max_val_pos = max_val

if mapping_type == MappingType.SYMMETRIC:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
