diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 4461355fb3..74eab6cad6 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -20,11 +20,10 @@ DynamicallyPerAxisQuantizedLinear, ) from torchao.quantization.quant_api import ( - apply_dynamic_quant, - apply_weight_only_int8_quant, - change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, - change_linear_weights_to_int4_woqtensors, + int4wo, + int8wo, + int8da_int8w, + quantize, _replace_with_custom_fn_if_matches_filter, ) from torchao.quantization.quant_primitives import ( @@ -73,7 +72,11 @@ from parameterized import parameterized import itertools import logging -from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +from torchao.utils import ( + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + unwrap_tensor_subclass, +) logger = logging.getLogger("INFO") @@ -82,9 +85,9 @@ # TODO: use this to reduce the number of tests TENSOR_SUBCLASS_APIS = [ - change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, - change_linear_weights_to_int4_woqtensors, + int4wo, + int8wo, + int8da_int8w, ] COMMON_DEVICES = ["cpu", "cuda"] @@ -736,7 +739,8 @@ def _test_lin_weight_subclass_api_impl( nn.Linear(k, n, device=test_device), nn.ReLU(), nn.Linear(n, n, device=test_device) ).to(test_dtype) ref_f = mod(x) - api(mod) + quantize(mod, api()) + unwrap_tensor_subclass(mod) test = mod(x) self.assertGreater( @@ -756,13 +760,13 @@ def _test_lin_weight_subclass_api_impl( @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( - change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype + int8da_int8w, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int8_weight_only_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( - change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype + int8wo, device, 40, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -772,7 +776,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): self.skipTest(f"Fails for {dtype}") for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])): self._test_lin_weight_subclass_api_impl( - change_linear_weights_to_int4_woqtensors, + int4wo, device, 15, test_shape=test_shape, @@ -789,7 +793,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): for inner_k_tiles in [4, 2]: kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles} self._test_lin_weight_subclass_api_impl( - lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs), + lambda: int4wo(**kwargs), device, 15, test_shape=test_shape, @@ -804,7 +808,7 @@ def test_dynamic_quant(self): m = nn.Sequential(nn.Linear(K, N)) y_ref = m(x) - apply_dynamic_quant(m) + quantize(m, int8da_int8w()) y_test = m(x) sqnr = compute_error(y_ref, y_test) @@ -818,7 +822,7 @@ def test_weight_only_quant(self): x = torch.randn(*x_shape) m = nn.Sequential(nn.Linear(4, 5)) y_ref = m(x) - apply_weight_only_int8_quant(m) + quantize(m, int8wo()) y_wo = m(x) sqnr = compute_error(y_ref, y_wo) self.assertGreater(sqnr, 44.0) @@ -841,7 +845,8 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): x = torch.randn(*x_shape).to(device).to(dtype) m = nn.Sequential(nn.Linear(4, 
5)).to(device).to(dtype) y_ref = m(x) - apply_weight_only_int8_quant(m) + m = quantize(m, int8wo()) + m = unwrap_tensor_subclass(m) m(x) m_c = torch.compile(m, mode="max-autotune") y_wo, (code,) = run_and_get_code(m_c, x) @@ -868,7 +873,8 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): x = torch.randn(*x_shape).to(device).to(dtype) m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype) y_ref = m(x) - apply_weight_only_int8_quant(m) + m = quantize(m, int8wo()) + m = unwrap_tensor_subclass(m) m_c = torch.compile(m, mode="max-autotune") y_wo, (code,) = run_and_get_code(m_c, x) sqnr = compute_error(y_ref, y_wo) @@ -908,7 +914,9 @@ def forward(self, x): ref_f = model(x) # save quantized state_dict - api(model) + quantize(model, api()) + unwrap_tensor_subclass(model) + torch.save(model.state_dict(), "test.pth") # get quantized reference model_qc = torch.compile(model, mode="max-autotune") @@ -919,11 +927,13 @@ def forward(self, x): # load model structure with torch.device('meta'): model = test_model().to(dtype=test_dtype) - api(model) + quantize(model, api()) + unwrap_tensor_subclass(model) # load quantized state_dict state_dict = torch.load("test.pth", mmap=True) os.remove("test.pth") + model.load_state_dict(state_dict, assign=True) model = model.to(device=test_device, dtype=test_dtype).eval() @@ -936,23 +946,26 @@ def forward(self, x): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Skip if torch version is prior to 2.4 since the current API relies on a parametrization fix") def test_save_load_dqtensors(self, device, dtype): if device == "cpu": self.skipTest(f"indcutor failed for cpu right now") - self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype) + self._test_handle_save_load_meta_impl(int8da_int8w, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Skip if torch version is prior to 2.4 since the current API relies on a parametrization fix") def test_save_load_int8woqtensors(self, device, dtype): - self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype) + self._test_handle_save_load_meta_impl(int8wo, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") @torch.no_grad() + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Skip if torch version is prior to 2.4 since the current API relies on a parametrization fix") def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") - self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype) + self._test_handle_save_load_meta_impl(int4wo, device, 20, test_dtype=dtype) class TorchCompileUnitTest(unittest.TestCase): @@ -1271,8 +1284,7 @@ def forward(self, x): model = test_model().to(dtype=test_dtype, device=test_device).eval() ref_f = model(x) - kwargs = {"dtype": test_dtype} - api(model, **kwargs) + quantize(model, api()) # running model model(x) @@ -1317,8 +1329,8 @@ def forward(self, x): model = test_model().to(dtype=test_dtype, device=test_device).eval() ref_f = model(x) - kwargs = {"dtype": test_dtype} - api(model, **kwargs) + model = quantize(model, api()) + model = unwrap_tensor_subclass(model) # running model ref = model(x) diff --git a/test/quantization/test_quant_api.py 
b/test/quantization/test_quant_api.py index e0ead9ad12..f16a2c4d2a 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -35,15 +35,13 @@ ) from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, - apply_dynamic_quant, - apply_weight_only_int8_quant, Quantizer, TwoStepQuantizer, quantize, - get_apply_8da4w_quant, - get_apply_int4wo_quant, - get_apply_int8wo_quant, - get_apply_int8dyn_quant, + int8da_int4w, + int4wo, + int8wo, + int8da_int8w, ) from torchao.utils import ( TORCH_VERSION_AFTER_2_3, @@ -53,6 +51,7 @@ from torchao._models.llama.tokenizer import get_tokenizer from torchao._models.llama.model import Transformer, prepare_inputs_for_model import copy +import tempfile def dynamic_quant(model, example_inputs): @@ -62,20 +61,6 @@ def dynamic_quant(model, example_inputs): m = convert_pt2e(m) return m -def _apply_dynamic_quant(model): - """ - Applies dynamic symmetric per-token activation and per-channel weight - quantization to all linear layers in the given model using - module swaps. - """ - _replace_with_custom_fn_if_matches_filter( - model, - lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)), - lambda mod, fqn: isinstance(mod, torch.nn.Linear), - ) - return model - - def capture_and_prepare(model, example_inputs): m = torch.export.export(model, example_inputs) quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) @@ -104,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: - apply_dynamic_quant(model) + quantize(model, int8da_int8w()) return model class ToyLinearModel(torch.nn.Module): @@ -127,11 +112,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs The deprecated implementation for int8 dynamic quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _in_features_greater_than_16 from torchao.quantization.quant_api import _is_linear from torchao.quantization.quant_api import _get_subclass_inserter from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight + def _in_features_greater_than_16(mod, *args): + return hasattr(mod, "in_features") and mod.in_features > 16 + if filter_fn is None: filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( *args @@ -167,7 +154,7 @@ class TestQuantFlow(unittest.TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() - m = _apply_dynamic_quant(m) + m = quantize(m, int8da_int8w()) quantized = m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) @@ -205,16 +192,21 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() - apply_weight_only_int8_quant(m) + m = quantize(m, int8wo()) + + from torchao.utils import unwrap_tensor_subclass + unwrap_tensor_subclass(m) example_inputs = m.example_inputs() ref = m(*example_inputs) - _TMP_FN = "_test.pt" - torch.save(m.state_dict(), 
_TMP_FN) + with tempfile.NamedTemporaryFile() as f: + torch.save(m.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) - state_dict = torch.load(_TMP_FN) - os.remove(_TMP_FN) m2 = ToyLinearModel().eval() - apply_weight_only_int8_quant(m2) + m2 = quantize(m2, int8wo()) + unwrap_tensor_subclass(m2) + m2.load_state_dict(state_dict) m2 = m2.to(device="cuda") example_inputs = map(lambda x: x.cuda(), example_inputs) @@ -508,7 +500,7 @@ def test_quantized_tensor_subclass_8da4w(self): m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize)) + m = quantize(m, int8da_int4w(groupsize=groupsize)) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -537,7 +529,7 @@ def test_quantized_tensor_subclass_int4(self): example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") groupsize = 32 - m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize)) + m = quantize(m, int4wo(groupsize=groupsize)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -557,7 +549,7 @@ def test_quantized_tensor_subclass_int8_wo(self): m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - m = quantize(m, get_apply_int8wo_quant()) + m = quantize(m, int8wo()) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -580,7 +572,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") - m = quantize(m, get_apply_int8dyn_quant()) + m = quantize(m, int8da_int8w()) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index ae05720b05..02737fadb6 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -394,8 +394,6 @@ def __new__( kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout ) - if dtype is None: - dtype = scale.dtype kwargs["dtype"] = dtype if strides is not None: kwargs["strides"] = strides diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 04efddd0ca..a86f9493d9 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -56,33 +56,139 @@ with open("quantization-cache.pkl", "wb") as f: with open("quantization-cache.pkl", "rb") as f: torchao.quantization.AUTOQUANT_CACHE.update(pickle.load(f)) ``` +## Affine Quantization +Affine quantization refers to the type of quantization that maps from floating point numbers to quantized numbers (typically integer) with an affine transformation, i.e.: `quantized_val = float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data. +### Quantization Primitives +We used to have different quantize and dequantize operators for quantization with different granularities. 
But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. -## A8W8 Dynamic Quantization +### Quantized Tensor Subclass +We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) + +### Quantization Flow Example +Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul +as an example: +```python +import torch +from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain +from torchao.dtypes import to_aq +from torch._inductor.runtime.runtime_utils import do_bench_gpu +import copy +from torchao.quantization.quant_api import ( + quantize, + int4wo, +) + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): + return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + +dtype = torch.bfloat16 +m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") +m_bf16 = copy.deepcopy(m) +example_inputs = m.example_inputs(dtype=dtype, device="cuda") + +m_bf16 = torch.compile(m_bf16, mode='max-autotune') +# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) +groupsize = 32 +m = quantize(m, int4wo(groupsize=groupsize)) + +torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.use_mixed_mm = True + +# temporary workaround for tensor subclass + torch.compile +from torchao.quantization.utils import unwrap_tensor_subclass +m = unwrap_tensor_subclass(m) +# compile the model to improve performance +m = torch.compile(m, mode='max-autotune') + +# benchmark to see the speedup +from torchao.utils import benchmark_model + +num_runs = 100 +torch._dynamo.reset() +bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0]) +print(f"bf16 mean time: {bf16_time}") +int4_time = benchmark_model(m, num_runs, example_inputs[0]) +print(f"int4 weight only quantized mean time: {int4_time}") +print(f"speedup: {bf16_time / int4_time}") + +# output (1xA100 GPU machine) +bf16 mean time: 71.457685546875 +int4 weight only quantized mean time: 31.4580908203125 +speedup: 2.2715200981216173 +``` + +What we do underlying the APIs are roughly the following: +``` +from torchao.dtypes.aqt import to_aq +def int8wo_quant(weight): + return to_aq(weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8, eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64) + +for n, m in model.named_modules(): + if isinstance(m, torch.nn.Linear): + # optional filtering for module name, shape etc. 
+ m.weight = nn.Parameter(int8wo_quant(m.weight)) + + # note: quantization for activation needs to be applied after the weight quantization + # quantize activation (needed by dynamic quantization) + input_quant_func = int8wo_quant # specify how input activation is quantized + m.weight = nn.Parameter(to_linear_act_quantized(m.weight, input_quant_func)) +``` +The model/tensor subclass should also be compatible with AOTI and torch.export; currently we can support +`torch.export.export` and `torch._export.aot_compile` with the following workaround: +``` +from torchao.quantization.utils import unwrap_tensor_subclass +m_unwrapped = unwrap_tensor_subclass(m) + + +# export +m = torch.export.export(m_unwrapped, example_inputs).module() + +# aot_compile +torch._export.aot_compile(m_unwrapped, example_inputs) +``` + +### Other Available Quantization Techniques +#### A8W8 Dynamic Quantization ```python # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor torch._inductor.config.force_fuse_int_mm_with_mul = True from torchao.quantization import quant_api -# convert linear modules to quantized tensor subclasses -quant_api.change_linear_weights_to_int8_dqtensors(model) + +from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import int8da_int8w +quantize(model, int8da_int8w()) ``` -## A16W8 WeightOnly Quantization +#### A16W8 WeightOnly Quantization ```python -from torchao.quantization import quant_api -quant_api.change_linear_weights_to_int8_woqtensors(model) +from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import int8wo +quantize(model, int8wo()) ``` This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor. -## A16W4 WeightOnly Quantization +#### A16W4 WeightOnly Quantization ```python -from torchao.quantization import quant_api -quant_api.change_linear_weights_to_int4_woqtensors(model) +from torchao.quantization.quant_api import quantize +from torchao.quantization.quant_api import int4wo +quantize(model, int4wo(groupsize=32)) ``` Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model. @@ -137,17 +243,17 @@ model = quantizer.quantize(model, inputs).cuda() ``` -## A8W8 Dynamic Quantization +## (To be deprecated) A8W8 Dynamic Quantization ```Python from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer -quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) +quantizer = Int8DynActInt4WeightQuantizer(groupsize=128) model = quantizer.quantize(model) ``` This is used in [ExecuTorch](https://github.com/pytorch/executorch) to quantize llama model right now. -## A8W8 Dynamic Quantization with Smoothquant +## (To be moved to prototype) A8W8 Dynamic Quantization with Smoothquant We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above. Due to requiring calibration, the API is more complicated.
@@ -186,118 +292,6 @@ model = torch.compile(model, mode='max-autotune') model(input) ``` -## Affine Quantization -Affine quantization refers to the type of quantization that maps from floating point numbers to quantized numbers (typically integer) with an affine transformation, i.e.: `quantized_val = float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data. - -### Quantization Primitives -We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. - -### Quantized Tensor Subclass -We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) - -### Quantization Flow -What we need to do afterwards is roughly the following - -``` -from torchao.dtypes.aqt import to_aq -def apply_int8wo_quant(weight): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - block_size = (1, weight.shape[1]) - return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - -for n, m in model.named_modules(): - if isinstance(m, torch.nn.Linear): - # optional filtering for module name, shape etc. - m.weight = nn.Parameter(apply_int8wo_quant(m.weight)) - # note: quantization for activation need to be applied after the weight quantization - # quantization activation (needed by dynamic quantization) - # input_quant_func = apply_int8wo_quant # specify how input activation is quantized - # m.weight = nn.Parameter(to_laq(m.weight, input_quant_func)) -``` -The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support -`torch.export.export` and `torch.aot_compile` with the following workaround: -``` -from torchao.quantization.utils import unwrap_tensor_subclass -m_unwrapped = unwrap_tensor_subclass(m) - - -# export -m = torch.export.export(m_unwrapped, example_inputs).module() - -# aot_compile -torch._export.aot_compile(m_unwrapped, example_inputs) -``` - -But we expect this will be integrated into the export path by default in the future. 
- - -### Example -Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul -as an example: -```python -import torch -from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain -from torchao.dtypes import to_aq -from torch._inductor.runtime.runtime_utils import do_bench_gpu -import copy -from torchao.quantization.quant_api import ( - quantize, - get_apply_int4wo_quant, -) - -class ToyLinearModel(torch.nn.Module): - def __init__(self, m=64, n=32, k=64): - super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) - - def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - -dtype = torch.bfloat16 -m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") -m_bf16 = copy.deepcopy(m) -example_inputs = m.example_inputs(dtype=dtype, device="cuda") - -m_bf16 = torch.compile(m_bf16, mode='max-autotune') -# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) -groupsize = 32 -m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize)) - -torch._inductor.config.force_fuse_int_mm_with_mul = True -torch._inductor.config.use_mixed_mm = True - -# temporary workaround for tensor subclass + torch.compile -from torchao.quantization.utils import unwrap_tensor_subclass -m = unwrap_tensor_subclass(m) -# compile the model to improve performance -m = torch.compile(m, mode='max-autotune') - -# benchmark to see the speedup -from torchao.utils import benchmark_model - -num_runs = 100 -torch._dynamo.reset() -bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0]) -print(f"bf16 mean time: {bf16_time}") -int4_time = benchmark_model(m, num_runs, example_inputs[0]) -print(f"int4 weight only quantized mean time: {int4_time}") -print(f"speedup: {bf16_time / int4_time}") - -# output (1xA100 GPU machine) -bf16 mean time: 71.457685546875 -int4 weight only quantized mean time: 31.4580908203125 -speedup: 2.2715200981216173 -``` - ## Notes 1. APIs have been hardware tested on A100 and T4(colab) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 510db85512..5b899a1afa 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -13,10 +13,6 @@ both because primitives were designed based on the fusions that come along with it and because that is how we access the intended quantized and mixed GEMM kernels - -TODO: There are 2 different approaches to quantizing a model. 
The first and more historically -popular approach is to use module swaps which explicitly change the linear modules and the second -approach is to instead use subclasses to change the interpretation of the linear module """ import torch @@ -24,18 +20,15 @@ import torch.nn.functional as F from typing import Any, Callable -from .dynamic_quant import DynamicallyPerAxisQuantizedLinear from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, ) from .subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, - to_laq, + LinearActQuantizedTensor, + to_linear_act_quantized, ) from .quant_primitives import ( @@ -52,11 +45,6 @@ __all__ = [ - "apply_weight_only_int8_quant", - "apply_dynamic_quant", - "change_linear_weights_to_int8_dqtensors", - "change_linear_weights_to_int8_woqtensors", - "change_linear_weights_to_int4_woqtensors", "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", @@ -65,10 +53,10 @@ "quantize", "autoquant", "_get_subclass_inserter", - "get_apply_8da4w_quant", - "get_apply_int4wo_quant", - "get_apply_int8wo_quant", - "get_apply_int8dyn_quant", + "int8da_int4w", + "int8da_int8w", + "int4wo", + "int8wo", ] from .GPTQ import ( @@ -81,6 +69,68 @@ "Int8DynActInt4WeightGPTQQuantizer", ] +### TO BE DEPRECATED START +from .subclass import ( + Int4WeightOnlyQuantizedLinearWeight, + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, +) + +def _in_features_greater_than_16(mod, *args): + return hasattr(mod, "in_features") and mod.in_features > 16 + +def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): + """ + Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` + Tensor subclass, effectively applying the same form of quantization + as apply_dynamic_quant while not modifying the linear modules. + """ + if filter_fn is None: + filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( + *args + ) + + _replace_with_custom_fn_if_matches_filter( + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + ) + + +def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): + """ + Converts all linear weight tensors to the + `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, + effectively applying the same form of quantization + as apply_weight_only_int8_quant while not modifying the linear modules. + """ + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs), + _is_linear if filter_fn is None else filter_fn, + ) + +def change_linear_weights_to_int4_woqtensors(model, groupsize=128, inner_k_tiles=8, filter_fn=None): + """ + Converts all linear weight tensors to the + `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, + effectively applying the same form of quantization + as apply_dynamic_quant while not modifying the linear modules. 
+ Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained, choices are [256, 128, 64, 32] + `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] + """ + if filter_fn is None: + filter_fn = _is_linear + + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles), + filter_fn, + ) + +### TO BE DEPRECATED END + + def _replace_with_custom_fn_if_matches_filter( model, @@ -115,39 +165,20 @@ def _replace_with_custom_fn_if_matches_filter( def _is_linear(mod, *args): + # avoid circular dep + from torchao.dtypes import AffineQuantizedTensor + + # adding weight tensor subclass isinstance check to make sure the weight is only quantized once + # when it is shared by multiple linear modules return ( isinstance(mod, torch.nn.Linear) and hasattr(mod, "weight") and not isinstance(mod.weight, QuantizedLinearWeightBase) and not isinstance(mod.weight, AutoQuantizableLinearWeight) + and not isinstance(mod.weight, AffineQuantizedTensor) + and not isinstance(mod.weight, LinearActQuantizedTensor) ) - -def _in_features_greater_than_16(mod, *args): - return hasattr(mod, "in_features") and mod.in_features > 16 - - -def apply_weight_only_int8_quant(model, filter_fn=None): - """ - Applies weight-only symmetric per-channel int8 quantization to all linear layers - in the given model using module swaps. - """ - _replace_with_custom_fn_if_matches_filter( - model, - WeightOnlyInt8QuantLinear.from_float, - _is_linear if filter_fn is None else filter_fn, - ) - - -def apply_dynamic_quant(model, filter_fn=None): - """ - Applies dynamic symmetric per-token activation and per-channel weight - quantization to all linear layers by converting all linear weight - tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass. - """ - change_linear_weights_to_int8_dqtensors(model, filter_fn) - - import torch.nn.utils.parametrize as parametrize def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs): @@ -178,70 +209,6 @@ def insert_subclass(lin): return insert_subclass -def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` - Tensor subclass, effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - """ - if filter_fn is None: - filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( - *args - ) - - if TORCH_VERSION_AFTER_2_4: - quantize(model, get_apply_int8dyn_quant(), filter_fn) - unwrap_tensor_subclass(model, filter_fn) - else: - _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn - ) - - -def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the - `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_weight_only_int8_quant while not modifying the linear modules. 
- """ - - if TORCH_VERSION_AFTER_2_4: - quantize(model, get_apply_int8wo_quant(), filter_fn) - unwrap_tensor_subclass(model, filter_fn) - else: - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs), - _is_linear if filter_fn is None else filter_fn, - ) - - -def change_linear_weights_to_int4_woqtensors(model, groupsize=128, inner_k_tiles=8, filter_fn=None): - """ - Converts all linear weight tensors to the - `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - - Args: - `groupsize`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained, choices are [256, 128, 64, 32] - `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] - """ - if filter_fn is None: - filter_fn = _is_linear - - if TORCH_VERSION_AFTER_2_4: - quantize(model, get_apply_int4wo_quant(groupsize=groupsize, inner_k_tiles=inner_k_tiles), filter_fn) - unwrap_tensor_subclass(model, filter_fn) - else: - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles), - filter_fn, - ) - def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. @@ -303,7 +270,9 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) + apply_weight_quant = lambda x: to_aq( + x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, + zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) # apply to modules under block0 submodule def filter_fn(module, fqn): @@ -319,11 +288,18 @@ def filter_fn(module, fqn): ) return model -def get_apply_8da4w_quant(groupsize=32): +def int8da_int4w(groupsize=32): + """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear + This is used to produce a model for executorch backend, but currently executorch did not + support lowering for the quantized model from this flow yet + Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + """ def apply_8da4w_quant(weight): # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_aq # weight settings mapping_type = MappingType.SYMMETRIC @@ -348,16 +324,25 @@ def get_per_token_block_size(x): input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) weight = to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps) - weight = to_laq(weight, input_quant_func) + weight = to_linear_act_quantized(weight, input_quant_func) return weight return apply_8da4w_quant -def get_apply_int4wo_quant(groupsize=32, inner_k_tiles=8): +def int4wo(groupsize=128, inner_k_tiles=8): + """ + Applies uint4 weight-only asymmetric per-group quantization to linear layers, using + 
"tensor_core_tiled" layout for speedup with tinygemm kernel + + Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained, choices are [256, 128, 64, 32] + `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] + """ def apply_int4wo_quant(weight): # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_aq mapping_type = MappingType.ASYMMETRIC block_size = (1, groupsize) @@ -373,10 +358,13 @@ def apply_int4wo_quant(weight): return apply_int4wo_quant -def get_apply_int8wo_quant(): +def int8wo(): + """ + Applies int8 weight-only symmetric per-channel quantization to linear layers. + """ def apply_int8wo_quant(weight): # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_aq mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -386,10 +374,19 @@ def apply_int8wo_quant(weight): return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) return apply_int8wo_quant -def get_apply_int8dyn_quant(): +def int8da_int8w(): + """ + Applies int8 dynamic symmetric per-token activation and int8 per-channel weight + quantization to linear layers + """ def apply_int8dyn_quant(weight): + in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + return weight + # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_aq # weight settings mapping_type = MappingType.SYMMETRIC def get_weight_block_size(x): @@ -414,6 +411,6 @@ def get_per_token_block_size(x): block_size = get_weight_block_size(weight) weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - weight = to_laq(weight, input_quant_func) + weight = to_linear_act_quantized(weight, input_quant_func) return weight return apply_int8dyn_quant diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index c299a5834f..a2801a622f 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -31,7 +31,7 @@ "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", "LinearActQuantizedTensor", - "to_laq", + "to_linear_act_quantized", ] @@ -751,4 +751,4 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) -to_laq = LinearActQuantizedTensor.from_float +to_linear_act_quantized = LinearActQuantizedTensor.from_float diff --git a/torchao/utils.py b/torchao/utils.py index 27650dae1c..4e2afc795d 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -126,6 +126,11 @@ def right_inverse(self, tensor): return plain_tensors def unwrap_tensor_subclass(model, filter_fn=None): + """Unwraps (nested) tensor subclass in the model to plain tensors + This is a workaround to make a model with tensor subclass to work with `torch.export.export` + and `torch.aot_compile`, we hope this can be integrated into compile stack soon + tracking issue: https://github.com/pytorch/ao/issues/345 + """ for name, child in model.named_children(): # make sure child.weight is a tensor subclass if ( diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index 918284ae1e..96ed6e00c8 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -3,6 +3,7 @@ from torchao.utils import benchmark_model, 
profiler_runner from torchvision import models +import torchao.quantization.quant_api as quant_api torch.set_float32_matmul_precision("high") # Load Vision Transformer model @@ -16,7 +17,7 @@ ## Quantization code - start # int8 act, int8 weight dynamic quantization, see README for other APIs -torchao.apply_dynamic_quant(model) +quant_api.quantize(model, quant_api.int8da_int8w()) ## Quantization code - end ## compilation configs
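
Taken together, the changes above converge on a single usage pattern: `quantize(model, <config>())` followed by the temporary `unwrap_tensor_subclass` workaround for `torch.compile` / `torch.export`. Below is a minimal sketch of that pattern, mirroring the updated tests; the API names (`quantize`, `int8wo`, `int8da_int8w`, `int4wo`, `unwrap_tensor_subclass`) come from this patch, while the toy model, shapes, and compile settings are illustrative assumptions only:

```python
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize, int8wo  # int8da_int8w() / int4wo(groupsize=32) plug into the same call
from torchao.utils import unwrap_tensor_subclass

# illustrative toy model; any model containing nn.Linear layers works
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).eval()
example_input = torch.randn(1, 1024)

# replace each linear weight with a quantized tensor subclass (int8 weight-only here)
model = quantize(model, int8wo())

# temporary workaround so the tensor subclasses work with torch.compile / torch.export
model = unwrap_tensor_subclass(model)

model = torch.compile(model, mode="max-autotune")
out = model(example_input)
```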