Skip to content

Commit

Permalink
Update on "Add generic fake quantized linear for QAT"
Browse files Browse the repository at this point in the history
**Summary:** This commit adds a generic fake quantized linear module
to replace the uses of the existing more specific QAT linears.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we would have to
create a new linear class every time we wish to experiment with
different fake quantization settings, e.g. different group size
or different bit width. Now we can express this easily using a
single linear module.

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]
  • Loading branch information
andrewor14 committed Oct 14, 2024
2 parents 5642f44 + 622b6df commit b5fe5a7
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def test_swap(self):
assert torch.allclose(y_ref, y)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
def test_weight_t_and_non_t_numerics_match(self):
# verify that numerics match whether weight is stored
# in transposed format (for cuBLAS) vs non-transposed format
Expand Down

0 comments on commit b5fe5a7

Please sign in to comment.