Skip to content

Commit

Permalink
[quant] Add per block quantization primitives
Browse files Browse the repository at this point in the history
Summary:
We want to use this to replace all q/dq/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py

Test Plan:
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Apr 24, 2024
1 parent 3124382 commit df719e3
Show file tree
Hide file tree
Showing 3 changed files with 439 additions and 4 deletions.
185 changes: 183 additions & 2 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
# This test takes a long time to run
import unittest
import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
MappingType,
)

from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)

class TestQuantPrimitives(unittest.TestCase):
SEED = 123
Expand Down Expand Up @@ -46,5 +56,176 @@ def test_get_group_qparams_symmetric(self):
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)

def test_choose_qparams_group_sym(self):
"""Note: groupwise asymmetric quant is using a different way of computing zero_points, so
we don't include it here. We may just replace it with per block quant
"""
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)


quant_min = -128
quant_max = 127
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype)
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (10, 10)
eps = torch.finfo(torch.float32).eps
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)

quant_min = -128
quant_max = 127
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype)
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

group_size = 2
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group(
input, scale, zero_point, quant_min, quant_max, torch.int8, group_size
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group(
quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32
)

self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 1)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
output_dtype = torch.float32
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
input, scale, zero_point, axis, quant_min, quant_max, torch.int8
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype
)
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (10, 10)
output_dtype = torch.float32
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)

axis = 1
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
input, scale, zero_point, quant_min, quant_max, torch.int8
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype
)
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

axis = 2
quant_min = -128
quant_max = 127
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
input, scale, zero_point, axis, quant_min, quant_max, torch.int8
)
dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32
)
self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (3, 3, 2, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=0.01, atol=0.01)


if __name__ == "__main__":
unittest.main()
7 changes: 5 additions & 2 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
safe_int_mm,
)
import torch.nn.functional as F
from torch._inductor.utils import do_bench
try:
from torch._inductor.utils import do_bench
except:
from torch._inductor.runtime.runtime_utils import do_bench

aten = torch.ops.aten

AUTOQUANT_CACHE = {}
Expand Down Expand Up @@ -387,4 +391,3 @@ def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filte
model(*example_input)
change_autoquantizable_to_quantized(model, **kwargs)
return model

Loading

0 comments on commit df719e3

Please sign in to comment.