Skip to content

Commit

Permalink
Add generic fake quantized linear for QAT
Browse files Browse the repository at this point in the history
Summary: This commit adds a generic fake quantized linear module
to replace the uses of the existing more specific QAT linears.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we would have to
create a new linear class every time we wish to experiment with
different fake quantization settings, e.g. different group size
or different bit width. Now we can express this easily using a
single linear module.

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

ghstack-source-id: e3761e47d36001ab161c33943705f35860076ffa
Pull Request resolved: #1020
  • Loading branch information
andrewor14 committed Oct 14, 2024
1 parent d4b2f33 commit c400006
Show file tree
Hide file tree
Showing 10 changed files with 903 additions and 211 deletions.
2 changes: 2 additions & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def test_swap(self):
assert torch.allclose(y_ref, y)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
def test_weight_t_and_non_t_numerics_match(self):
# verify that numerics match whether weight is stored
# in transposed format (for cuBLAS) vs non-transposed format
Expand Down Expand Up @@ -1126,6 +1127,7 @@ def test_shape_logger(self):
class SmoothquantIntegrationTest(unittest.TestCase):
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
def test_non_dynamically_quantizable_linear(self):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("test requires SM capability of at least (8, 0).")
Expand Down
336 changes: 293 additions & 43 deletions test/quantization/test_qat.py

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ change_linear_weights_to_int8_dqtensors(model)

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.quant_api import PerTensor
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```

Expand Down Expand Up @@ -321,7 +320,7 @@ This API works today but has not been extensively tested and benchmarked yet. Ha

```python
# for torch 2.5+
from torchao.quantization.quant_api import quantize_, PerRow, float8_dynamic_activation_float8_weight
from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```

Expand Down
5 changes: 3 additions & 2 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from .weight_only import * # noqa: F403
from .unified import *
from .autoquant import *
from .linear_activation_quantized_tensor import ( # noqat: F403
from .granularity import *
from .linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from .linear_activation_scale import ( # noqat: F403
from .linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)

Expand Down
24 changes: 19 additions & 5 deletions torchao/quantization/granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class PerTensor(Granularity):
"""
Represents per-tensor granularity in quantization.
This granularity type calcualtes the quantization parameters
This granularity type calculates the quantization parameters
based off the entire tensor.
"""
pass
Expand All @@ -32,26 +32,24 @@ class PerAxis(Granularity):
"""
Represents per-axis granularity in quantization.
This granularity type calcualtes different quantization parameters
This granularity type calculates different quantization parameters
along a specified axis of the tensor.
For example if the input tensor is shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor.
Giving a total of 8 quantization parameters.
Attributes:
axis (int): The axis along which reduction is performed.
"""
axis: int

@dataclass(frozen=True)

class PerGroup(Granularity):
"""
Represents per-channel group granularity in quantization.
This granularity type calcualtes different quantization parameters
This granularity type calculates different quantization parameters
for each group of <group_size> elements.
For example if the input tensor is shape [8, 16], and the group size is 4, then
Expand All @@ -74,3 +72,19 @@ class PerRow(Granularity):
is quantized with a block_size of (1, weight.shape[1]).
"""
pass

class PerToken(Granularity):
"""
Represents per-token granularity in quantization.
This granularity type calculates a different set of quantization parameters
for each token, which is represented as the last dimension of the tensor.
For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
with 4 elements each, and we will calculate 6 sets of quantization parameters,
one for each token.
If the input tensor has only two dimensions, e.g. [8, 16], then this is
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
"""
pass
215 changes: 214 additions & 1 deletion torchao/quantization/prototype/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,224 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union

import torch

from torchao.quantization.granularity import (
Granularity,
PerAxis,
PerGroup,
PerToken,
)
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.quant_primitives import (
_SUB_BYTE_INT_BOUNDS,
_SUB_BYTE_UINT_BOUNDS,
MappingType,
TorchAODType,
ZeroPointDomain,
)


@dataclass
class FakeQuantizeConfig:
"""
Config for how to fake quantize weights or activations.
args:
dtype: dtype to simulate during fake quantization, e.g. torch.int8.
For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
granularity: granularity of scales and zero points, e.g. PerGroup(32).
We also support the following strings:
1) 'per_token': equivalent to PerToken()
2) 'per_channel': equivalent to PerAxis(0)
3) 'per_group': equivalent to PerGroup(group_size), must be combined
with separate `group_size` kwarg, Alternatively, just set the
`group_size` kwarg and leave this field empty.
mapping_type: whether to use symmetric (default) or asymmetric quantization
Alternatively, set `is_symmetric` (bool) and leave this field empty.
scale_precision: scale dtype (default torch.fp32)
zero_point_precision: zero point dtype (default torch.int32)
zero_point_domain: whether zero point is in integer (default) or float domain
is_dynamic: whether to use dynamic (defualt) or static scale and zero points
range_learning: whether to learn scale and zero points during training (coming soon)
kwargs (optional):
group_size: size of each group in per group fake quantization,
can be set instead of `granularity`
is_symmetric: whether to use symmetric or asymmetric quantization,
can be set instead of `mapping_type`
Example usage::
# Per token asymmetric quantization
FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
# Per channel symmetric quantization
FakeQuantizeConfig(torch.int4, "per_channel")
FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
# Per group symmetric quantization
FakeQuantizeConfig(torch.int4, group_size=32)
FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
"""
dtype: Union[torch.dtype, TorchAODType]
granularity: Granularity
mapping_type: MappingType
scale_precision: torch.dtype
zero_point_precision: torch.dtype
zero_point_domain: ZeroPointDomain
is_dynamic: bool = True
range_learning: bool = False

def __init__(
self,
dtype: Union[torch.dtype, TorchAODType],
granularity: Union[Granularity, str, None] = None,
mapping_type: Optional[MappingType] = None,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
is_dynamic: bool = True,
range_learning: bool = False,
*,
group_size: Optional[int] = None,
is_symmetric: Optional[bool] = None,
):
self.dtype = dtype
self.granularity = self._get_granularity(granularity, group_size)
self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision
self.zero_point_domain = zero_point_domain
self.is_dynamic = is_dynamic
self.range_learning = range_learning

# Validate dtype
all_dtypes = [torch.int8, torch.uint8]
all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
if dtype not in all_dtypes:
raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes))

def _get_granularity(
self,
granularity: Union[Granularity, str, None],
group_size: Optional[int],
) -> Granularity:
"""
Parse the `Granularity` represented in the args.
Granularity can be specified in one of three ways:
1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size)
2) str: one of 'per_token', 'per_channel', and 'per_group'
3) None: `group_size` must be set instead, represents per group granularity
"""
# If group_size is set, then granularity must be either "per_group" or None
if group_size is not None and granularity != "per_group" and granularity is not None:
raise ValueError("`group_size` conflicts with granularity '%s'" % granularity)

# Case 1: Granularity object
if isinstance(granularity, Granularity):
if not isinstance(granularity, (PerToken, PerAxis, PerGroup)):
raise ValueError("Granularity '%s' is not supported" % granularity)
if isinstance(granularity, PerAxis) and granularity.axis != 0:
raise ValueError("Only axis=0 is supported for PerAxis granularity")
return granularity

# Case 2: str granularity
if granularity == "per_token":
return PerToken()
elif granularity == "per_channel":
return PerAxis(axis=0)
elif granularity == "per_group":
if group_size is None:
raise ValueError("Granularity was 'per_group' but no `group_size` was set")
return PerGroup(group_size)
elif isinstance(granularity, str):
raise ValueError(
"Unexpected granularity: '%s', must be one of %s" %
(granularity, ["per_token", "per_channel", "per_group"])
)

# Case 3: None granularity + group_size was specified
if granularity is not None:
raise ValueError(
"Granularity '%s' has unexpected type %s" % (granularity, type(granularity))
)
if group_size is None:
raise ValueError("At least one of `granularity` or `group_size` must be set")
return PerGroup(group_size)

def _get_mapping_type(
self,
mapping_type: Optional[MappingType],
is_symmetric: Optional[bool],
) -> MappingType:
"""
Parse the `MappingType` represented in the args.
Mapping type can be specified in one of two ways:
1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC
2): is_symmetric bool
"""
if mapping_type is not None and is_symmetric is not None:
raise ValueError("Cannot set both `mapping_type` and `is_symmetric`")

# Case 0: Default to symmetric
if mapping_type is None and is_symmetric is None:
return MappingType.SYMMETRIC

# Case 1: MappingType object
if mapping_type is not None:
if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]:
raise ValueError("MappingType '%s' is not supported" % mapping_type)
return mapping_type

# Case 2: is_symmetric flag
assert is_symmetric is not None
if is_symmetric:
return MappingType.SYMMETRIC
else:
return MappingType.ASYMMETRIC

@property
def group_size(self) -> int:
"""
If this is per group granularity, return the group size.
Otherwise, throw an error.
"""
if isinstance(self.granularity, PerGroup):
return self.granularity.group_size
else:
raise ValueError("`group_size` is undefined for %s granularity" % self.granularity)

@property
def is_symmetric(self) -> bool:
"""
Return True if mapping type is symmetric, else False (asymmetric).
"""
return self.mapping_type == MappingType.SYMMETRIC

def __setattr__(self, name: str, value: Any):
"""
Support setting `group_size` and `is_symmetric`.
"""
if name == "group_size":
super().__setattr__("granularity", PerGroup(value))
elif name == "is_symmetric":
mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC
super().__setattr__("mapping_type", mapping_type)
else:
super().__setattr__(name, value)


class ComposableQATQuantizer(TwoStepQuantizer):
Expand Down
Loading

0 comments on commit c400006

Please sign in to comment.