Quantization Docstrings (#273)
msaroufim authored May 25, 2024
1 parent bea1927 commit a7bc592
Showing 7 changed files with 215 additions and 82 deletions.
82 changes: 71 additions & 11 deletions torchao/quantization/autoquant.py
@@ -16,8 +16,6 @@
except:
from torch._inductor.runtime.runtime_utils import do_bench

from .utils import TORCH_VERSION_AFTER_2_4

aten = torch.ops.aten

AUTOQUANT_CACHE = {}
@@ -28,10 +26,21 @@ def check_cache(cls, shapes_and_dtype):
def update_cache(cls, shapes_and_dtype, res):
AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res

# TODO: Document the methods
class AutoQuantizableLinearWeight(torch.Tensor):
"""
when run, finds best type of quantization for this tensor and swaps itself with that
A subclass of torch.Tensor that, when run, finds the best type of quantization for itself and swaps
its data with the quantized version.
Args:
weight (torch.Tensor): The initial weight tensor.
qtensor_class_list (list): A list of tensor classes to be considered for quantization.
*args: Additional positional arguments.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
(e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
**kwargs: Additional keyword arguments.
"""

@staticmethod
def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
kwargs["device"] = weight.device
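A minimal sketch of constructing this subclass directly, based only on the __new__ signature shown above; in practice torchao.autoquant (documented later in this file) performs the wrapping, and extra kwargs not visible in this excerpt may be required.

import torch
from torchao.quantization.autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST

lin = torch.nn.Linear(128, 256)
# Hypothetical direct construction; additional kwargs (e.g. dtype) may be needed.
aq_weight = AutoQuantizableLinearWeight(lin.weight.detach(), DEFAULT_CLASS_LIST, mode=["relu", None])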
@@ -214,7 +223,18 @@ def _is_interpolate_mode(mode):

class AQMixin():
"""
Mixin to turn normal quantized subclasses into autoquantizable ones
Tests and benchmarks the autoquantization process for the given activation matrix, weight, and bias.
Args:
act_mat (torch.Tensor): The activation matrix.
weight (torch.Tensor): The weight tensor.
bias (torch.Tensor or None): The bias tensor.
best_time (float): The best time to beat for the quantization process.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
(e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
Returns:
float: The benchmarked time for the autoquantization process.
"""
@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
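For orientation only, a generic sketch of the kind of timing comparison an _autoquant_test performs; the real implementation relies on do_bench (imported at the top of this file) rather than torch.utils.benchmark.

import torch
from torch.utils.benchmark import Timer

def _toy_autoquant_test(act_mat, q_weight, bias, best_time):
    # Time the candidate op; autoquant keeps whichever candidate beats best_time.
    timer = Timer(
        stmt="torch.nn.functional.linear(act_mat, q_weight, bias)",
        globals={"torch": torch, "act_mat": act_mat, "q_weight": q_weight, "bias": bias},
    )
    return timer.timeit(20).mean * 1e3  # milliseconds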
@@ -237,6 +257,20 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLi
"""
@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
"""
Tests and benchmarks the autoquantization process with special handling for interpolate mode.
Args:
act_mat (torch.Tensor): The activation matrix.
weight (torch.Tensor): The weight tensor.
bias (torch.Tensor or None): The bias tensor.
best_time (float): The best time to beat for the quantization process.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
(e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
Returns:
float: The benchmarked time for the autoquantization process.
"""
if not _is_interpolate_mode(mode):
return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
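One plausible reading of the interpolate mode, stated here as an assumption rather than a description of the actual implementation: the reported time is a weighted blend of a ReLU-fused benchmark and a plain benchmark, with mode[1] as the weight.

def blend_times(time_with_relu, time_without_relu, w=0.85):
    # w corresponds to mode[1]; which benchmark receives the larger weight is an assumption.
    return w * time_with_relu + (1.0 - w) * time_without_relu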

@@ -279,6 +313,17 @@ class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQ
"""
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
"""
Performs the quantized linear operation.
Args:
act_mat (torch.Tensor): The activation matrix.
w_qtensor (torch.Tensor): The quantized weight tensor.
bias (torch.Tensor or None): The bias tensor.
Returns:
torch.Tensor: The result of the quantized operation.
"""
orig_dtype = act_mat.dtype
orig_shape = act_mat.shape
act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
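As a generic reference point (a sketch, not this class's exact kernel), weight-only int8 linear can fold dequantization into a per-output-channel rescale of the matmul result:

import torch

def int8_weight_only_linear(act_mat, w_int8, w_scales, bias=None):
    # Compute in the activation dtype, then rescale each output channel.
    y = torch.nn.functional.linear(act_mat, w_int8.to(act_mat.dtype)) * w_scales
    return y if bias is None else y + bias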
@@ -383,18 +428,33 @@ def change_autoquantizable_to_quantized(model, **kwargs):
torch._dynamo.config.automatic_dynamic_shapes = hold
torch._dynamo.reset()

# TODO: example_input seems weird to include in the API
# TODO: Document all the modes
# TODO: Mode being a list is weird, should be a string or some object
@torch.no_grad()
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
"""
wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
model and then letting the torch.compile run/tracing occur.
Example usage::
Wraps the given model in an AutoQuantWrapper. If `example_input` is provided, performs a forward pass on the input.
Otherwise, returns the wrapped model. The AutoQuantWrapper manages cases where the model is torch-compiled by first
performing autoquantization on the original model and then allowing the torch.compile run/tracing to occur.
Args:
model (torch.nn.Module): The model to be autoquantized.
example_input (Any, optional): An example input for the model. If provided, the function performs a forward pass
on this input. Defaults to None.
qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_CLASS_LIST.
filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"),
and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
**aq_kwargs: Additional keyword arguments for the autoquantization process.
Returns:
torch.nn.Module: The autoquantized and wrapped model. If `example_input` is provided, the function performs a forward pass
on the input and returns the result of the forward pass.
Example usage:
torchao.autoquant(torch.compile(model))
model(*example_input)
"""
# the hook we will use to intercept the model forward and perform
# autoquantization
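A slightly fuller version of the docstring's usage example; the toy model, device, and input shapes below are placeholders, not part of the API.

import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().eval()  # placeholder model
example_input = (torch.randn(8, 128, device="cuda"),)

model = torchao.autoquant(torch.compile(model))
model(*example_input)  # first call benchmarks candidates and swaps in quantized weights
model(*example_input)  # later calls run with the selected quantized kernels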
28 changes: 26 additions & 2 deletions torchao/quantization/quant_api.py
@@ -13,6 +13,10 @@
both because primitives were designed based on the fusions that
come along with it and because that is how we access the intended quantized
and mixed GEMM kernels
TODO: There are 2 different approaches to quantizing a model. The first and historically more
popular approach is to use module swaps, which explicitly change the linear modules; the second
approach is to instead use subclasses to change the interpretation of the linear module.
"""

import torch
@@ -51,6 +55,7 @@
"Int4WeightOnlyQuantizer",
"quantize",
"autoquant",
"_get_subclass_inserter",
]

if TORCH_VERSION_AFTER_2_3:
@@ -72,8 +77,17 @@ def _replace_with_custom_fn_if_matches_filter(
cur_fqn="",
) -> None:
"""
For each `child` in `model`, replaces it with `replacement_fn(child)`
if `filter_fn(child)` is `True`
Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
if `filter_fn(child)` returns `True`.
Args:
model (torch.nn.Module): The model containing modules to be replaced.
replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
Returns:
None
"""
if filter_fn(model, cur_fqn[:-1]):
model = replacement_fn(model)
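A minimal usage sketch of this traversal helper, assuming (as the call above suggests) that filter_fn receives the module and its fully qualified name:

import torch
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

def freeze_linear(lin):
    # Example replacement: return the module with a frozen copy of its weight.
    lin.weight = torch.nn.Parameter(lin.weight.detach().clone(), requires_grad=False)
    return lin

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
_replace_with_custom_fn_if_matches_filter(
    model, freeze_linear, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
)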
@@ -125,6 +139,16 @@ def apply_dynamic_quant(model, filter_fn=None):
import torch.nn.utils.parametrize as parametrize

def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
"""
Returns a function which inserts the given subclass into all linear modules
in the model. The inserted module will have its weight set to the result of
`cls(mod.weight, **kwargs)`. If parametrization is enabled then this will be done using
torch.nn.utils.parametrize instead of directly setting the attribute on the module.
Args:
cls (type): The tensor subclass used to wrap each matching module's weight.
kwargs (Any): Any additional arguments for the constructor.
"""
constructor = kwargs.pop("constructor", "subclass_constructor")
from_float = kwargs.pop("method", "from_float")
def insert_subclass(lin):
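A sketch of how this inserter might compose with the traversal helper above; the specific subclass and its import path are assumptions for illustration, not the exact wiring used by apply_dynamic_quant.

import torch
from torchao.quantization.quant_api import (
    _get_subclass_inserter,
    _replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.subclass import Int8WeightOnlyQuantizedLinearWeight  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(32, 32))
insert_fn = _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight)
_replace_with_custom_fn_if_matches_filter(
    model, insert_fn, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
)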
75 changes: 39 additions & 36 deletions torchao/quantization/quant_primitives.py
@@ -375,7 +375,6 @@ def choose_qparams_affine(


# copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736

def dynamically_quantize_per_tensor(
x,
quant_min,
@@ -401,8 +400,6 @@
# taken from
# https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
# and slightly modified


def quantize_activation_per_token_absmax(t):
# if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
mapping_type = MappingType.SYMMETRIC
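A generic sketch of per-token abs-max scaling consistent with the comment above; the actual primitive routes through choose_qparams_affine with MappingType.SYMMETRIC rather than doing this by hand.

import torch

t = torch.randn(2, 4, 8)                             # [B, N, K]
scales = t.abs().amax(dim=-1, keepdim=True) / 127.0  # [B, N, 1], one scale per token
t_int8 = torch.clamp(torch.round(t / scales), -128, 127).to(torch.int8)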
@@ -426,10 +423,12 @@ def quantize_activation_per_token_absmax(t):


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
# assumes symmetric quantization
# assumes axis == 0
# assumes dense memory format
# TODO(future): relax ^ as needed
"""
assumes symmetric quantization
assumes axis == 0
assumes dense memory format
TODO(future): relax ^ as needed
"""

assert x.dim() == 2, "only support 2d Tensors"
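Under the assumptions listed above (symmetric, axis 0, dense, 2d), per-channel int8 quantization reduces to a sketch like the following; the [-128, 127] quant range here is illustrative.

import torch

x = torch.randn(64, 128)                                   # 2d tensor, channels along dim 0
scales = x.abs().amax(dim=1, keepdim=True) / 127.0         # one scale per channel
x_int8 = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
zero_points = torch.zeros(x.shape[0], dtype=torch.int64)   # symmetric => zero-points are 0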

@@ -512,21 +511,22 @@ def quant_int8_matmul(
w_scales,
out_dtype=torch.float32,
):
# Quantized matmul of int8 operands that accumulates to int32 and returns
# out_dtype. For now, this is written for approximate numerical
# correctness, and things like aligning accumulation behaviors and
# performance optimizations are left for a future PR.
# Assumes that weight quantization is symmetric, i.e. w_zp is 0.
# Assumes that weight quantization is per-channel.

# see
# https://github.com/google/gemmlowp/blob/master/doc/quantization.md
# for an overview of quantized matmul compute

# in scalar form, assuming out_dtype is fp32 and zw == 0:
#
# Y_i_j_fp32 = sx * sw (dot(X_i, W_j) - zx * sum(W_j))
#
"""
Quantized matmul of int8 operands that accumulates to int32 and returns
out_dtype. For now, this is written for approximate numerical
correctness, and things like aligning accumulation behaviors and
performance optimizations are left for a future PR.
Assumes that weight quantization is symmetric, i.e. w_zp is 0.
Assumes that weight quantization is per-channel.
see
https://github.com/google/gemmlowp/blob/master/doc/quantization.md
for an overview of quantized matmul compute
in scalar form, assuming out_dtype is fp32 and zw == 0:
Y_i_j_fp32 = sx * sw (dot(X_i, W_j) - zx * sum(W_j))
"""

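The scalar form above can be sanity-checked with a few lines of plain PyTorch; this is a hypothetical numeric check, not library code.

import torch

x_fp, w_fp = torch.randn(4, 8), torch.randn(8, 5)
sx, zx = 0.05, 3                                  # asymmetric activation quantization
x_q = torch.clamp(torch.round(x_fp / sx) + zx, -128, 127)
sw = w_fp.abs().amax(dim=0) / 127.0               # per-channel weight scales, zw == 0
w_q = torch.round(w_fp / sw)
y = sx * sw * (x_q @ w_q - zx * w_q.sum(dim=0))   # the docstring's scalar form
print((y - x_fp @ w_fp).abs().max())              # error on the order of the quantization step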
assert x_vals_int8.dtype in (
torch.uint8,
@@ -571,8 +571,10 @@ def quant_int8_dynamic_per_token_linear(
bias,
out_dtype,
):
# like F.linear, but with int8 dynamic quantization of activation,
# and a quantized weight
"""
like F.linear, but with int8 dynamic quantization of activation,
and a quantized weight
"""
x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
mm_out = quant_int8_per_token_matmul(
x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
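A usage sketch under two assumptions not visible in this excerpt: that dynamically_quantize_per_channel returns (int8 values, scales, zero points), and that the full parameter list here is (x, w_vals_int8_t, w_scales, bias, out_dtype).

import torch
from torchao.quantization.quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
)

x = torch.randn(4, 16)
lin = torch.nn.Linear(16, 8)
w_int8, w_scales, _ = dynamically_quantize_per_channel(lin.weight.detach(), -128, 127, torch.int8)
y = quant_int8_dynamic_per_token_linear(x, w_int8.t(), w_scales, lin.bias, torch.float32)
# y approximates torch.nn.functional.linear(x, lin.weight, lin.bias)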
@@ -589,20 +591,21 @@ def quant_int8_per_token_matmul(
w_scales,
output_dtype=torch.float32,
):
# Quantized matmul of int8 operands that accumulates to int32 and returns
# output_dtype. For now, this is written for approximate numerical
# Assumes that activation and weight quantization are symmetric,
# i.e. act_zp and w_zp is 0.
# Assumes that weight quantization is per-channel.
"""
Quantized matmul of int8 operands that accumulates to int32 and returns
output_dtype. For now, this is written for approximate numerical correctness.
Assumes that activation and weight quantization are symmetric,
i.e. act_zp and w_zp are 0.
Assumes that weight quantization is per-channel.
# see
# https://github.com/google/gemmlowp/blob/master/doc/quantization.md
# for an overview of quantized matmul compute
see
https://github.com/google/gemmlowp/blob/master/doc/quantization.md
for an overview of quantized matmul compute
# in scalar form, assuming output_dtype is fp32 and zw == 0:
#
# Y_i_j_fp32 = sx * sw dot(X_i, W_j)
#
in scalar form, assuming output_dtype is fp32 and zw == 0:
Y_i_j_fp32 = sx * sw dot(X_i, W_j)
"""

assert (
x_vals_int8.dtype == torch.int8
40 changes: 32 additions & 8 deletions torchao/quantization/smoothquant.py
@@ -34,12 +34,15 @@

def get_scale(X_absmax, W_absmax, alpha=0.5):
"""
Calculate the scale based on abs(max(X)), abs(max(W)) and alpha
If X is of dimension `b*n*k` and W is dimension `k*m`, the returned
scale is of dimension `k`.
Note: X_absmax is calculated outside of this function because we
need to keep a running version of it during calibration. W_absmax
is calculated outside of this function for consistency with X_absmax.
Calculate the scale based on abs(max(X)), abs(max(W)), and alpha.
Args:
X_absmax (torch.Tensor): Absolute maximum values of the input tensor X.
W_absmax (torch.Tensor): Absolute maximum values of the weight tensor W.
alpha (float, optional): Scaling factor. Defaults to 0.5.
Returns:
torch.Tensor: The calculated scale of dimension `k` if X is of dimension `b*n*k` and W is of dimension `k*m`.
"""
X_pow = torch.pow(X_absmax, alpha)
W_pow = torch.pow(W_absmax, 1.0 - alpha)
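A small usage sketch; the abs-max reductions below are one way a caller might produce X_absmax and W_absmax, which the old docstring notes are computed outside this function during calibration.

import torch
from torchao.quantization.smoothquant import get_scale

X, W = torch.randn(2, 4, 16), torch.randn(16, 8)   # X: (b, n, k), W: (k, m)
X_absmax = X.abs().amax(dim=(0, 1))                # running abs-max per input channel, shape (k,)
W_absmax = W.abs().amax(dim=1)                     # abs-max per input channel of W, shape (k,)
scale = get_scale(X_absmax, W_absmax, alpha=0.5)   # shape (k,)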
@@ -210,6 +213,18 @@ def set_debug_x_absmax(self):
def swap_linear_with_smooth_fq_linear(
model, skip_fqn_list=None, cur_fqn="", alpha=0.5
) -> None:
"""
Replaces linear layers in the model with their SmoothFakeDynamicallyQuantizedLinear equivalents.
Args:
model (torch.nn.Module): The model containing linear layers to be replaced.
skip_fqn_list (list of str, optional): List of fully qualified names to skip during replacement. Defaults to None.
cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
alpha (float, optional): The scaling factor for SmoothQuant. Defaults to 0.5.
Returns:
None
"""

name_to_child = dict(model.named_children())
for name, child in name_to_child.items():
@@ -228,6 +243,17 @@ def swap_linear_with_smooth_fq_linear(


def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None:
"""
Prepares the model for inference by calculating the smoothquant scale for each SmoothFakeDynamicallyQuantizedLinear layer.
Args:
model (torch.nn.Module): The model containing SmoothFakeDynamicallyQuantizedLinear layers.
debug_skip_calibration (bool, optional): If True, sets the running maximum of activations to a debug value for performance benchmarking.
Defaults to False.
Returns:
None
"""
for _, mod in model.named_modules():
if isinstance(mod, tuple(source_cls_to_target_cls.values())):
if debug_skip_calibration:
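Putting the two functions documented in this file together, a typical calibrate-then-convert flow might look like the following; the toy model and calibration loop are placeholders.

import torch
from torchao.quantization.smoothquant import (
    swap_linear_with_smooth_fq_linear,
    smooth_fq_linear_to_inference,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()  # placeholder model
swap_linear_with_smooth_fq_linear(model, alpha=0.5)

with torch.no_grad():                    # calibration passes accumulate running abs-max stats
    for _ in range(8):
        model(torch.randn(4, 64))

smooth_fq_linear_to_inference(model)     # compute the smoothquant scales for inference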
@@ -237,8 +263,6 @@ def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None:

# useful for quickly toggling smoothquant debug settings on all smoothquant
# modules in a model


def set_smooth_fq_attribute(model, attribute_name, new_attribute_val):
for _, mod in model.named_modules():
if isinstance(mod, tuple(source_cls_to_target_cls.values())):