Add support for using AffineQuantizedTensor with weights_only=True #630

Merged · 1 commit · Aug 8, 2024
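In brief, this change registers the torchao tensor subclasses (and the quantization helper functions they reference) with torch.serialization.add_safe_globals, so checkpoints containing them can be loaded with torch.load(..., weights_only=True) on torch 2.5+. Below is a minimal sketch of the resulting workflow, mirroring the test added in this PR (assumes a CUDA device and torch 2.5+; not part of the diff itself):

import tempfile
import torch
from torchao.quantization.quant_api import int8_weight_only

# int8_weight_only() returns a callable that swaps the linear's weight for an
# AffineQuantizedTensor, the same pattern used in the test below
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantized = int8_weight_only()(linear)

with tempfile.NamedTemporaryFile() as f:
    torch.save(quantized.state_dict(), f)
    f.seek(0)
    # before this change, weights_only=True rejected these tensor subclasses during unpickling
    state_dict = torch.load(f, weights_only=True)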
25 changes: 23 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -2,17 +2,23 @@
TestCase,
run_tests,
)
from torchao.quantization.quant_api import int4_weight_only
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
)
import torch
import unittest
import tempfile
from torchao.utils import (
TORCH_VERSION_AFTER_2_5,
)


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
@@ -31,5 +37,20 @@ def test_tensor_core_layout_transpose(self):
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only(self):
for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AFTER_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)


if __name__ == "__main__":
run_tests()
4 changes: 4 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -916,3 +916,7 @@ def _(func, types, args, kwargs):

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static

if TORCH_VERSION_AFTER_2_5:
# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([AffineQuantizedTensor])
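As a quick way to see the effect of this registration, a sketch (assumptions: torch 2.5+, AffineQuantizedTensor re-exported from torchao.dtypes, and get_safe_globals available alongside add_safe_globals):

import torch
from torchao.dtypes import AffineQuantizedTensor  # importing torchao runs the registration above

if hasattr(torch.serialization, "get_safe_globals"):
    # the class is now on the allowlist consulted by the weights_only unpickler
    assert AffineQuantizedTensor in torch.serialization.get_safe_globals()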
5 changes: 4 additions & 1 deletion torchao/dtypes/utils.py
@@ -3,6 +3,7 @@
from collections import defaultdict
import functools
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

"""
Helper function for implementing aten op or torch function dispatch
@@ -94,7 +95,6 @@ def extra_repr(self) -> str:
class PlainLayoutType(LayoutType):
pass


"""
layout tensor constructor registration for different tensor subclasses

@@ -117,6 +117,9 @@ def _register_layout_cls(cls: Callable, layout_type_class: type(LayoutType)):
"""
def decorator(layout_cls):
_LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] = layout_cls.from_plain
if TORCH_VERSION_AFTER_2_5:
# Allow serialization to work for models that use this layout tensor subclass
torch.serialization.add_safe_globals([layout_type_class, layout_cls])
return layout_cls
return decorator

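For context, a hypothetical sketch of how this decorator is used (the layout type and layout tensor names are illustrative, not taken from the diff; the from_plain signature is elided to a stub). Registering the layout classes inside the decorator is what lets checkpoints containing them pass the weights_only allowlist:

import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.utils import LayoutType, _register_layout_cls

# illustrative only: a new layout type and its layout tensor class
class MyLayoutType(LayoutType):
    pass

@_register_layout_cls(AffineQuantizedTensor, MyLayoutType)
class MyLayoutTensor(torch.Tensor):
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        # pack plain quantized data into this layout; elided in this sketch
        raise NotImplementedError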
5 changes: 5 additions & 0 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -6,6 +6,7 @@
)
from typing import Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TORCH_VERSION_AFTER_2_5

__all__ = [
"LinearActivationQuantizedTensor",
@@ -171,3 +172,7 @@ def _(func, types, args, kwargs):
)

to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float

if TORCH_VERSION_AFTER_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearActivationQuantizedTensor])
79 changes: 49 additions & 30 deletions torchao/quantization/quant_api.py
@@ -48,6 +48,7 @@
from .utils import _get_per_token_block_size
import logging
from .autoquant import autoquant, AutoQuantizableLinearWeight
from torchao.utils import TORCH_VERSION_AFTER_2_5


__all__ = [
@@ -326,6 +327,35 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
_is_linear if filter_fn is None else filter_fn,
)

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
# avoid circular dep
from torchao.dtypes import to_affine_quantized

mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int8
return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)

def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
if weight.shape[-1] % group_size != 0:
return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized

# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# input settings
input_quant_func = _int8_asymm_per_token_quant

weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_linear_activation_quantized(weight, input_quant_func)
return weight

def int8_dynamic_activation_int4_weight(group_size=32):
"""Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
@@ -336,31 +366,11 @@ def int8_dynamic_activation_int4_weight(group_size=32):
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained
"""
def apply_int8_dynamic_activation_int4_weight_quant(weight):
if weight.shape[-1] % group_size != 0:
return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized

# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype)

weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_linear_activation_quantized(weight, input_quant_func)
return weight
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(apply_int8_dynamic_activation_int4_weight_quant(lin.weight, group_size), requires_grad=False)
return lin

return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)
return insert_subclass


def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -421,6 +431,16 @@ def apply_int8wo_quant(weight):

return _get_linear_subclass_inserter(apply_int8wo_quant)

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
# avoid circular dep
from torchao.dtypes import to_affine_quantized
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)


def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
"""
@@ -444,12 +464,7 @@ def get_weight_block_size(x):
zero_point_dtype = torch.int64

# input settings
input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
input_quant_func = _int8_symm_per_token_reduced_range_quant

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
@@ -466,3 +481,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
"""
from torchao.dtypes import SemiSparseLayoutType
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


if TORCH_VERSION_AFTER_2_5:
torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])

@mikaylagawarecki commented on Aug 8, 2024:

A n00b question: it wasn't super clear to me why these functions are used during unpickling? But I trust that you added them because they appeared in the error message :)

Contributor (Author) replied:

Yeah, that's correct. These functions are serialized as well when we do dynamic quantization.
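To make the point above concrete, a minimal sketch (not part of the diff; the function name is a hypothetical stand-in): LinearActivationQuantizedTensor keeps a reference to its input_quant_func, so pickling a quantized state dict pickles that function by its module-qualified name, and the weights_only unpickler only resolves names that have been allowlisted. This is also why the previous lambdas were lifted into named module-level functions here: lambdas cannot be pickled by reference at all.

import pickle

def my_quant_func(x):
    # hypothetical stand-in for _int8_asymm_per_token_quant; a module-level function
    # is pickled by reference, i.e. as its qualified name, not by value
    return x

payload = pickle.dumps(my_quant_func)   # succeeds: stores a reference to my_quant_func
# pickle.dumps(lambda x: x)             # would fail: a lambda has no importable name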
3 changes: 3 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -52,6 +52,9 @@ class ZeroPointDomain(Enum):
INT = auto()
FLOAT = auto()

if TORCH_VERSION_AFTER_2_5:
torch.serialization.add_safe_globals([MappingType, ZeroPointDomain])

"""
Map from dtype to the bound value of integers
TODO: maybe can replace this with call to torch.iinfo