quantize_activation_per_token_absmax use general quant primitives #193

Merged · 3 commits · May 3, 2024
28 changes: 28 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -126,6 +126,34 @@ def test_choose_qparams_tensor_sym(self):
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_quantize_activation_per_token_abs_max(self):
+        from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
+        input = torch.randn(10, 10)
+        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+
+        mapping_type = MappingType.SYMMETRIC
+        block_size = list(input.shape)
+        for i in range(len(block_size) - 1):
+            block_size[i] = 1
+        dtype = torch.int8
+        eps = 1e-5
+        quant_min = -127
+        quant_max = 127
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(scale, scale_ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_quantize_activation_per_token_abs_max_zero_input(self):
+        from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
+        input = torch.zeros(10, 10)
+        # make sure it still works
+        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
     def test_quantize_dequantize_group_sym(self):
24 changes: 13 additions & 11 deletions torchao/quantization/quant_primitives.py
@@ -330,21 +330,23 @@ def dynamically_quantize_per_tensor(


 def quantize_activation_per_token_absmax(t):
-    n_bits = 8
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
-
-    scales = t.abs().amax(dim=-1, keepdim=True)
-    if scales.dtype == torch.float16:
-        scales = (
-            scales.float()
-        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
-    q_max = 2 ** (n_bits - 1) - 1
-    scales = scales.clamp(min=1e-5).div(q_max)
jerryzh168 marked this conversation as resolved.

jerryzh168 (Contributor, Author) commented on May 1, 2024:

@vkuzo @HDCharles do you know if this clamping before the scale is calculated is required for quantize_activation_per_token_absmax to work for our use cases?

Contributor commented:

we need an epsilon to handle the case where max(abs(x)) is zero
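A quick illustrative sketch (not from the PR) of the failure mode being described: when every value in a token is zero, the absmax scale is zero and the division during quantization produces NaN, which is what the min=1e-5 clamp above guards against.

    import torch

    t = torch.zeros(2, 4)                        # every activation in each "token" (row) is zero
    scales = t.abs().amax(dim=-1, keepdim=True)  # -> all-zero scales
    print(t / scales)                            # 0 / 0 -> nan everywhere

    scales = scales.clamp(min=1e-5)              # the clamp under discussion
    print(torch.round(t / scales))               # well-defined zeros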

jerryzh168 (Contributor, Author) commented:

I see, so can we do this after we divide by q_max? that's where we are doing the clamp typically

Contributor commented:

    if eps is not None:
        scale = torch.clamp(scale, min=eps)

you have this logic in choose_qparams_affine, I would imagine it's for the same purpose. I would recommend:

  1. always specify epsilon (not sure why this is optional)
  2. ensure test cases include testing for max(abs(x)) == 0

Contributor commented:

@vkuzo Is eps best chosen to be (a multiple of / the) machine epsilon of the input tensor's dtype? Or is it a parameter that needs to be searched depending on input data distribution?

jerryzh168 (Contributor, Author) commented:

OK I'll add a test. will think about if we want to make it required, we didn't do clamping in some of the ops right now I think

Contributor commented:

> will think about if we want to make it required

IMO the case of max(abs(x)) == 0 should be handled for every op

> Is eps best chosen to be (a multiple of / the) machine epsilon of the input tensor's dtype?

I think that choosing eps based on the dtype makes sense, it should not be data dependent. My educated guess of how to choose this number would be "smallest number so that the resulting scale is finite", I haven't looked into whether expressing that in terms of machine epsilon would make sense.
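A small check (an illustrative sketch, not from this thread) of the "smallest number so that the resulting scale is finite" idea, using each floating-point dtype's machine epsilon:

    import torch

    # Clamping an all-zero absmax to the dtype's machine epsilon keeps 1/scale finite.
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        eps = torch.finfo(dtype).eps
        scale = torch.zeros((), dtype=dtype).clamp(min=eps)
        print(dtype, eps, torch.isfinite(1.0 / scale).item())  # True for each dtype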

jerryzh168 (Contributor, Author) commented on May 1, 2024:

after a second thought I guess it makes sense to have an eps, otherwise we may have divide by zero during quantization, I'll change it in a separate PR

for eps, I think we can do torch.finfo(input_dtype).eps if it's not provided from user
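A minimal sketch of that follow-up idea (the helper below is hypothetical, not part of this PR): default eps to the machine epsilon of the input's floating-point dtype when the caller does not pass one.

    import torch

    # Hypothetical helper: fall back to the machine epsilon of the input's dtype.
    def default_eps(t: torch.Tensor, eps=None):
        if eps is None:
            eps = torch.finfo(t.dtype).eps  # e.g. ~1.19e-07 for float32
        return eps

    scales = torch.zeros(4, 1)                      # all-zero tokens -> zero scales
    scales = scales.clamp(min=default_eps(scales))  # keeps the later division well defined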

+    mapping_type = MappingType.SYMMETRIC
+    block_size = list(t.shape)
+    for i in range(len(block_size) - 1):
+        block_size[i] = 1
+    dtype = torch.int8
+    eps = 1e-5
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
     # if we don't clamp. TODO(future) look into this further.
-    t = torch.round(t / scales).clamp(-127, 127).to(torch.int8)
-    return t, scales
+    quant_min = -127
+    quant_max = 127
+    scale, zero_point = choose_qparams_affine(t, mapping_type, block_size, dtype, quant_min, quant_max, eps, scale_dtype=torch.float)
+
+    quantized = quantize_affine(t, block_size, scale, zero_point, dtype, quant_min, quant_max)
+
+    return quantized, scale


 def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
Binary file modified tutorials/quantize_vit/quant.json.gz
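For reference, a small usage sketch (not part of the diff) of the refactored function: block_size ends up as [1, ..., 1, K], i.e. one quantization block per token, so the function returns an int8 tensor plus one scale per token.

    import torch
    from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax

    x = torch.randn(2, 3, 8)                 # [B, N, K] activations
    quantized, scale = quantize_activation_per_token_absmax(x)
    print(quantized.shape, quantized.dtype)  # torch.Size([2, 3, 8]) torch.int8
    print(scale.shape)                       # one scale per (B, N) token position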