diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 139116def2..90fd8f8bf0 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -126,6 +126,34 @@ def test_choose_qparams_tensor_sym(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + def test_quantize_activation_per_token_abs_max(self): + from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax + input = torch.randn(10, 10) + quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + + mapping_type = MappingType.SYMMETRIC + block_size = list(input.shape) + for i in range(len(block_size) - 1): + block_size[i] = 1 + dtype = torch.int8 + eps = 1e-5 + quant_min = -127 + quant_max = 127 + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + + quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + + self.assertTrue(torch.equal(quantized, quantized_ref)) + self.assertTrue(torch.equal(scale, scale_ref)) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + def test_quantize_activation_per_token_abs_max_zero_input(self): + from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax + input = torch.zeros(10, 10) + # make sure it still works + quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") def test_quantize_dequantize_group_sym(self): diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index c3089224de..90316e1557 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -330,21 +330,23 @@ def dynamically_quantize_per_tensor( def quantize_activation_per_token_absmax(t): - n_bits = 8 # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] - - scales = t.abs().amax(dim=-1, keepdim=True) - if scales.dtype == torch.float16: - scales = ( - scales.float() - ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) - q_max = 2 ** (n_bits - 1) - 1 - scales = scales.clamp(min=1e-5).div(q_max) + mapping_type = MappingType.SYMMETRIC + block_size = list(t.shape) + for i in range(len(block_size) - 1): + block_size[i] = 1 + dtype = torch.int8 + eps = 1e-5 # Note: the original smoothquant does not clamp to qmin/qmax here, # but some of the tests with bfloat16 ended up with a flipped sign # if we don't clamp. TODO(future) look into this further. - t = torch.round(t / scales).clamp(-127, 127).to(torch.int8) - return t, scales + quant_min = -127 + quant_max = 127 + scale, zero_point = choose_qparams_affine(t, mapping_type, block_size, dtype, quant_min, quant_max, eps, scale_dtype=torch.float) + + quantized = quantize_affine(t, block_size, scale, zero_point, dtype, quant_min, quant_max) + + return quantized, scale def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): diff --git a/tutorials/quantize_vit/quant.json.gz b/tutorials/quantize_vit/quant.json.gz index a207cefc5f..8caa43d81e 100644 Binary files a/tutorials/quantize_vit/quant.json.gz and b/tutorials/quantize_vit/quant.json.gz differ