From aac19a10160e7fb5b26b1315ed5f05cb79c7ab4a Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 5 Sep 2024 15:33:09 -0700 Subject: [PATCH] [Float8Quant] Add rowwise scaling option to float8 dyanmic quant stack-info: PR: https://github.com/pytorch/ao/pull/819, branch: drisspg/stack/11 --- ruff.toml | 1 + test/dtypes/test_affine_quantized_float.py | 114 +++++++++++++++------ torchao/dtypes/affine_quantized_tensor.py | 56 ++++++---- torchao/float8/inference.py | 8 ++ torchao/quantization/observer.py | 10 ++ torchao/quantization/quant_api.py | 75 ++++++++++++-- 6 files changed, 202 insertions(+), 62 deletions(-) diff --git a/ruff.toml b/ruff.toml index 04c9e32cca..f1a48ad9a0 100644 --- a/ruff.toml +++ b/ruff.toml @@ -10,4 +10,5 @@ include = [ "torchao/float8/float8_tensor.py", "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", + "test/dtypes/test_affine_quantized_float.py", ] diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 7e2ce278d5..aae164e828 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -1,34 +1,29 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, - unwrap_tensor_subclass, ) import pytest if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from numpy import full -from torch.testing._internal.common_utils import ( - run_tests, -) from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils -from torch._dynamo.testing import CompileCounterWithBackend from torchao.quantization import ( quantize_, float8_weight_only, float8_dynamic_activation_float8_weight, ) +from torchao.quantization.observer import PerTensor, PerRow from torchao.float8.float8_utils import compute_error import torch import unittest import pytest -import tempfile import copy import random - -from unittest.mock import patch +from functools import partial +from typing import Tuple +from contextlib import nullcontext random.seed(0) @@ -56,6 +51,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only"]) @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize( + "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()] + ) # Inputs are (M,..), K, N @common_utils.parametrize( "sizes", @@ -68,33 +66,81 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): ], ) def test_fp8_linear_variants( - self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple + self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity ): - M, N, K = sizes - input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") - - mode_map = { - "dynamic": float8_dynamic_activation_float8_weight, - "weight-only": float8_weight_only, - } - - # Create a linear layer with bfloat16 dtype - model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + raises = ( + isinstance(granularity, PerRow) + and mode == "dynamic" + and dtype != torch.bfloat16 + ) + context = ( + nullcontext() + if not raises + else pytest.raises( + AssertionError, + match="PerRow quantization only works for bfloat16 precision", + ) + ) + with context: + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + mode_map = { + "dynamic": partial( + float8_dynamic_activation_float8_weight, granularity=granularity + ), + "weight-only": float8_weight_only, + } + + # Create a linear layer with bfloat16 dtype + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + quantized_model = copy.deepcopy(model) + factory = mode_map[mode]() + quantize_(model, factory) + + if compile: + quantized_model = torch.compile(quantized_model, fullgraph=True) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + + error = compute_error(output_original, output_quantized) + assert ( + compute_error(output_original, output_quantized) > 20 + ), f"Quantization error is too high got a SQNR of {error}" + + def test_invalid_granularity(self): + with pytest.raises(ValueError, match="Invalid granularity specification"): + float8_dynamic_activation_float8_weight(granularity="invalid") + + def test_mismatched_granularity(self): + with pytest.raises( + ValueError, + match="Different granularities for activation and weight are not supported", + ): + float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) + + def test_unsupported_granularity(self): + class UnsupportedGranularity: + pass + + with pytest.raises(ValueError, match="Invalid granularity types"): + float8_dynamic_activation_float8_weight( + granularity=(UnsupportedGranularity(), UnsupportedGranularity()) + ) - quantized_model = copy.deepcopy(model) - factory = mode_map[mode]() - quantize_(model, factory) - - if compile: - quantized_model = torch.compile(quantized_model, fullgraph=True) - - output_original = model(input_tensor) - output_quantized = quantized_model(input_tensor) - - error = compute_error(output_original, output_quantized) - assert ( - compute_error(output_original, output_quantized) > 20 - ), f"Quantization error is too high got a SQNR of {error}" + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + def test_per_row_with_float32(self): + with pytest.raises( + AssertionError, + match="PerRow quantization only works for bfloat16 precision", + ): + model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda") + quantize_( + model, float8_dynamic_activation_float8_weight(granularity=PerRow()) + ) common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 784c3c5d87..47fe91e8a4 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -27,6 +27,12 @@ is_device, get_out_shape, ) +from torchao.float8.inference import ( + preprocess_data, + Float8MMConfig, + addmm_float8_unwrapped_inference, + _is_rowwise_scaled +) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass from torchao.utils import ( @@ -1355,20 +1361,29 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) -def _linear_fp_act_fp8_tensor_wise_weight_check( +def _linear_fp_act_fp8_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], bias: Optional[torch.Tensor], ) -> bool: - def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: + def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and isinstance(aqt.layout_type, Float8LayoutType) and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and aqt.shape == aqt.block_size + and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) - return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor) + return check_aqt(input_tensor) and check_aqt(weight_tensor) + +def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): + """ Ensures input tensor is correctly formated for _scaled_mm """ + input_scale = input_scale.unsqueeze(-1) + + if input_scale.dim() > 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + return input_scale def _linear_fp_act_fp8_weight_impl( input_tensor: AffineQuantizedTensor, @@ -1376,32 +1391,31 @@ def _linear_fp_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - ) - scaled_mm_config = weight_tensor.layout_type.mm_config - scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig() + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + # Weight tensor preprocessing w_layout = weight_tensor.layout_tensor - w_data = weight_tensor.layout_tensor.float8_data - w_data = w_data.T if w_layout.transposed else w_data + assert not w_layout.transposed, "Weight tensor must be contiguous" + w_data = w_layout.float8_data w_scale = w_layout.scale - w_scale = w_scale if w_layout.transposed else w_scale - - out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + # Input tensor preprocessing inpt_data = input_tensor.layout_tensor.float8_data - # Handle case where input tensor is more than 2D - inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1]) input_scale = input_tensor.layout_tensor.scale - if input_scale.dim() > 2: - input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, input_tensor.shape) + # Preprocess data inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + # Perform the computation return addmm_float8_unwrapped_inference( inpt_data, input_scale, @@ -1458,7 +1472,7 @@ def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), + (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f1d56ba2dd..816a55ee61 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference( use_fast_accum=use_fast_accum, ) return output + + +def _is_rowwise_scaled(x) -> bool: + """Checks if an AQT tensor is rowwise scaled + Args: + x: AffineQuantizedTensor tensor + """ + return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 08a3eacf6b..1ea8134ae2 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -53,6 +53,16 @@ class PerAxis(GranularityType): """ axis: int +@dataclass(frozen=True) +class PerRow(GranularityType): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + pass # borrowed from torch.ao.quantization.observer class _PartialWrapper: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8b48c66c29..c8508ab9e4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -19,7 +19,7 @@ import torchao import torch.nn as nn import torch.nn.functional as F -from typing import Any, Callable, Union, Dict, Optional +from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple import types from torchao.dtypes.uintx.Uintx import UintxLayoutType @@ -65,6 +65,8 @@ ) from torchao.float8.inference import Float8MMConfig +from torchao.quantization.observer import PerTensor, PerAxis, PerRow + logger = logging.getLogger(__name__) __all__ = [ @@ -641,17 +643,53 @@ def apply_float8wo_quant(weight): return _get_linear_subclass_inserter(apply_float8wo_quant) +_fp8_granularities = Literal[PerTensor, PerRow] + + +# Validate and process granularity input +def _validate_granularity( + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] +) -> Tuple[_fp8_granularities, _fp8_granularities]: + if granularity is None: + return (PerTensor(), PerTensor()) + elif isinstance(granularity, (PerTensor, PerRow)): + return (granularity, granularity) + elif isinstance(granularity, tuple) and len(granularity) == 2: + if not ( + isinstance(granularity[0], (PerTensor, PerRow)) + and isinstance(granularity[1], (PerTensor, PerRow)) + ): + raise ValueError(f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported.") + if type(granularity[0]) != type(granularity[1]): + raise ValueError( + f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." + ) + return granularity + else: + raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.") + + def float8_dynamic_activation_float8_weight( activation_dtype: torch.dtype = torch.float8_e4m3fn, weight_dtype: torch.dtype = torch.float8_e4m3fn, - mm_config: Optional[Float8MMConfig] = None + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None, + mm_config: Optional[Float8MMConfig] = None, ): """ - Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers. + Applies float8 dynamic symmetric quantization to both activations and weights of linear layers. Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + granularity: + The granularity for quantization. Can be either a single granularity (applied to both + activations and weights) or a tuple of two granularities (one for activations, one for weights). + If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And + only PerTensor and PerRow are supported. mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ @@ -660,23 +698,46 @@ def float8_dynamic_activation_float8_weight( if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) - #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling + activation_granularity, weight_granularity = _validate_granularity(granularity) + + def get_block_size(x: torch.Tensor, granularity: _fp8_granularities): + if isinstance(granularity, PerTensor): + return x.shape + elif isinstance(granularity, PerRow): + return (1,) * (x.dim() - 1) + (x.shape[-1],) + else: + raise ValueError(f"Unsupported granularity: {granularity}") + def apply_float8_dynamic_activation_quant(weight: torch.Tensor): + if isinstance(weight_granularity, PerRow): + assert ( + weight.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input weight" + + block_size = get_block_size(weight, weight_granularity) quantized_weight = to_affine_quantized_floatx( input_float=weight, - block_size=weight.shape, + block_size=block_size, target_dtype=weight_dtype, scale_dtype=torch.float32, layout_type=Float8LayoutType(mm_config=mm_config), ) def input_quant_func(x: torch.Tensor): + if isinstance(activation_granularity, PerRow): + assert ( + x.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input activation" + + block_size = get_block_size(x, activation_granularity) activation = to_affine_quantized_floatx( input_float=x, - block_size=x.shape, + block_size=block_size, target_dtype=activation_dtype, scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight + layout_type=Float8LayoutType( + mm_config=None + ), # Config is stored on weight ) return activation