Reapply Autoquant (#82) #109

Merged
29 commits merged on Apr 5, 2024
Commits (29)
a10d5b1 Autoquant (#82) (HDCharles, Mar 25, 2024)
edd2708 Skip autoquant CPU tests (cpuhrsch, Apr 1, 2024)
550d6af Add more device skips (cpuhrsch, Apr 1, 2024)
015b31a Merge main (cpuhrsch, Apr 1, 2024)
c14ab3d Remove merge artifact (cpuhrsch, Apr 1, 2024)
c40a175 Remove duplicate test (cpuhrsch, Apr 1, 2024)
bc4eb7b Remove duplicate test (cpuhrsch, Apr 1, 2024)
9b632e2 Reduce top level API (cpuhrsch, Apr 1, 2024)
9bc66b4 Version guards (cpuhrsch, Apr 2, 2024)
e24c607 Change import path (cpuhrsch, Apr 2, 2024)
429fd86 Clean up init (cpuhrsch, Apr 5, 2024)
4a603d7 Merge remote-tracking branch 'origin' into autoquant2 (cpuhrsch, Apr 5, 2024)
d4b11bc Merge remote-tracking branch 'origin' into autoquant2 (cpuhrsch, Apr 5, 2024)
57d9ffa Clean up import (cpuhrsch, Apr 5, 2024)
e3e82d9 Calm down test shapes and deal with imports (cpuhrsch, Apr 5, 2024)
62f3787 Clean up import (cpuhrsch, Apr 5, 2024)
d2a573b Multiple of 16 (cpuhrsch, Apr 5, 2024)
74526f6 Version guards (cpuhrsch, Apr 5, 2024)
a6fbe2f More parameterizations (cpuhrsch, Apr 5, 2024)
0f225ec More parameterizations (cpuhrsch, Apr 5, 2024)
907ca84 Merge branch 'main' of github.com:pytorch-labs/ao into autoquant2 (cpuhrsch, Apr 5, 2024)
33bbc33 Run all tests (cpuhrsch, Apr 5, 2024)
f433553 bfloat16 guard (cpuhrsch, Apr 5, 2024)
e4c62e4 Shape guards (cpuhrsch, Apr 5, 2024)
24d2bcc Working shapes (cpuhrsch, Apr 5, 2024)
e75b68c Exclude huge shape (cpuhrsch, Apr 5, 2024)
0a908bf Exclude huge shape (cpuhrsch, Apr 5, 2024)
6207659 Exclude huge shape (cpuhrsch, Apr 5, 2024)
a1b8a0c Update readme (cpuhrsch, Apr 5, 2024)
2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
@@ -53,4 +53,4 @@ jobs:

- name: Run tests
run: |
pytest test --verbose -s -x
pytest test --verbose -s
33 changes: 25 additions & 8 deletions README.md
@@ -1,8 +1,8 @@
# torchao: PyTorch Architecture Optimization

**Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open a GitHub issue**

The `torchao` package allows you to quantize and prune your models using native PyTorch.

The repo hosts both
1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4
@@ -38,30 +38,46 @@ pip install -e .

Typically, quantization algorithms will have different schemes for how the activations and weights are quantized, so A16W8, for instance, means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a one-line change.

### A8W8 Dynamic Quantization
### Autoquantization

```Python
The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps that class in for the linear layer. Currently this api chooses between no quantization, int8 dynamic quantization, and int8 weight-only quantization for each layer.

```python
import torch
from torchao.quantization import quant_api
import torchao

# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)
# perform autoquantization
torchao.autoquant(model, (input))

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```
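For context while reading this PR, the multi-input test added in `test/integration/test_integration.py` below exercises a lower-level two-step flow behind `autoquant`. A minimal sketch of that flow, reusing the toy model from the README example (illustrative only, not part of the README diff):

```python
import torch
import torchao

# Sketch only: mirrors the flow used in test_autoquant_multi_input below.
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
example_input2 = torch.randn(64, 32, dtype=torch.bfloat16, device='cuda')

# 1) swap linear weights for auto-quantizable placeholders that record activation shapes
torchao.quantization.change_linears_to_autoquantizable(model)
model(example_input)   # calibration passes: observed shapes are logged
model(example_input2)

# 2) benchmark the recorded shapes and swap in the fastest quantized class per layer
torchao.quantization.change_autoquantizable_to_quantized(model)
out = model(example_input)
```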


### A8W8 Dynamic Quantization

```python
# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True
from torchao.quantization import quant_api
# convert linear modules to quantized tensor subclasses
quant_api.change_linear_weights_to_int8_dqtensors(model)
```

### A16W8 WeightOnly Quantization

```python
from torchao.quantization import quant_api
quant_api.change_linear_weights_to_int8_woqtensors(model)
```
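For reference (an illustrative sketch, not part of the diff), the one-liner above drops into the same model/compile setup shown in the autoquant example:

```python
import torch
from torchao.quantization import quant_api

# toy model and input reused from the autoquant example above
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# swap linear weights for int8 weight-only quantized tensor subclasses
quant_api.change_linear_weights_to_int8_woqtensors(model)

model = torch.compile(model, mode='max-autotune')
model(input)
```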

@@ -71,6 +87,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
### A16W4 WeightOnly Quantization

```python
from torchao.quantization import quant_api
quant_api.change_linear_weights_to_int4_woqtensors(model)
```

113 changes: 113 additions & 0 deletions test/integration/test_integration.py
@@ -7,11 +7,13 @@
# mypy: ignore-errors
import copy
import unittest
import itertools

import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code
from torch._dynamo import config
import torchao
from torch.ao.quantization import MinMaxObserver, QConfigMapping

from torchao.quantization.dynamic_quant import (
@@ -54,6 +56,13 @@
_fqn_to_op_to_shape_to_count,
LoggingTensorMode,
)
from torchao.quantization.autoquant import (
AQInt8DynamicallyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
AQWeightOnlyQuantizedLinearWeight3

)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
@@ -71,6 +80,12 @@
("cuda", torch.bfloat16),
]

def combine_parameters(a, b):
new_tuples = []
for (tuple1, tuple2) in itertools.product(a, b):
new_tuples.append(tuple1 + tuple2)
return new_tuples

def run_supported_device_dtype(test_method):
def wrapper(*args, **kwargs):
if args[2] == "cuda" and not torch.cuda.is_available():
@@ -907,6 +922,36 @@ def test_int8_weight_only_quant_subclass(self, device, dtype):
Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
def test_int4_weight_only_quant_subclass(self, device, dtype):
@@ -1290,6 +1335,74 @@ def test_on_dummy_distilbert(self):
print("sqnr_pt_quant", sqnr_pt_quant)
self.assertTrue(sqnr_sq >= 8.0)

class TestAutoQuant(unittest.TestCase):
@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(16, 128, 128),
(64, 128, 128),
# (2**15, 128, 128), TODO: Runs out of shared memory on T4
(16, 128, 256),
# (64, 128, 256), # TODO: Runs out of shared memory on T4
(16, 256, 128),
(64, 256, 128),
# (256, 256, 128), TODO: Runs out of shared memory on T4
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_one_input(self, device, dtype, m, k, n):
print("(m, k, n): ", (m, k, n))
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")
if m == 1:
self.skipTest(f"Shape {(m, k, n)} requires sm80+")
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._dynamo.config.automatic_dynamic_shapes = False

example_input = torch.randn(m, k, device=device, dtype=dtype)
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to(device).to(dtype)
out = model(example_input)
torchao.autoquant(model, example_input)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(1, 1, 128, 128),
(1, 32, 128, 128),
(32, 32, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")
if m1 == 1 or m2 == 1:
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to(device).to(dtype)
example_input = torch.randn(m1, k, device=device, dtype=dtype)
example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
torchao.quantization.change_linears_to_autoquantizable(model)
out=model(example_input)
model(example_input2)
torchao.quantization.change_autoquantizable_to_quantized(model)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

if __name__ == "__main__":
unittest.main()
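To run just the autoquant-related tests from this diff locally, one option is the sketch below; the `-k autoquant` filter is an assumption, chosen to match the `TestAutoQuant` class and `test_autoquant_*` names added above, while `--verbose -s` mirrors the CI workflow change.

```python
# Sketch: invoke pytest programmatically with the same verbosity flags the CI
# workflow uses, filtered (assumed) to the autoquant tests added in this PR.
import pytest

pytest.main(["test/integration/test_integration.py", "-k", "autoquant", "--verbose", "-s"])
```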
13 changes: 9 additions & 4 deletions torchao/__init__.py
@@ -1,8 +1,13 @@
from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
autoquant,
)
from . import dtypes
from .quantization.quant_api import apply_dynamic_quant
from .quantization.quant_api import apply_weight_only_int8_quant

__all__ = [
"dtypes",
"apply_dynamic_quant",
"dtypes",
"apply_dynamic_quant",
"apply_weight_only_int8_quant",
"autoquant",
]
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
from .utils import * # noqa: F403
from .weight_only import * # noqa: F403
from .unified import *
from .autoquant import *

__all__ = [
"DynamicallyPerAxisQuantizedLinear",
@@ -26,6 +27,9 @@
"dynamically_quantize_per_channel",
"dequantize_per_tensor",
"dequantize_per_channel",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
"quant_int8_dynamic_linear",
"quant_int8_matmul",
"quant_int8_dynamic_per_token_linear",
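As a closing reference, a minimal import check (a sketch grounded in the `__all__` additions above) of the names this PR exports from `torchao.quantization`:

```python
# Names added to torchao.quantization.__all__ by this PR
from torchao.quantization import (
    autoquant,
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)

# autoquant is also re-exported at the package root (see torchao/__init__.py above)
import torchao
assert torchao.autoquant is autoquant
```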