-
Notifications
You must be signed in to change notification settings - Fork 207
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[StaticQuant] add a linear observer class and test
stack-info: PR: #807, branch: drisspg/stack/8
- Loading branch information
Showing
5 changed files
with
362 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
import torch | ||
import torch.nn as nn | ||
from typing import Callable, Optional, Dict | ||
from torch.utils._python_dispatch import return_and_correct_aliasing | ||
from torchao.utils import ( | ||
TorchAOBaseTensor, | ||
TORCH_VERSION_AT_LEAST_2_5, | ||
) | ||
|
||
from torchao.quantization.quant_api import ( | ||
_replace_with_custom_fn_if_matches_filter, | ||
_is_linear, | ||
) | ||
from torchao.quantization.observer import AffineQuantizedObserverBase | ||
|
||
__all__ = [ | ||
"LinearActivationWeightObservedTensor", | ||
"insert_observers_", | ||
] | ||
|
||
aten = torch.ops.aten | ||
Tensor = torch.Tensor | ||
|
||
|
||
class LinearActivationWeightObservedTensor(TorchAOBaseTensor): | ||
""" | ||
This subclass of Tensor is used in conjuction with a static calibration flow. | ||
The flow is broken up into 3 parts; | ||
1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers | ||
2. Run the model with a calibration dataset, the observer will record the min/max of the input and weight | ||
3. quantize_ the model to static using the statistics recorded by the observer | ||
This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer | ||
will first calculat statistics on BOTH the input and weight, and then run the linear op. | ||
""" | ||
|
||
original_weight_tensor: torch.Tensor | ||
input_observer: Optional[AffineQuantizedObserverBase] | ||
weight_observer: Optional[AffineQuantizedObserverBase] | ||
|
||
def __new__( | ||
cls, | ||
original_weight_tensor: torch.Tensor, | ||
input_observer: Optional[AffineQuantizedObserverBase] = None, | ||
weight_observer: Optional[AffineQuantizedObserverBase] = None, | ||
): | ||
kwargs = {} | ||
dtype = original_weight_tensor.dtype | ||
kwargs["dtype"] = dtype | ||
kwargs["requires_grad"] = False | ||
kwargs["device"] = original_weight_tensor.device | ||
shape = original_weight_tensor.shape | ||
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] | ||
|
||
def __init__( | ||
self, | ||
original_weight_tensor: torch.Tensor, | ||
input_observer: Optional[AffineQuantizedObserverBase] = None, | ||
weight_observer: Optional[AffineQuantizedObserverBase] = None, | ||
): | ||
self.original_weight_tensor = original_weight_tensor | ||
self.input_observer = input_observer | ||
self.weight_observer = weight_observer | ||
|
||
def __repr__(self): | ||
return ( | ||
f"LinearActivationWeightObservedTensor(\n" | ||
f"original_weight={self.original_weight_tensor}\n" | ||
f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n" | ||
f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)" | ||
) | ||
|
||
def __tensor_flatten__(self): | ||
return ["original_weight_tensor"], [self.input_observer, self.weight_observer] | ||
|
||
@classmethod | ||
def __tensor_unflatten__( | ||
cls, | ||
tensor_data_dict: Dict[str, Tensor], | ||
tensor_attributes, | ||
outer_size, | ||
outer_stride, | ||
): | ||
original_weight_tensor = tensor_data_dict["original_weight_tensor"] | ||
(input_observer, weight_observer) = tensor_attributes | ||
return cls(original_weight_tensor, input_observer, weight_observer) | ||
|
||
@classmethod | ||
def from_float( | ||
cls, | ||
original_weight_tensor: Tensor, | ||
input_observer: Optional[AffineQuantizedObserverBase] = None, | ||
weight_observer: Optional[AffineQuantizedObserverBase] = None, | ||
): | ||
return cls(original_weight_tensor, input_observer, weight_observer) | ||
|
||
def _apply_fn_to_data(self, fn: Callable): | ||
"""Applies a fn to the tensor component of the LinearActivationWeightObservedTensor""" | ||
return self.__class__( | ||
fn(self.original_weight_tensor), | ||
self.input_observer, | ||
self.weight_observer, | ||
) | ||
|
||
def to(self, *args, **kwargs): | ||
kwargs = self._get_to_kwargs(*args, **kwargs) | ||
return self._apply_fn_to_data(lambda x: x.to(**kwargs)) | ||
|
||
|
||
implements = LinearActivationWeightObservedTensor.implements | ||
|
||
|
||
@implements(torch.nn.functional.linear) | ||
def _(func, types, args, kwargs): | ||
input_tensor, weight_tensor, bias = ( | ||
args[0], | ||
args[1], | ||
args[2] if len(args) > 2 else None, | ||
) | ||
if weight_tensor.input_observer is not None: | ||
input_tensor = weight_tensor.input_observer(input_tensor) | ||
if weight_tensor.weight_observer is not None: | ||
weight_tensor = weight_tensor.weight_observer( | ||
weight_tensor.original_weight_tensor | ||
) | ||
else: | ||
weight_tensor = weight_tensor.original_weight_tensor | ||
|
||
return torch.nn.functional.linear(input_tensor, weight_tensor, bias) | ||
|
||
|
||
@implements(aten.detach.default) | ||
def _(func, types, args, kwargs): | ||
return return_and_correct_aliasing( | ||
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) | ||
) | ||
|
||
|
||
@implements(aten.clone.default) | ||
def _(func, types, args, kwargs): | ||
return return_and_correct_aliasing( | ||
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) | ||
) | ||
|
||
|
||
@implements(aten._to_copy.default) | ||
def _(func, types, args, kwargs): | ||
return return_and_correct_aliasing( | ||
func, | ||
args, | ||
kwargs, | ||
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), | ||
) | ||
|
||
|
||
if TORCH_VERSION_AT_LEAST_2_5: | ||
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` | ||
torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) | ||
|
||
|
||
def insert_observers_( | ||
model: nn.Module, | ||
input_observer: Optional[AffineQuantizedObserverBase], | ||
weight_observer: Optional[AffineQuantizedObserverBase], | ||
*, | ||
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, | ||
): | ||
""" | ||
Converts the weight of a linear module to a LinearActivationWeightObservedTensor. | ||
This function wraps the weight of the given linear module with a LinearActivationWeightObservedTensor, | ||
which enables observation of both input and weight tensors during forward passes. | ||
The wrapped weight is then re-wrapped as a nn.Parameter to maintain compatibility | ||
with PyTorch's module system. | ||
Example:: | ||
``` | ||
import torch | ||
import torch.nn as nn | ||
from torchao.quantization.linear_observer_tensor import insert_observers_ | ||
from torchao.quantization.observer import ( | ||
AffineQuantizedMinMaxObserver, | ||
PerTensor, | ||
MappingType | ||
) | ||
# Create observers | ||
input_observer = AffineQuantizedMinMaxObserver( | ||
MappingType.SYMMETRIC, | ||
torch.float8_e4m3fn, | ||
granularity_type=PerTensor(), | ||
eps=torch.finfo(torch.float32).eps, | ||
scale_dtype=torch.float, | ||
zero_point_dtype=torch.int, | ||
zero_point_domain=None, | ||
) | ||
# Create a linear module | ||
linear_module = nn.Linear(10, 20) | ||
# Convert the linear module's weight to an observed tensor | ||
insert_observers_(linear_module, input_observer, weight_observer=None) | ||
# The linear_module can now be used as usual, with observers calculating statistics | ||
output = linear_module(torch.randn(10, 10)) | ||
``` | ||
Args: | ||
model (nn.Module): The nn.Module to convert. | ||
input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor. | ||
weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor. | ||
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Filter function to select which modules to convert. | ||
If not provided, all linear modules will be converted. | ||
Returns: | ||
nn.Linear: The modified linear module with its weight wrapped in a LinearActivationWeightObservedTensor. | ||
""" | ||
|
||
def convert_to_linear_observer(linear_module: nn.Linear): | ||
# Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter | ||
linear_module.weight = nn.Parameter( | ||
LinearActivationWeightObservedTensor.from_float( | ||
linear_module.weight, | ||
input_observer=input_observer, | ||
weight_observer=weight_observer, | ||
), | ||
requires_grad=linear_module.weight.requires_grad, | ||
) | ||
return linear_module | ||
|
||
_replace_with_custom_fn_if_matches_filter( | ||
model, | ||
convert_to_linear_observer, | ||
_is_linear if filter_fn is None else filter_fn, | ||
) |
Oops, something went wrong.