[Float8Quant] Add rowwise scaling option to float8 dynamic quant
stack-info: PR: #819, branch: drisspg/stack/11
drisspg committed Sep 6, 2024
1 parent 65d86c6 commit 5d3683d
Showing 6 changed files with 122 additions and 59 deletions.
1 change: 1 addition & 0 deletions ruff.toml
@@ -10,4 +10,5 @@ include = [
    "torchao/float8/float8_tensor.py",
    "torchao/quantization/linear_activation_weight_observer.py",
    "test/quantization/test_observer.py",
    "test/dtypes/test_affine_quantized_float.py",
]
82 changes: 47 additions & 35 deletions test/dtypes/test_affine_quantized_float.py
@@ -1,34 +1,29 @@
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    unwrap_tensor_subclass,
)
import pytest

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from numpy import full
from torch.testing._internal.common_utils import (
    run_tests,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, RowWise
from torchao.float8.float8_utils import compute_error
import torch
import unittest
import pytest
import tempfile
import copy
import random

from unittest.mock import patch
from functools import partial
from typing import Tuple
from contextlib import nullcontext


random.seed(0)
@@ -56,6 +51,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize(
        "granularity", [PerTensor, RowWise] if is_H100 else [PerTensor]
    )
    # Inputs are (M,..), K, N
    @common_utils.parametrize(
        "sizes",
@@ -68,33 +66,47 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
        ],
    )
    def test_fp8_linear_variants(
        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
        self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
    ):
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

        mode_map = {
            "dynamic": float8_dynamic_activation_float8_weight,
            "weight-only": float8_weight_only,
        }

        # Create a linear layer with bfloat16 dtype
        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

        quantized_model = copy.deepcopy(model)
        factory = mode_map[mode]()
        quantize_(model, factory)

        if compile:
            quantized_model = torch.compile(quantized_model, fullgraph=True)

        output_original = model(input_tensor)
        output_quantized = quantized_model(input_tensor)

        error = compute_error(output_original, output_quantized)
        assert (
            compute_error(output_original, output_quantized) > 20
        ), f"Quantization error is too high got a SQNR of {error}"
        raises = (
            granularity == RowWise and mode == "dynamic" and dtype != torch.bfloat16
        )
        context = (
            nullcontext()
            if not raises
            else pytest.raises(
                AssertionError,
                match="RowWise quantization only works for bfloat16 precision input weight and activation",
            )
        )
        with context:
            M, N, K = sizes
            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

            mode_map = {
                "dynamic": partial(
                    float8_dynamic_activation_float8_weight, granularity=granularity
                ),
                "weight-only": float8_weight_only,
            }

            # Create a linear layer with bfloat16 dtype
            model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

            quantized_model = copy.deepcopy(model)
            factory = mode_map[mode]()
            quantize_(model, factory)

            if compile:
                quantized_model = torch.compile(quantized_model, fullgraph=True)

            output_original = model(input_tensor)
            output_quantized = quantized_model(input_tensor)

            error = compute_error(output_original, output_quantized)
            assert (
                compute_error(output_original, output_quantized) > 20
            ), f"Quantization error is too high got a SQNR of {error}"


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
56 changes: 35 additions & 21 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -27,6 +27,12 @@
    is_device,
    get_out_shape,
)
from torchao.float8.inference import (
    preprocess_data,
    Float8MMConfig,
    addmm_float8_unwrapped_inference,
    _is_rowwise_scaled
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from dataclasses import dataclass
from torchao.utils import (
@@ -1355,53 +1361,61 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):

    return out.view(*act.shape[:-1], out_dim).to(act.dtype)

def _linear_fp_act_fp8_tensor_wise_weight_check(
def _linear_fp_act_fp8_weight_check(
    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
    bias: Optional[torch.Tensor],
) -> bool:
    def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
        return (
            isinstance(aqt, AffineQuantizedTensor) and
            isinstance(aqt.layout_type, Float8LayoutType)
            and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
            and aqt.shape == aqt.block_size
            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
        )
    return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor)
    return check_aqt(input_tensor) and check_aqt(weight_tensor)


def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
    """ Ensures input tensor is correctly formated for _scaled_mm """
    input_scale = input_scale.unsqueeze(-1)

    if input_scale.dim() > 2:
        input_scale = input_scale.reshape(-1, input_scale.shape[-1])

    return input_scale

def _linear_fp_act_fp8_weight_impl(
    input_tensor: AffineQuantizedTensor,
    weight_tensor: AffineQuantizedTensor,
    bias: Optional[torch.Tensor],
):
    """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
    from torchao.float8.inference import (
        preprocess_data,
        Float8MMConfig,
        addmm_float8_unwrapped_inference,
    )

    scaled_mm_config = weight_tensor.layout_type.mm_config
    scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

    # Weight tensor preprocessing
    w_layout = weight_tensor.layout_tensor
    w_data = weight_tensor.layout_tensor.float8_data
    w_data = w_data.T if w_layout.transposed else w_data
    assert not w_layout.transposed, "Weight tensor must be contiguous"
    w_data = w_layout.float8_data
    w_scale = w_layout.scale
    w_scale = w_scale if w_layout.transposed else w_scale

    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

    # Input tensor preprocessing
    inpt_data = input_tensor.layout_tensor.float8_data
    # Handle case where input tensor is more than 2D
    inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
    input_scale = input_tensor.layout_tensor.scale
    if input_scale.dim() > 2:
        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
    # Handle case where input tensor is more than 2D
    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])

    # Handle rowwise case
    if _is_rowwise_scaled(weight_tensor):
        assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size"
        w_scale = w_scale.unsqueeze(-1).T
        input_scale = preprocess_scale(input_scale, input_tensor.shape)

    # Preprocess data
    inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

    # Perform the computation
    return addmm_float8_unwrapped_inference(
        inpt_data,
        input_scale,
@@ -1458,7 +1472,7 @@ def _register_aqt_quantized_linear_dispatches():
    for dispatch_condition, impl in [
        (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
        (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
        (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
        (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
        (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
        (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
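For context on the rowwise path above, here is a minimal standalone sketch (not part of this diff) of the scale shapes the kernel prepares for torch._scaled_mm: a float32 per-row scale of shape (M, 1) for the activation and (1, N) for the transposed weight. The toy sizes, the hardware assumption (an H100-class GPU), and the PyTorch version (one whose _scaled_mm accepts rowwise scales) are assumptions, not facts from the commit.

import torch

# Illustrative sketch only -- assumes an H100-class GPU and a PyTorch build whose
# torch._scaled_mm accepts rowwise scales; the sizes below are arbitrary.
M, K, N = 16, 32, 64
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)  # activation
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # weight stored as (N, K)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
# One float32 scale per row, as in the rowwise branch above.
a_scale = a.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / fp8_max  # (M, 1)
b_scale = b.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / fp8_max  # (N, 1)

a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)

# _scaled_mm takes the weight as a column-major (K, N) view with its scale broadcast
# as (1, N); this is what w_scale.unsqueeze(-1).T and preprocess_scale arrange above.
out = torch._scaled_mm(
    a_fp8,                # (M, K)
    b_fp8.t(),            # (K, N) column-major view
    scale_a=a_scale,      # (M, 1)
    scale_b=b_scale.t(),  # (1, N)
    out_dtype=torch.bfloat16,
)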
8 changes: 8 additions & 0 deletions torchao/float8/inference.py
Expand Up @@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
use_fast_accum=use_fast_accum,
)
return output


def _is_rowwise_scaled(x) -> bool:
"""Checks if an AQT tensor is rowwise scaled
Args:
x: AffineQuantizedTensor tensor
"""
return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
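To make the block-size convention concrete, a small self-contained sketch (not from the diff); the _FakeAQT stand-in below is hypothetical and only models the two attributes the helper reads.

from dataclasses import dataclass
from typing import Tuple

# Hypothetical stand-in for an AffineQuantizedTensor: only block_size, shape, and dim()
# are modeled, since those are all _is_rowwise_scaled touches.
@dataclass
class _FakeAQT:
    shape: Tuple[int, ...]
    block_size: Tuple[int, ...]

    def dim(self) -> int:
        return len(self.shape)

def _is_rowwise_scaled(x) -> bool:
    return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)

# A (16, 32) weight with one scale per row -> block_size (1, 32): rowwise.
print(_is_rowwise_scaled(_FakeAQT(shape=(16, 32), block_size=(1, 32))))       # True
# The same weight scaled per-tensor -> block_size equals the full shape.
print(_is_rowwise_scaled(_FakeAQT(shape=(16, 32), block_size=(16, 32))))      # False
# A 3D activation (B, M, K) scaled per row -> block_size (1, 1, K): rowwise.
print(_is_rowwise_scaled(_FakeAQT(shape=(4, 8, 32), block_size=(1, 1, 32))))  # True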
10 changes: 10 additions & 0 deletions torchao/quantization/observer.py
@@ -53,6 +53,16 @@ class PerAxis(GranularityType):
    """
    axis: int

@dataclass(frozen=True)
class RowWise(GranularityType):
    """
    Represents row-wise granularity in quantization.
    This is a special case of per-axis quantization and is unique to Float8 matmuls,
    where the input is quantized with a block_size of (1, ..., input.shape[-1]) and the
    weight is quantized with a block_size of (1, weight.shape[1]).
    """
    pass

# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
24 changes: 21 additions & 3 deletions torchao/quantization/quant_api.py
@@ -19,7 +19,7 @@
import torchao
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional
from typing import Any, Callable, Union, Dict, Optional, Literal
import types

from torchao.dtypes.uintx.Uintx import UintxLayoutType
@@ -65,6 +65,8 @@
)
from torchao.float8.inference import Float8MMConfig

from torchao.quantization.observer import GranularityType, PerTensor, PerAxis, RowWise

logger = logging.getLogger(__name__)

__all__ = [
@@ -644,6 +646,7 @@ def apply_float8wo_quant(weight):
def float8_dynamic_activation_float8_weight(
    activation_dtype: torch.dtype = torch.float8_e4m3fn,
    weight_dtype: torch.dtype = torch.float8_e4m3fn,
    granularity: Literal[PerTensor, RowWise] = PerTensor,
    mm_config: Optional[Float8MMConfig] = None
):
    """
@@ -660,20 +663,35 @@ def float8_dynamic_activation_float8_weight(
    if mm_config is None:
        mm_config = Float8MMConfig(use_fast_accum=True)

    def get_block_size(x: torch.Tensor, granularity: Literal[PerTensor, RowWise]):
        if isinstance(granularity, PerTensor):
            return x.shape
        elif isinstance(granularity, RowWise):
            return (1,) * (x.dim() - 1) + (x.shape[-1],)
        else:
            raise ValueError(f"Unsupported granularity: {granularity}")

    #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
    def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
        if isinstance(granularity, RowWise):
            assert weight.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"

        block_size = get_block_size(weight, granularity)
        quantized_weight = to_affine_quantized_floatx(
            input_float=weight,
            block_size=weight.shape,
            block_size=block_size,
            target_dtype=weight_dtype,
            scale_dtype=torch.float32,
            layout_type=Float8LayoutType(mm_config=mm_config),
        )

        def input_quant_func(x: torch.Tensor):
            if isinstance(granularity, RowWise):
                assert x.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"
            block_size = get_block_size(x, granularity)
            activation = to_affine_quantized_floatx(
                input_float=x,
                block_size=x.shape,
                block_size=block_size,
                target_dtype=activation_dtype,
                scale_dtype=torch.float32,
                layout_type=Float8LayoutType(mm_config=None), # Config is stored on weight
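A hedged end-to-end usage sketch of the new option (not part of the commit): the two-layer model and sizes are made up, and a CUDA device with float8 support is assumed. A RowWise() instance is passed here, which is what the isinstance checks in quant_api.py above expect.

import torch
import torch.nn as nn

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import RowWise

# Stand-in model for illustration; the RowWise path asserts bfloat16 weights and activations.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
model = model.to(torch.bfloat16).to("cuda").eval()

# Per-row scales for both activation and weight; PerTensor remains the default.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=RowWise()))

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)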
