Skip to content

Commit

Permalink
[Float8Quant] Add rowwise scaling option to float8 dyanmic quant
Browse files Browse the repository at this point in the history
stack-info: PR: #819, branch: drisspg/stack/11
  • Loading branch information
drisspg committed Sep 5, 2024
1 parent a246d87 commit 45997fe
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 29 deletions.
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ include = [
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"test/dtypes/test_affine_quantized_float.py"
]
24 changes: 13 additions & 11 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,28 @@
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
)
import pytest

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from numpy import full
from torch.testing._internal.common_utils import (
run_tests,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.quantization import (
quantize_,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, RowWise
from torchao.float8.float8_utils import compute_error
import torch
import unittest
import pytest
import tempfile
import copy
import random

from unittest.mock import patch
from functools import partial
from typing import Tuple


random.seed(0)
Expand Down Expand Up @@ -56,6 +50,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor, RowWise])
# Inputs are (M,..), K, N
@common_utils.parametrize(
"sizes",
Expand All @@ -68,13 +63,20 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
],
)
def test_fp8_linear_variants(
self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
):
if granularity == RowWise and mode == "dynamic":
pytest.skip(
"RowWise quantization only works for bfloat16 precision input weight and activation for now"
)

M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

mode_map = {
"dynamic": float8_dynamic_activation_float8_weight,
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=granularity
),
"weight-only": float8_weight_only,
}

Expand Down
47 changes: 32 additions & 15 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,21 +1161,32 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):

return out.view(*act.shape[:-1], out_dim).to(act.dtype)

def _linear_fp_act_fp8_tensor_wise_weight_check(
def _linear_fp_act_fp8_weight_check(
input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
bias: Optional[torch.Tensor],
) -> bool:
def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
row_block_size = (1,) * (aqt.dim() - 1) + (aqt.shape[-1],)
return (
isinstance(aqt, AffineQuantizedTensor) and
isinstance(aqt.layout_type, Float8LayoutType)
and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and aqt.shape == aqt.block_size
and (aqt.shape == aqt.block_size or aqt.block_size == row_block_size)
)
return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor)
return check_aqt(input_tensor) and check_aqt(weight_tensor)


def preprocess_input_tensor(inpt_data: torch.Tensor, input_scale: torch.Tensor, input_shape: Tuple[int]):
""" Ensures input tensor is correctly formated for _scaled_mm """
if input_scale.size(0) != 1:
input_scale = input_scale.unsqueeze(-1)

if input_scale.dim() > 2:
input_scale = input_scale.reshape(-1, input_scale.shape[-1])

return inpt_data, input_scale

def _linear_fp_act_fp8_weight_impl(
input_tensor: AffineQuantizedTensor,
weight_tensor: AffineQuantizedTensor,
Expand All @@ -1186,28 +1197,34 @@ def _linear_fp_act_fp8_weight_impl(
preprocess_data,
Float8MMConfig,
addmm_float8_unwrapped_inference,
_is_rowwise_scaled
)

scaled_mm_config = weight_tensor.layout_type.mm_config
scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Weight tensor preprocessing
w_layout = weight_tensor.layout_tensor
w_data = weight_tensor.layout_tensor.float8_data
w_data = w_data.T if w_layout.transposed else w_data
assert not w_layout.transposed, "Weight tensor must be contiguous"
w_data = w_layout.float8_data
w_scale = w_layout.scale
w_scale = w_scale if w_layout.transposed else w_scale

out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Input tensor preprocessing
inpt_data = input_tensor.layout_tensor.float8_data
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
input_scale = input_tensor.layout_tensor.scale
if input_scale.dim() > 2:
input_scale = input_scale.reshape(-1, input_scale.shape[-1])
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])

# Handle rowwise case
if _is_rowwise_scaled(weight_tensor):
assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size"
w_scale = w_scale.unsqueeze(-1).T
inpt_data, input_scale = preprocess_input_tensor(inpt_data, input_scale, input_tensor.shape)

# Preprocess data
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

# Perform the computation
return addmm_float8_unwrapped_inference(
inpt_data,
input_scale,
Expand All @@ -1223,7 +1240,7 @@ def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
(_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
(_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
(_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
(_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
Expand Down
8 changes: 8 additions & 0 deletions torchao/float8/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
use_fast_accum=use_fast_accum,
)
return output


def _is_rowwise_scaled(x) -> bool:
"""Checks if an AQT tensor is rowwise scaled
Args:
x: AffineQuantizedTensor tensor
"""
return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
11 changes: 11 additions & 0 deletions torchao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ class PerAxis(GranularityType):
"""
axis: int

@dataclass(frozen=True)
class RowWise(GranularityType):
"""
Represents row-wise granularity in quantization.
This is a special case of per-axis quantization and is unique to Float8 matmuls
where the input is quantized with a block_size of (1, input.shape[1]). And the weight
is quantized with a block_size of (weight.shape[0], 1).
"""
pass

# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
def __init__(self, p):
Expand Down
24 changes: 21 additions & 3 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torchao
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional
from typing import Any, Callable, Union, Dict, Optional, Literal
import types

from torchao.dtypes.uintx.Uintx import UintxLayoutType
Expand Down Expand Up @@ -60,6 +60,8 @@
from .autoquant import autoquant, AutoQuantizableLinearWeight
from torchao.float8.inference import Float8MMConfig

from torchao.quantization.observer import GranularityType, PerTensor, PerAxis, RowWise

logger = logging.getLogger(__name__)

__all__ = [
Expand Down Expand Up @@ -550,6 +552,7 @@ def apply_float8wo_quant(weight):
def float8_dynamic_activation_float8_weight(
activation_dtype: torch.dtype = torch.float8_e4m3fn,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
granularity: Literal[PerTensor, RowWise] = PerTensor,
mm_config: Optional[Float8MMConfig] = None
):
"""
Expand All @@ -566,20 +569,35 @@ def float8_dynamic_activation_float8_weight(
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)

def get_block_size(x: torch.Tensor, granularity: Literal[PerTensor, RowWise]):
if granularity == PerTensor:
return x.shape
elif granularity == RowWise:
return (1,) * (x.dim() - 1) + (x.shape[-1],)
else:
raise ValueError(f"Unsupported granularity: {granularity}")

#TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
if granularity == RowWise:
assert weight.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"

block_size = get_block_size(weight, granularity)
quantized_weight = to_affine_quantized_floatx(
input_float=weight,
block_size=weight.shape,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=mm_config),
)

def input_quant_func(x: torch.Tensor):
if granularity == RowWise:
assert weight.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"
block_size = get_block_size(x, granularity)
activation = to_affine_quantized_floatx(
input_float=x,
block_size=x.shape,
block_size=block_size,
target_dtype=activation_dtype,
scale_dtype=torch.float32,
layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight
Expand Down

0 comments on commit 45997fe

Please sign in to comment.