Skip to content

Commit

Permalink
Deprecate top level quantization APIs (#344)
Browse files Browse the repository at this point in the history
Summary:
This PR deprecates a few quantization APIs and here are the bc-breaking notes:

1. int8 weight only quantization
int8 weight only quant module swap API
```
apply_weight_only_int8_quant(model)
```

and
int8 weight only tensor subclass API
```
change_linear_weights_to_int8_woqtensors(model)
```

-->

unified tensor subclass API
```
quantize(model, get_apply_int8wo_quant()))
```

2. int8 dynamic quantization

```
apply_dynamic_quant(model)
```
or
```
change_linear_weights_to_int8_dqtensors(model)
```

-->

unified tensor subclass API
```
quantize(model, get_apply_int8dyn_quant()))
```

3. int4 weight only quantization
```
change_linear_weights_to_int4_wotensors(model)
```

-->

unified tensor subclass API
```
quantize(model, get_apply_int4wo_quant()))
```

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Jun 13, 2024
1 parent 1de761b commit c2235af
Show file tree
Hide file tree
Showing 11 changed files with 399 additions and 326 deletions.
4 changes: 2 additions & 2 deletions test/dtypes/test_aq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
TestCase,
run_tests,
)
from torchao.quantization.quant_api import get_apply_int4wo_quant
from torchao.quantization.quant_api import int4wo
import torch
import unittest

Expand All @@ -12,7 +12,7 @@ class TestAQ(TestCase):
def test_tensor_core_layout_transpose(self):
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
shape = t.shape
apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32)
apply_int4wo_quant = int4wo(groupsize=32)
aqt = apply_int4wo_quant(t)
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
Expand Down
92 changes: 66 additions & 26 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
DynamicallyPerAxisQuantizedLinear,
)
from torchao.quantization.quant_api import (
apply_dynamic_quant,
apply_weight_only_int8_quant,
int4wo,
int8wo,
int8da_int8w,
quantize,
_replace_with_custom_fn_if_matches_filter,
)
# APIs to be deprecated (used for torch 2.2.2 and 2.3)
from torchao.quantization.quant_api import (
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import (
safe_int_mm,
Expand Down Expand Up @@ -73,26 +78,53 @@
from parameterized import parameterized
import itertools
import logging
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
is_fbcode,
)

logger = logging.getLogger("INFO")

torch.manual_seed(0)
config.cache_size_limit = 100

# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
]

COMMON_DEVICES = ["cpu", "cuda"]

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

def _int8wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8wo())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8da_int8w())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4wo())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod)

# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
_int8wo_api,
_int8da_int8w_api,
_int4wo_api,
]


def combine_parameters(a, b):
new_tuples = []
for (tuple1, tuple2) in itertools.product(a, b):
Expand Down Expand Up @@ -756,14 +788,14 @@ def _test_lin_weight_subclass_api_impl(
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
_int8da_int8w_api, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
_int8wo_api, device, 40, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand All @@ -773,7 +805,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
self.skipTest(f"Fails for {dtype}")
for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int4_woqtensors,
_int4wo_api,
device,
15,
test_shape=test_shape,
Expand All @@ -789,8 +821,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
for groupsize in [64, 32]:
for inner_k_tiles in [4, 2]:
kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}

def api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4wo(**kwargs))
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod, **kwargs)

self._test_lin_weight_subclass_api_impl(
lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
api,
device,
15,
test_shape=test_shape,
Expand All @@ -805,7 +845,7 @@ def test_dynamic_quant(self):
m = nn.Sequential(nn.Linear(K, N))

y_ref = m(x)
apply_dynamic_quant(m)
quantize(m, int8da_int8w())
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
Expand All @@ -819,7 +859,7 @@ def test_weight_only_quant(self):
x = torch.randn(*x_shape)
m = nn.Sequential(nn.Linear(4, 5))
y_ref = m(x)
apply_weight_only_int8_quant(m)
_int8wo_api(m)
y_wo = m(x)
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 44.0)
Expand All @@ -842,7 +882,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
x = torch.randn(*x_shape).to(device).to(dtype)
m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
y_ref = m(x)
apply_weight_only_int8_quant(m)
_int8wo_api(m)
m(x)
m_c = torch.compile(m, mode="max-autotune")
y_wo, (code,) = run_and_get_code(m_c, x)
Expand All @@ -869,7 +909,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
x = torch.randn(*x_shape).to(device).to(dtype)
m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
y_ref = m(x)
apply_weight_only_int8_quant(m)
_int8wo_api(m)
m_c = torch.compile(m, mode="max-autotune")
y_wo, (code,) = run_and_get_code(m_c, x)
sqnr = compute_error(y_ref, y_wo)
Expand Down Expand Up @@ -910,6 +950,7 @@ def forward(self, x):

# save quantized state_dict
api(model)

torch.save(model.state_dict(), "test.pth")
# get quantized reference
model_qc = torch.compile(model, mode="max-autotune")
Expand All @@ -925,6 +966,7 @@ def forward(self, x):
# load quantized state_dict
state_dict = torch.load("test.pth", mmap=True)
os.remove("test.pth")

model.load_state_dict(state_dict, assign=True)
model = model.to(device=test_device, dtype=test_dtype).eval()

Expand All @@ -941,21 +983,21 @@ def forward(self, x):
def test_save_load_dqtensors(self, device, dtype):
if device == "cpu":
self.skipTest(f"indcutor failed for cpu right now")
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_save_load_int8woqtensors(self, device, dtype):
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@torch.no_grad()
def test_save_load_int4woqtensors(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)


class TorchCompileUnitTest(unittest.TestCase):
Expand Down Expand Up @@ -1275,8 +1317,7 @@ def forward(self, x):
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)
api(model)

# running model
model(x)
Expand Down Expand Up @@ -1321,8 +1362,7 @@ def forward(self, x):
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)
api(model)

# running model
ref = model(x)
Expand Down
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_inference_compile_simple(elem_dtype):
if elem_dtype is torch.float8_e4m3fn:
assert sqnr >= 20.0
else:
assert sqnr >= 14.0
assert sqnr >= 13.5


def test_filter_fn():
Expand Down
Loading

0 comments on commit c2235af

Please sign in to comment.