Deprecate top level quantization APIs
Summary:
This PR deprecates a few top-level quantization APIs in favor of the unified `quantize` API; here are the BC-breaking notes:

1. int8 weight-only quantization
int8 weight-only module-swap API
```
apply_weight_only_int8_quant(model)
```

and
int8 weight-only tensor subclass API
```
change_linear_weights_to_int8_woqtensors(model)
```

-->

unified tensor subclass API
```
quantize(model, int8wo())
```
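For illustration, a minimal before/after sketch of this migration, based on the usage in the updated tests below (the toy model and shapes are placeholders, not from this PR):
```
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize, int8wo

# placeholder toy model and input
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64)).eval()
x = torch.randn(1, 64)

# before: apply_weight_only_int8_quant(model)
#     or: change_linear_weights_to_int8_woqtensors(model)
# after:  the unified tensor subclass API
model = quantize(model, int8wo())
y = model(x)
```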

2. int8 dynamic quantization

```
apply_dynamic_quant(model)
```
or
```
change_linear_weights_to_int8_dqtensors(model)
```

-->

unified tensor subclass API
```
quantize(model, int8da_int8w())
```
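Likewise for int8 dynamic quantization, a sketch of the new call (again with a placeholder model and shapes):
```
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize, int8da_int8w

# placeholder toy model and input
model = nn.Sequential(nn.Linear(1024, 1024)).eval()
x = torch.randn(1, 1024)

# before: apply_dynamic_quant(model)
#     or: change_linear_weights_to_int8_dqtensors(model)
# after:  the unified tensor subclass API
model = quantize(model, int8da_int8w())
y = model(x)
```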

3. int4 weight-only quantization
```
change_linear_weights_to_int4_woqtensors(model)
```

-->

unified tensor subclass API
```
quantize(model, int4wo())
```
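And for int4 weight-only quantization, a sketch of the new call; in the updated tests this path uses bfloat16 weights on CUDA and `int4wo` accepts a `groupsize` argument (the value 32 mirrors the tests, everything else is a placeholder):
```
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize, int4wo

# placeholder toy model and input (bfloat16 on CUDA, as in the updated tests)
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda").eval()
x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")

# before: change_linear_weights_to_int4_woqtensors(model)
# after:  the unified tensor subclass API
model = quantize(model, int4wo(groupsize=32))
y = model(x)
```
Note that in the updated tests, `unwrap_tensor_subclass(model)` from `torchao.utils` is called on the quantized model before `torch.compile` or before saving/loading a `state_dict`.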

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 12, 2024
1 parent 950a893 commit b66f0cf
Showing 8 changed files with 307 additions and 308 deletions.
66 changes: 39 additions & 27 deletions test/integration/test_integration.py
@@ -20,11 +20,10 @@
DynamicallyPerAxisQuantizedLinear,
)
from torchao.quantization.quant_api import (
apply_dynamic_quant,
apply_weight_only_int8_quant,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
int4wo,
int8wo,
int8da_int8w,
quantize,
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import (
@@ -73,7 +72,11 @@
from parameterized import parameterized
import itertools
import logging
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
)

logger = logging.getLogger("INFO")

@@ -82,9 +85,9 @@

# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
int4wo,
int8wo,
int8da_int8w,
]

COMMON_DEVICES = ["cpu", "cuda"]
@@ -736,7 +739,8 @@ def _test_lin_weight_subclass_api_impl(
nn.Linear(k, n, device=test_device), nn.ReLU(), nn.Linear(n, n, device=test_device)
).to(test_dtype)
ref_f = mod(x)
api(mod)
quantize(mod, api())
unwrap_tensor_subclass(mod)

test = mod(x)
self.assertGreater(
@@ -756,13 +760,13 @@ def _test_lin_weight_subclass_api_impl(
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
int8da_int8w, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
int8wo, device, 40, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand All @@ -772,7 +776,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
self.skipTest(f"Fails for {dtype}")
for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int4_woqtensors,
int4wo,
device,
15,
test_shape=test_shape,
@@ -789,7 +793,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
for inner_k_tiles in [4, 2]:
kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
self._test_lin_weight_subclass_api_impl(
lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
lambda: int4wo(**kwargs),
device,
15,
test_shape=test_shape,
@@ -804,7 +808,7 @@ def test_dynamic_quant(self):
m = nn.Sequential(nn.Linear(K, N))

y_ref = m(x)
apply_dynamic_quant(m)
quantize(m, int8da_int8w())
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
@@ -818,7 +822,7 @@ def test_weight_only_quant(self):
x = torch.randn(*x_shape)
m = nn.Sequential(nn.Linear(4, 5))
y_ref = m(x)
apply_weight_only_int8_quant(m)
quantize(m, int8wo())
y_wo = m(x)
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 44.0)
@@ -841,7 +845,8 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
x = torch.randn(*x_shape).to(device).to(dtype)
m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
y_ref = m(x)
apply_weight_only_int8_quant(m)
m = quantize(m, int8wo())
m = unwrap_tensor_subclass(m)
m(x)
m_c = torch.compile(m, mode="max-autotune")
y_wo, (code,) = run_and_get_code(m_c, x)
@@ -868,7 +873,8 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
x = torch.randn(*x_shape).to(device).to(dtype)
m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
y_ref = m(x)
apply_weight_only_int8_quant(m)
m = quantize(m, int8wo())
m = unwrap_tensor_subclass(m)
m_c = torch.compile(m, mode="max-autotune")
y_wo, (code,) = run_and_get_code(m_c, x)
sqnr = compute_error(y_ref, y_wo)
@@ -908,7 +914,9 @@ def forward(self, x):
ref_f = model(x)

# save quantized state_dict
api(model)
quantize(model, api())
unwrap_tensor_subclass(model)

torch.save(model.state_dict(), "test.pth")
# get quantized reference
model_qc = torch.compile(model, mode="max-autotune")
@@ -919,11 +927,13 @@ def forward(self, x):
# load model structure
with torch.device('meta'):
model = test_model().to(dtype=test_dtype)
api(model)
quantize(model, api())
unwrap_tensor_subclass(model)

# load quantized state_dict
state_dict = torch.load("test.pth", mmap=True)
os.remove("test.pth")

model.load_state_dict(state_dict, assign=True)
model = model.to(device=test_device, dtype=test_dtype).eval()

@@ -936,23 +946,26 @@ def forward(self, x):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Skip if torch version is prior to 2.4 since the current API relies on a parametrization fix")
def test_save_load_dqtensors(self, device, dtype):
if device == "cpu":
self.skipTest(f"indcutor failed for cpu right now")
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
self._test_handle_save_load_meta_impl(int8da_int8w, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Skip if torch version is prior to 2.4 since the current API relies on a parametrization fix")
def test_save_load_int8woqtensors(self, device, dtype):
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
self._test_handle_save_load_meta_impl(int8wo, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@torch.no_grad()
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Skip if torch version is prior to 2.4 since the current API relies on a parametrization fix")
def test_save_load_int4woqtensors(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
self._test_handle_save_load_meta_impl(int4wo, device, 20, test_dtype=dtype)


class TorchCompileUnitTest(unittest.TestCase):
@@ -1271,8 +1284,7 @@ def forward(self, x):
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)
quantize(model, api())

# running model
model(x)
@@ -1317,8 +1329,8 @@ def forward(self, x):
model = test_model().to(dtype=test_dtype, device=test_device).eval()
ref_f = model(x)

kwargs = {"dtype": test_dtype}
api(model, **kwargs)
model = quantize(model, api())
model = unwrap_tensor_subclass(model)

# running model
ref = model(x)
58 changes: 25 additions & 33 deletions test/quantization/test_quant_api.py
@@ -35,15 +35,13 @@
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
apply_dynamic_quant,
apply_weight_only_int8_quant,
Quantizer,
TwoStepQuantizer,
quantize,
get_apply_8da4w_quant,
get_apply_int4wo_quant,
get_apply_int8wo_quant,
get_apply_int8dyn_quant,
int8da_int4w,
int4wo,
int8wo,
int8da_int8w,
)
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
@@ -53,6 +51,7 @@
from torchao._models.llama.tokenizer import get_tokenizer
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
import copy
import tempfile


def dynamic_quant(model, example_inputs):
@@ -62,20 +61,6 @@ def dynamic_quant(model, example_inputs):
m = convert_pt2e(m)
return m

def _apply_dynamic_quant(model):
"""
Applies dynamic symmetric per-token activation and per-channel weight
quantization to all linear layers in the given model using
module swaps.
"""
_replace_with_custom_fn_if_matches_filter(
model,
lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)
return model


def capture_and_prepare(model, example_inputs):
m = torch.export.export(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
@@ -104,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:

class TorchCompileDynamicQuantizer(Quantizer):
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
apply_dynamic_quant(model)
quantize(model, int8da_int8w())
return model

class ToyLinearModel(torch.nn.Module):
@@ -127,11 +112,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
The deprecated implementation for int8 dynamic quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _in_features_greater_than_16
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

def _in_features_greater_than_16(mod, *args):
return hasattr(mod, "in_features") and mod.in_features > 16

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
@@ -167,7 +154,7 @@ class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = _apply_dynamic_quant(m)
m = quantize(m, int8da_int8w())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -205,16 +192,21 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int8_wo_quant_save_load(self):
m = ToyLinearModel().eval().cpu()
apply_weight_only_int8_quant(m)
m = quantize(m, int8wo())

from torchao.utils import unwrap_tensor_subclass
unwrap_tensor_subclass(m)
example_inputs = m.example_inputs()
ref = m(*example_inputs)
_TMP_FN = "_test.pt"
torch.save(m.state_dict(), _TMP_FN)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

state_dict = torch.load(_TMP_FN)
os.remove(_TMP_FN)
m2 = ToyLinearModel().eval()
apply_weight_only_int8_quant(m2)
m2 = quantize(m2, int8wo())
unwrap_tensor_subclass(m2)

m2.load_state_dict(state_dict)
m2 = m2.to(device="cuda")
example_inputs = map(lambda x: x.cuda(), example_inputs)
@@ -508,7 +500,7 @@ def test_quantized_tensor_subclass_8da4w(self):
m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))
m = quantize(m, int8da_int4w(groupsize=groupsize))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -537,7 +529,7 @@ def test_quantized_tensor_subclass_int4(self):
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

groupsize = 32
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
m = quantize(m, int4wo(groupsize=groupsize))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -557,7 +549,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

m = quantize(m, get_apply_int8wo_quant())
m = quantize(m, int8wo())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -580,7 +572,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
m = quantize(m, get_apply_int8dyn_quant())
m = quantize(m, int8da_int8w())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
2 changes: 0 additions & 2 deletions torchao/dtypes/aqt.py
@@ -394,8 +394,6 @@ def __new__(
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout
)
if dtype is None:
dtype = scale.dtype
kwargs["dtype"] = dtype
if strides is not None:
kwargs["strides"] = strides