Commit 97e8d9d

revert public change
msaroufim committed May 24, 2024
1 parent 816e3b5 · commit 97e8d9d
Showing 4 changed files with 13 additions and 13 deletions. Every change is the same rename: get_subclass_inserter becomes _get_subclass_inserter again, restoring the helper to private status at its definition, in __all__, and at every call site.
benchmarks/benchmark_sam.py (2 additions, 2 deletions)
@@ -5,7 +5,7 @@
 from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
-    get_subclass_inserter,
+    _get_subclass_inserter,
     _is_linear,
     QuantizedLinearWeightBase,
     Int8DynamicallyQuantizedLinearWeight,
@@ -101,7 +101,7 @@ def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True,
             if filter_fn(mod, name):
                 mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
         elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
-            _replace_with_custom_fn_if_matches_filter(model, get_subclass_inserter(subclass), filter_fn)
+            _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)
 
     if compile:
         model = torch.compile(model, mode='max-autotune')
test/sparsity/test_sparse_api.py (2 additions, 2 deletions)
@@ -8,7 +8,7 @@
 from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
-    get_subclass_inserter,
+    _get_subclass_inserter,
     _is_linear,
 )
 from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
@@ -61,7 +61,7 @@ def test_quant_semi_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)
 
-        _replace_with_custom_fn_if_matches_filter(model, get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
+        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
         sparse_result = model(input)
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
torchao/quantization/autoquant.py (4 additions, 4 deletions)
@@ -393,10 +393,10 @@ def change_linears_to_autoquantizable(model, **kwargs):
     kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
     kwargs["mode"] = kwargs.get("mode", ["relu", None])
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-    from torchao.quantization.quant_api import get_subclass_inserter
+    from torchao.quantization.quant_api import _get_subclass_inserter
     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
+        _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
         filter_fn if filter_fn is not None else _is_linear,
     )

@@ -417,10 +417,10 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     )
     error_on_unseen=kwargs.pop("error_on_unseen", True)
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-    from torchao.quantization.quant_api import get_subclass_inserter
+    from torchao.quantization.quant_api import _get_subclass_inserter
     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(
+        _get_subclass_inserter(
             AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs
         ),
         filter_fn,
torchao/quantization/quant_api.py (5 additions, 5 deletions)
@@ -55,7 +55,7 @@
"Int4WeightOnlyQuantizer",
"quantize",
"autoquant",
"get_subclass_inserter",
"_get_subclass_inserter",
]

if TORCH_VERSION_AFTER_2_3:
@@ -138,7 +138,7 @@ def apply_dynamic_quant(model, filter_fn=None):

 import torch.nn.utils.parametrize as parametrize
 
-def get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
+def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
     """
     Returns a function which inserts the given subclass into all linear modules
     in the model. The inserted module will have its weight set to the result of
@@ -178,7 +178,7 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
     )
 
     _replace_with_custom_fn_if_matches_filter(
-        model, get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
+        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
     )


@@ -191,7 +191,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
_replace_with_custom_fn_if_matches_filter(
model,
get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
_is_linear if filter_fn is None else filter_fn,
)

@@ -207,7 +207,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):

     _replace_with_custom_fn_if_matches_filter(
         model,
-        get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
+        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
         filter_fn,
     )

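For reference, after this revert callers import the underscore-prefixed (private) helper again. Below is a minimal usage sketch in the spirit of the call sites above; the toy model is hypothetical, while the imports and the replacement pattern come straight from benchmark_sam.py and quant_api.py in this diff:

import torch
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
    Int8DynamicallyQuantizedLinearWeight,
)

# Hypothetical toy model; any module with nn.Linear children would do.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
)

# Swap each linear weight for the dynamically quantized tensor subclass,
# using the same pattern as the call sites changed in this commit.
_replace_with_custom_fn_if_matches_filter(
    model,
    _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight),
    _is_linear,
)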
