diff --git a/ruff.toml b/ruff.toml index dee9710df4..04c9e32cca 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,4 +8,6 @@ include = [ "torchao/dtypes/nf4tensor.py", "test/dtypes/test_nf4.py", "torchao/float8/float8_tensor.py", + "torchao/quantization/linear_activation_weight_observer.py", + "test/quantization/test_observer.py", ] diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index e0c9257a96..1964913a12 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -1,5 +1,6 @@ import re import torch +import torch.nn as nn from torch.testing._internal.common_utils import TestCase from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, @@ -9,13 +10,23 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.quant_api import ( + insert_observers_, +) +from torch.testing._internal import common_utils import unittest + # NOTE: we can copy paste these here if we decide to deprecate them in torch.ao from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver + class TestQuantFlow(TestCase): def _test_obs_helper(self, obs1, obs2): - example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)] + example_inputs = [ + torch.randn(10, 2048), + torch.randn(10, 2048), + torch.randn(10, 2048), + ] for example_input in example_inputs: obs1(example_input) obs2(example_input) @@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2): self.assertTrue(torch.allclose(zero_point1, zero_point2)) def test_min_max_per_tensor_affine(self): - obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) + obs = AffineQuantizedMinMaxObserver( + MappingType.ASYMMETRIC, + torch.uint8, + granularity_type=PerTensor(), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + 
zero_point_dtype=torch.int, + ) ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine) self._test_obs_helper(obs, ref_obs) def test_min_max_per_channel_affine(self): - obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) - ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine) + obs = AffineQuantizedMinMaxObserver( + MappingType.ASYMMETRIC, + torch.uint8, + granularity_type=PerAxis(axis=0), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + ) + ref_obs = PerChannelMinMaxObserver( + dtype=torch.uint8, qscheme=torch.per_channel_affine + ) self._test_obs_helper(obs, ref_obs) def test_block_size_calc_success(self): @@ -109,5 +136,82 @@ def test_block_size_row_errors(self): obs(example_input) +class TestLinearObserver(TestCase): + @common_utils.parametrize("observe_weight", [True, False]) + def test_linear_observer_tensor(self, observe_weight: bool): + # Create a simple linear layer + in_features, out_features = 10, 5 + linear = nn.Linear(in_features, out_features) + + # Create observers + input_observer = AffineQuantizedMinMaxObserver( + MappingType.SYMMETRIC, + torch.float8_e4m3fn, + granularity_type=PerTensor(), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + zero_point_domain=None, + ) + if observe_weight: + weight_observer = AffineQuantizedMinMaxObserver( + MappingType.SYMMETRIC, + torch.float8_e4m3fn, + granularity_type=PerTensor(), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + zero_point_domain=None, + ) + else: + weight_observer = None + + # Wrap the weight with LinearActivationWeightObservedTensor + insert_observers_(linear, input_observer, weight_observer) + + # Create some example inputs + example_inputs = [torch.randn(5, in_features) 
for _ in range(3)] + max_val = 42.1234 + min_val = -39.760 + big_tensor = torch.full((6, in_features), max_val) + small_tensor = torch.full((40, in_features), min_val) + example_inputs.extend([big_tensor, small_tensor]) + + # Run forward passes + for example_input in example_inputs: + _ = linear(example_input) + + input_observer = linear.weight.input_observer + + # Check that the observers have recorded statistics + assert input_observer.min_val == min_val + assert input_observer.max_val == max_val + + # Calculate qparams and ensure they're not None + input_scale, input_zero_point = input_observer.calculate_qparams() + + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + self.assertEqual( + input_scale.item(), + max_val / max_fp8, + ) + self.assertIsNotNone(input_zero_point) + + if observe_weight: + weight_observer = linear.weight.weight_observer + weight_scale, weight_zero_point = weight_observer.calculate_qparams() + torch.testing.assert_close( + weight_scale, + torch.max(linear.weight.original_weight_tensor) / max_fp8, + atol=5e-5, + rtol=0.0, + ) + self.assertIsNotNone(weight_zero_point) + else: + self.assertIsNone(linear.weight.weight_observer) + + +common_utils.instantiate_parametrized_tests(TestLinearObserver) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/linear_activation_weight_observer.py b/torchao/quantization/linear_activation_weight_observer.py new file mode 100644 index 0000000000..f7c2ab3742 --- /dev/null +++ b/torchao/quantization/linear_activation_weight_observer.py @@ -0,0 +1,152 @@ +import torch +from typing import Callable, Optional, Dict +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.utils import ( + TorchAOBaseTensor, + TORCH_VERSION_AT_LEAST_2_5, +) + +from torchao.quantization.observer import AffineQuantizedObserverBase + +__all__ = [ + "LinearActivationWeightObservedTensor", +] + +aten = torch.ops.aten +Tensor = torch.Tensor + + +class 
LinearActivationWeightObservedTensor(TorchAOBaseTensor): + """ + This subclass of Tensor is used in conjunction with a static calibration flow. + The flow is broken up into 3 parts: + 1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers + 2. Run the model with a calibration dataset, the observer will record the min/max of the input and weight + 3. quantize_ the model to static using the statistics recorded by the observer + + This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer + will first calculate statistics on BOTH the input and weight, and then run the linear op. + """ + + original_weight_tensor: torch.Tensor + input_observer: Optional[AffineQuantizedObserverBase] + weight_observer: Optional[AffineQuantizedObserverBase] + + def __new__( + cls, + original_weight_tensor: torch.Tensor, + input_observer: Optional[AffineQuantizedObserverBase] = None, + weight_observer: Optional[AffineQuantizedObserverBase] = None, + ): + kwargs = {} + dtype = original_weight_tensor.dtype + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + kwargs["device"] = original_weight_tensor.device + shape = original_weight_tensor.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + original_weight_tensor: torch.Tensor, + input_observer: Optional[AffineQuantizedObserverBase] = None, + weight_observer: Optional[AffineQuantizedObserverBase] = None, + ): + self.original_weight_tensor = original_weight_tensor + self.input_observer = input_observer + self.weight_observer = weight_observer + + def __repr__(self): + return ( + f"LinearActivationWeightObservedTensor(\n" + f"original_weight={self.original_weight_tensor}\n" + f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n" + f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)" + ) + + 
def __tensor_flatten__(self): + return ["original_weight_tensor"], [self.input_observer, self.weight_observer] + + @classmethod + def __tensor_unflatten__( + cls, + tensor_data_dict: Dict[str, Tensor], + tensor_attributes, + outer_size, + outer_stride, + ): + original_weight_tensor = tensor_data_dict["original_weight_tensor"] + (input_observer, weight_observer) = tensor_attributes + return cls(original_weight_tensor, input_observer, weight_observer) + + @classmethod + def from_float( + cls, + original_weight_tensor: Tensor, + input_observer: Optional[AffineQuantizedObserverBase] = None, + weight_observer: Optional[AffineQuantizedObserverBase] = None, + ): + return cls(original_weight_tensor, input_observer, weight_observer) + + def _apply_fn_to_data(self, fn: Callable): + """Applies a fn to the tensor component of the LinearActivationWeightObservedTensor""" + return self.__class__( + fn(self.original_weight_tensor), + self.input_observer, + self.weight_observer, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self._apply_fn_to_data(lambda x: x.to(**kwargs)) + + +implements = LinearActivationWeightObservedTensor.implements + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if weight_tensor.input_observer is not None: + input_tensor = weight_tensor.input_observer(input_tensor) + if weight_tensor.weight_observer is not None: + weight_tensor = weight_tensor.weight_observer( + weight_tensor.original_weight_tensor + ) + else: + weight_tensor = weight_tensor.original_weight_tensor + + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, 
kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with LinearActivationWeightObservedTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 10e0113bfc..08a3eacf6b 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -8,9 +8,10 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Callable, List, Tuple, Optional, Any +from typing import Tuple, Optional, Any from functools import partial import logging + logger = logging.getLogger(__name__) @@ -52,6 +53,7 @@ class PerAxis(GranularityType): """ axis: int + # borrowed from torch.ao.quantization.observer class _PartialWrapper: def __init__(self, p): @@ -66,6 +68,7 @@ def __repr__(self): def with_args(self, *args, **kwargs): return _with_args(self, *args, **kwargs) + def _with_args(cls_or_self, *args, **kwargs): r"""Wrapper that allows creation of class factories. 
@@ -103,8 +106,10 @@ def get_block_size( return tuple(block_size) raise ValueError(f"Unsupported GranularityType: {granularity_type}") + ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + class AffineQuantizedObserverBase(ABC, torch.nn.Module): """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) @@ -114,9 +119,11 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module): Current supported granularity type are `PerTensor` and `PerAxis` other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` """ + with_args = classmethod(_with_args) - def __init__(self, + def __init__( + self, mapping_type: MappingType, target_dtype: torch.dtype, granularity_type: GranularityType, @@ -126,7 +133,7 @@ def __init__(self, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain = ZeroPointDomain.INT, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ): super().__init__() assert granularity_type is not None, "granularity_type is None" @@ -144,7 +151,7 @@ def __init__(self, @abstractmethod def forward(self, input: torch.Tensor) -> torch.Tensor: - """ forward function should take the input tensor + """forward function should take the input tensor and updates internal stats and return the original input Tensor """ pass @@ -156,6 +163,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: """ pass + class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): def forward(self, input: torch.Tensor): if input.numel() == 0: @@ -200,5 +208,5 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: self.scale_dtype, self.zero_point_dtype, self.preserve_zero, - self.zero_point_domain + self.zero_point_domain, ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 89bccf1264..239c369d2c 100644 --- 
a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -58,6 +58,10 @@ from .utils import _get_per_token_block_size import logging from .autoquant import autoquant, AutoQuantizableLinearWeight +from torchao.quantization.observer import AffineQuantizedObserverBase +from torchao.quantization.linear_activation_weight_observer import ( + LinearActivationWeightObservedTensor, +) from torchao.float8.inference import Float8MMConfig logger = logging.getLogger(__name__) @@ -279,6 +283,86 @@ def replace_conv2d_1x1(conv): _replace_with_custom_fn_if_matches_filter( model, replace_conv2d_1x1, filter_fn=filter_fn ) +def insert_observers_( + model: nn.Module, + input_observer: Optional[AffineQuantizedObserverBase], + weight_observer: Optional[AffineQuantizedObserverBase], + *, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, +): + """ + Converts the weight of a linear module to a LinearActivationWeightObservedTensor. + + This function wraps the weight of the given linear module with a LinearActivationWeightObservedTensor, + which enables observation of both input and weight tensors during forward passes. + The wrapped weight is then re-wrapped as a nn.Parameter to maintain compatibility + with PyTorch's module system. 
+ + Example:: + + ``` + import torch + import torch.nn as nn + from torchao.quantization.quant_api import insert_observers_ + from torchao.quantization.observer import ( + AffineQuantizedMinMaxObserver, + PerTensor, + MappingType + ) + + # Create observers + input_observer = AffineQuantizedMinMaxObserver( + MappingType.SYMMETRIC, + torch.float8_e4m3fn, + granularity_type=PerTensor(), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + zero_point_domain=None, + ) + + # Create a linear module + linear_module = nn.Linear(10, 20) + + # Convert the linear module's weight to an observed tensor + insert_observers_(linear_module, input_observer, weight_observer=None) + + # The linear_module can now be used as usual, with observers calculating statistics + output = linear_module(torch.randn(10, 10)) + + # Get the scale and zero point of the input observer + scale, zero_point = linear_module.weight.input_observer.calculate_qparams() + ``` + + Args: + model (nn.Module): The nn.Module to convert. + input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor. + weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor. + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Filter function to select which modules to convert. + If not provided, all linear modules will be converted. This function should take a module and its fully qualified name. + + Returns: + None: The model is modified in place; the weight of each matching module is wrapped in a LinearActivationWeightObservedTensor. 
+ """ + + def convert_to_linear_observer(linear_module: nn.Linear): + # Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter + linear_module.weight = nn.Parameter( + LinearActivationWeightObservedTensor.from_float( + linear_module.weight, + input_observer=input_observer, + weight_observer=weight_observer, + ), + requires_grad=linear_module.weight.requires_grad, + ) + return linear_module + + _replace_with_custom_fn_if_matches_filter( + model, + convert_to_linear_observer, + _is_linear if filter_fn is None else filter_fn, + ) + def _quantization_type(weight: torch.Tensor): if isinstance(weight, AffineQuantizedTensor): diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index d8f5f9aeec..484e12e6e4 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -580,7 +580,7 @@ def choose_qparams_affine( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain = ZeroPointDomain.INT, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -641,7 +641,7 @@ def choose_qparams_affine_with_min_max( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain = ZeroPointDomain.INT, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` operator that pass in min_val and max_val directly instead of deriving these from a single input.