From 06a56280b906c60f3ceca5051277018daaf0a500 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Thu, 5 Sep 2024 15:33:09 -0700
Subject: [PATCH] [Float8Quant] Add rowwise scaling option to float8 dynamic
 quant

stack-info: PR: https://github.com/pytorch/ao/pull/819, branch: drisspg/stack/11
---
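Example usage of the new granularity argument (illustrative sketch, not part
of the diff below; assumes a CUDA GPU with float8 support and uses the names
introduced in this patch):

    import torch
    from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
    from torchao.quantization.observer import RowWise

    # RowWise scaling currently requires bfloat16 weights and activations
    model = torch.nn.Sequential(
        torch.nn.Linear(2048, 1024, dtype=torch.bfloat16, device="cuda")
    )
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=RowWise))

    x = torch.randn(16, 2048, dtype=torch.bfloat16, device="cuda")
    y = model(x)
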
 ruff.toml                                  |  1 +
 test/dtypes/test_affine_quantized_float.py | 24 ++++-----
 torchao/dtypes/affine_quantized_tensor.py  | 57 ++++++++++++++--------
 torchao/float8/inference.py                |  8 +++
 torchao/quantization/observer.py           | 11 +++++
 torchao/quantization/quant_api.py          | 24 +++++++--
 6 files changed, 90 insertions(+), 35 deletions(-)

diff --git a/ruff.toml b/ruff.toml
index dee9710df4..b3a283a317 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -8,4 +8,5 @@ include = [
     "torchao/dtypes/nf4tensor.py",
     "test/dtypes/test_nf4.py",
     "torchao/float8/float8_tensor.py",
+    "test/dtypes/test_affine_quantized_float.py"
 ]
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 7e2ce278d5..64ae5e7a6b 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -1,34 +1,28 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    unwrap_tensor_subclass,
 )
 import pytest

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

-from numpy import full
-from torch.testing._internal.common_utils import (
-    run_tests,
-)
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torch._dynamo.testing import CompileCounterWithBackend

 from torchao.quantization import (
     quantize_,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.observer import PerTensor, RowWise
 from torchao.float8.float8_utils import compute_error
 import torch
 import unittest
 import pytest
-import tempfile
 import copy
 import random
-
-from unittest.mock import patch
+from functools import partial
+from typing import Tuple


 random.seed(0)
@@ -56,6 +50,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor, RowWise])
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
@@ -68,13 +63,20 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         ],
     )
     def test_fp8_linear_variants(
-        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
+        self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
+        if granularity == RowWise and mode == "dynamic" and dtype != torch.bfloat16:
+            pytest.skip(
+                "RowWise quantization only works for bfloat16 precision input weight and activation for now"
+            )
+
         M, N, K = sizes
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

         mode_map = {
-            "dynamic": float8_dynamic_activation_float8_weight,
+            "dynamic": partial(
+                float8_dynamic_activation_float8_weight, granularity=granularity
+            ),
             "weight-only": float8_weight_only,
         }

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index f4e5446ba5..b0389a5872 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -26,6 +26,12 @@
     is_device,
     get_out_shape,
 )
+from torchao.float8.inference import (
+    preprocess_data,
+    Float8MMConfig,
+    addmm_float8_unwrapped_inference,
+    _is_rowwise_scaled
+)
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from dataclasses import dataclass
 from torchao.utils import (
@@ -1161,20 +1167,30 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)


-def _linear_fp_act_fp8_tensor_wise_weight_check(
+def _linear_fp_act_fp8_weight_check(
     input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
+    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
         return (
             isinstance(aqt, AffineQuantizedTensor)
             and isinstance(aqt.layout_type, Float8LayoutType)
             and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
-            and aqt.shape == aqt.block_size
+            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
         )
-    return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor)
+    return check_aqt(input_tensor) and check_aqt(weight_tensor)
+
+def preprocess_input_tensor(inpt_data: torch.Tensor, input_scale: torch.Tensor, input_shape: Tuple[int]):
+    """Ensures the input tensor and its scale are correctly formatted for _scaled_mm"""
+    if input_scale.size(0) != 1:
+        input_scale = input_scale.unsqueeze(-1)
+
+    if input_scale.dim() > 2:
+        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+
+    return inpt_data, input_scale


 def _linear_fp_act_fp8_weight_impl(
     input_tensor: AffineQuantizedTensor,
@@ -1182,32 +1198,31 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
-    from torchao.float8.inference import (
-        preprocess_data,
-        Float8MMConfig,
-        addmm_float8_unwrapped_inference,
-    )
-
     scaled_mm_config = weight_tensor.layout_type.mm_config
-    scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

+    # Weight tensor preprocessing
     w_layout = weight_tensor.layout_tensor
-    w_data = weight_tensor.layout_tensor.float8_data
-    w_data = w_data.T if w_layout.transposed else w_data
+    assert not w_layout.transposed, "Weight tensor must be contiguous"
+    w_data = w_layout.float8_data
     w_scale = w_layout.scale
-    w_scale = w_scale if w_layout.transposed else w_scale
-
-    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

+    # Input tensor preprocessing
     inpt_data = input_tensor.layout_tensor.float8_data
-    # Handle case where input tensor is more than 2D
-    inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
     input_scale = input_tensor.layout_tensor.scale
-    if input_scale.dim() > 2:
-        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+    # Handle case where input tensor is more than 2D
+    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
+
+    # Handle rowwise case
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), "Input tensor must also be rowwise scaled"
+        w_scale = w_scale.unsqueeze(-1).T
+        inpt_data, input_scale = preprocess_input_tensor(inpt_data, input_scale, input_tensor.shape)

+    # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

+    # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
         input_scale,
@@ -1223,7 +1238,7 @@ def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
-        (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index f1d56ba2dd..816a55ee61 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
         use_fast_accum=use_fast_accum,
     )
     return output
+
+
+def _is_rowwise_scaled(x) -> bool:
+    """Checks if an AQT tensor is rowwise scaled
+    Args:
+        x: AffineQuantizedTensor tensor
+    """
+    return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py
index 10e0113bfc..b921ecf532 100644
--- a/torchao/quantization/observer.py
+++ b/torchao/quantization/observer.py
@@ -52,6 +52,17 @@ class PerAxis(GranularityType):
     """
     axis: int

+@dataclass(frozen=True)
+class RowWise(GranularityType):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization and is unique to Float8 matmuls,
+    where the input is quantized with a block_size of (1, input.shape[1]) and the
+    weight is quantized with a block_size of (1, weight.shape[1]), i.e. one scale per row.
+    """
+    pass
+
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:
     def __init__(self, p):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 89bccf1264..4f071f6b5a 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -19,7 +19,7 @@
 import torchao
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, Union, Dict, Optional
+from typing import Any, Callable, Union, Dict, Optional, Literal
 import types

 from torchao.dtypes.uintx.Uintx import UintxLayoutType
@@ -60,6 +60,8 @@
 from .autoquant import autoquant, AutoQuantizableLinearWeight
 from torchao.float8.inference import Float8MMConfig

+from torchao.quantization.observer import GranularityType, PerTensor, PerAxis, RowWise
+
 logger = logging.getLogger(__name__)

 __all__ = [
@@ -550,6 +552,7 @@ def apply_float8wo_quant(weight):
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
     weight_dtype: torch.dtype = torch.float8_e4m3fn,
+    granularity: Literal[PerTensor, RowWise] = PerTensor,
     mm_config: Optional[Float8MMConfig] = None
 ):
     """
@@ -566,20 +569,35 @@ def float8_dynamic_activation_float8_weight(
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

+    def get_block_size(x: torch.Tensor, granularity: Literal[PerTensor, RowWise]):
+        if granularity == PerTensor:
+            return x.shape
+        elif granularity == RowWise:
+            return (1,) * (x.dim() - 1) + (x.shape[-1],)
+        else:
+            raise ValueError(f"Unsupported granularity: {granularity}")
+
     #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
+        if granularity == RowWise:
+            assert weight.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"
+
+        block_size = get_block_size(weight, granularity)
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
-            block_size=weight.shape,
+            block_size=block_size,
             target_dtype=weight_dtype,
             scale_dtype=torch.float32,
             layout_type=Float8LayoutType(mm_config=mm_config),
         )

         def input_quant_func(x: torch.Tensor):
+            if granularity == RowWise:
+                assert x.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"
+            block_size = get_block_size(x, granularity)
             activation = to_affine_quantized_floatx(
                 input_float=x,
-                block_size=x.shape,
+                block_size=block_size,
                 target_dtype=activation_dtype,
                 scale_dtype=torch.float32,
                 layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight