Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Float8Quant] Add rowwise scaling option to float8 dynamic quant #819

Merged
merged 1 commit into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ include = [
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
]
170 changes: 139 additions & 31 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
)
import pytest

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from numpy import full
from torch.testing._internal.common_utils import (
run_tests,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.quantization import (
quantize_,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, PerRow
from torchao.float8.float8_utils import compute_error
import torch
import unittest
import pytest
import tempfile
import copy
import random

from unittest.mock import patch
from functools import partial
from typing import Tuple
from contextlib import nullcontext
import io


random.seed(0)
Expand Down Expand Up @@ -56,6 +52,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
)
# Inputs are (M,..), K, N
@common_utils.parametrize(
"sizes",
Expand All @@ -68,33 +67,142 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
],
)
def test_fp8_linear_variants(
    self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
):
    """Compare a float8-quantized ToyLinearModel against its unquantized reference.

    Args:
        dtype: dtype used for both the input tensor and the model.
        mode: "dynamic" (float8 dynamic activation + float8 weight) or
            "weight-only".
        compile: When True, the quantized model is wrapped in
            ``torch.compile(..., fullgraph=True)`` before being run.
        sizes: Three-element tuple unpacked as ``M, N, K``; ``M`` is itself
            unpacked into the leading input dimensions.
        granularity: Granularity object forwarded to the dynamic-mode factory.

    The combination PerRow + "dynamic" + non-bfloat16 is expected to fail with
    an AssertionError; every other combination must reach a
    ``compute_error`` value above 20 (reported as SQNR in the failure message).
    """
    raises = (
        isinstance(granularity, PerRow)
        and mode == "dynamic"
        and dtype != torch.bfloat16
    )
    context = (
        pytest.raises(
            AssertionError,
            match="PerRow quantization only works for bfloat16 precision",
        )
        if raises
        else nullcontext()
    )
    with context:
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

        mode_map = {
            "dynamic": partial(
                float8_dynamic_activation_float8_weight, granularity=granularity
            ),
            "weight-only": float8_weight_only,
        }

        # Reference model in the parametrized dtype
        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

        # Quantize the copy, not the reference: otherwise the model that gets
        # compiled below is the unquantized one and the float8 path is never
        # exercised under torch.compile.
        quantized_model = copy.deepcopy(model)
        factory = mode_map[mode]()
        quantize_(quantized_model, factory)

        if compile:
            quantized_model = torch.compile(quantized_model, fullgraph=True)

        output_original = model(input_tensor)
        output_quantized = quantized_model(input_tensor)

        error = compute_error(output_original, output_quantized)
        assert error > 20, f"Quantization error is too high got a SQNR of {error}"

def test_invalid_granularity(self):
    """Passing a plain string as ``granularity`` must raise a ValueError."""
    expected_failure = pytest.raises(
        ValueError, match="Invalid granularity specification"
    )
    with expected_failure:
        float8_dynamic_activation_float8_weight(granularity="invalid")

def test_mismatched_granularity(self):
    """A (PerTensor, PerRow) granularity pair must be rejected with a ValueError."""
    with pytest.raises(
        ValueError,
        match="Different granularities for activation and weight are not supported",
    ):
        float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

def test_unsupported_granularity(self):
    """A pair of objects of an unrelated class must raise a ValueError."""

    class UnsupportedGranularity:
        pass

    bogus_pair = (UnsupportedGranularity(), UnsupportedGranularity())
    with pytest.raises(ValueError, match="Invalid granularity types"):
        float8_dynamic_activation_float8_weight(granularity=bogus_pair)

# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
def test_per_row_with_float32(self):
    """PerRow dynamic float8 quantization of a float32 model must fail.

    The expected failure is an AssertionError whose message states that PerRow
    quantization only works for bfloat16 precision.
    """
    # Build the model outside the raises block so that only the quantization
    # step can satisfy the expected failure.
    model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
    with pytest.raises(
        AssertionError,
        match="PerRow quantization only works for bfloat16 precision",
    ):
        quantize_(
            model, float8_dynamic_activation_float8_weight(granularity=PerRow())
        )

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
def test_serialization(self, mode: str):
    """Round-trip a float8-quantized state dict through torch.save/torch.load.

    Args:
        mode: "dynamic" quantizes with float8_dynamic_activation_float8_weight;
            any other value (parametrized as "weight-only") uses
            float8_weight_only.

    The float8 payload of ``linear1`` and ``linear2`` must be unchanged after
    the state dict is saved, reloaded, and assigned onto a fresh model.
    """
    # Create and quantize the model
    model = ToyLinearModel(16, 32).to(device="cuda")
    if mode == "dynamic":
        factory = float8_dynamic_activation_float8_weight()
    else:
        factory = float8_weight_only()
    quantize_(model, factory)

    # Save the state dict to an in-memory buffer, then rewind and reload it
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)
    loaded_state_dict = torch.load(buffer)

    # Create a new model on the meta device and adopt the loaded tensors
    with torch.device("meta"):
        new_model = ToyLinearModel(16, 32)
        new_model.load_state_dict(loaded_state_dict, assign=True)

    def float8_payload(weight):
        # In "weight-only" mode the layout tensor hangs directly off the
        # weight; otherwise it is reached through ``original_weight_tensor``.
        if mode != "weight-only":
            weight = weight.original_weight_tensor
        return weight.layout_tensor.float8_data.to(torch.float32)

    # Compare the original and loaded models layer by layer
    for layer_name in ("linear1", "linear2"):
        original = float8_payload(getattr(model, layer_name).weight)
        restored = float8_payload(getattr(new_model, layer_name).weight)
        assert torch.allclose(original, restored)


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
Expand Down
56 changes: 35 additions & 21 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
is_device,
get_out_shape,
)
from torchao.float8.inference import (
preprocess_data,
Float8MMConfig,
addmm_float8_unwrapped_inference,
_is_rowwise_scaled
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from dataclasses import dataclass
from torchao.utils import (
Expand Down Expand Up @@ -1355,53 +1361,61 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):

return out.view(*act.shape[:-1], out_dim).to(act.dtype)

def _linear_fp_act_fp8_weight_check(
    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
    bias: Optional[torch.Tensor],
) -> bool:
    """Dispatch condition for the float8-activation / float8-weight linear path.

    Args:
        input_tensor: Activation passed to the linear op.
        weight_tensor: Weight passed to the linear op.
        bias: Not inspected by this check.

    Returns:
        True only when both ``input_tensor`` and ``weight_tensor`` are
        AffineQuantizedTensors with a Float8LayoutType layout, a float8
        (e4m3fn or e5m2) layout tensor, and a block size that either equals the
        full tensor shape or is row-wise according to ``_is_rowwise_scaled``.
    """

    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
        return (
            isinstance(aqt, AffineQuantizedTensor)
            and isinstance(aqt.layout_type, Float8LayoutType)
            and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
        )

    return check_aqt(input_tensor) and check_aqt(weight_tensor)


def preprocess_scale(
    input_scale: torch.Tensor, input_shape: Tuple[int, ...]
) -> torch.Tensor:
    """Ensure the input scale is correctly formatted for _scaled_mm.

    A trailing singleton dimension is appended to the scale. If the result has
    more than two dimensions, all leading dimensions are flattened into one so
    that a 2-D tensor is returned.

    Args:
        input_scale: Scale tensor to reshape.
        input_shape: Shape of the tensor the scale belongs to. Currently
            unused; kept so existing callers keep working.

    Returns:
        The reshaped scale tensor.
    """
    expanded = input_scale.unsqueeze(-1)
    if expanded.dim() <= 2:
        return expanded
    # Collapse every leading dim, keeping the trailing singleton column
    return expanded.reshape(-1, expanded.shape[-1])

def _linear_fp_act_fp8_weight_impl(
input_tensor: AffineQuantizedTensor,
weight_tensor: AffineQuantizedTensor,
bias: Optional[torch.Tensor],
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
from torchao.float8.inference import (
preprocess_data,
Float8MMConfig,
addmm_float8_unwrapped_inference,
)

scaled_mm_config = weight_tensor.layout_type.mm_config
scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Weight tensor preprocessing
w_layout = weight_tensor.layout_tensor
w_data = weight_tensor.layout_tensor.float8_data
w_data = w_data.T if w_layout.transposed else w_data
assert not w_layout.transposed, "Weight tensor must be contiguous"
w_data = w_layout.float8_data
w_scale = w_layout.scale
w_scale = w_scale if w_layout.transposed else w_scale

out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Input tensor preprocessing
inpt_data = input_tensor.layout_tensor.float8_data
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
input_scale = input_tensor.layout_tensor.scale
if input_scale.dim() > 2:
input_scale = input_scale.reshape(-1, input_scale.shape[-1])
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])

# Handle rowwise case
if _is_rowwise_scaled(weight_tensor):
assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size"
w_scale = w_scale.unsqueeze(-1).T
input_scale = preprocess_scale(input_scale, input_tensor.shape)

# Preprocess data
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

# Perform the computation
return addmm_float8_unwrapped_inference(
inpt_data,
input_scale,
Expand Down Expand Up @@ -1458,7 +1472,7 @@ def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
(_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
(_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
(_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
(_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
Expand Down
8 changes: 8 additions & 0 deletions torchao/float8/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
use_fast_accum=use_fast_accum,
)
return output


def _is_rowwise_scaled(x) -> bool:
"""Checks if an AQT tensor is rowwise scaled
Args:
x: AffineQuantizedTensor tensor
"""
return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
12 changes: 12 additions & 0 deletions torchao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ class PerAxis(GranularityType):
"""
axis: int

@dataclass(frozen=True)
class PerRow(GranularityType):
    """
    Row-wise quantization granularity.

    A special case of per-axis quantization that is specific to Float8 matmuls:
    the input is quantized with a block_size of (1, ..., input.shape[-1]) and
    the weight with a block_size of (1, weight.shape[1]).

    The class carries no fields.
    """

# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
Expand Down Expand Up @@ -104,6 +114,8 @@ def get_block_size(
block_size = list(input_shape)
block_size[granularity_type.axis] = 1
return tuple(block_size)
elif isinstance(granularity_type, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
raise ValueError(f"Unsupported GranularityType: {granularity_type}")


Expand Down
Loading
Loading