Add register_apply_tensor_subclass #366

Merged: 1 commit, Jun 14, 2024
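This PR adds a `register_apply_tensor_subclass` API so a user-defined `apply_tensor_subclass` function can be registered under a string shortcut and then invoked by name through `quantize`; passing an unregistered name now raises a `ValueError` instead of tripping an assert. A minimal sketch of the flow, adapted from the new test (the single-linear toy model and shapes here are illustrative, not part of the PR):

import torch
from torchao import quantize, register_apply_tensor_subclass

def apply_my_dtype(weight):
    # stand-in transform; a real registration would return a tensor-subclass weight
    return weight * 2

m = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
register_apply_tensor_subclass("my_dtype", apply_my_dtype)
m = quantize(m, "my_dtype")  # looks up "my_dtype" in the registry, applies apply_my_dtype to linear weights
m(torch.randn(1, 16))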
20 changes: 18 additions & 2 deletions test/quantization/test_quant_api.py
@@ -31,11 +31,11 @@
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao import quantize
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Quantizer,
TwoStepQuantizer,
quantize,
int8da_int4w,
int4wo,
int8wo,
@@ -51,6 +51,7 @@
from torchao.utils import unwrap_tensor_subclass
import copy
import tempfile
from torch.testing._internal.common_utils import TestCase


def dynamic_quant(model, example_inputs):
@@ -147,7 +148,7 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)

class TestQuantFlow(unittest.TestCase):
class TestQuantFlow(TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
@@ -601,5 +602,20 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

def test_register_apply_tensor_subclass(self):
from torchao import register_apply_tensor_subclass
def apply_my_dtype(weight):
return weight * 2

m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
with self.assertRaisesRegex(ValueError, "not supported"):
quantize(m, "my_dtype")

register_apply_tensor_subclass("my_dtype", apply_my_dtype)
# make sure it runs
quantize(m, "my_dtype")
m(*example_inputs)

if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions torchao/__init__.py
@@ -25,10 +25,14 @@

from torchao.quantization import (
autoquant,
quantize,
register_apply_tensor_subclass,
)
from . import dtypes

__all__ = [
"dtypes",
"autoquant",
"quantize",
"register_apply_tensor_subclass",
]
12 changes: 2 additions & 10 deletions torchao/quantization/__init__.py
@@ -14,12 +14,6 @@
from .autoquant import *

__all__ = [
"DynamicallyPerAxisQuantizedLinear",
"apply_weight_only_int8_quant",
"apply_dynamic_quant",
"change_linear_weights_to_int8_dqtensors",
"change_linear_weights_to_int8_woqtensors",
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear"
"safe_int_mm",
"autoquant",
@@ -31,14 +25,12 @@
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"Int8DynamicallyQuantizedLinearWeight",
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"compute_error",
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize_affine",
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"register_apply_tensor_subclass",
]
23 changes: 21 additions & 2 deletions torchao/quantization/quant_api.py
@@ -41,6 +41,7 @@
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
import logging
from .autoquant import autoquant, AutoQuantizableLinearWeight


@@ -50,13 +51,14 @@
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize",
"autoquant",
"_get_subclass_inserter",
"quantize",
"int8da_int4w",
"int8da_int8w",
"int4wo",
"int8wo",
"register_apply_tensor_subclass",
]

from .GPTQ import (
@@ -292,7 +294,8 @@ def filter_fn(module, fqn):
m = quantize(m, apply_weight_quant, filter_fn)
"""
if isinstance(apply_tensor_subclass, str):
assert apply_tensor_subclass in _APPLY_TS_TABLE, f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}"
if apply_tensor_subclass not in _APPLY_TS_TABLE:
raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}")
apply_tensor_subclass = _APPLY_TS_TABLE[apply_tensor_subclass]

assert not isinstance(apply_tensor_subclass, str)
@@ -438,3 +441,19 @@ def get_per_token_block_size(x):
"int8_weight_only": int8wo(),
"int8_dynamic": int8da_int8w(),
}

def register_apply_tensor_subclass(name: str, apply_tensor_subclass: Callable):
"""Register a string shortcut for `apply_tensor_subclass` that takes a weight Tensor
as input and outputs a tensor with the tensor subclass applied

Example:
def apply_my_dtype(weight):
return weight * 2

register_apply_tensor_subclass("my_dtype", apply_my_dtype)
# calls `apply_my_dtype` on weights
quantize(m, "my_dtype")
"""
if name in _APPLY_TS_TABLE:
logging.warning(f"shortcut string {name} already exists, overwriting")
_APPLY_TS_TABLE[name] = apply_tensor_subclass
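
Two behaviors of the registry are worth noting: `quantize` now raises `ValueError` (rather than asserting) when the shortcut string is not registered, and re-registering an existing name logs a warning and overwrites the previous entry, so the last registration wins. A small sketch of both paths, with hypothetical names:

import torch
from torchao import quantize, register_apply_tensor_subclass

m = torch.nn.Linear(8, 8).eval()
try:
    quantize(m, "no_such_dtype")  # unregistered shortcut
except ValueError as e:
    print(e)  # "no_such_dtype not supported: ..."

register_apply_tensor_subclass("my_dtype", lambda w: w * 2)
register_apply_tensor_subclass("my_dtype", lambda w: w * 3)  # logs "already exists, overwriting"; w * 3 wins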