-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[quant] Add per block quantization primitives #159
Conversation
a02d061
to
794d9b5
Compare
return shape_for_reduction, reduction_dims | ||
|
||
|
||
def quantize_affine_per_block( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might be easier to first write a version of this that assumes
input.dim() == len(block_size) == scale.dim() == zero_point.dim()
and then use various tools to implement broadcasting. But our ops should kind of imply the broadcasting here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cpuhrsch any ideas on how we can support broadcasting for the example I described here: #159 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Is the next step to rewrite existing quant primitives using this? How will this work for qdq ops currently living in pytorch?
torch.uint7: (0, 2**7-1), | ||
}) | ||
|
||
def _get_qmin_qmax(dtype, quant_min, quant_max): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel this is more like _check_qmin_qmax
? Alternatively make quant_min
and quant_max
default to None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah this actually combined two functions, i can split them as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are thinking about just don't allow quant_min/quant_max, will update this after we made a decision here, I'll add a TODO here
zero_point = zero_point.view(shape_after_reduction) | ||
|
||
quant = torch.clamp( | ||
torch.round(input / scale) + zero_point, quant_min, quant_max |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is slightly different from existing quant primitives, e.g. quantize_per_token
does round after adding zp: https://github.com/pytorch/pytorch/blob/d40774f4ed4a45c70d49e66f4e1f197dfc274758/torch/ao/quantization/fx/_decomposed.py#L771
However, as written this is consistent with the existing torch.fake_quantize_per_channel_affine
, which adds the zp after round. Which one do we want to follow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should just choose one, we can start with this and check to see if we can adjust others to use this I think, we could make the dtypes more explicit as well
) | ||
|
||
self.assertTrue(torch.equal(quantized, quantized_ref)) | ||
self.assertTrue(torch.equal(dequantized, dequantized_ref)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this tell you how many elements were different and by how much? Should we use this instead?
torch.testing.assert_close(quantized, quantized_ref, atol=0, rtol=0)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one is equal actually
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accepting now with the intent of iterating on this over time
yeah we'll rewrite existing primitives in torchao to use this first, and then expand to pytorch later, we'll need to move the ops to pytorch in order to refactor the ops there |
I also tried to not include block_size as args and use scale size: (1, 1, 5, 1) I'm not sure how can we broadcast the scale to be size (1, 1, 10, 1) in order for it to be divided by input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these look significantly more complicated than the old code so i'm wondering if we can still torch.compile them into performant kernels?
Would like to see some ad-hoc microbenchmarks at least to indicate that we're not immediately going to see a huge perf hit from this change, at the very least for per token symmetric quant.
@HDCharles this is just a starting point, I'm planning to replace the existing ops in separate PRs and we can make improvement at that time, including making sure perf is good etc. does that sounds good? |
df719e3
to
93bd2e3
Compare
Summary: We want to use this to replace all q/dq/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py Test Plan: python test/quantization/test_quant_primitives.py Reviewers: Subscribers: Tasks: Tags:
reduction_dims = [] | ||
cur_dim = 0 | ||
for i in range(len(block_size)): | ||
if block_size[i] != input_size[i] and block_size[i] > 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is block_size[i] != input_size[i] and block_size[i] == 1
. As in if corresponding block size is 1. What would that mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
block_size[i] == 1 means that for ith dimension, each slice will have their own qparams
reduction_dims.append(cur_dim + 1) | ||
cur_dim += 2 | ||
else: | ||
# block_size[i] == input_size[i] or block_size[i] == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok i see it here
|
||
def quantize_affine( | ||
input: torch.Tensor, | ||
block_size: List[int], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we decided that you can use scale/zero point shape to infer htis?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is some issues with broadcasting: #159 (comment) let me know if you have some ideas
""" | ||
# TODO: validations | ||
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) | ||
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should you also validate that the blocksize should also correspond to the scale/zp size?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let me just add a TODO for now, so we don't over complicate the code, maybe we could remove some of the shape code if broadcasting is working in the future
dequant = input.to(torch.float32) | ||
scale = scale.to(torch.float32) | ||
if zero_point is not None: | ||
zero_point = zero_point.to(torch.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not feel accurate. I think we have had some discussion around this that it should be `(input.to(torch.int32) - zero_point.to(torch.int32)).to(torch.float32) * scale)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, makes sense, will fix
quantized = quantize_affine(input, block_size, scale, zero_point, dtype) | ||
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) | ||
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float | ||
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add tests where you expect exceptions thrown
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for reviewing htis late so just leaving some comments. Mainly two
- Do we need block size or it can be drived?
- My concern around dequantize routine
sorry just saw the comments
|
input : (3, 3, 10, 10) It we assume scale is always of valid shape than to broadcast scale[2] == 5 to input[2]==10, we will have to interpret scale as blockwise scale where block size = 2., Right? |
Summary: We want to use this to replace all q/dq/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py Test Plan: python test/quantization/test_quant_primitives.py Reviewers: Subscribers: Tasks: Tags:
…ch#159) * move loading of modelfor inference into _load_inference_model * type * load_inference_model * load_inference_model * typo
Summary:
We want to use this to replace all quantize/dequantize/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
Note: this PR only adds the ops, we'll do replacement in separate PRs and make sure it does not degrade the performance or accuracy
Test Plan:
python test/quantization/test_quant_primitives.py
Reviewers:
Subscribers:
Tasks:
Tags: