def test_register_apply_tensor_subclass(self):
    """Check that a custom string shortcut can be registered and then used by `quantize`."""
    from torchao import register_apply_tensor_subclass
    from torchao.quantization.quant_api import _APPLY_TS_TABLE

    def apply_my_dtype(weight):
        return weight * 2

    m = ToyLinearModel().eval()
    example_inputs = m.example_inputs()
    # an unregistered shortcut string must be rejected with ValueError
    with self.assertRaisesRegex(ValueError, "not supported"):
        quantize(m, "my_dtype")

    # The shortcut table is module-level state shared by every test in the
    # process: remove the entry afterwards so that other tests (and a re-run
    # of this one, whose first check requires "my_dtype" to be unknown) are
    # not affected by the registration below.
    self.addCleanup(_APPLY_TS_TABLE.pop, "my_dtype", None)
    register_apply_tensor_subclass("my_dtype", apply_my_dtype)
    # make sure it runs
    quantize(m, "my_dtype")
    m(*example_inputs)
b/torchao/__init__.py @@ -25,10 +25,14 @@ from torchao.quantization import ( autoquant, + quantize, + register_apply_tensor_subclass, ) from . import dtypes __all__ = [ "dtypes", "autoquant", + "quantize", + "register_apply_tensor_subclass", ] diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index cd8e0f91f0..e461171d9e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -14,12 +14,6 @@ from .autoquant import * __all__ = [ - "DynamicallyPerAxisQuantizedLinear", - "apply_weight_only_int8_quant", - "apply_dynamic_quant", - "change_linear_weights_to_int8_dqtensors", - "change_linear_weights_to_int8_woqtensors", - "change_linear_weights_to_int4_woqtensors", "swap_conv2d_1x1_to_linear" "safe_int_mm", "autoquant", @@ -31,14 +25,12 @@ "swap_linear_with_smooth_fq_linear", "smooth_fq_linear_to_inference", "set_smooth_fq_attribute", - "Int8DynamicallyQuantizedLinearWeight", - "Int8WeightOnlyQuantizedLinearWeight", - "Int4WeightOnlyQuantizedLinearWeight", "compute_error", - "WeightOnlyInt8QuantLinear", "Int4WeightOnlyGPTQQuantizer", "Int4WeightOnlyQuantizer", "quantize_affine", "dequantize_affine", "choose_qprams_affine", + "quantize", + "register_apply_tensor_subclass", ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e97d7e8ec0..2a65d3c831 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -41,6 +41,7 @@ Int4WeightOnlyGPTQQuantizer, Int4WeightOnlyQuantizer, ) +import logging from .autoquant import autoquant, AutoQuantizableLinearWeight @@ -50,13 +51,14 @@ "TwoStepQuantizer", "Int4WeightOnlyGPTQQuantizer", "Int4WeightOnlyQuantizer", - "quantize", "autoquant", "_get_subclass_inserter", + "quantize", "int8da_int4w", "int8da_int8w", "int4wo", "int8wo", + "register_apply_tensor_subclass", ] from .GPTQ import ( @@ -292,7 +294,8 @@ def filter_fn(module, fqn): m = quantize(m, apply_weight_quant, filter_fn) """ if 
def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable):
    """Register a string shortcut for an `apply_tensor_subclass` callable.

    The callable takes a weight Tensor as input and outputs a tensor with the
    tensor subclass applied. After registration, `name` can be passed to
    `quantize` in place of the callable itself.

    Args:
        name: shortcut string under which the callable is stored; an existing
            entry with the same name is overwritten (a warning is logged)
        apply_tensor_subclass: callable to run on weights when `name` is used

    Raises:
        TypeError: if `name` is not a string or `apply_tensor_subclass` is not
            callable

    Example:
        def apply_my_dtype(weight):
            return weight * 2

        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
        # calls `apply_my_dtype` on weights
        quantize(m, "my_dtype")
    """
    # `quantize` only consults the table for `str` arguments, so a non-string
    # key could never be looked up again.
    if not isinstance(name, str):
        raise TypeError(
            f"shortcut name must be a str, got {type(name).__name__}"
        )
    # Reject bad values here instead of letting them surface later inside
    # `quantize`, far away from the registration that caused them.
    if not callable(apply_tensor_subclass):
        raise TypeError(
            "apply_tensor_subclass must be callable, got "
            f"{type(apply_tensor_subclass).__name__}"
        )
    if name in _APPLY_TS_TABLE:
        logging.warning("shortcut string %s already exist, overwriting", name)
    _APPLY_TS_TABLE[name] = apply_tensor_subclass