Skip to content

Commit

Permalink
Add generic fake quantized linear for QAT
Browse files Browse the repository at this point in the history
Summary: This commit adds a generic fake quantized linear module
to replace the uses of the existing more specific QAT linears.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(
    bit_width=8,
    granularity="per_token",
    symmetric=False,
    dynamic=True,
)
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=8,
    symmetric=True,
    dynamic=True,
)
fq_linear = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config,
)
```

The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we would have to
create a new linear class every time we wish to experiment with
different fake quantization settings, e.g. different group size
or different bit width. Now we can express this easily using a
single linear module.

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_config
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

ghstack-source-id: 44843b01a98db95f2f8620f581bef5e6dde66642
Pull Request resolved: #1020
  • Loading branch information
andrewor14 committed Oct 8, 2024
1 parent 2a2817c commit bd09e9d
Show file tree
Hide file tree
Showing 5 changed files with 558 additions and 191 deletions.
229 changes: 186 additions & 43 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,27 @@
import unittest

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.dtypes import (
TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
QuantizationGranularity,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.prototype.qat.linear import (
FakeQuantizedLinear,
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_get_qmin_qmax,
_GenericFakeQuantize,
)
from torchao.quantization.quant_api import (
Expand Down Expand Up @@ -92,15 +102,10 @@ def forward(self, x):
class TestQAT(unittest.TestCase):
SEED = 123

def _get_qmin_qmax(self, n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_per_channel_group(self):
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(qmin, qmax) = _get_qmin_qmax(n_bit)
group_size = 128

torch.manual_seed(self.SEED)
Expand All @@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_per_token(self):
(qmin, qmax) = self._get_qmin_qmax(8)
(qmin, qmax) = _get_qmin_qmax(8)

torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
Expand Down Expand Up @@ -165,11 +170,11 @@ def _set_ptq_weight(
Int4WeightOnlyQATLinear,
)
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(qmin, qmax) = _get_qmin_qmax(n_bit)
group_size = qat_linear.weight_fake_quantizer.config.group_size
if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
fp32_weight = qat_linear.weight
group_size = qat_linear.groupsize
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
Expand All @@ -180,7 +185,7 @@ def _set_ptq_weight(
elif isinstance(ptq_linear, WeightOnlyInt4Linear):
assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
(q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
qat_linear.weight, n_bit, qat_linear.groupsize,
qat_linear.weight, n_bit, group_size,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"), qat_linear.inner_k_tiles,
Expand Down Expand Up @@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
torch.manual_seed(self.SEED)
m = M()
m2 = copy.deepcopy(m)
subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# Convert QAT model and compare model values
subclass_model = subclass_quantizer.convert(subclass_model)
module_swap_model = module_swap_quantizer.convert(module_swap_model)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
Expand Down Expand Up @@ -275,9 +285,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled)
self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled)
self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled)
self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled)
self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled)
self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled)

# Disabled fake quant is just a normal linear
m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
Expand All @@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

# Renable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled)
self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled)
self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled)
self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled)
self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled)
self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
Expand Down Expand Up @@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self):
the numerics of existing fake quantize ops in Pytorch in both
the forward and the backward passes.
"""
(qmin, qmax) = self._get_qmin_qmax(4)
(qmin, qmax) = _get_qmin_qmax(4)
py_input = torch.randn(16, 64).float().requires_grad_()
py_s = torch.randn(16).float()
py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
Expand Down Expand Up @@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self):
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
inner_k_tiles = 8
Expand All @@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self):
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)
subclass_quantizer = Int4WeightOnlyQATQuantizer(
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
module_swap_quantizer = Int4WeightOnlyQATQuantizer(
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = [i.to(device).to(dtype) for i in m.example_inputs()]
x2 = copy.deepcopy(x)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
self._assert_close_4w(qat_out, ptq_out)

# Convert QAT model and compare model values
subclass_model = subclass_quantizer.convert(subclass_model)
module_swap_model = module_swap_quantizer.convert(module_swap_model)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

class _MyQATQuantizer(TwoStepQuantizer):
"""
Expand Down Expand Up @@ -603,5 +624,127 @@ def test_qat_4w_embedding(self):
converted = quantizer.convert(model)
converted_out = converted(*x)

def test_fake_quantize_config(self):
"""
Test initialization and property setting of `FakeQuantizeConfig`.
"""
# basic configs
per_token_config = FakeQuantizeConfig(8, "per_token")
self.assertEqual(per_token_config.bit_width, 8)
self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)
self.assertIsNone(per_token_config.group_size)
per_channel_config = FakeQuantizeConfig(4, "per_channel")
self.assertEqual(per_channel_config.bit_width, 4)
self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL)
self.assertIsNone(per_channel_config.group_size)

# initialize per_group config using only group size
per_group_config = FakeQuantizeConfig(4, group_size=32)
self.assertEqual(per_group_config.bit_width, 4)
self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
self.assertEqual(per_group_config.group_size, 32)

# set granularity after initialization, should accept str as before
per_group_config.granularity = "per_token"
self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)

# set group_size after initialization, should also update granularity
per_group_config.group_size = 16
self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
self.assertEqual(per_group_config.group_size, 16)

# bad config1: no granularity or group size provided
with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"):
FakeQuantizeConfig(8)

# bad config2: 'per_group' but no group size
with self.assertRaisesRegex(ValueError, "no group_size was set"):
FakeQuantizeConfig(8, "per_group")

# bad config3: group size was set but granularity was not 'per_group'
with self.assertRaisesRegex(ValueError, "group_size was set"):
FakeQuantizeConfig(8, "per_token", group_size=16)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantized_linear_8da4w(self):
"""
Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`.
"""
group_size = 128
torch.manual_seed(self.SEED)
fq_linear = FakeQuantizedLinear(
256,
688,
bias=False,
activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False),
weight_config=FakeQuantizeConfig(4, group_size=group_size),
)

def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
"""
# activations
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
(qmin, qmax) = _get_qmin_qmax(8)
x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)

# weights
(s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
zp = zp.to(torch.int32)
(qmin, qmax) = _get_qmin_qmax(4)
w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
return F.linear(x_fq, w_fq)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256)
x2 = copy.deepcopy(x)
fq_out = fq_linear(x)
baseline_out = linear_forward_8da4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantized_linear_4w(self):
"""
Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`.
"""
group_size = 128
weight_config = FakeQuantizeConfig(
bit_width=4,
group_size=group_size,
symmetric=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)
torch.manual_seed(self.SEED)
fq_linear = FakeQuantizedLinear(
256,
688,
bias=False,
activation_config=None,
weight_config=weight_config,
)

def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Baseline for int4 weight only fake quantization that simulates the tinygemm kernel.
"""
(qmin, qmax) = _get_qmin_qmax(4, symmetric=False)
(s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32)
zp = zp.to(torch.int32)
w_fq = _fake_quantize_per_channel_group(
weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT,
)
return F.linear(x, w_fq)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256)
x2 = copy.deepcopy(x)
fq_out = fq_linear(x)
baseline_out = linear_forward_4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit bd09e9d

Please sign in to comment.