diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 621e3596e0..43d5b48eed 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -26,11 +26,17 @@ float8_weight_only, quantize_, ) -from torchao.quantization.observer import PerRow, PerTensor +from torchao.quantization.granularity import ( + PerRow, + PerTensor, +) from torchao.quantization.quant_api import ( float8_static_activation_float8_weight, ) -from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_affine, +) random.seed(0) torch.manual_seed(0) diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 8c8007871b..0526ee01b2 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -9,11 +9,13 @@ from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase -from torchao.quantization.observer import ( - AffineQuantizedMinMaxObserver, +from torchao.quantization.granularity import ( PerAxis, PerTensor, ) +from torchao.quantization.observer import ( + AffineQuantizedMinMaxObserver, +) from torchao.quantization.quant_api import ( insert_observers_, ) @@ -42,7 +44,7 @@ def test_min_max_per_tensor_affine(self): obs = AffineQuantizedMinMaxObserver( MappingType.ASYMMETRIC, torch.uint8, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -54,7 +56,7 @@ def test_min_max_per_channel_affine(self): obs = AffineQuantizedMinMaxObserver( MappingType.ASYMMETRIC, torch.uint8, - granularity_type=PerAxis(axis=0), + granularity=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -68,7 +70,7 @@ def test_block_size_calc_success(self): obs = 
AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -87,7 +89,7 @@ def test_block_size_calc_success(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerAxis(1), + granularity=PerAxis(1), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -102,7 +104,7 @@ def test_block_size_row_errors(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerAxis(0), + granularity=PerAxis(0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -121,7 +123,7 @@ def test_block_size_row_errors(self): obs = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerAxis(1), + granularity=PerAxis(1), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -149,7 +151,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_observer = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, @@ -159,7 +161,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): weight_observer = AffineQuantizedMinMaxObserver( MappingType.SYMMETRIC, torch.float8_e4m3fn, - granularity_type=PerTensor(), + granularity=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 6d46e45878..46883867c6 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -24,9 +24,9 @@ float8_dynamic_activation_float8_weight, 
float8_static_activation_float8_weight, ) -from torchao.quantization.observer import PerRow, PerTensor from torchao._models._eval import TransformerEvalWrapper, InputRecorder from torchao._models.llama.model import prepare_inputs_for_model +from torchao.quantization.granularity import PerRow, PerTensor from tokenizer import get_tokenizer import time @@ -255,4 +255,4 @@ def run_evaluation( args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, - ) \ No newline at end of file + ) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 270054e130..23ed9864f8 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -216,7 +216,7 @@ def main( float8_weight_only, float8_dynamic_activation_float8_weight, ) - from torchao.quantization.observer import PerTensor, PerRow + from torchao.quantization.granularity import PerTensor, PerRow if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index e3a8827e2a..fc1b04f940 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -1,13 +1,13 @@ import torch import torch.nn.functional as F +from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, _DTYPE_TO_QVALUE_BOUNDS, ) from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata -from torchao.quantization.observer import PerGroup from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import( diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 77810a2e4a..725c168f90 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -7,12 +7,13 @@ from torch.utils._python_dispatch import 
return_and_correct_aliasing from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType from torchao.dtypes import to_affine_quantized_intx +from torchao.quantization.granularity import Granularity from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) from torchao.quantization.observer import ( - AffineQuantizedObserverBase, GranularityType + AffineQuantizedObserverBase, ) @@ -20,7 +21,7 @@ class AWQObserver(AffineQuantizedObserverBase): def __init__(self, weight: torch.Tensor, bias: torch.Tensor, - quantization_granularity: GranularityType, + quantization_granularity: Granularity, mapping_type: MappingType, target_dtype: torch.dtype, n_validation_examples: int, @@ -40,7 +41,7 @@ def __init__(self, Args: weight: The weight tensor to be observed. bias: The bias tensor to be observed. - quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point input_dtype: The data type of the input tensor. 
mapping_type: Always set to asymmetric target_dtype: The target data type of the quantized tensor @@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver): observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias - return observed_linear \ No newline at end of file + return observed_linear diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index c936b7ef83..9d7b049470 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.4+ from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight -from torchao.quantization.observer import PerTensor +from torchao.quantization.quant_api import PerTensor quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor())) ``` diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index a5568c4e17..15dc8b4e0d 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -12,12 +12,14 @@ from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torch.utils._python_dispatch import return_and_correct_aliasing -from .quant_primitives import ( - safe_int_mm, +from .granularity import ( + PerAxis, + PerRow, + PerTensor, ) +from .quant_primitives import safe_int_mm from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 from torchao.quantization.utils import quantize_activation_per_token_absmax -from torchao.quantization.observer import PerAxis, PerTensor, PerRow from torchao.float8.inference import 
Float8MMConfig import torch.nn.functional as F diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py new file mode 100644 index 0000000000..5251c7865e --- /dev/null +++ b/torchao/quantization/granularity.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. + """ + pass + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. + + This granularity type calculates the quantization parameters + based off the entire tensor. + """ + pass + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. + + This granularity type calculates different quantization parameters + along a specified axis of the tensor. + + For example if the input tensor is shape [8, 16] and axis=0, then + the quantization parameters are calculated for each row of the tensor. + Giving a total of 8 quantization parameters. + + + Attributes: + axis (int): The axis along which reduction is performed. + """ + axis: int + +@dataclass(frozen=True) + +class PerGroup(Granularity): + """ + Represents per-channel group granularity in quantization. + + This granularity type calculates different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [32, 4] and + quantization parameters are calculated for each group of 4 elements, + giving a total of 32 quantization parameters.
+ + Attributes: + group_size (int): The size of each quantization group + + """ + group_size: int + +class PerRow(Granularity): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + pass diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index bef4abe710..f3f1ea385f 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -1,4 +1,10 @@ import torch +from .granularity import ( + Granularity, + PerAxis, + PerRow, + PerTensor, +) from .quant_primitives import ( _get_reduction_params, choose_qparams_affine_with_min_max, @@ -8,7 +14,6 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from typing import Tuple, Optional, Any from functools import partial import logging @@ -16,74 +21,6 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class GranularityType: - """ - Base class for representing the granularity of quantization. - - This class serves as a parent for specific granularity types used in - quantization operations, such as per-tensor or per-axis quantization. - """ - pass - -@dataclass(frozen=True) -class PerTensor(GranularityType): - """ - Represents per-tensor granularity in quantization. - - This granularity type calcualtes the quantization parameters - based off the entire tensor. - """ - pass - -@dataclass(frozen=True) -class PerAxis(GranularityType): - """ - Represents per-axis granularity in quantization. - - This granularity type calcualtes different quantization parameters - along a specified axis of the tensor. - - For example if the input tensor is shape [8, 16] and axis=0, then - the quantization parameters are calculated for each row of the tensor. 
- Giving a total of 8 quantization parameters. - - - Attributes: - axis (int): The axis along which reduction is performed. - """ - axis: int - -@dataclass(frozen=True) - -class PerGroup(GranularityType): - """ - Represents per-channel group granularity in quantization. - - This granularity type calcualtes different quantization parameters - for each group of elements. - - For example if the input tensor is shape [8, 16], and the group size is 4, then - the input tensor is reshaped to [64, 4] - quantization parameters are calculated for each group of 4 elements, - giving a total of 64 quantization parameters. - - Attributes: - group_size (int): The size of each quantization group - - """ - group_size: int - -class PerRow(GranularityType): - """ - Represents row-wise granularity in quantization. - - This is a special case of per-axis quantization and is unique to Float8 matmuls - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight - is quantized with a block_size of (1, weight.shape[1]). - """ - pass - # borrowed from torch.ao.quantization.observer class _PartialWrapper: def __init__(self, p): @@ -120,23 +57,23 @@ def _with_args(cls_or_self, *args, **kwargs): def get_block_size( - input_shape: Tuple[int, ...], granularity_type: GranularityType + input_shape: Tuple[int, ...], granularity: Granularity ) -> Tuple[int, ...]: """Get the block size based on the input shape and granularity type. 
Args: input_shape: The input tensor shape possibly more than 2 dimensions - granularity_type: The granularity type of the quantization + granularity: The granularity type of the quantization """ - if isinstance(granularity_type, PerTensor): + if isinstance(granularity, PerTensor): return input_shape - elif isinstance(granularity_type, PerAxis): + elif isinstance(granularity, PerAxis): block_size = list(input_shape) - block_size[granularity_type.axis] = 1 + block_size[granularity.axis] = 1 return tuple(block_size) - elif isinstance(granularity_type, PerRow): + elif isinstance(granularity, PerRow): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - raise ValueError(f"Unsupported GranularityType: {granularity_type}") + raise ValueError(f"Unsupported Granularity: {granularity}") ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: @@ -146,7 +83,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module): """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) Args: - `granularity_type` and `block_size`: The granularity of the quantization, + `granularity` and `block_size`: The granularity of the quantization, must specify at least one, if both are specified `block_size` takes precedence Current supported granularity type are `PerTensor` and `PerAxis` other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` @@ -158,7 +95,7 @@ def __init__( self, mapping_type: MappingType, target_dtype: torch.dtype, - granularity_type: GranularityType, + granularity: Granularity, quant_min: Optional[int] = None, quant_max: Optional[int] = None, eps: Optional[float] = None, @@ -168,11 +105,11 @@ def __init__( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, ): super().__init__() - assert granularity_type is not None, "granularity_type is None" + assert granularity is not None, "granularity is None" self.mapping_type = mapping_type self.target_dtype = 
target_dtype - self.granularity_type = granularity_type + self.granularity = granularity self.quant_min = quant_min self.quant_max = quant_max self.eps = eps @@ -202,8 +139,8 @@ def forward(self, input: torch.Tensor): return input input_detached = input.detach() - assert self.granularity_type is not None, "granularity_type is None" - block_size = get_block_size(input_detached.shape, self.granularity_type) + assert self.granularity is not None, "granularity is None" + block_size = get_block_size(input_detached.shape, self.granularity) shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input_detached.size() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c41425062..dbaaa15295 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -51,7 +51,10 @@ from torchao.quantization.weight_tensor_linear_activation_quantization import ( to_weight_tensor_with_linear_activation_quantization_metadata, ) - +from .granularity import ( + PerRow, + PerTensor, +) from .quant_primitives import ( MappingType, ZeroPointDomain, @@ -71,7 +74,7 @@ ) from torchao.float8.inference import Float8MMConfig -from torchao.quantization.observer import PerTensor, PerRow, get_block_size +from torchao.quantization.observer import get_block_size logger = logging.getLogger(__name__) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index b1561e4cff..ea3a9d54c5 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -64,6 +64,7 @@ class ZeroPointDomain(Enum): INT = auto() FLOAT = auto() + if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 037dbae0f6..b71933e3b9 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -22,8 +22,10 
@@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, +) +from torchao.quantization.granularity import ( PerAxis, + PerTensor, ) from torchao.quantization.quant_primitives import ( MappingType, diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index edb1b257ee..9b639091db 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -37,10 +37,10 @@ from torchao.quantization import quantize_ from torchao.quantization import to_linear_activation_quantized from torchao.quantization import LinearActivationQuantizedTensor +from torchao.quantization.granularity import PerTensor from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, ) from torchao.quantization.quant_primitives import ( MappingType, diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index f75485d3d5..31d2be201d 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -17,8 +17,10 @@ from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter from torchao.quantization.observer import ( AffineQuantizedMinMaxObserver, - PerTensor, +) +from torchao.quantization.granularity import ( PerAxis, + PerTensor, ) from torchao.quantization.quant_primitives import ( MappingType,