[StaticQuant] add a linear observer class and test (#807)
drisspg authored Sep 6, 2024
1 parent 038a0a2 commit 422301b
Showing 6 changed files with 361 additions and 11 deletions.
2 changes: 2 additions & 0 deletions ruff.toml
@@ -8,4 +8,6 @@ include = [
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
]
112 changes: 108 additions & 4 deletions test/quantization/test_observer.py
@@ -1,5 +1,6 @@
import re
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
@@ -9,13 +10,23 @@
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.quant_api import (
insert_observers_,
)
from torch.testing._internal import common_utils
import unittest

# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


class TestQuantFlow(TestCase):
def _test_obs_helper(self, obs1, obs2):
example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
example_inputs = [
torch.randn(10, 2048),
torch.randn(10, 2048),
torch.randn(10, 2048),
]
for example_input in example_inputs:
obs1(example_input)
obs2(example_input)
@@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2):
self.assertTrue(torch.allclose(zero_point1, zero_point2))

def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
self._test_obs_helper(obs, ref_obs)

def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = PerChannelMinMaxObserver(
dtype=torch.uint8, qscheme=torch.per_channel_affine
)
self._test_obs_helper(obs, ref_obs)

def test_block_size_calc_success(self):
@@ -109,5 +136,82 @@ def test_block_size_row_errors(self):
obs(example_input)


class TestLinearObserver(TestCase):
@common_utils.parametrize("observe_weight", [True, False])
def test_linear_observer_tensor(self, observe_weight: bool):
# Create a simple linear layer
in_features, out_features = 10, 5
linear = nn.Linear(in_features, out_features)

# Create observers
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
if observe_weight:
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
else:
weight_observer = None

# Wrap the linear layer's weight with LinearActivationWeightObservedTensor
insert_observers_(linear, input_observer, weight_observer)

# Create some example inputs
example_inputs = [torch.randn(5, in_features) for _ in range(3)]
max_val = 42.1234
min_val = -39.760
big_tensor = torch.full((6, in_features), max_val)
small_tensor = torch.full((40, in_features), min_val)
example_inputs.extend([big_tensor, small_tensor])

# Run forward passes
for example_input in example_inputs:
_ = linear(example_input)

input_observer = linear.weight.input_observer

# Check that the observers have recorded statistics
assert input_observer.min_val == min_val
assert input_observer.max_val == max_val

# Calculate qparams and ensure they're not None
input_scale, input_zero_point = input_observer.calculate_qparams()

max_fp8 = torch.finfo(torch.float8_e4m3fn).max
self.assertEqual(
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNotNone(input_zero_point)

if observe_weight:
weight_observer = linear.weight.weight_observer
weight_scale, weight_zero_point = weight_observer.calculate_qparams()
torch.testing.assert_close(
weight_scale,
torch.max(linear.weight.original_weight_tensor) / max_fp8,
atol=5e-5,
rtol=0.0,
)
self.assertIsNotNone(weight_zero_point)
else:
self.assertIsNone(linear.weight.weight_observer)


common_utils.instantiate_parametrized_tests(TestLinearObserver)

if __name__ == "__main__":
unittest.main()
152 changes: 152 additions & 0 deletions torchao/quantization/linear_activation_weight_observer.py
@@ -0,0 +1,152 @@
import torch
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
"LinearActivationWeightObservedTensor",
]

aten = torch.ops.aten
Tensor = torch.Tensor


class LinearActivationWeightObservedTensor(TorchAOBaseTensor):
"""
This Tensor subclass is used in conjunction with a static calibration flow.
The flow is broken up into three parts:
1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers
2. Run the model on a calibration dataset; the observers record the min/max of the inputs and weights
3. quantize_ the model to a static quantization scheme using the statistics recorded by the observers
This subclass wraps the original weight tensor of the nn.Linear layer. When forward is called, the observers
first calculate statistics on BOTH the input and the weight, and then the linear op is run.
"""

original_weight_tensor: torch.Tensor
input_observer: Optional[AffineQuantizedObserverBase]
weight_observer: Optional[AffineQuantizedObserverBase]

def __new__(
cls,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
kwargs = {}
dtype = original_weight_tensor.dtype
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
kwargs["device"] = original_weight_tensor.device
shape = original_weight_tensor.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
self.original_weight_tensor = original_weight_tensor
self.input_observer = input_observer
self.weight_observer = weight_observer

def __repr__(self):
return (
f"LinearActivationWeightObservedTensor(\n"
f"original_weight={self.original_weight_tensor}\n"
f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
)

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_observer, self.weight_observer]

@classmethod
def __tensor_unflatten__(
cls,
tensor_data_dict: Dict[str, Tensor],
tensor_attributes,
outer_size,
outer_stride,
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
(input_observer, weight_observer) = tensor_attributes
return cls(original_weight_tensor, input_observer, weight_observer)

@classmethod
def from_float(
cls,
original_weight_tensor: Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
return cls(original_weight_tensor, input_observer, weight_observer)

def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to the tensor component of the LinearActivationWeightObservedTensor"""
return self.__class__(
fn(self.original_weight_tensor),
self.input_observer,
self.weight_observer,
)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self._apply_fn_to_data(lambda x: x.to(**kwargs))


implements = LinearActivationWeightObservedTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if weight_tensor.input_observer is not None:
input_tensor = weight_tensor.input_observer(input_tensor)
if weight_tensor.weight_observer is not None:
weight_tensor = weight_tensor.weight_observer(
weight_tensor.original_weight_tensor
)
else:
weight_tensor = weight_tensor.original_weight_tensor

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)


if TORCH_VERSION_AT_LEAST_2_5:
    # Allow a model with LinearActivationWeightObservedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor])
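
As a reference for how the pieces above fit together, here is a minimal calibration sketch following the three-step flow described in the class docstring. It mirrors the observer settings and insert_observers_ call used in test_observer.py; the layer sizes and number of calibration batches are illustrative assumptions, and the final quantize_ step is only indicated in a comment.

import torch
import torch.nn as nn
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_api import insert_observers_

# Step 1: wrap the nn.Linear weight with LinearActivationWeightObservedTensor
# via insert_observers_, attaching a per-tensor min/max observer to the input.
linear = nn.Linear(10, 5)  # illustrative layer sizes
input_observer = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=None,
)
insert_observers_(linear, input_observer, None)  # no weight observer in this sketch

# Step 2: run calibration data through the model; the observer records min/max.
for _ in range(3):
    linear(torch.randn(8, 10))

# Step 3: the recorded statistics feed static quantization (e.g. via quantize_);
# here we only read back the computed qparams.
scale, zero_point = linear.weight.input_observer.calculate_qparams()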
18 changes: 13 additions & 5 deletions torchao/quantization/observer.py
@@ -8,9 +8,10 @@

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional, Any
from typing import Tuple, Optional, Any
from functools import partial
import logging

logger = logging.getLogger(__name__)


@@ -52,6 +53,7 @@ class PerAxis(GranularityType):
"""
axis: int


# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
def __init__(self, p):
@@ -66,6 +68,7 @@ def __repr__(self):
def with_args(self, *args, **kwargs):
return _with_args(self, *args, **kwargs)


def _with_args(cls_or_self, *args, **kwargs):
r"""Wrapper that allows creation of class factories.
@@ -103,8 +106,10 @@ def get_block_size(
return tuple(block_size)
raise ValueError(f"Unsupported GranularityType: {granularity_type}")


ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:


class AffineQuantizedObserverBase(ABC, torch.nn.Module):
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
@@ -114,9 +119,11 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
Current supported granularity type are `PerTensor` and `PerAxis`
other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
"""

with_args = classmethod(_with_args)

def __init__(self,
def __init__(
self,
mapping_type: MappingType,
target_dtype: torch.dtype,
granularity_type: GranularityType,
@@ -126,7 +133,7 @@ def __init__(self,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
):
super().__init__()
assert granularity_type is not None, "granularity_type is None"
@@ -144,7 +151,7 @@ def __init__(self,

@abstractmethod
def forward(self, input: torch.Tensor) -> torch.Tensor:
""" forward function should take the input tensor
"""forward function should take the input tensor
and updates internal stats and return the original input Tensor
"""
pass
@@ -156,6 +163,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
pass


class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
def forward(self, input: torch.Tensor):
if input.numel() == 0:
@@ -200,5 +208,5 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
self.scale_dtype,
self.zero_point_dtype,
self.preserve_zero,
self.zero_point_domain
self.zero_point_domain,
)
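
The with_args helper exposed on AffineQuantizedObserverBase (borrowed from torch.ao.quantization.observer) turns an observer class into a factory of pre-configured instances. A small sketch of the intended usage follows, assuming the same _PartialWrapper call semantics as the torch.ao original; the argument values are taken from the per-tensor test above.

# Hypothetical factory usage: with_args returns a callable that constructs
# observers with the given arguments pre-bound.
obs_ctr = AffineQuantizedMinMaxObserver.with_args(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)
obs = obs_ctr()                     # construct a fresh, pre-configured observer
obs(torch.randn(10, 2048))          # record statistics on a batch
scale, zero_point = obs.calculate_qparams()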