quantize_activation_per_token_absmax use general quant primitives #193
Diff in generated quantized code: https://www.internalfb.com/phabricator/paste/view/P1226948181
Changes to generated quantized code for vit: https://www.internalfb.com/phabricator/paste/view/P1227030036
Hm, those changes seem to indicate that something substantial is still different about this refactor. There must be a slight difference somewhere.
Yeah, I don't think the code is exactly the same. It's calling mostly the same ops, but with different args. For example, https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py#L273 has different dim and keepdim args. Is there a way to test the speed of the generated code?
Looks like the benchmark code will print the time, let me check.
before change: elapsed_time: 1.4610368347167968 milliseconds. Looks like not much difference.
I'm less worried about speed and more about correctness. I don't think a single benchmark datapoint will be conclusive here either. I saw an additional call to
Sure. For correctness, are you referring to numerics? We have regression tests here: https://github.com/pytorch/ao/blob/main/test/integration/test_integration.py#L586 and they pass.
I just did a quick test with a 10x10 tensor. The quantized integer values match exactly, but there are some differences in the scales. I think it's because the original op clamps before dividing by q_max, instead of after: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py#L335
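To make the ordering question concrete, here is a minimal pure-Python sketch (no torch dependency) of the two ways a per-token absmax scale could be clamped. The helper names `scale_clamp_before` and `scale_clamp_after` are hypothetical, the default `eps=1e-5` mirrors the constant in the original op, and the real primitives work on tensors rather than lists. In this sketch the two orders agree whenever `absmax / q_max` is at least `eps`, and differ only when the clamp actually takes effect.

```python
def scale_clamp_before(row, n_bits=8, eps=1e-5):
    """Clamp the absmax first, then divide by q_max (order used by the original op)."""
    q_max = 2 ** (n_bits - 1) - 1
    absmax = max(abs(v) for v in row)
    return max(absmax, eps) / q_max


def scale_clamp_after(row, n_bits=8, eps=1e-5):
    """Divide by q_max first, then clamp the resulting scale (assumed order in the general primitive)."""
    q_max = 2 ** (n_bits - 1) - 1
    absmax = max(abs(v) for v in row)
    return max(absmax / q_max, eps)


# Ordinary magnitudes: the clamp is inactive, so both orders give the same scale.
normal_row = [0.5, -1.0, 0.25]
same = scale_clamp_before(normal_row) == scale_clamp_after(normal_row)

# All-zero row: the clamp is active, and the two scales differ by a factor of q_max.
zero_row = [0.0, 0.0, 0.0]
before_zero = scale_clamp_before(zero_row)
after_zero = scale_clamp_after(zero_row)
```

Because the two orders only disagree for rows whose absmax is tiny, the quantized integers for typical inputs would be expected to match even where the scales for degenerate rows do not.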
scales.float()
)  # want float scales to avoid overflows for fp16 (bf16 has wide enough range)
q_max = 2 ** (n_bits - 1) - 1
scales = scales.clamp(min=1e-5).div(q_max)
@vkuzo @HDCharles do you know if this clamping before the scale is calculated is required for quantize_activation_per_token_absmax to work for our use cases?
we need an epsilon to handle the case where max(abs(x)) is zero
I see, so can we do this after we divide by q_max? That's where we typically do the clamp.
if eps is not None:
    scale = torch.clamp(scale, min=eps)
You have this logic in choose_qparams_affine; I would imagine it's for the same purpose. I would recommend:
- always specify epsilon (not sure why this is optional)
- ensure test cases include testing for max(abs(x)) == 0
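A test along the lines recommended above might look like the following pure-Python sketch. It is not the torchao test suite: `quantize_per_token` is a hypothetical stand-in for the real op, the int8 range of [-128, 127] and the clamp-after-divide order are assumptions, and a plain Python list stands in for one token of activations.

```python
def quantize_per_token(row, n_bits=8, eps=1e-5):
    """Symmetric absmax quantization of one token (a list of floats)."""
    q_max = 2 ** (n_bits - 1) - 1
    q_min = -(2 ** (n_bits - 1))
    absmax = max(abs(v) for v in row)
    # eps keeps the scale strictly positive even when absmax is zero.
    scale = max(absmax / q_max, eps)
    quantized = [min(max(round(v / scale), q_min), q_max) for v in row]
    return quantized, scale


def test_all_zero_row():
    # max(abs(x)) == 0: the scale must stay positive and the values must quantize to 0.
    quantized, scale = quantize_per_token([0.0, 0.0, 0.0])
    assert scale > 0
    assert quantized == [0, 0, 0]


def zero_eps_fails():
    # With eps disabled the scale is 0.0 and plain Python raises on 0.0 / 0.0.
    # (Tensor libraries typically produce nan or inf here instead of raising.)
    try:
        quantize_per_token([0.0, 0.0], eps=0.0)
    except ZeroDivisionError:
        return True
    return False


test_all_zero_row()
```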
@vkuzo Is eps best chosen to be (a multiple of / the) machine epsilon of the input tensor's dtype? Or is it a parameter that needs to be searched depending on input data distribution?
OK, I'll add a test. I'll think about whether we want to make it required; I think some of the ops don't do clamping right now.
"will think about if we want to make it required"
IMO the case of max(abs(x)) == 0 should be handled for every op.
"Is eps best chosen to be (a multiple of / the) machine epsilon of the input tensor's dtype?"
I think choosing eps based on the dtype makes sense; it should not be data dependent. My educated guess for how to choose this number would be "smallest number so that the resulting scale is finite". I haven't looked into whether expressing that in terms of machine epsilon would make sense.
On second thought, I guess it makes sense to have an eps; otherwise we may divide by zero during quantization. I'll change it in a separate PR.
For eps, I think we can use torch.finfo(input_dtype).eps if it's not provided by the user.
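The fallback proposed here could be sketched as follows. To keep the example dependency-free, a hypothetical lookup table `MACHINE_EPS` stands in for `torch.finfo(dtype).eps`, and `resolve_eps` is an illustrative name, not an existing torchao function. The table values are the standard machine epsilons, 2 raised to minus the number of explicit mantissa bits of each format.

```python
import sys

# Hypothetical stand-in for torch.finfo(dtype).eps, keyed by dtype name.
MACHINE_EPS = {
    "float64": 2.0 ** -52,
    "float32": 2.0 ** -23,
    "float16": 2.0 ** -10,
    "bfloat16": 2.0 ** -7,
}


def resolve_eps(dtype_name, eps=None):
    """Return the user-supplied eps, or fall back to the dtype's machine epsilon."""
    if eps is not None:
        return eps
    return MACHINE_EPS[dtype_name]


# Sanity check: Python floats are float64, so the table must match sys.float_info.
assert MACHINE_EPS["float64"] == sys.float_info.epsilon

default_fp32 = resolve_eps("float32")       # falls back to 2 ** -23
explicit = resolve_eps("float16", eps=1e-5)  # user value wins
```

Note that these defaults span several orders of magnitude across dtypes (2 ** -7 for bfloat16 versus 2 ** -23 for float32), so a dtype-based default is not equivalent to the fixed 1e-5 used in the original op.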
Can we verify in a unit test that the torch.compile'd code is the same between the old and new primitives? It seems pretty important to make sure this is doing what we want after torch.compile on an ongoing basis.
e.g. we use https://github.com/pytorch/ao/blob/main/test/integration/test_integration.py#L1110 to extract the triton code in this test; it'd be good if we could compare the generated triton code.
There are some differences for this specific op, see: #193 (comment) (for the vit model), but we understand that they come from the order in which we do the clamping: #193 (comment). I feel we probably want to establish high-level benchmarks like eval accuracy and performance numbers for these in the future.
Oh, do you mean to verify that mix_mm is called in the compiled code? That's a good idea, I can check that. edit:
I don't think that's an issue; just reverse the order in the original op and compare the generated code for that, then. The triton generation test is mostly about verifying perf between two things written in very different ways; functional changes aren't super important to that concern (though still relevant if they are significantly different; this one shouldn't have a large impact).
So my point is that the generated code may change, and the change should be allowed if the generated code has the desired op (like the int4mm op) and similar performance and accuracy.
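The two kinds of check discussed in this thread (exact comparison of generated code versus only requiring that a desired op appears) can be contrasted with a small stdlib-only sketch. The helpers `has_op` and `diff_generated` are hypothetical, and the code strings below are made-up placeholders, not real Inductor or Triton output; in a real test the strings would come from whatever utility the integration test uses to extract the generated code.

```python
import difflib


def has_op(generated_code, op_name):
    """Loose check: the generated code mentions the desired op somewhere."""
    return op_name in generated_code


def diff_generated(old_code, new_code):
    """Strict check: return unified-diff lines; an empty list means identical code."""
    return list(
        difflib.unified_diff(
            old_code.splitlines(),
            new_code.splitlines(),
            fromfile="old",
            tofile="new",
            lineterm="",
        )
    )


# Placeholder strings standing in for generated code (not real compiler output).
old_code = "tmp0 = clamp(absmax, 1e-5)\ntmp1 = tmp0 / 127\nout = mix_mm(x, w)"
new_code = "tmp0 = absmax / 127\ntmp1 = clamp(tmp0, 1e-5)\nout = mix_mm(x, w)"

changed_lines = diff_generated(old_code, new_code)
still_has_op = has_op(new_code, "mix_mm")
```

Here the strict check reports a difference (the clamp moved) while the loose check still passes, which is the trade-off between the two positions taken above.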