diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 808f7d89d3..9eeb146f55 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -16,8 +16,6 @@ except:
     from torch._inductor.runtime.runtime_utils import do_bench

-from .utils import TORCH_VERSION_AFTER_2_4
-
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}

@@ -28,10 +26,21 @@ def check_cache(cls, shapes_and_dtype):
 def update_cache(cls, shapes_and_dtype, res):
     AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res

+# TODO: Document the methods
 class AutoQuantizableLinearWeight(torch.Tensor):
     """
-    when run, finds best type of quantization for this tensor and swaps itself with that
+    A subclass of torch.Tensor that, when run, finds the best type of quantization for itself and swaps
+    its data with the quantized version.
+
+    Args:
+        weight (torch.Tensor): The initial weight tensor.
+        qtensor_class_list (list): A list of tensor classes to be considered for quantization.
+        *args: Additional positional arguments.
+        mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+            (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+        **kwargs: Additional keyword arguments.
     """
+
     @staticmethod
     def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
         kwargs["device"] = weight.device
@@ -214,7 +223,18 @@ def _is_interpolate_mode(mode):

 class AQMixin():
     """
-    Mixin to turn normal quantized subclasses into autoquantizable ones
+    Tests and benchmarks the autoquantization process for the given activation matrix, weight, and bias.
+
+    Args:
+        act_mat (torch.Tensor): The activation matrix.
+        weight (torch.Tensor): The weight tensor.
+        bias (torch.Tensor or None): The bias tensor.
+        best_time (float): The best time to beat for the quantization process.
+        mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+            (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+    Returns:
+        float: The benchmarked time for the autoquantization process.
     """
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
@@ -237,6 +257,20 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLi
     """
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        """
+        Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+            best_time (float): The best time to beat for the quantization process.
+            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+        Returns:
+            float: The benchmarked time for the autoquantization process.
+        """
         if not _is_interpolate_mode(mode):
             return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

@@ -279,6 +313,17 @@ class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQ
     """
     @staticmethod
     def _quantized_op(act_mat, w_qtensor, bias):
+        """
+        Performs the quantized linear operations
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            w_qtensor (torch.Tensor): The quantized weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+
+        Returns:
+            torch.Tensor: The result of the quantized operation.
+        """
         orig_dtype = act_mat.dtype
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
@@ -383,18 +428,33 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     torch._dynamo.config.automatic_dynamic_shapes = hold
     torch._dynamo.reset()

+# TODO: example_input seems weird to include in the API
+# TODO: Document all the modes
+# TODO: Mode being a list is weird, should be a string or some object
 @torch.no_grad()
 def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
     """
-    wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
-    AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
-    model and then letting the torch.compile run/tracing occur.
-
-    Example usage::
-
+    Wraps the given model in an AutoQuantWrapper. If `example_input` is provided, performs a forward pass on the input.
+    Otherwise, returns the wrapped model. The AutoQuantWrapper manages cases where the model is torch-compiled by first
+    performing autoquantization on the original model and then allowing the torch.compile run/tracing to occur.
+
+    Args:
+        model (torch.nn.Module): The model to be autoquantized.
+        example_input (Any, optional): An example input for the model. If provided, the function performs a forward pass
+            on this input. Defaults to None.
+        qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_CLASS_LIST.
+        filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None.
+        mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"),
+            and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
+        **aq_kwargs: Additional keyword arguments for the autoquantization process.
+
+    Returns:
+        torch.nn.Module: The autoquantized and wrapped model. If `example_input` is provided, the function performs a forward pass
+            on the input and returns the result of the forward pass.
+
+    Example usage:
         torchao.autoquant(torch.compile(model))
         model(*example_input)
-
     """
     # the hook we will use to intercept the model forward and perform
     # autoquantization
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 39a977dd00..d9b731bace 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -13,6 +13,10 @@
 both because primitives were designed based on the fusions that
 come along with it and because that is how we access the intended quantized
 and mixed GEMM kernels
+
+TODO: There are two different approaches to quantizing a model. The first and more historically
+popular approach is to use module swaps, which explicitly change the linear modules; the second
+approach is to instead use tensor subclasses to change the interpretation of the linear module.
 """

 import torch
@@ -51,6 +55,7 @@
     "Int4WeightOnlyQuantizer",
     "quantize",
     "autoquant",
+    "_get_subclass_inserter",
 ]

 if TORCH_VERSION_AFTER_2_3:
@@ -72,8 +77,17 @@ def _replace_with_custom_fn_if_matches_filter(
     cur_fqn="",
 ) -> None:
     """
-    For each `child` in `model`, replaces it with `replacement_fn(child)`
-    if `filter_fn(child)` is `True`
+    Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
+    if `filter_fn(child)` returns `True`.
+
+    Args:
+        model (torch.nn.Module): The model containing modules to be replaced.
+        replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
+        filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
+        cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
+
+    Returns:
+        None
     """
     if filter_fn(model, cur_fqn[:-1]):
         model = replacement_fn(model)
@@ -125,6 +139,16 @@ def apply_dynamic_quant(model, filter_fn=None):
 import torch.nn.utils.parametrize as parametrize

 def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
+    """
+    Returns a function which inserts the given subclass into all linear modules
+    in the model. The inserted module will have its weight set to the result of
+    `cls(mod.weight, **kwargs)`. If parametrization is enabled, then this will be done using
+    torch.nn.utils.parametrize instead of directly setting the attribute on the module.
+
+    Args:
+        cls (type): The tensor subclass used to wrap the linear module's weight.
+        kwargs (Any): Any additional arguments for the constructor.
+    """
     constructor = kwargs.pop("constructor", "subclass_constructor")
     from_float = kwargs.pop("method", "from_float")
     def insert_subclass(lin):
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index bc2d44e576..d86966b48c 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -375,7 +375,6 @@ def choose_qparams_affine(

 # copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736

-
 def dynamically_quantize_per_tensor(
     x,
     quant_min,
@@ -401,8 +400,6 @@ def dynamically_quantize_per_tensor(
 # taken from
 # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
 # and slightly modified
-
-
 def quantize_activation_per_token_absmax(t):
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
     mapping_type = MappingType.SYMMETRIC
@@ -426,10 +423,12 @@ def quantize_activation_per_token_absmax(t):


 def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
-    # assumes symmetric quantization
-    # assumes axis == 0
-    # assumes dense memory format
-    # TODO(future): relax ^ as needed
+    """
+    assumes symmetric quantization
+    assumes axis == 0
+    assumes dense memory format
+    TODO(future): relax ^ as needed
+    """

     assert x.dim() == 2, "only support 2d Tensors"

@@ -512,21 +511,22 @@ def quant_int8_matmul(
     w_scales,
     out_dtype=torch.float32,
 ):
-    # Quantized matmul of int8 operands that accumulates to int32 and returns
-    # out_dtype. For now, this is written for approximate numerical
-    # correctness, and things like aligning accumulation behaviors and
-    # performance optimizations are left for a future PR.
-    # Assumes that weight quantization is symmetric, i.e. w_zp is 0.
-    # Assumes that weight quantization is per-channel.
-
-    # see
-    # https://github.com/google/gemmlowp/blob/master/doc/quantization.md
-    # for an overview of quantized matmul compute
-
-    # in scalar form, assuming out_dtype is fp32 and zw == 0:
-    #
-    # Y_i_j_fp32 = sx * sw (dot(X_i, W_j) - zx * sum(W_j))
-
+    """
+    Quantized matmul of int8 operands that accumulates to int32 and returns
+    out_dtype. For now, this is written for approximate numerical
+    correctness, and things like aligning accumulation behaviors and
+    performance optimizations are left for a future PR.
+    Assumes that weight quantization is symmetric, i.e. w_zp is 0.
+    Assumes that weight quantization is per-channel.
+
+    see
+    https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+    for an overview of quantized matmul compute
+
+    in scalar form, assuming out_dtype is fp32 and zw == 0:
+
+    Y_i_j_fp32 = sx * sw (dot(X_i, W_j) - zx * sum(W_j))
+    """
     assert x_vals_int8.dtype in (
         torch.uint8,
         torch.int8,
@@ -571,8 +571,10 @@ def quant_int8_dynamic_per_token_linear(
     bias,
     out_dtype,
 ):
-    # like F.linear, but with int8 dynamic quantization of activation,
-    # and a quantized weight
+    """
+    like F.linear, but with int8 dynamic quantization of activation,
+    and a quantized weight
+    """
     x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
     mm_out = quant_int8_per_token_matmul(
         x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
@@ -589,20 +591,21 @@ def quant_int8_per_token_matmul(
     w_scales,
     output_dtype=torch.float32,
 ):
-    # Quantized matmul of int8 operands that accumulates to int32 and returns
-    # output_dtype. For now, this is written for approximate numerical
-    # Assumes that activation and weight quantization are symmetric,
-    # i.e. act_zp and w_zp is 0.
-    # Assumes that weight quantization is per-channel.
+    """
+    Quantized matmul of int8 operands that accumulates to int32 and returns
+    output_dtype. For now, this is written for approximate numerical correctness.
+    Assumes that activation and weight quantization are symmetric,
+    i.e. act_zp and w_zp is 0.
+    Assumes that weight quantization is per-channel.

-    # see
-    # https://github.com/google/gemmlowp/blob/master/doc/quantization.md
-    # for an overview of quantized matmul compute
+    see
+    https://github.com/google/gemmlowp/blob/master/doc/quantization.md
+    for an overview of quantized matmul compute

-    # in scalar form, assuming output_dtype is fp32 and zw == 0:
-    #
-    # Y_i_j_fp32 = sx * sw dot(X_i, W_j)
-    #
+    in scalar form, assuming output_dtype is fp32 and zw == 0:
+
+    Y_i_j_fp32 = sx * sw dot(X_i, W_j)
+    """

     assert (
         x_vals_int8.dtype == torch.int8
diff --git a/torchao/quantization/smoothquant.py b/torchao/quantization/smoothquant.py
index 35b54382c0..dd81bada7e 100644
--- a/torchao/quantization/smoothquant.py
+++ b/torchao/quantization/smoothquant.py
@@ -34,12 +34,15 @@ def get_scale(X_absmax, W_absmax, alpha=0.5):
     """
-    Calculate the scale based on abs(max(X)), abs(max(W)) and alpha
-    If X is of dimension `b*n*k` and W is dimension `k*m`, the returned
-    scale is of dimension `k`.
-    Note: X_absmax is calculated outside of this function because we
-    need to keep a running version of it during calibration. W_absmax
-    is calculated outside of this function for consistency with X_absmax.
+    Calculate the scale based on abs(max(X)), abs(max(W)), and alpha.
+
+    Args:
+        X_absmax (torch.Tensor): Absolute maximum values of the input tensor X.
+        W_absmax (torch.Tensor): Absolute maximum values of the weight tensor W.
+        alpha (float, optional): Scaling factor. Defaults to 0.5.
+
+    Returns:
+        torch.Tensor: The calculated scale of dimension `k` if X is of dimension `b*n*k` and W is of dimension `k*m`.
     """
     X_pow = torch.pow(X_absmax, alpha)
     W_pow = torch.pow(W_absmax, 1.0 - alpha)
@@ -210,6 +213,18 @@ def set_debug_x_absmax(self):

 def swap_linear_with_smooth_fq_linear(
     model, skip_fqn_list=None, cur_fqn="", alpha=0.5
 ) -> None:
+    """
+    Replaces linear layers in the model with their SmoothFakeDynamicallyQuantizedLinear equivalents.
+
+    Args:
+        model (torch.nn.Module): The model containing linear layers to be replaced.
+        skip_fqn_list (list of str, optional): List of fully qualified names to skip during replacement. Defaults to None.
+        cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
+        alpha (float, optional): The scaling factor for SmoothQuant. Defaults to 0.5.
+
+    Returns:
+        None
+    """
     name_to_child = dict(model.named_children())
     for name, child in name_to_child.items():
@@ -228,6 +243,17 @@ def swap_linear_with_smooth_fq_linear(


 def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None:
+    """
+    Prepares the model for inference by calculating the smoothquant scale for each SmoothFakeDynamicallyQuantizedLinear layer.
+
+    Args:
+        model (torch.nn.Module): The model containing SmoothFakeDynamicallyQuantizedLinear layers.
+        debug_skip_calibration (bool, optional): If True, sets the running maximum of activations to a debug value for performance benchmarking.
+            Defaults to False.
+
+    Returns:
+        None
+    """
     for _, mod in model.named_modules():
         if isinstance(mod, tuple(source_cls_to_target_cls.values())):
             if debug_skip_calibration:
@@ -237,8 +263,6 @@ def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None:

 # useful for quickly toggling smoothquant debug settings on all smoothquant
 # modules in a model
-
-
 def set_smooth_fq_attribute(model, attribute_name, new_attribute_val):
     for _, mod in model.named_modules():
         if isinstance(mod, tuple(source_cls_to_target_cls.values())):
diff --git a/torchao/quantization/unified.py b/torchao/quantization/unified.py
index 16112ac0f0..7da915dec7 100644
--- a/torchao/quantization/unified.py
+++ b/torchao/quantization/unified.py
@@ -1,9 +1,19 @@
 import torch
 from typing import Any
+from abc import ABC, abstractmethod
+
+"""
+The vast majority of quantization algorithms follow one of two patterns:
+1. Single quantize call to create a quantized model with quantized state_dict
+2. Flow that needs calibration or training
+
+This file defines the API for both patterns.
+"""
+

-############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
-class Quantizer:
+class Quantizer(ABC):
+    @abstractmethod
     def quantize(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
     ) -> torch.nn.Module:
@@ -13,6 +23,7 @@ def quantize(

 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
+    @abstractmethod
     def prepare(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
     ) -> torch.nn.Module:
@@ -24,6 +35,3 @@ def convert(
     ) -> torch.nn.Module:
         pass

-
-
-############################# Unified Quantization APIs ##############################
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 948c1357c8..74cb7deb20 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -138,8 +138,6 @@ def unwrap_tensor_subclass(model, filter_fn=None):


 # https://discuss.pytorch.org/t/finding-model-size/130275
-
-
 def get_model_size_in_bytes(model):
     s = 0
     for p in model.parameters():
diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py
index 099df0f17f..bb6a0136ef 100644
--- a/torchao/quantization/weight_only.py
+++ b/torchao/quantization/weight_only.py
@@ -5,19 +5,34 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-
 from .quant_primitives import dynamically_quantize_per_channel

 __all__ = ["WeightOnlyInt8QuantLinear"]

-
 class WeightOnlyInt8QuantLinear(torch.nn.Linear):
     """
     This class is a replacement for `torch.nn.Linear`. It implements a
-    mixed dtype matmul using int8 symmetric per-channel weight quantization
+    mixed dtype matrix multiplication using int8 symmetric per-channel weight quantization.
+
+    The primary goal of this class is to leverage int8 quantization for weights to reduce the
+    memory footprint and computational requirements while performing linear transformations.
+    This can be particularly beneficial for deploying models in low-latency environments.
+
+    Attributes:
+        w_int8 (torch.Tensor): The quantized weights in int8 format.
+        scales (torch.Tensor): The scaling factors for each channel to convert the quantized
+            weights back to floating point format during the forward pass.
     """

     def __init__(self, *args, **kwargs):
+        """
+        Initializes the WeightOnlyInt8QuantLinear module.
+
+        Args:
+            *args: Variable length argument list for `torch.nn.Linear`.
+            **kwargs: Arbitrary keyword arguments.
+                Must include 'w_int8' (int8 quantized weights) and 'scales' (scaling factors).
+        """
         w_int8 = kwargs.pop("w_int8")
         scales = kwargs.pop("scales")
         super().__init__(*args, **kwargs)
@@ -25,21 +40,20 @@ def __init__(self, *args, **kwargs):
         self.register_buffer("w_int8", w_int8)
         self.register_buffer("scales", scales)

-    def forward(self, x, *args, **kwargs):
+    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
-        Performs the forward pass of the quantized linear layer which consists
-        ofmixed dtype matmul using int8 symmetric per-channel weight quantization
+        Performs the forward pass of the quantized linear layer, which consists of
+        mixed dtype matrix multiplication using int8 symmetric per-channel weight quantization.

         Args:
-            X (torch.Tensor): The input floating point tensor to the quantized linear layer.
+            x (torch.Tensor): The input floating point tensor to the quantized linear layer.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.

         Returns:
-            torch.Tensor: The output floating point tensor after the quantized matmul and rescale.
-
+            torch.Tensor: The output floating point tensor after the quantized matrix multiplication
+                and rescale.
         """
-        # if len(x.shape)<=2:
-        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
-        # else:
         # turn x into 2d tensor, then undo it for y
         x_view = x.view(-1, x.shape[-1])
         y = torch.mm(x_view, self.w_int8.to(x.dtype)) * self.scales
         y = y.reshape(*x.shape[:-1], -1)
@@ -48,23 +62,25 @@ def forward(self, x, *args, **kwargs):
         return y

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod: torch.nn.Linear):
         """
-        Converts a `mod` of class `torch.nn.Linear` to the
-        `WeightOnlyInt8QuantLinear` class
+        Converts a `torch.nn.Linear` module to a `WeightOnlyInt8QuantLinear` module.
+
+        This method performs the conversion by dynamically quantizing the weights of the original
+        floating point linear layer to int8 format and creating a new `WeightOnlyInt8QuantLinear`
+        instance with these quantized weights and the corresponding scaling factors.

         Args:
             mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.

         Returns:
-            WeightOnlyInt8QuantLinear: The converted quantized linear module.
-
+            WeightOnlyInt8QuantLinear: The converted quantized linear module with int8 weights.
         """
         w_fp32 = mod.weight
         w_int8, scales, _zp = dynamically_quantize_per_channel(
             w_fp32, -128, 127, torch.int8
         )
-        # create the new module with a toy size to ensure initialization is fast
+        # Create the new module with a toy size to ensure initialization is fast
         fake_in_features, fake_out_features = 8, 8
         new_mod = cls(
             fake_in_features,