From 10d038fcda78c2f8a0f63a954f5524145f33d236 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Fri, 6 Sep 2024 19:04:23 -0700 Subject: [PATCH] [Float8Quant] Add rowwise scaling option to float8 dyanmic quant (#819) --- ruff.toml | 1 + test/dtypes/test_affine_quantized_float.py | 170 +++++++++++++++++---- torchao/dtypes/affine_quantized_tensor.py | 56 ++++--- torchao/float8/inference.py | 8 + torchao/quantization/observer.py | 12 ++ torchao/quantization/quant_api.py | 93 +++++++++-- 6 files changed, 272 insertions(+), 68 deletions(-) diff --git a/ruff.toml b/ruff.toml index 04c9e32cca..f1a48ad9a0 100644 --- a/ruff.toml +++ b/ruff.toml @@ -10,4 +10,5 @@ include = [ "torchao/float8/float8_tensor.py", "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", + "test/dtypes/test_affine_quantized_float.py", ] diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 7e2ce278d5..3599be9af9 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -1,34 +1,30 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, - unwrap_tensor_subclass, ) import pytest if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from numpy import full -from torch.testing._internal.common_utils import ( - run_tests, -) from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils -from torch._dynamo.testing import CompileCounterWithBackend from torchao.quantization import ( quantize_, float8_weight_only, float8_dynamic_activation_float8_weight, ) +from torchao.quantization.observer import PerTensor, PerRow from torchao.float8.float8_utils import compute_error import torch import unittest import pytest -import tempfile import copy import random - -from unittest.mock import patch +from functools import partial +from typing import Tuple +from contextlib import nullcontext +import io random.seed(0) @@ -56,6 +52,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only"]) @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize( + "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()] + ) # Inputs are (M,..), K, N @common_utils.parametrize( "sizes", @@ -68,33 +67,142 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): ], ) def test_fp8_linear_variants( - self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple + self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity ): - M, N, K = sizes - input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") - - mode_map = { - "dynamic": float8_dynamic_activation_float8_weight, - "weight-only": float8_weight_only, - } + raises = ( + isinstance(granularity, PerRow) + and mode == "dynamic" + and dtype != torch.bfloat16 + ) + context = ( + nullcontext() + if not raises + else pytest.raises( + AssertionError, + match="PerRow quantization only works for bfloat16 precision", + ) + ) + with context: + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + mode_map = { + "dynamic": partial( + float8_dynamic_activation_float8_weight, granularity=granularity + ), + "weight-only": float8_weight_only, + } + + # Create a linear layer with bfloat16 dtype + model = 
ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + quantized_model = copy.deepcopy(model) + factory = mode_map[mode]() + quantize_(model, factory) + + if compile: + quantized_model = torch.compile(quantized_model, fullgraph=True) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + + error = compute_error(output_original, output_quantized) + assert ( + compute_error(output_original, output_quantized) > 20 + ), f"Quantization error is too high got a SQNR of {error}" + + def test_invalid_granularity(self): + with pytest.raises(ValueError, match="Invalid granularity specification"): + float8_dynamic_activation_float8_weight(granularity="invalid") + + def test_mismatched_granularity(self): + with pytest.raises( + ValueError, + match="Different granularities for activation and weight are not supported", + ): + float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) + + def test_unsupported_granularity(self): + class UnsupportedGranularity: + pass + + with pytest.raises(ValueError, match="Invalid granularity types"): + float8_dynamic_activation_float8_weight( + granularity=(UnsupportedGranularity(), UnsupportedGranularity()) + ) - # Create a linear layer with bfloat16 dtype - model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + def test_per_row_with_float32(self): + with pytest.raises( + AssertionError, + match="PerRow quantization only works for bfloat16 precision", + ): + model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda") + quantize_( + model, float8_dynamic_activation_float8_weight(granularity=PerRow()) + ) - quantized_model = copy.deepcopy(model) - factory = mode_map[mode]() + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + def test_serialization(self, mode: str): + # Create and quantize the model + model = ToyLinearModel(16, 32).to(device="cuda") + if mode == "dynamic": + factory = float8_dynamic_activation_float8_weight() + else: + factory = float8_weight_only() quantize_(model, factory) - if compile: - quantized_model = torch.compile(quantized_model, fullgraph=True) - - output_original = model(input_tensor) - output_quantized = quantized_model(input_tensor) - - error = compute_error(output_original, output_quantized) - assert ( - compute_error(output_original, output_quantized) > 20 - ), f"Quantization error is too high got a SQNR of {error}" + # Save the state dict to an in-memory buffer + buffer = io.BytesIO() + torch.save(model.state_dict(), buffer) + + # Reset the buffer position + buffer.seek(0) + + # Load the state dict from the buffer + loaded_state_dict = torch.load(buffer) + + # Create a new model and load the state dict + with torch.device("meta"): + new_model = ToyLinearModel(16, 32) + new_model.load_state_dict(loaded_state_dict, assign=True) + + # Compare the original and loaded models + if mode == "weight-only": + model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to( + torch.float32 + ) + new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to( + torch.float32 + ) + + model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to( + torch.float32 + ) + new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to( + 
torch.float32 + ) + + else: + model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to( + torch.float32 + ) + new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to( + torch.float32 + ) + + model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to( + torch.float32 + ) + new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to( + torch.float32 + ) + + assert torch.allclose(model_weight_1, new_model_weight_1) + assert torch.allclose(model_weight_2, new_model_weight_2) common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index c6a3730859..1ec6421ea6 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -26,6 +26,12 @@ is_device, get_out_shape, ) +from torchao.float8.inference import ( + preprocess_data, + Float8MMConfig, + addmm_float8_unwrapped_inference, + _is_rowwise_scaled +) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass from torchao.utils import ( @@ -1354,20 +1360,29 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) -def _linear_fp_act_fp8_tensor_wise_weight_check( +def _linear_fp_act_fp8_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], bias: Optional[torch.Tensor], ) -> bool: - def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: + def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and isinstance(aqt.layout_type, Float8LayoutType) and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and aqt.shape == aqt.block_size + and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) - return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor) + return check_aqt(input_tensor) and check_aqt(weight_tensor) + +def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): + """ Ensures input tensor is correctly formated for _scaled_mm """ + input_scale = input_scale.unsqueeze(-1) + + if input_scale.dim() > 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + return input_scale def _linear_fp_act_fp8_weight_impl( input_tensor: AffineQuantizedTensor, @@ -1375,32 +1390,31 @@ def _linear_fp_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - ) - scaled_mm_config = weight_tensor.layout_type.mm_config - scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig() + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + # Weight tensor preprocessing w_layout = weight_tensor.layout_tensor - w_data = weight_tensor.layout_tensor.float8_data - w_data = w_data.T if w_layout.transposed else w_data + assert not w_layout.transposed, "Weight tensor must be contiguous" + w_data = w_layout.float8_data w_scale = w_layout.scale - w_scale = w_scale if w_layout.transposed else w_scale - - out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + # Input 
tensor preprocessing inpt_data = input_tensor.layout_tensor.float8_data - # Handle case where input tensor is more than 2D - inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1]) input_scale = input_tensor.layout_tensor.scale - if input_scale.dim() > 2: - input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, input_tensor.shape) + # Preprocess data inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + # Perform the computation return addmm_float8_unwrapped_inference( inpt_data, input_scale, @@ -1459,7 +1473,7 @@ def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), + (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f1d56ba2dd..816a55ee61 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference( use_fast_accum=use_fast_accum, ) return output + + +def _is_rowwise_scaled(x) -> bool: + """Checks if an AQT tensor is rowwise scaled + Args: + x: AffineQuantizedTensor tensor + """ + return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 08a3eacf6b..596bf403de 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -53,6 +53,16 @@ class PerAxis(GranularityType): """ axis: int +@dataclass(frozen=True) +class PerRow(GranularityType): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). 
+ """ + pass # borrowed from torch.ao.quantization.observer class _PartialWrapper: @@ -104,6 +114,8 @@ def get_block_size( block_size = list(input_shape) block_size[granularity_type.axis] = 1 return tuple(block_size) + elif isinstance(granularity_type, PerRow): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) raise ValueError(f"Unsupported GranularityType: {granularity_type}") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bbee8589b1..9516da9763 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -19,12 +19,13 @@ import torchao import torch.nn as nn import torch.nn.functional as F -from typing import Any, Callable, Union, Dict, Optional +from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple import types from torchao.dtypes.uintx.Uintx import UintxLayoutType from torchao.dtypes import ( to_affine_quantized_intx, + to_affine_quantized_floatx, TensorCoreTiledLayoutType, PlainLayoutType, AffineQuantizedTensor, @@ -65,6 +66,8 @@ ) from torchao.float8.inference import Float8MMConfig +from torchao.quantization.observer import PerTensor, PerRow, get_block_size + logger = logging.getLogger(__name__) __all__ = [ @@ -641,44 +644,102 @@ def apply_float8wo_quant(weight): return _get_linear_subclass_inserter(apply_float8wo_quant) +_fp8_granularities = Union[PerTensor, PerRow] + + +# Validate and process granularity input +def _normalize_granularity( + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] +) -> Tuple[_fp8_granularities, _fp8_granularities]: + if granularity is None: + return (PerTensor(), PerTensor()) + elif isinstance(granularity, (PerTensor, PerRow)): + return (granularity, granularity) + elif isinstance(granularity, tuple) and len(granularity) == 2: + if not ( + isinstance(granularity[0], (PerTensor, PerRow)) + and isinstance(granularity[1], (PerTensor, PerRow)) + ): + raise ValueError(f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported.") + if not isinstance(granularity[0], type(granularity[1])): + raise ValueError( + f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." + ) + return granularity + else: + raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.") + + +def _input_quant_func_dyanmic_fp8( + x: torch.Tensor, + activation_granularity: _fp8_granularities, + activation_dtype: torch.dtype, +): + if isinstance(activation_granularity, PerRow): + assert ( + x.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input activation" + + block_size = get_block_size(x.shape, activation_granularity) + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight + ) + return activation + + def float8_dynamic_activation_float8_weight( activation_dtype: torch.dtype = torch.float8_e4m3fn, weight_dtype: torch.dtype = torch.float8_e4m3fn, - mm_config: Optional[Float8MMConfig] = None + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None, + mm_config: Optional[Float8MMConfig] = None, ): """ - Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers. 
+ Applies float8 dynamic symmetric quantization to both activations and weights of linear layers. Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + granularity: + The granularity for quantization. Can be either a single granularity (applied to both + activations and weights) or a tuple of two granularities (one for activations, one for weights). + If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And + only PerTensor and PerRow are supported. mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ - from torchao.dtypes import to_affine_quantized_floatx - if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) - #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling + activation_granularity, weight_granularity = _normalize_granularity(granularity) + def apply_float8_dynamic_activation_quant(weight: torch.Tensor): + if isinstance(weight_granularity, PerRow): + assert ( + weight.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input weight" + + block_size = get_block_size(weight.shape, weight_granularity) quantized_weight = to_affine_quantized_floatx( input_float=weight, - block_size=weight.shape, + block_size=block_size, target_dtype=weight_dtype, scale_dtype=torch.float32, layout_type=Float8LayoutType(mm_config=mm_config), ) - def input_quant_func(x: torch.Tensor): - activation = to_affine_quantized_floatx( - input_float=x, - block_size=x.shape, - target_dtype=activation_dtype, - scale_dtype=torch.float32, - layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight - ) - return activation + input_quant_func = partial( + _input_quant_func_dyanmic_fp8, + activation_granularity=activation_granularity, + activation_dtype=activation_dtype, + ) quantized_weight = to_linear_activation_quantized( quantized_weight, input_quant_func
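
A minimal usage sketch of the granularity option this patch introduces. The identifiers (quantize_, float8_dynamic_activation_float8_weight, PerRow) come from the patch itself; the toy model and shapes are illustrative, and a GPU that supports rowwise float8 (the tests gate PerRow on is_H100) is assumed.

    import torch
    from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
    from torchao.quantization.observer import PerRow

    # PerRow dynamic float8 quantization asserts bfloat16 precision for both
    # weights and activations, so convert the model before quantizing.
    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).to("cuda")
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))

    # Activations are quantized dynamically per row at runtime.
    x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
    y = model(x)

Passing a single granularity applies it to both activation and weight; per the patch, a (PerTensor(), PerRow()) mix raises a ValueError, and omitting granularity keeps the previous per-tensor behavior.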