diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py index 82f867d96f..e130773719 100644 --- a/scripts/hf_eval.py +++ b/scripts/hf_eval.py @@ -111,7 +111,7 @@ def all_linear(mod, name): parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply') + parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply') parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--save', action='store_true', help='Whether to save the model.') diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py new file mode 100644 index 0000000000..45c2f2877c --- /dev/null +++ b/test/dtypes/test_affine_quantized_float.py @@ -0,0 +1,134 @@ +from numpy import full +from torch.testing._internal.common_utils import ( + run_tests, +) +from torch._inductor.test_case import TestCase as InductorTestCase +from torch.testing._internal import common_utils +from torch._dynamo.testing import CompileCounterWithBackend + +from torchao.quantization.quant_api import ( + quantize_, + float8_weight_only, + float8_dynamic_activation_float8_weight, +) +from torchao.float8.float8_utils import compute_error +import torch +import unittest +import pytest +import tempfile +import copy +import random + +from unittest.mock import patch +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + unwrap_tensor_subclass, +) + +if not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +random.seed(0) +torch.manual_seed(0) + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) +is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear1 = torch.nn.Linear(in_features, out_features, bias=False) + self.linear2 = torch.nn.Linear(out_features, in_features, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestAffineQuantizedFloat8(InductorTestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tensor_core_layout_transpose(self): + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + t = l.weight + shape = t.shape + apply_float8_weight_only_quant = float8_weight_only() + ql = apply_float8_weight_only_quant(l) + aqt = ql.weight + aqt_shape = aqt.shape + assert aqt_shape == shape + + # transpose shape test + for _ in range(10): + t = t.t() + aqt = aqt.t() + shape = t.shape + aqt_shape = aqt.shape + assert aqt_shape == shape + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_weights_only_save_load(self): + for apply_quant in [float8_weight_only()]: + # TODO Fails when l requires grad + l 
= torch.nn.Linear( + 128, 256, dtype=torch.bfloat16, device="cuda" + ).requires_grad_(False) + ql = apply_quant(l) + with tempfile.NamedTemporaryFile() as f: + torch.save(ql.state_dict(), f) + f.seek(0) + # `weights_only=True` is enabled for torch 2.5+ + if TORCH_VERSION_AT_LEAST_2_5: + _ = torch.load(f, weights_only=True) + else: + _ = torch.load(f, weights_only=False) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_cuda_8_9, "Need H100") + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + @common_utils.parametrize("compile", [True, False]) + # Inputs are (M,..), K, N + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((256,), 512, 256), + ((64,), 128, 64), + ((32, 128), 64, 256), + ((64, 256), 512, 128), + ], + ) + def test_dynamic_fp8_linear( + self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple + ): + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + mode_map = { + "dynamic": float8_dynamic_activation_float8_weight, + "weight-only": float8_weight_only, + } + + # Create a linear layer with bfloat16 dtype + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + quantized_model = copy.deepcopy(model) + factory = mode_map[mode]() + quantize_(model, factory) + + if compile: + quantized_model = torch.compile(quantized_model, fullgraph=True) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + + assert compute_error(output_original, output_quantized) > 20, "Error is too low" + + +common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index f1080dfeef..eeed767970 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -10,6 +10,8 @@ PlainLayoutType, SemiSparseLayoutType, TensorCoreTiledLayoutType, + Float8LayoutType, + Float8AQTLayout, ) __all__ = [ @@ -24,4 +26,6 @@ "PlainLayoutType", "SemiSparseLayoutType", "TensorCoreTiledLayoutType", + "Float8LayoutType", + "Float8AQTLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index a1c9b2efaa..88bd5a4ece 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable, Any, Tuple, Optional +from typing import Dict, Callable, Any, Tuple, Optional, Union from collections import defaultdict import functools import math @@ -26,6 +26,7 @@ LayoutType, PlainLayoutType, is_device, + get_out_shape, ) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass @@ -35,6 +36,7 @@ TORCH_VERSION_AT_LEAST_2_5, ) +from torchao.float8.float8_tensor import ScaledMMConfig aten = torch.ops.aten ############################### @@ -113,8 +115,8 @@ def __new__( layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, strides=None, @@ -135,8 +137,8 @@ def __init__( layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + 
quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, strides=None, @@ -270,9 +272,11 @@ def from_hp_to_floatx( cls, input_float: torch.Tensor, block_size: Tuple[int, ...], - target_dtype: torch.dtype = torch.float8_e4m3fn, + target_dtype: torch.dtype, + scale_dtype: Optional[torch.dtype] = None, layout_type: LayoutType = PlainLayoutType(), ): + if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -282,11 +286,11 @@ def from_hp_to_floatx( quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), eps=torch.finfo(torch.float32).eps, - scale_dtype=None, + scale_dtype=scale_dtype, zero_point_dtype=None, preserve_zero=True, - zero_point_domain=ZeroPointDomain.INT, - layout_type=PlainLayoutType(), + zero_point_domain=None, + layout_type=layout_type, use_hqq=False, ) else: @@ -376,6 +380,13 @@ def extra_repr(self): return f"inner_k_tiles={self.inner_k_tiles}" +@dataclass(frozen=True) +class Float8LayoutType(LayoutType): + mm_config: ScaledMMConfig + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + return input + @register_layout_cls(PlainLayoutType) class PlainAQTLayout(AQTLayout): """ @@ -529,6 +540,113 @@ def from_plain( return cls(int_data_compressed, scale, zero_point, layout_type) +@register_layout_cls(Float8LayoutType) +class Float8AQTLayout(AQTLayout): + """ + Layout storage class for float8 layout for affine quantized tensor + """ + float8_data: torch.Tensor + scale: torch.Tensor + transposed: bool + mm_config: Optional[ScaledMMConfig] + + def __new__( + cls, + float8_data: torch.Tensor, + scale: torch.Tensor, + mm_config: Optional[ScaledMMConfig], + transposed: bool, + layout_type: LayoutType, + ): + kwargs = {} + kwargs["device"] = float8_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout + ) + kwargs["dtype"] = float8_data.dtype + kwargs["requires_grad"] = False + shape = float8_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + float8_data: torch.Tensor, + scale: torch.Tensor, + mm_config: Optional[ScaledMMConfig], + transposed: bool, + layout_type: LayoutType, + ): + self.float8_data = float8_data + self.scale = scale + self.transposed = transposed + self.mm_config = mm_config + self.layout_type = layout_type + + def _apply_fn_to_data(self, fn): + fn(self.float8_data) + fn(self.scale) + return self + + def __tensor_flatten__(self): + return ["float8_data", "scale"], [self.transposed, self.mm_config, self.layout_type] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] + transposed, mm_config, layout_type, = tensor_attributes + return cls(float8_data, scale, mm_config, transposed, layout_type) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + args[0].transposed = not args[0].transposed + return 
return_and_correct_aliasing(func, args, kwargs, args[0]) + + raise NotImplementedError( + f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.float8_data, self.scale, None + + def get_layout_type(self) -> LayoutType: + return self.layout_type + + @classmethod + def from_plain( + cls, + float8_data: torch.Tensor, + scale: torch.Tensor, + zero_points: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, Float8LayoutType) + return cls(float8_data, scale, layout_type.mm_config, False, layout_type) + + def __repr__(self): + float8_data, scale, _ = self.get_plain() + layout_type = self.get_layout_type() + return (f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"mm_config={self.mm_config}, " + f"transposed={self.transposed}, " + f"layout_type={layout_type})") + + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): """ @@ -680,6 +798,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout_type(self) -> LayoutType: return self.layout_type + ##################################################### # torch functional and aten operator implementation # ##################################################### @@ -892,10 +1011,69 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y +def _linear_fp_act_fp8_tensor_wise_weight_check( + input_tensor: torch.Tensor, + weight_tensor: AffineQuantizedTensor, + bias: Optional[torch.Tensor], +) -> bool: + def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool: + return ( + is_traceable_wrapper_subclass(input_tensor) and + isinstance(input_tensor, AffineQuantizedTensor) and + isinstance(aqt, AffineQuantizedTensor) and + isinstance(aqt.layout_tensor, Float8AQTLayout) + and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and aqt.shape == aqt.block_size + ) + return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor) + + +def _linear_fp_act_fp8_weight_impl( + input_tensor: AffineQuantizedTensor, + weight_tensor: AffineQuantizedTensor, + bias: Optional[torch.Tensor], +): + from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data + from torchao.float8.float8_tensor import ScaledMMConfig + from torchao.float8.float8_python_api import addmm_float8_unwrapped + + scaled_mm_config = weight_tensor.layout_tensor.mm_config + scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else ScaledMMConfig() + + w_layout = weight_tensor.layout_tensor + w_data = weight_tensor.layout_tensor.float8_data + w_data = w_data.T if w_layout.transposed else w_data + w_scale = w_layout.scale + w_scale = w_scale if w_layout.transposed else w_scale + + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + inpt_data = input_tensor.layout_tensor.float8_data + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1]) + input_scale = input_tensor.layout_tensor.scale + if input_scale.dim() >= 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + return addmm_float8_unwrapped( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + 
use_fast_accum=scaled_mm_config.use_fast_accum, + inverse_scale=False + ).reshape(out_shape) + + def _register_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), + (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index d906251f80..036a5ca929 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable, Union +from typing import Dict, Callable, Union, Tuple from collections import defaultdict import functools from dataclasses import dataclass @@ -143,3 +143,15 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout def is_device(target_device_str: str, device: Union[str, torch.device]): return torch.device(device).type == target_device_str + +def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]: + """Returns the unflattened shape of the input tensor. + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + weight_shape: The weight tensor shape. + Returns: + The unflattened shape of the input tensor. + """ + out_dim = weight_shape[0] + inpt_dims = input_shape[:-1] + return (*inpt_dims, out_dim) diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 56c7b28f79..43065d2b8d 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -25,10 +25,13 @@ ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp -# Needed to load Float8Tensor with weights_only = True -from torch.serialization import add_safe_globals +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) + +if TORCH_VERSION_AT_LEAST_2_5: + # Needed to load Float8Tensor with weights_only = True + from torch.serialization import add_safe_globals + add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) __all__ = [ # configuration diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py index 16e2705740..ade00a0a66 100644 --- a/torchao/float8/float8_python_api.py +++ b/torchao/float8/float8_python_api.py @@ -30,14 +30,19 @@ def addmm_float8_unwrapped( output_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_fast_accum: bool = False, + inverse_scale: bool = True ) -> torch.Tensor: """ This is the unwrapped version of addmm_float8, which does not take in Float8Tensors as inputs. This is used to standardize the logic between subclassed and non subclassed versions of the linear module. 
""" - a_inverse_scale = a_scale.reciprocal() - b_inverse_scale = b_scale.reciprocal() + if inverse_scale: + a_inverse_scale = a_scale.reciprocal() + b_inverse_scale = b_scale.reciprocal() + else: + a_inverse_scale = a_scale + b_inverse_scale = b_scale if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index ccf83d7cef..66f83d933a 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from enum import auto, Enum -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -242,3 +242,26 @@ def quantize_to_float8( lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum), module_filter_fn=module_filter_fn, ) + +from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul + +def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]: + """ Preprocess the inner fp8 data tensors for admmm + Args: + a_data: Input tensor A. + b_data: Input tensor B. + scaled_mm_config: Configuration for _scaled_mm. + Returns: + Preprocessed tensors A and B in the format for _scaled_mm. + """ + if scaled_mm_config.pad_inner_dim: + assert a_data.size(1) == b_data.size( + 0 + ), f"Inner dims must match for mm, got {a_data.size(1)} and {b_data.size(0)}" + a_data = pad_tensor_for_matmul(a_data, dims=1) + b_data = pad_tensor_for_matmul(b_data, dims=0) + if not is_row_major(a_data.stride()): + a_data = a_data.contiguous() + if is_row_major(b_data.stride()): + b_data = b_data.t().contiguous().t() + return a_data, b_data diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index cbd8c7f525..0c99078f25 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -28,7 +28,9 @@ PlainLayoutType, AffineQuantizedTensor, SemiSparseLayoutType, - to_affine_quantized_floatx + to_affine_quantized_floatx, + Float8AQTLayout, + Float8LayoutType ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, @@ -57,6 +59,7 @@ from .utils import _get_per_token_block_size import logging from .autoquant import autoquant, AutoQuantizableLinearWeight +from torchao.float8.float8_tensor import ScaledMMConfig __all__ = [ "swap_conv2d_1x1_to_linear", @@ -156,7 +159,6 @@ def change_linear_weights_to_int4_woqtensors(model, groupsize=128, inner_k_tiles ### TO BE DEPRECATED END - def _replace_with_custom_fn_if_matches_filter( model, replacement_fn, @@ -489,19 +491,77 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): """ return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()) + def float8_weight_only(target_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. + + Args: + target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + + Note: + The actual matmul will be computed in original precision of the weight tensor. 
+ """ from torchao.dtypes import to_affine_quantized_floatx + def apply_float8wo_quant(weight): - # avoid circular dep block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx(input_float=weight, block_size=block_size, target_dtype=target_dtype) + return to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=target_dtype, + layout_type=Float8LayoutType(mm_config=None), + ) return _get_linear_subclass_inserter(apply_float8wo_quant) +def float8_dynamic_activation_float8_weight( + target_dtype: torch.dtype = torch.float8_e4m3fn, + activation_dtype: torch.dtype = torch.float8_e4m3fn, + mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True) +): + """ + Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers. + + Args: + target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. + mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + + """ + + from torchao.dtypes import to_affine_quantized_floatx + + #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling + def apply_float8_dynamic_activation_quant(weight: torch.Tensor): + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=weight.shape, + target_dtype=target_dtype, + scale_dtype=torch.float32, + layout_type=Float8LayoutType(mm_config=None), + ) + + def input_quant_func(x: torch.Tensor): + activation = to_affine_quantized_floatx( + input_float=x, + block_size=x.shape, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + layout_type=Float8LayoutType(mm_config=None), + ) + return activation + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func + ) + return quantized_weight + + return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant) + + def uintx_weight_only(dtype, group_size=64, pack_dim=-1): """ Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index c54f2a025e..bdc4eeaf88 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -166,8 +166,8 @@ def quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], output_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ @@ -211,7 +211,7 @@ def quantize_affine( output_dtype, quant_min, quant_max, - zero_point_domain.name, + zero_point_domain.name if zero_point_domain is not None else None, ) @@ -222,9 +222,9 @@ def _quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], output_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library """ @@ -249,9 +249,9 @@ def 
_quantize_affine_no_dtype_cast( block_size: List[int], scale: torch.Tensor, zero_point: Optional[torch.Tensor], - quant_min: int, - quant_max: int, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -267,6 +267,11 @@ def _quantize_affine_no_dtype_cast( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) + if zero_point_domain is None: + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + quant = quant.view(original_shape) + return quant + if zero_point_domain == ZeroPointDomain.INT.name: quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max @@ -291,8 +296,8 @@ def dequantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], input_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, *, output_dtype: torch.dtype = torch.float32, @@ -326,7 +331,7 @@ def dequantize_affine( input_dtype, quant_min, quant_max, - zero_point_domain.name, + zero_point_domain.name if zero_point_domain is not None else None, output_dtype=output_dtype, ) @@ -338,9 +343,9 @@ def _dequantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], input_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library @@ -367,9 +372,9 @@ def _dequantize_affine_no_dtype_check( block_size: List[int], scale: torch.Tensor, zero_point: Optional[torch.Tensor], - quant_min: int, - quant_max: int, - zero_point_domain: str = ZeroPointDomain.INT.name, + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}" @@ -379,7 +384,14 @@ def _dequantize_affine_no_dtype_check( shape_after_reduction = shape_for_reduction for i in reduction_dims: shape_after_reduction[i] = 1 - scale = scale.view(shape_after_reduction) + scale = scale.view(shape_after_reduction) + + if zero_point_domain is None: + assert zero_point is None, "zero_point should be None when zero_point_domain is None" + dequant = input.to(output_dtype) + dequant = dequant * scale + return dequant.view(original_shape).to(output_dtype) + if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) @@ -411,8 +423,8 @@ def fake_quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], quant_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ @@ -455,8 +467,8 @@ def fake_quantize_affine_cachemask( scale: 
torch.Tensor, zero_point: Optional[torch.Tensor], quant_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -498,8 +510,8 @@ def _do_fake_quantize_affine( scale: torch.Tensor, zero_point: Optional[torch.Tensor], quant_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -535,8 +547,8 @@ def choose_qparams_affine( mapping_type: MappingType, block_size: Tuple[int, ...], target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, @@ -586,7 +598,7 @@ def choose_qparams_affine( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name + zero_point_domain.name if zero_point_domain is not None else None, ) @@ -637,13 +649,13 @@ def _choose_qparams_affine( mapping_type: str, block_size: List[int], target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: str = "INT", + zero_point_domain: Optional[str] = "INT", min_val: Optional[torch.Tensor] = None, max_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -689,7 +701,7 @@ def _choose_qparams_affine( scale = max_val_pos / (float(quant_max - quant_min) / 2) if not preserve_zero: raise ValueError("preserve_zero == False is not supported for symmetric quantization") - if zero_point_domain != ZeroPointDomain.INT.name: + if zero_point_domain is not None and zero_point_domain != ZeroPointDomain.INT.name: raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization") scale = torch.clamp(scale, min=eps) zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
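
Usage sketch (not part of the diff): the example below, modeled on test_dynamic_fp8_linear above, shows how the float8 inference entry points introduced in this PR are meant to be used together. The quantize_ API, the float8_weight_only / float8_dynamic_activation_float8_weight factories, and the compute_error helper come straight from this change; the toy module, tensor sizes, dtype, and print statement are illustrative assumptions. Running it requires a CUDA GPU with compute capability >= 8.9 (per the is_cuda_8_9 guard) and a PyTorch version that satisfies the TORCH_VERSION_AT_LEAST_2_5 check at the top of the test file.

import copy
import torch
from torchao.quantization.quant_api import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
from torchao.float8.float8_utils import compute_error


class ToyLinearModel(torch.nn.Module):
    # Same toy module as in test_affine_quantized_float.py.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


# Illustrative shape/dtype; the test parametrizes several (M, K, N) sizes.
model = ToyLinearModel(256, 128).eval().to(torch.bfloat16).to("cuda")
model_ref = copy.deepcopy(model)

# Quantize weights and activations to float8 per-tensor, in place.
# float8_weight_only() would instead quantize only the weights and, per its
# docstring above, keep the matmul in the weight's original precision.
quantize_(model, float8_dynamic_activation_float8_weight())

# Optional and assumed to work on torch >= 2.5: compile the quantized model
# (the test parametrizes a compile=True case with fullgraph=True).
model = torch.compile(model, fullgraph=True)

x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
error = compute_error(model_ref(x), model(x))  # SQNR; the test asserts > 20
print("SQNR vs. bf16 reference (dB):", error.item())

In the dynamic-activation case the forward pass is routed through the new _linear_fp_act_fp8_tensor_wise_weight_check / _linear_fp_act_fp8_weight_impl dispatch pair and ultimately addmm_float8_unwrapped, i.e. torch._scaled_mm, while the weight-only case stores the weight as a Float8AQTLayout tensor and defers the matmul to the original precision.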