Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add custom CUDA tinygemm unpacker #415

Merged
merged 21 commits into from
Jul 4, 2024

Conversation

jeromeku
Copy link
Collaborator

@jeromeku jeromeku commented Jun 21, 2024

Description

Adds CUDA custom ops to unpack weights that have been packed with torch.ops.aten._convert_weight_to_int4pack for use with torch.ops.aten._weight_int4pack_mm.

Currently there is only a packing function that permutes and prepacks the weights in tensor-core format. However, there is no equivalent unpacking function that reorders the weights back to the original logical layout.

The implementation is an adaptation of the original packing code (int4mm.cu) with modifications to simplify indexing logic and fused unpacking & dequantization.

Motivation

Fast unpacking of packed weights is needed when switching quantized gemm backends during inference.

As workloads transition from memory-bound to compute-bound (i.e., context length growth during decoding), users might wish to switch to a different kernel implementation that is more performant in this regime than tinygemm.

In order to do this, the weights need to be unpacked from the packed format. Alternative would be to store 2 copies of the weights -- one packed, one in logical format -- but this is clearly not ideal given memory burden.

Features

Add 2 custom CUDA ops, registered per the instructions in torchao custom op documentation:

  1. torchao.ops.unpack_int4 - unpacks the packed weight to the original N x K logical layout with dtype torch.int. Can be used within TensorCoreTiledAQTLayout.get_plain to recover original layout of the (quantized) tensor.
  2. torchao.ops.dequantize_int4 - dequantizes the packed weight to bfloat16 tensor with original N x K logical layout. This is useful for developers who want to unpack and dequantize the packed weight when switching quantized matmul backends on the fly.

Tests

Tests have been added to test/test_ops.py for both correctness as well as for correct custom op registration.

Note: opcheck test_aot_dispatch_dynamic currently failing, investigating.

TODO

  • Fuse dequant into unpacking kernel
    • Kernel works against a reference implementation per my understanding of dequantization but needs further verification (see notes in test/test_ops.py:test_dequant_int4_correctness)
  • Implement dequantize for ZeroPointDomain.Float per update
  • Debug test_aot_dispatch_dynamic opcheck failure
  • Integrate with AQT get_plain

@msaroufim

Copy link

pytorch-bot bot commented Jun 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/415

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 1 New Failure

As of commit e90e280 with merge base e5548b7 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 21, 2024
@jerryzh168
Copy link
Contributor

great, after this is landed we can replace this workaround code

def get_plain(self):
with the op I think

@jeromeku
Copy link
Collaborator Author

jeromeku commented Jun 22, 2024

@jerryzh168

Took a look at get_plain for int4 AQT type: will revise the kernel so that it more closely aligns with the API of get_plain.

Are there tests for get_plain that I can adapt to verify my implementation?

Currently working on fusing dequant into the unpacking kernel, however, a simple sanity check using the same logic as get_plain is failing.

That is, I'm using tinygemm to dequantize by passing in an identity matrix as operand a and packed weights, scales, and zeros, respectively. Comparing against a reference dequantize method that does (q - zero) * scale on the unpacked weights, scales, and zeros does not check out. Am I misinterpreting the scales and zeros?

@jeromeku jeromeku marked this pull request as draft June 22, 2024 16:26
@jerryzh168
Copy link
Contributor

@jerryzh168

Took a look at get_plain for int4 AQT type: will revise the kernel so that it more closely aligns with the API of get_plain.

Are there tests for get_plain that I can adapt to verify my implementation?

Currently working on fusing dequant into the unpacking kernel, however, a simple sanity check using the same logic as get_plain is failing.

That is, I'm using tinygemm to dequantize by passing in an identity matrix as operand a and packed weights, scales, and zeros, respectively. Comparing against a reference dequantize method that does (q - zero) * scale on the unpacked weights, scales, and zeros does not check out. Am I misinterpreting the scales and zeros?

we don't have get_plain() tests yet, but I'm planning to add some tests for AffineQuantizedTensor in the future

the way that tinygemm dequantize implmeneted is a bit different from the normal path, here is how it's implmeneted:

function call to our primitive ops:

return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

code path:

main difference is the zero_point is in floating point domain (while the the quant/dequant that we are more familiar with is in integer domain):

integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale

@jeromeku
Copy link
Collaborator Author

jeromeku commented Jun 24, 2024

@jerryzh168

Many thanks on the clarification!

Helps explain why tinygemm kernel is able to use a single fma to dequantize (using integer zero-point would require a sub then mul unless zeros are stored as scales * zeros, which is not the case).

This is good to know as the original motivation for this PR was to help answer.ai / hqq developers who are using tinygemm as a quantized matmul backend. However, I believe hqq is using the integer zeropoint derivation (but keeping the zeropoint in the original floating-point dtype), which will result in incorrect results when using tinygemm kernel, which is dequantizing based on the floating-point zeropoint calculation.

What is the mathematical derivation of the float dequantization method vs the more common integer quantization scheme? Are there any papers / blogs that explain the reasoning for this difference?

@jerryzh168
Copy link
Contributor

Helps explain why tinygemm kernel is able to use a single fma to dequantize (using integer zero-point would require a sub then mul unless zeros are stored as scales * zeros, which is not the case).

yes, the motivation for tinygemm to have zero_point in floating point domain is exactly to use a single fma (I talked to Jeff about this).

This is good to know as the original motivation for this PR was to help answer.ai / hqq developers who are using tinygemm as a quantized matmul backend. However, I believe hqq is using the integer zeropoint derivation (but keeping the zeropoint in the original floating-point dtype), which will result in incorrect results when using tinygemm kernel, which is dequantizing based on the floating-point zeropoint calculation.

for hqq yeah we need to make sure this detail is correct since they are using tinygemm kernels. cc @HDCharles @mobicham

What is the mathematical derivation of the float dequantization method vs the more common integer quantization scheme? Are there any papers / blogs that explain the reasoning for this difference?

I'm not aware of any formal papers or blogs. so the differences are shown in our quant_primitive ops in these two flags:

preserve_zero (bool): a flag to indicate whether we need zero to be exactly
representable or not, this is typically required for ops that needs zero padding, like convolution
it's less important for ops that doesn't have zero padding in the op itself, like linear.
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
we'll make sure there is a integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such
gurantee.
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT

traditional integer quantization:

  1. preserve_zero is True: this is because traditionally we use quantization on conv, and it can have zero_padding, so there is a domain specific requirement of floating point zero has to be exactly representable: https://github.com/google/gemmlowp/blob/master/doc/quantization.md#domain-specific-constraint-the-real-value-0-must-be-exactly-representable

  2. zero_point is in integer domain
    This is probably for static quantization where there are hardwares that only support integer compute

tinygemm:

  1. preserve_zero is False because mainly we care about linear, this will also help improve accuracy in some cases since we don't need to always include zero during quantization
  2. zero_point is in floating point domain
    this is because of fma I think

@jeromeku
Copy link
Collaborator Author

@jerryzh168

Thanks!

Looking at gpt-fast quantization code here, scales / zeros are calculated as:

  scales = (max_val - min_val).clamp(min=1e-6) / max_int
  zeros = min_val + scales * (2 ** (n_bit - 1)) 

where 2 ** (nbit-1) is the mid-point. Hence it is fusing the mid-point into the zeros.

Then in the tinygemm kernel, dequantization is performed as q * scale + zero, which is close to but does not match exactly ZeroPointDomain.FLOAT dequant. Can you elaborate on what zero_point (float) refers to here?

Comment on lines 265 to 266
m.impl("torchao::unpack_int4_to_int", &_unpack_int4_to_int);
m.impl("torchao::dequantize_int4", &_dequantize_int4);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I feel we probably need to mention tensor_core_tiled layout in the name of these ops if these are specific to that packing format

@jerryzh168
Copy link
Contributor

@jerryzh168

Thanks!

Looking at gpt-fast quantization code here, scales / zeros are calculated as:

  scales = (max_val - min_val).clamp(min=1e-6) / max_int
  zeros = min_val + scales * (2 ** (n_bit - 1)) 

where 2 ** (nbit-1) is the mid-point. Hence it is fusing the mid-point into the zeros.

Then in the tinygemm kernel, dequantization is performed as q * scale + zero, which is close to but does not match exactly ZeroPointDomain.FLOAT dequant. Can you elaborate on what zero_point (float) refers to here?

sorry I just wrote down the quantize function there, it's not the dequant function, I should probably add all algorithms (choose_qparams, quant, dequant) there. the dequant we are using is here:

assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
if zero_point is not None:
dequant += zero_point

@jeromeku
Copy link
Collaborator Author

@jerryzh168
Sorry for persisting on this matter but still a gap in my understanding:

If we unpack what gpt-fast and tinygemm is doing:

  scales = (max_val - min_val).clamp(min=1e-6) / max_int
  zeros = min_val + scales * (2 ** (n_bit - 1)) = min_val + scales * mid_point

Then in tinygemm, dequantization is calculated as:

x = q * scales + zeros 
   =  q * scales + min_val + scales * mid_point

Where x is dequantized value, q is quantized value.

Comparing this to torchao dequant per your link:

 x = (q - mid_point) * scales + zeros 
    = q * scales - scales * mid_point + zeros

Assuming zero_point is calculated per gpt_fast:

x = q * scales - scales * mid_point + min_val + scales * mid_point
   = q * scales + min_val

How to reconcile these differences? How are zeros expected to be calculated in torchao?

@jerryzh168
Copy link
Contributor

@jerryzh168 Sorry for persisting on this matter but still a gap in my understanding:

If we unpack what gpt-fast and tinygemm is doing:

  scales = (max_val - min_val).clamp(min=1e-6) / max_int
  zeros = min_val + scales * (2 ** (n_bit - 1)) = min_val + scales * mid_point

Then in tinygemm, dequantization is calculated as:

x = q * scales + zeros 
   =  q * scales + min_val + scales * mid_point

Where x is dequantized value, q is quantized value.

Comparing this to torchao dequant per your link:

 x = (q - mid_point) * scales + zeros 
    = q * scales - scales * mid_point + zeros

Assuming zero_point is calculated per gpt_fast:

x = q * scales - scales * mid_point + min_val + scales * mid_point
   = q * scales + min_val

How to reconcile these differences? How are zeros expected to be calculated in torchao?

yeah no problem, I think the main difference as you listed is this part:

> Then in `tinygemm`, dequantization is calculated as:
> 
> ```python
> x = q * scales + zeros 
>    =  q * scales + min_val + scales * mid_point
> ```

I feel the q here should probably be (q - mid_point)

I'm not very familiar with tinygemm kernel implementation itself, but I think this should be accounted for either by preprocess of q or post process of the results after that dequant op.

also for some additional context, current quant primitives in torchao are adapted from the original gpt-fast/tinygemm choose_qparams/quantize/dequantize implemnetations and we have regression tests to make sure they match:

# Legacy tinygemm ops

@jerryzh168
Copy link
Contributor

maybe related to https://github.com/pytorch/pytorch/blob/93a33bf3ac0b4c9560b49780eabcad2f76dcf43e/aten/src/ATen/native/cuda/int4mm.cu#L197

cc @HDCharles do you know how tinygemm kernel dequant implementation match up with the python dequant implementation?

@jeromeku
Copy link
Collaborator Author

@jerryzh168

For the purposes of this PR, then, what should be the expected behavior of dequantize_int4?

That is, given packed weights, scales, zeros, etc., what should be the calculation to dequantize the weights from int4 to bfloat16?

Checking the quant_primitives dequant methodology against calling tinygemm with an identity matrix to dequant gives a small error (~ 1e-2), see this script.

@jerryzh168
Copy link
Contributor

@jeromeku OK I just confirmed with Jeff Johnson that this code is actually doing both a uint4 -> int4 conversion ([0, 15] --> [-8, 7]) which is equivalent to (q_val - mid_point) in our dequant code, and also a conversion to bfloat16

so I feel the dequantize_op in this case should follow what we are doing in dequantize_affine op, at least for int4, I also need to think a bit about this uint4 -> int4 conversion stuff, I feel it should probably be done outside of quant primitives op

for test, what you described make sense, I think you can do two test:

  1. quant + dequant v.s. (quant + pack + packed_dequant_int4 (your op))
  2. quant + pack + tinygemm_int4_mm with identity matrix v.s. (quant + pack + packed_dequant_int4 (your op))

@jeromeku
Copy link
Collaborator Author

@jerryzh168

Thanks for the clarification.

Will update the dequant kernel to reflect this change: u4 -> s4 + upcast followed by scale + shift and add relevant tests.

Would be good to add some additional documentation explaining the pre-processing / post-processing needed to use quantized weights, scales, and zero-points prepared using "conventional" (ZeroPointDomain.INT) schemes for use with tinygemm, e.g., hqq.

@jerryzh168
Copy link
Contributor

sure thanks, I'll add some docs for quant_primitives in our README

test/test_ops.py Outdated Show resolved Hide resolved
test/test_ops.py Outdated Show resolved Hide resolved
test/test_ops.py Outdated Show resolved Hide resolved
torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1")
torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2")

return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this, is this supposed to call dequantize_tensor_core_tiled_layout

Copy link
Collaborator Author

@jeromeku jeromeku Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this was the expected pattern for registering a custom op? I was following the example of the pre-existing fp6_linear custom op already in ops.py.

Previously one would register an abstract impl for composability with torch.compile. Thought this was the expected interface with the new custom op registration API. That is, register a "fake" implementation that runs checks and just returns the expected shape of the output.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh OK, I think I understand now, register_custom_op is calling register_fake/impl_abstract, I feel we need to rename this util to something more accurate, cc @msaroufim

torchao/ops.py Outdated Show resolved Hide resolved
Copy link
Contributor

@jerryzh168 jerryzh168 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good overall, thanks for working on this @jeromeku! just had a few nits + requested additional tests and questions around the motivation of having two ops

jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Jul 2, 2024
Summary:
att, per request in pytorch#415 (comment)

Test Plan:
doc changes

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Jul 2, 2024
Summary:
att, per request in pytorch#415 (comment)

Test Plan:
doc changes

Reviewers:

Subscribers:

Tasks:

Tags:
msaroufim pushed a commit that referenced this pull request Jul 3, 2024
Summary:
att, per request in #415 (comment)

Test Plan:
doc changes

Reviewers:

Subscribers:

Tasks:

Tags:
@jeromeku
Copy link
Collaborator Author

jeromeku commented Jul 3, 2024

@jerryzh168

Fixed all the above:

  • Renamed innerKTiles -> inner_k_tiles
  • Changed unpack test from testing for closeness to equality
  • Added additional (unpack_tensor_core_tiled_layout_op + dequant) vs. dequantize_tensor_core_tiled_layout_op test
  • Added comments clarifying the logic of the fused dequant kernel tests
    • Since tinygemm id matrix dequant hack and the fused dequant kernel utilize the same underlying fast CUDA numeric conversion path from u4 -> s4 -> bf16, they have identical numerics
    • Both result in ~1e-2 discrepancies when compared with ao groupwise_affine_dequantize
    • These conditions are tested for in the dequantize_tensor_core_layout tests:
      • difference between tinygemm id matrix dequant and dequant_tensor_core_layout is 0
      • the difference between tinygemm id matrix vs. groupwise_affine_dequant is the same as the difference between dequantize_tensor_core_layout and groupwise_affine_dequant.


return torch.empty((N, K), dtype=torch.int32, device=packed_w.device)

def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor:
Copy link
Contributor

@jerryzh168 jerryzh168 Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this specific to uint4 btw?

looks like so, maybe we can add uint4 to the name as well in that case, unless this layout makes sense for other dtypes as well and we want to extend it to other dtypes in the future

Copy link
Contributor

@jerryzh168 jerryzh168 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! really appreciate adding this functionality and the thorough comments/testing!

@msaroufim
Copy link
Member

Just a minor merge conflict and this should be good to merge

@jeromeku
Copy link
Collaborator Author

jeromeku commented Jul 4, 2024

@jerryzh168 @msaroufim

Getting CI failure unrelated to PR:

 =========================== short test summary info ============================
  FAILED test/integration/test_integration.py::SmoothquantIntegrationTest::test_on_dummy_distilbert - requests.exceptions.ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 72162155-ccbe-44a5-9304-e6e08c061a4e)')
  ====== 1 failed, 202 passed, 556 skipped, 11 warnings in 93.47s (0:01:33) ======
  Error: Process completed with exit code 1.

@msaroufim msaroufim self-requested a review July 4, 2024 01:50
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool thank you for the awesome work @jeromeku and thank you for the thorough review @jerryzh168

The CI failure indeed seems unrelated, most likely a flake due to connection issues with HF

@msaroufim msaroufim merged commit 74846da into pytorch:main Jul 4, 2024
12 of 13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
…#469)

Summary:
att, per request in pytorch#415 (comment)

Test Plan:
doc changes

Reviewers:

Subscribers:

Tasks:

Tags:
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* add unpack cuda

* add tests

* fix tests

* refactor tinygemm unpacking kernel

* add dequant

* add additional dequant check

* update tinygemm dequantize test

* correct dequant kernel logic

* clean up kernel

* update dequantize kernel tests

* rename kernel ops to tensor_core_tiled_layout

* add renamed kernel source

* add back test_aot_dispatch opcheck

* rename innerKTiles to inner_k_tiles

* add unpack and dequant test

* additional numerical checks for unpack then dequant

* rebase test_ops on main

* remove commented out code

* skip dynamic opcheck unless torch>=2.5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants