diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 76e7cd9ff2..6152cb4e12 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -62,9 +62,14 @@ Affine quantization refers to the type of quantization that maps from floating p ### Quantization Primitives We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. +Note: these primitive ops supports two "types" of quantization, distinguished by whether `zero_point` is in floating point domain or integer domain. See docstrings for `choose_qparams` for more details. + ### Quantized Tensor Subclass We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) +#### Layouts +We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels. + ### Quantization Flow Example Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul as an example: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 31ab71f385..d8e618635f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -364,6 +364,14 @@ def int4_weight_only(group_size=128, inner_k_tiles=8): Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel + Note: + This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference + of quantization algorithm compared to the more traditional type of integer quantization is the following: + 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) + 2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`) + please follow the relevant code in `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` + to learn about how the quantization parameters are chosen and how the Tensor is quantized/dequantized for tinygemm + Args: `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 39f76928c4..4b960f2e60 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -324,6 +324,7 @@ def _dequantize_affine( dequant = dequant * scale else: assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}" + # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) mid_point = (quant_max + quant_min + 1) / 2 # This should allocate new memory and avoid input modification dequant = input - mid_point