diff --git a/benchmarks/benchmark_sam.py b/benchmarks/benchmark_sam.py
index f0c4021de2..14c2d8bc5e 100644
--- a/benchmarks/benchmark_sam.py
+++ b/benchmarks/benchmark_sam.py
@@ -5,7 +5,7 @@
 from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
-    get_subclass_inserter,
+    _get_subclass_inserter,
     _is_linear,
     QuantizedLinearWeightBase,
     Int8DynamicallyQuantizedLinearWeight,
@@ -101,7 +101,7 @@ def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True,
            if filter_fn(mod, name):
                mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
    elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
-        _replace_with_custom_fn_if_matches_filter(model, get_subclass_inserter(subclass), filter_fn)
+        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)
    if compile:
        model = torch.compile(model, mode='max-autotune')
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 193238427c..83c0544f6e 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -8,7 +8,7 @@
 from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
-    get_subclass_inserter,
+    _get_subclass_inserter,
     _is_linear,
 )
 from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
@@ -61,7 +61,7 @@ def test_quant_semi_sparse(self):
        apply_fake_sparsity(model)
        dense_result = model(input)

-        _replace_with_custom_fn_if_matches_filter(model, get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
+        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index d4679d1286..9eeb146f55 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -393,10 +393,10 @@ def change_linears_to_autoquantizable(model, **kwargs):
     kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
     kwargs["mode"] = kwargs.get("mode", ["relu", None])
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-    from torchao.quantization.quant_api import get_subclass_inserter
+    from torchao.quantization.quant_api import _get_subclass_inserter
     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
+        _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
         filter_fn if filter_fn is not None else _is_linear,
     )
@@ -417,10 +417,10 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     )
     error_on_unseen=kwargs.pop("error_on_unseen", True)
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-    from torchao.quantization.quant_api import get_subclass_inserter
+    from torchao.quantization.quant_api import _get_subclass_inserter
     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(
+        _get_subclass_inserter(
             AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs
         ),
         filter_fn,
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 6a0a6c80c2..d9b731bace 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -55,7 +55,7 @@
     "Int4WeightOnlyQuantizer",
     "quantize",
     "autoquant",
-    "get_subclass_inserter",
+    "_get_subclass_inserter",
 ]

 if TORCH_VERSION_AFTER_2_3:
@@ -138,7 +138,7 @@ def apply_dynamic_quant(model, filter_fn=None):

 import torch.nn.utils.parametrize as parametrize

-def get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
+def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
     """
     Returns a function which inserts the given subclass into all linear modules
     in the model. The inserted module will have its weight set to the result of
@@ -178,7 +178,7 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
     )

     _replace_with_custom_fn_if_matches_filter(
-        model, get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
+        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
     )
@@ -191,7 +191,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
+        _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
         _is_linear if filter_fn is None else filter_fn,
     )
@@ -207,7 +207,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):

     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
+        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
         filter_fn,
     )