From 4e1d9f4cd3b1a4516057070ceee81b217df48461 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 4 Jun 2024 13:35:39 -0400 Subject: [PATCH] Refactor rest of tinygemm quant primitive ops Summary: This PR replaces the remaining tinygemm specific quant primitive ops with the general quant primitive ops that we want to use for everything, we could delete these ops in a separate PR if needed Test Plan: python test/quantization/test_quant_primitives.py -k test_get_groupwise_affine_qparams python test/quantization/test_quant_primitives.py -k test_groupwise_affine_quantize_tensor_from_qparams python test/quantization/test_quant_primitives.py -k test_groupwise_affine_dequantize_tensor_from_qparams accuracy: perf: no diff for generated code with `TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py` --- test/quantization/test_quant_primitives.py | 109 ++++++++++++++++++++- torchao/quantization/quant_primitives.py | 75 +++++++------- 2 files changed, 141 insertions(+), 43 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 3ce53cbde1..6054c6e66f 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -11,6 +11,8 @@ from torchao.quantization.quant_primitives import ( get_group_qparams_symmetric, get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor_from_qparams, + groupwise_affine_dequantize_tensor_from_qparams, quantize_affine, dequantize_affine, choose_qparams_affine, @@ -38,6 +40,86 @@ def check_idempotent(self, fn, *args, **kwargs): self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.") return output1 +# Legacy tinygemm ops +def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + # assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to( + dtype=dtype + ).reshape(w.shape[0], -1) + +def _groupwise_affine_quantize_tensor_from_qparams( + w, + scales, + zeros, + n_bit=4, + groupsize=128, +): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + # assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int4x8 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int4x8 + +def _groupwise_affine_dequantize_tensor_from_qparams( + w_int4x8, + scales, + zeros, + n_bit=4, + groupsize=128, +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int4x8.shape[-1] + assert w_int4x8.shape[-1] % groupsize == 0 + assert w_int4x8.dim() == 2 + + w_int4x8_grouped = w_int4x8.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int4x8_grouped.sub(2 ** (n_bit - 1)) + .mul(scales) + .add(zeros) + .reshape_as(w_int4x8) + ) + return w_dq + class TestQuantPrimitives(unittest.TestCase): SEED = 123 @@ -356,12 +438,12 @@ def test_not_preserve_zero_not_supported(self): ) - def test_tinygemm_get_groupwise_affine_qparams(self): + def test_get_groupwise_affine_qparams(self): from torchao.quantization.quant_primitives import ZeroPointDomain input = torch.randn(10, 256) n_bit = 4 - scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) + scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 @@ -389,6 +471,29 @@ def test_tinygemm_get_groupwise_affine_qparams(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zero_point_ref)) + def test_groupwise_affine_quantize_tensor_from_qparams(self): + input = torch.randn(10, 256) + scales = torch.randn(10, 2) + zeros = torch.randn(10, 2) + n_bit = 4 + groupsize = 128 + + w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + + self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref)) + + def test_groupwise_affine_dequantize_tensor_from_qparams(self): + input = torch.randint(0, 15, (10, 256), dtype=torch.int32) + scales = torch.randn(10, 2).bfloat16() + zeros = torch.randn(10, 2).bfloat16() + n_bit = 4 + groupsize = 128 + + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + + self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 34dddb0516..5b7b7e021f 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -252,7 +252,7 @@ def dequantize_affine( # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size - assert input.dtype == input_dtype + assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}" assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) @@ -647,22 +647,36 @@ def quant_int8_per_token_matmul( def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): - """This is tinygemm specific, we'll keep this for now""" if groupsize > w.shape[-1]: groupsize = w.shape[-1] assert groupsize > 1 assert w.shape[-1] % groupsize == 0 assert w.dim() == 2 + assert n_bit <= 8, f"only n_bit smaller than 8 is supported, got: {n_bit}" - to_quant = w.reshape(-1, groupsize) - # assert torch.isnan(to_quant).sum() == 0 + mapping_type = MappingType.ASYMMETRIC + dtype = torch.int32 + block_size = (1, groupsize) + quant_min = 0 + quant_max = 2**n_bit - 1 + eps = 1e-6 + scale_dtype = dtype + zero_point_dtype = dtype - max_val = to_quant.amax(dim=1, keepdim=True) - min_val = to_quant.amin(dim=1, keepdim=True) - max_int = 2**n_bit - 1 - scales = (max_val - min_val).clamp(min=1e-6) / max_int - zeros = min_val + scales * (2 ** (n_bit - 1)) - return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to( + scale, zero_point = choose_qparams_affine( + w, + mapping_type, + block_size, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=False, + zero_point_domain=ZeroPointDomain.FLOAT + ) + + return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero.to( dtype=dtype ).reshape(w.shape[0], -1) @@ -695,7 +709,6 @@ def groupwise_affine_quantize_tensor_from_qparams( n_bit=4, groupsize=128, ): - """This is tinygemm specific, we'll keep this for now""" assert groupsize > 1 # needed for GPTQ single column quantize if groupsize > w.shape[-1] and scales.shape[-1] == 1: @@ -704,25 +717,12 @@ def groupwise_affine_quantize_tensor_from_qparams( assert w.shape[-1] % groupsize == 0 assert w.dim() == 2 - to_quant = w.reshape(-1, groupsize) - # assert torch.isnan(to_quant).sum() == 0 - - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - min_val = zeros - scales * (2 ** (n_bit - 1)) - max_int = 2**n_bit - 1 - min_int = 0 - w_int4x8 = ( - to_quant.sub(min_val) - .div(scales) - .round() - .clamp_(min_int, max_int) - .to(torch.int32) - .reshape_as(w) - ) - - return w_int4x8 + block_size = (1, groupsize) + output_dtype = torch.int32 + quant_min = 0 + quant_max = 2 ** n_bit - 1 + return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) def groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, @@ -731,7 +731,6 @@ def groupwise_affine_dequantize_tensor_from_qparams( n_bit=4, groupsize=128, ): - """This is tinygemm specific, we'll keep this for now""" assert groupsize > 1 # needed for GPTQ single column dequantize if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1: @@ -739,17 +738,11 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert w_int4x8.shape[-1] % groupsize == 0 assert w_int4x8.dim() == 2 - w_int4x8_grouped = w_int4x8.reshape(-1, groupsize) - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - - w_dq = ( - w_int4x8_grouped.sub(2 ** (n_bit - 1)) - .mul(scales) - .add(zeros) - .reshape_as(w_int4x8) - ) - return w_dq + block_size = (1, groupsize) + input_dtype = torch.int32 + quant_min = 0 + quant_max = 2**n_bit - 1 + return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype) def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):