diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 5cc5ac1fa3..d772fda831 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -8,6 +8,7 @@
 # This test takes a long time to run
 import unittest
 import torch
+import os
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
@@ -18,9 +19,10 @@
     get_symmetric_quantization_config,
 )
 
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.quant_api import apply_dynamic_quant
 from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    apply_dynamic_quant,
+    apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
 )
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int8_wo_quant_save_load(self):
+        m = M().eval().cpu()
+        apply_weight_only_int8_quant(m)
+        example_inputs = m.example_inputs()
+        ref = m(*example_inputs)
+        _TMP_FN = "_test.pt"
+        torch.save(m.state_dict(), _TMP_FN)
+
+        state_dict = torch.load(_TMP_FN)
+        os.remove(_TMP_FN)
+        m2 = M().eval()
+        apply_weight_only_int8_quant(m2)
+        m2.load_state_dict(state_dict)
+        m2 = m2.to(device="cuda")
+        example_inputs = map(lambda x: x.cuda(), example_inputs)
+        res = m2(*example_inputs)
+
+        torch.testing.assert_close(ref, res.cpu())
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py
index 2ab5adf3d1..099df0f17f 100644
--- a/torchao/quantization/weight_only.py
+++ b/torchao/quantization/weight_only.py
@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
         scales = kwargs.pop("scales")
         super().__init__(*args, **kwargs)
 
-        self.w_int8 = w_int8
-
-        self.scales = scales
+        self.register_buffer("w_int8", w_int8)
+        self.register_buffer("scales", scales)
 
     def forward(self, x, *args, **kwargs):
         """
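
Context note (not part of the diff): the weight_only.py change works because register_buffer makes w_int8 and scales visible to nn.Module machinery, so they are included in state_dict() and moved by .to(device=...), which is exactly what the new test_int8_wo_quant_save_load exercises. A minimal standalone sketch of that behavior is below; the class names are illustrative, not from torchao.

# Sketch: plain tensor attributes are dropped from state_dict(), buffers are kept.
import torch
import torch.nn as nn

class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_int8 = torch.zeros(4, dtype=torch.int8)  # plain attribute, untracked

class Buffered(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w_int8", torch.zeros(4, dtype=torch.int8))  # tracked buffer

print(PlainAttr().state_dict().keys())  # odict_keys([]) -> tensor silently lost on save
print(Buffered().state_dict().keys())   # odict_keys(['w_int8']) -> survives save/load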