Add generic fake quantized linear for QAT #1020

Merged
36 commits merged on Oct 15, 2024
Commits
0756f39
Make module swap the main QAT flow again
andrewor14 Oct 4, 2024
9e9fdef
Add generic fake quantized linear for QAT
andrewor14 Oct 4, 2024
7f623a5
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
c8f9f37
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
75fcd21
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
9185cc4
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
d671826
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
59b6644
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
d4332cb
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
8e5d2ea
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
dbad878
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
ab43744
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
d6750a9
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 9, 2024
15a3d81
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 9, 2024
0153d66
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
8de3ba6
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
c18c60f
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
ef4f062
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
8f48663
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
e442439
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
4239d47
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
b0c6cc7
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
75c83ef
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
39ebc46
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
e08517c
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
f9286c5
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
d0d9573
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
c0ed9ed
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
f9a2f4c
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
83e2f10
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
5b4feb0
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
fbc0259
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
756cb8d
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
5642f44
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
622b6df
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
b5fe5a7
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
2 changes: 2 additions & 0 deletions test/integration/test_integration.py
@@ -328,6 +328,7 @@ def test_swap(self):
assert torch.allclose(y_ref, y)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
def test_weight_t_and_non_t_numerics_match(self):
# verify that numerics match whether weight is stored
# in transposed format (for cuBLAS) vs non-transposed format
@@ -1126,6 +1127,7 @@ def test_shape_logger(self):
class SmoothquantIntegrationTest(unittest.TestCase):
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
def test_non_dynamically_quantizable_linear(self):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("test requires SM capability of at least (8, 0).")
336 changes: 293 additions & 43 deletions test/quantization/test_qat.py

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions torchao/quantization/README.md
@@ -136,8 +136,7 @@ change_linear_weights_to_int8_dqtensors(model)

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.quant_api import PerTensor
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```

@@ -321,7 +320,7 @@ This API works today but has not been extensively tested and benchmarked yet. Ha

```python
# for torch 2.5+
from torchao.quantization.quant_api import quantize_, PerRow, float8_dynamic_activation_float8_weight
from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```

5 changes: 3 additions & 2 deletions torchao/quantization/__init__.py
@@ -12,11 +12,12 @@
from .weight_only import * # noqa: F403
from .unified import *
from .autoquant import *
from .linear_activation_quantized_tensor import ( # noqat: F403
from .granularity import *
from .linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from .linear_activation_scale import ( # noqat: F403
from .linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)

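With `from .granularity import *` added to `torchao/quantization/__init__.py` (matching the updated README snippets above), the granularity classes should now be importable from the package root. A minimal sketch, assuming the wildcard re-export exposes all of the granularity classes:

```python
# Previously these classes lived only in torchao.quantization.granularity;
# after this PR they are re-exported from the package root.
from torchao.quantization import PerAxis, PerGroup, PerRow, PerTensor

# Granularities are small frozen dataclasses, so they print and compare by value.
print(PerGroup(32))                  # PerGroup(group_size=32)
assert PerGroup(32) == PerGroup(32)
assert PerAxis(axis=0).axis == 0
```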
24 changes: 19 additions & 5 deletions torchao/quantization/granularity.py
@@ -22,7 +22,7 @@ class PerTensor(Granularity):
"""
Represents per-tensor granularity in quantization.
This granularity type calcualtes the quantization parameters
This granularity type calculates the quantization parameters
based off the entire tensor.
"""
pass
@@ -32,26 +32,24 @@ class PerAxis(Granularity):
"""
Represents per-axis granularity in quantization.
This granularity type calcualtes different quantization parameters
This granularity type calculates different quantization parameters
along a specified axis of the tensor.
For example if the input tensor is shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor.
Giving a total of 8 quantization parameters.
Attributes:
axis (int): The axis along which reduction is performed.
"""
axis: int

@dataclass(frozen=True)

class PerGroup(Granularity):
"""
Represents per-channel group granularity in quantization.
This granularity type calcualtes different quantization parameters
This granularity type calculates different quantization parameters
for each group of <group_size> elements.
For example if the input tensor is shape [8, 16], and the group size is 4, then
@@ -74,3 +72,19 @@ class PerRow(Granularity):
is quantized with a block_size of (1, weight.shape[1]).
"""
pass

class PerToken(Granularity):
"""
Represents per-token granularity in quantization.
This granularity type calculates a different set of quantization parameters
for each token, which is represented as the last dimension of the tensor.
For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
with 4 elements each, and we will calculate 6 sets of quantization parameters,
one for each token.
If the input tensor has only two dimensions, e.g. [8, 16], then this is
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
"""
pass
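To make the granularity semantics above concrete, the sketch below counts how many sets of (scale, zero_point) parameters each granularity implies for a given tensor shape. `num_param_sets` is a hypothetical helper written purely for illustration; it is not part of this PR.

```python
import math

from torchao.quantization.granularity import PerAxis, PerGroup, PerTensor, PerToken


def num_param_sets(shape, granularity):
    """Hypothetical helper: number of (scale, zero_point) pairs implied by a granularity."""
    if isinstance(granularity, PerTensor):
        return 1                                            # one pair for the whole tensor
    if isinstance(granularity, PerAxis):
        return shape[granularity.axis]                      # one pair per slice along the axis
    if isinstance(granularity, PerGroup):
        return math.prod(shape) // granularity.group_size   # one pair per group of elements
    if isinstance(granularity, PerToken):
        return math.prod(shape[:-1])                        # one pair per token (last dim = token)
    raise ValueError(f"unsupported granularity: {granularity}")


print(num_param_sets((8, 16), PerAxis(axis=0)))   # 8, as in the PerAxis docstring
print(num_param_sets((8, 16), PerGroup(4)))       # 32 groups of 4 elements
print(num_param_sets((2, 3, 4), PerToken()))      # 6, as in the PerToken docstring
```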
215 changes: 214 additions & 1 deletion torchao/quantization/prototype/qat/api.py
@@ -4,11 +4,224 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union

import torch

from torchao.quantization.granularity import (
Granularity,
PerAxis,
PerGroup,
PerToken,
)
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.quant_primitives import (
_SUB_BYTE_INT_BOUNDS,
_SUB_BYTE_UINT_BOUNDS,
MappingType,
TorchAODType,
ZeroPointDomain,
)


@dataclass
class FakeQuantizeConfig:
"""
Config for how to fake quantize weights or activations.
args:
dtype: dtype to simulate during fake quantization, e.g. torch.int8.
For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
granularity: granularity of scales and zero points, e.g. PerGroup(32).
We also support the following strings:
1) 'per_token': equivalent to PerToken()
2) 'per_channel': equivalent to PerAxis(0)
3) 'per_group': equivalent to PerGroup(group_size), must be combined
with separate `group_size` kwarg. Alternatively, just set the
`group_size` kwarg and leave this field empty.
mapping_type: whether to use symmetric (default) or asymmetric quantization.
Alternatively, set `is_symmetric` (bool) and leave this field empty.
scale_precision: scale dtype (default torch.float32)
zero_point_precision: zero point dtype (default torch.int32)
zero_point_domain: whether zero point is in integer (default) or float domain
is_dynamic: whether to use dynamic (default) or static scale and zero points
range_learning: whether to learn scale and zero points during training (coming soon)
kwargs (optional):
group_size: size of each group in per group fake quantization,
can be set instead of `granularity`
is_symmetric: whether to use symmetric or asymmetric quantization,
can be set instead of `mapping_type`
Example usage::
# Per token asymmetric quantization
FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
# Per channel symmetric quantization
FakeQuantizeConfig(torch.int4, "per_channel")
FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
# Per group symmetric quantization
FakeQuantizeConfig(torch.int4, group_size=32)
FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
"""
dtype: Union[torch.dtype, TorchAODType]
granularity: Granularity
mapping_type: MappingType
scale_precision: torch.dtype
zero_point_precision: torch.dtype
zero_point_domain: ZeroPointDomain
is_dynamic: bool = True
range_learning: bool = False

def __init__(
self,
dtype: Union[torch.dtype, TorchAODType],
granularity: Union[Granularity, str, None] = None,
mapping_type: Optional[MappingType] = None,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
is_dynamic: bool = True,
range_learning: bool = False,
*,
group_size: Optional[int] = None,
is_symmetric: Optional[bool] = None,
):
self.dtype = dtype
self.granularity = self._get_granularity(granularity, group_size)
self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision
self.zero_point_domain = zero_point_domain
self.is_dynamic = is_dynamic
self.range_learning = range_learning

# Validate dtype
all_dtypes = [torch.int8, torch.uint8]
all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
if dtype not in all_dtypes:
raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes))

def _get_granularity(
self,
granularity: Union[Granularity, str, None],
group_size: Optional[int],
) -> Granularity:
"""
Parse the `Granularity` represented in the args.
Granularity can be specified in one of three ways:
1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size)
2) str: one of 'per_token', 'per_channel', and 'per_group'
3) None: `group_size` must be set instead, represents per group granularity
"""
# If group_size is set, then granularity must be either "per_group" or None
if group_size is not None and granularity != "per_group" and granularity is not None:
raise ValueError("`group_size` conflicts with granularity '%s'" % granularity)

# Case 1: Granularity object
if isinstance(granularity, Granularity):
if not isinstance(granularity, (PerToken, PerAxis, PerGroup)):
raise ValueError("Granularity '%s' is not supported" % granularity)
if isinstance(granularity, PerAxis) and granularity.axis != 0:
raise ValueError("Only axis=0 is supported for PerAxis granularity")
return granularity

# Case 2: str granularity
if granularity == "per_token":
return PerToken()
elif granularity == "per_channel":
return PerAxis(axis=0)
elif granularity == "per_group":
if group_size is None:
raise ValueError("Granularity was 'per_group' but no `group_size` was set")
return PerGroup(group_size)
elif isinstance(granularity, str):
raise ValueError(
"Unexpected granularity: '%s', must be one of %s" %
(granularity, ["per_token", "per_channel", "per_group"])
)

# Case 3: None granularity + group_size was specified
if granularity is not None:
raise ValueError(
"Granularity '%s' has unexpected type %s" % (granularity, type(granularity))
)
if group_size is None:
raise ValueError("At least one of `granularity` or `group_size` must be set")
return PerGroup(group_size)

def _get_mapping_type(
self,
mapping_type: Optional[MappingType],
is_symmetric: Optional[bool],
) -> MappingType:
"""
Parse the `MappingType` represented in the args.
Mapping type can be specified in one of two ways:
1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC
2): is_symmetric bool
"""
if mapping_type is not None and is_symmetric is not None:
raise ValueError("Cannot set both `mapping_type` and `is_symmetric`")

# Case 0: Default to symmetric
if mapping_type is None and is_symmetric is None:
return MappingType.SYMMETRIC

# Case 1: MappingType object
if mapping_type is not None:
if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]:
raise ValueError("MappingType '%s' is not supported" % mapping_type)
return mapping_type

# Case 2: is_symmetric flag
assert is_symmetric is not None
if is_symmetric:
return MappingType.SYMMETRIC
else:
return MappingType.ASYMMETRIC

@property
def group_size(self) -> int:
"""
If this is per group granularity, return the group size.
Otherwise, throw an error.
"""
if isinstance(self.granularity, PerGroup):
return self.granularity.group_size
else:
raise ValueError("`group_size` is undefined for %s granularity" % self.granularity)

@property
def is_symmetric(self) -> bool:
"""
Return True if mapping type is symmetric, else False (asymmetric).
"""
return self.mapping_type == MappingType.SYMMETRIC

def __setattr__(self, name: str, value: Any):
"""
Support setting `group_size` and `is_symmetric`.
"""
if name == "group_size":
super().__setattr__("granularity", PerGroup(value))
elif name == "is_symmetric":
mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC
super().__setattr__("mapping_type", mapping_type)
else:
super().__setattr__(name, value)


class ComposableQATQuantizer(TwoStepQuantizer):
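Finally, a hedged end-to-end sketch of the `FakeQuantizeConfig` surface added in this diff: the three spellings below should all resolve to 4-bit, per-group(32), symmetric fake quantization, matching the equivalences listed in the class docstring. The import path is inferred from the file path in the diff, and `torch.int4` availability depends on your PyTorch version (see the dtype note in the docstring).

```python
import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.quant_primitives import MappingType

# Three equivalent ways to ask for int4, per-group(32), symmetric fake quantization.
c1 = FakeQuantizeConfig(torch.int4, group_size=32)
c2 = FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
c3 = FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)

for config in (c1, c2, c3):
    assert config.granularity == PerGroup(32)
    assert config.mapping_type == MappingType.SYMMETRIC
    assert config.group_size == 32   # derived from the PerGroup granularity
    assert config.is_symmetric       # derived from the mapping type

# __setattr__ translates the convenience attributes back into the canonical fields.
c1.group_size = 64
assert c1.granularity == PerGroup(64)
c1.is_symmetric = False
assert c1.mapping_type == MappingType.ASYMMETRIC
```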