Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quantize_activation_per_token_absmax use general quant primitives #193

Merged
merged 3 commits into from
May 3, 2024

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented Apr 30, 2024

Summary:
att

Test Plan:
OSS CI

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 30, 2024
@jerryzh168
Copy link
Contributor Author

diff in generated quantized code: https://www.internalfb.com/phabricator/paste/view/P1226948181

@jerryzh168 jerryzh168 changed the title Refactor quantize_activation_per_token_absmax to use general quant … quantize_activation_per_token_absmax use general quant primitives Apr 30, 2024
@jerryzh168
Copy link
Contributor Author

Changes to generated quantized code for vit: https://www.internalfb.com/phabricator/paste/view/P1227030036

@cpuhrsch
Copy link
Contributor

cpuhrsch commented May 1, 2024

Hm those changes seem to indicate that something substantial is still different about this refactor. There must be a slight difference somewhere.

@jerryzh168
Copy link
Contributor Author

jerryzh168 commented May 1, 2024

Hm those changes seem to indicate that something substantial is still different about this refactor. There must be a slight difference somewhere.

yeah I don't think the code is exactly the same, it's calling mostly the same ops though, but with different args, for example: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py#L273 has different dim and keepdim args

is there a way to test the speed of generated code?

@jerryzh168
Copy link
Contributor Author

Hm those changes seem to indicate that something substantial is still different about this refactor. There must be a slight difference somewhere.

yeah I don't think the code is exactly the same, it's calling mostly the same ops though, but with different args, for example: main/torchao/quantization/quant_primitives.py#L273 has different dim and keepdim args

is there a way to test the speed of generated code?

looks like the benchmark code will print the time, let me check

@jerryzh168
Copy link
Contributor Author

before change: elapsed_time: 1.4610368347167968 milliseconds
after change: elapsed_time: 1.4671536254882813 milliseconds

looks like not much difference

@cpuhrsch
Copy link
Contributor

cpuhrsch commented May 1, 2024

I'm less worried about speed and more about correctness. I don't think a single benchmark datapoint will be conclusive here either. I saw an additional call to mul in your generated code. Can you try changing the generic primitives until it matches for this one example only? At least so we can see the difference in code needed to match.

@jerryzh168
Copy link
Contributor Author

I'm less worried about speed and more about correctness. I don't think a single benchmark datapoint will be conclusive here either. I saw an additional call to mul in your generated code. Can you try changing the generic primitives until it matches for this one example only? At least so we can see the difference in code needed to match.

sure, for correctness, are you referring to numerics? so we have regression tests here: https://github.com/pytorch/ao/blob/main/test/integration/test_integration.py#L586 and it passes

@jerryzh168
Copy link
Contributor Author

jerryzh168 commented May 1, 2024

I just did a quick test with a 10x10 tensor, the quantized integer values match exactly, but there are some differences for scales, I think it's because it's doing clamping before dividing q_max, instead of after: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py#L335

@jerryzh168 jerryzh168 force-pushed the dedup-2 branch 2 times, most recently from 4738a2c to 9ede63b Compare May 1, 2024 02:59
scales.float()
) # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
q_max = 2 ** (n_bits - 1) - 1
scales = scales.clamp(min=1e-5).div(q_max)
Copy link
Contributor Author

@jerryzh168 jerryzh168 May 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vkuzo @HDCharles do you know if this clampping before the scale is calculated is required for quantize_activation_per_token_absmax to work for our use cases?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need an epsilon to handle the case where max(abs(x)) is zero

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so can we do this after we divide by q_max? that's where we are doing the clamp typically

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    if eps is not None:
        scale = torch.clamp(scale, min=eps)

you have this logic in choose_qparams_affine, I would imagine it's for the same purpose. I would recommend:

  1. always specify epsilon (not sure why this is optional)
  2. ensure test cases include testing for max(abs(x)) == 0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vkuzo Is eps best chosen to be (a multiple of / the) machine epsilon of the input tensor's dtype? Or is it a parameter that needs to be searched depending on input data distribution?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I'll add a test. will think about if we want to make it required, we didn't do clamping in some of the ops right now I think

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will think about if we want to make it required

IMO the case of max(abs(x)) == 0 should be handled for every op

Is eps best chosen to be (a multiple of / the) machine epsilon of the input tensor's dtype?

I think that choosing eps based on the dtype makes sense, it should not be data dependent. My educated guess of how to choose this number would be "smallest number so that the resulting scale is finite", I haven't looked into whether expressing that in terms of machine epsilon would make sense.

Copy link
Contributor Author

@jerryzh168 jerryzh168 May 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after a second thought I guess it makes sense to have an eps, otherwise we may have divide by zero during quantization, I'll change it in a separate PR

for eps, I think we can do torch.finfo(input_dtype).eps if it's not provided from user

…primitives

Summary:
att

Test Plan:
OSS CI

Reviewers:

Subscribers:

Tasks:

Tags:
Copy link
Contributor

@HDCharles HDCharles left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we verify the torch.compiled code is the same between the old/new primitives in a unit test? It seems pretty important to make sure this is doing what we want after torch.compile on an ongoing basis.

e.g. we use https://github.com/pytorch/ao/blob/main/test/integration/test_integration.py#L1110
to extract the triton code in this test, it'd be good if we could compare the generated triton code.

@jerryzh168
Copy link
Contributor Author

can we verify the torch.compiled code is the same between the old/new primitives in a unit test? It seems pretty important to make sure this is doing what we want after torch.compile on an ongoing basis.

e.g. we use main/test/integration/test_integration.py#L1110 to extract the triton code in this test, it'd be good if we could compare the generated triton code.

There are some differences for this specific op, see: #193 (comment) (for vit model), but we understand that it's from the order of when we do clampping: #193 (comment)

I feel we probably want to establish high level benchmarks like eval accuracy and performance numbers for these in the future.

@jerryzh168
Copy link
Contributor Author

jerryzh168 commented May 3, 2024

can we verify the torch.compiled code is the same between the old/new primitives in a unit test? It seems pretty important to make sure this is doing what we want after torch.compile on an ongoing basis.

e.g. we use main/test/integration/test_integration.py#L1110 to extract the triton code in this test, it'd be good if we could compare the generated triton code.

oh do you mean to verify that mix_mm is called in the compiled code? that's a good idea, I can check that

edit:
it seems that this op is not related to mixed_mm though?

@msaroufim msaroufim merged commit 5364de6 into pytorch:main May 3, 2024
15 checks passed
@HDCharles
Copy link
Contributor

HDCharles commented May 9, 2024

can we verify the torch.compiled code is the same between the old/new primitives in a unit test? It seems pretty important to make sure this is doing what we want after torch.compile on an ongoing basis.
e.g. we use main/test/integration/test_integration.py#L1110 to extract the triton code in this test, it'd be good if we could compare the generated triton code.

There are some differences for this specific op, see: #193 (comment) (for vit model), but we understand that it's from the order of when we do clampping: #193 (comment)

I feel we probably want to establish high level benchmarks like eval accuracy and performance numbers for these in the future.

i don't think that's an issue, just reverse the order in the original op and compare the generated code for that, then.

the triton generation test is mostly about verifying perf between two things written in very different ways, functional changes aren't super important to that concern (though still relevant if they are significantly different, this one shouldn't have a large impact).

@jerryzh168
Copy link
Contributor Author

jerryzh168 commented May 9, 2024

i don't think that's an issue, just reverse the order in the original op and compare the generated code for that, then.

the triton generation test is mostly about verifying perf between two things written in very different ways, functional changes aren't super important to that concern (though still relevant if they are significantly different, this one shouldn't have a large impact).

so my point is that the generated code may change and the change should be allowed if the generated code has desired op (like int4mm op), similar performance and accuracy

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
…ytorch#193)

Summary:
att

Test Plan:
OSS CI

Reviewers:

Subscribers:

Tasks:

Tags:
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Add GUI-based chat to table and mark as not available
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants