[StaticQuant] add a linear observer class and test #807

Merged
merged 1 commit on Sep 6, 2024
2 changes: 2 additions & 0 deletions ruff.toml
@@ -8,4 +8,6 @@ include = [
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
]
112 changes: 108 additions & 4 deletions test/quantization/test_observer.py
@@ -1,5 +1,6 @@
import re
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
@@ -9,13 +10,23 @@
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.quant_api import (
insert_observers_,
)
from torch.testing._internal import common_utils
import unittest

# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


class TestQuantFlow(TestCase):
def _test_obs_helper(self, obs1, obs2):
example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
example_inputs = [
torch.randn(10, 2048),
torch.randn(10, 2048),
torch.randn(10, 2048),
]
for example_input in example_inputs:
obs1(example_input)
obs2(example_input)
@@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2):
self.assertTrue(torch.allclose(zero_point1, zero_point2))

def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
self._test_obs_helper(obs, ref_obs)

def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = PerChannelMinMaxObserver(
dtype=torch.uint8, qscheme=torch.per_channel_affine
)
self._test_obs_helper(obs, ref_obs)

def test_block_size_calc_success(self):
@@ -109,5 +136,82 @@ def test_block_size_row_errors(self):
obs(example_input)


class TestLinearObserver(TestCase):
@common_utils.parametrize("observe_weight", [True, False])
def test_linear_observer_tensor(self, observe_weight: bool):
# Create a simple linear layer
in_features, out_features = 10, 5
linear = nn.Linear(in_features, out_features)

# Create observers
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
if observe_weight:
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
else:
weight_observer = None

# Wrap the weight with LinearActivationWeightObservedTensor
insert_observers_(linear, input_observer, weight_observer)

# Create some example inputs
example_inputs = [torch.randn(5, in_features) for _ in range(3)]
max_val = 42.1234
min_val = -39.760
big_tensor = torch.full((6, in_features), max_val)
small_tensor = torch.full((40, in_features), min_val)
example_inputs.extend([big_tensor, small_tensor])

# Run forward passes
for example_input in example_inputs:
_ = linear(example_input)

input_observer = linear.weight.input_observer

# Check that the observers have recorded statistics
assert input_observer.min_val == min_val
assert input_observer.max_val == max_val

# Calculate qparams and ensure they're not None
input_scale, input_zero_point = input_observer.calculate_qparams()

max_fp8 = torch.finfo(torch.float8_e4m3fn).max
self.assertEqual(
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNotNone(input_zero_point)

if observe_weight:
weight_observer = linear.weight.weight_observer
weight_scale, weight_zero_point = weight_observer.calculate_qparams()
torch.testing.assert_close(
weight_scale,
torch.max(linear.weight.original_weight_tensor) / max_fp8,
atol=5e-5,
rtol=0.0,
)
self.assertIsNotNone(weight_zero_point)
else:
self.assertIsNone(linear.weight.weight_observer)


common_utils.instantiate_parametrized_tests(TestLinearObserver)

if __name__ == "__main__":
unittest.main()
152 changes: 152 additions & 0 deletions torchao/quantization/linear_activation_weight_observer.py
@@ -0,0 +1,152 @@
import torch
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
"LinearActivationWeightObservedTensor",
]

aten = torch.ops.aten
Tensor = torch.Tensor


class LinearActivationWeightObservedTensor(TorchAOBaseTensor):
"""
This subclass of Tensor is used in conjunction with a static calibration flow.
The flow is broken up into 3 parts:
1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers
2. Run the model with a calibration dataset; the observer will record the min/max of the input and weight
3. quantize_ the model to a static quantized model using the statistics recorded by the observers

This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer
will first calculate statistics on BOTH the input and weight, and then run the linear op.
"""

original_weight_tensor: torch.Tensor
input_observer: Optional[AffineQuantizedObserverBase]
weight_observer: Optional[AffineQuantizedObserverBase]

def __new__(
cls,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
kwargs = {}
dtype = original_weight_tensor.dtype
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
kwargs["device"] = original_weight_tensor.device
shape = original_weight_tensor.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
self.original_weight_tensor = original_weight_tensor
self.input_observer = input_observer
self.weight_observer = weight_observer

def __repr__(self):
return (
f"LinearActivationWeightObservedTensor(\n"
f"original_weight={self.original_weight_tensor}\n"
f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
)

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_observer, self.weight_observer]

@classmethod
def __tensor_unflatten__(
cls,
tensor_data_dict: Dict[str, Tensor],
tensor_attributes,
outer_size,
outer_stride,
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
(input_observer, weight_observer) = tensor_attributes
return cls(original_weight_tensor, input_observer, weight_observer)

@classmethod
def from_float(
cls,
original_weight_tensor: Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
return cls(original_weight_tensor, input_observer, weight_observer)

def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to the tensor component of the LinearActivationWeightObservedTensor"""
return self.__class__(
fn(self.original_weight_tensor),
self.input_observer,
self.weight_observer,
)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self._apply_fn_to_data(lambda x: x.to(**kwargs))


implements = LinearActivationWeightObservedTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if weight_tensor.input_observer is not None:
input_tensor = weight_tensor.input_observer(input_tensor)
if weight_tensor.weight_observer is not None:
weight_tensor = weight_tensor.weight_observer(
weight_tensor.original_weight_tensor
)
else:
weight_tensor = weight_tensor.original_weight_tensor

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor])
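
A minimal sketch of the calibration flow described in the class docstring above, using only the APIs exercised in test/quantization/test_observer.py. The layer sizes, calibration loop, and the choice to skip the weight observer are illustrative assumptions, and the final quantize_ step is omitted because the static quantization entry point is not part of this diff:

import torch
import torch.nn as nn

from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
)
from torchao.quantization.quant_api import insert_observers_
from torchao.quantization.quant_primitives import MappingType

# 1. Wrap the Linear weight with LinearActivationWeightObservedTensor,
#    observing only the input activations in this sketch.
linear = nn.Linear(128, 64)
input_observer = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=None,
)
insert_observers_(linear, input_observer, None)  # no weight observer here

# 2. Calibrate: every forward pass updates the observer's running min/max.
for _ in range(8):
    linear(torch.randn(4, 128))

# 3. Read back the recorded statistics as quantization parameters.
scale, zero_point = linear.weight.input_observer.calculate_qparams()
print(scale, zero_point)
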
18 changes: 13 additions & 5 deletions torchao/quantization/observer.py
@@ -8,9 +8,10 @@

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional, Any
from typing import Tuple, Optional, Any
from functools import partial
import logging

logger = logging.getLogger(__name__)


@@ -52,6 +53,7 @@ class PerAxis(GranularityType):
"""
axis: int


# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
def __init__(self, p):
@@ -66,6 +68,7 @@ def __repr__(self):
def with_args(self, *args, **kwargs):
return _with_args(self, *args, **kwargs)


def _with_args(cls_or_self, *args, **kwargs):
r"""Wrapper that allows creation of class factories.

@@ -103,8 +106,10 @@ def get_block_size(
return tuple(block_size)
raise ValueError(f"Unsupported GranularityType: {granularity_type}")


ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:


class AffineQuantizedObserverBase(ABC, torch.nn.Module):
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

@@ -114,9 +119,11 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
Currently supported granularity types are `PerTensor` and `PerAxis`
other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
"""

with_args = classmethod(_with_args)

def __init__(self,
def __init__(
self,
mapping_type: MappingType,
target_dtype: torch.dtype,
granularity_type: GranularityType,
@@ -126,7 +133,7 @@ def __init__(self,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
):
super().__init__()
assert granularity_type is not None, "granularity_type is None"
@@ -144,7 +151,7 @@ def __init__(self,

@abstractmethod
def forward(self, input: torch.Tensor) -> torch.Tensor:
""" forward function should take the input tensor
"""forward function should take the input tensor
and update internal stats and return the original input Tensor
"""
pass
@@ -156,6 +163,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
pass


class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
def forward(self, input: torch.Tensor):
if input.numel() == 0:
@@ -200,5 +208,5 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
self.scale_dtype,
self.zero_point_dtype,
self.preserve_zero,
self.zero_point_domain
self.zero_point_domain,
)
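
The with_args helper above (a classmethod built on _PartialWrapper, borrowed from torch.ao.quantization.observer) turns an observer class into a factory with its constructor arguments bound up front, so fresh observer instances can be stamped out per layer. A small sketch of that pattern, reusing the constructor arguments from the tests; the tensor shapes and variable names are illustrative:

import torch

from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerAxis,
)
from torchao.quantization.quant_primitives import MappingType

# Bind the constructor arguments once...
observer_factory = AffineQuantizedMinMaxObserver.with_args(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity_type=PerAxis(axis=0),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)

# ...then create independent observers on demand, e.g. one per layer.
obs_a = observer_factory()
obs_b = observer_factory()

obs_a(torch.randn(16, 32))  # forward records per-axis min/max
scale, zero_point = obs_a.calculate_qparams()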