From bdb7bc2c0f986c07f8f75744413bf55127f5860d Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 3 Apr 2024 23:37:48 -0700
Subject: [PATCH] Support `model.to` int8 weight only quantized model

Summary:
registering fields as buffers so they get picked up in `model.to`

Test Plan:
python test/quantization/test_quant_api.py -k test_int8_wo_quant_save_load

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py | 26 ++++++++++++++++++++++++--
 torchao/quantization/weight_only.py |  5 ++---
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 5cc5ac1fa3..d772fda831 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -8,6 +8,7 @@
 # This test takes a long time to run
 import unittest
 import torch
+import os
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
@@ -18,9 +19,10 @@
     get_symmetric_quantization_config,
 )
 
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.quant_api import apply_dynamic_quant
 from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    apply_dynamic_quant,
+    apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
 )
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int8_wo_quant_save_load(self):
+        m = M().eval().cpu()
+        apply_weight_only_int8_quant(m)
+        example_inputs = m.example_inputs()
+        ref = m(*example_inputs)
+        _TMP_FN = "_test.pt"
+        torch.save(m.state_dict(), _TMP_FN)
+
+        state_dict = torch.load(_TMP_FN)
+        os.remove(_TMP_FN)
+        m2 = M().eval()
+        apply_weight_only_int8_quant(m2)
+        m2.load_state_dict(state_dict)
+        m2 = m2.to(device="cuda")
+        example_inputs = map(lambda x: x.cuda(), example_inputs)
+        res = m2(*example_inputs)
+
+        torch.testing.assert_close(ref, res.cpu())
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py
index 2ab5adf3d1..099df0f17f 100644
--- a/torchao/quantization/weight_only.py
+++ b/torchao/quantization/weight_only.py
@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
         scales = kwargs.pop("scales")
         super().__init__(*args, **kwargs)
 
-        self.w_int8 = w_int8
-
-        self.scales = scales
+        self.register_buffer("w_int8", w_int8)
+        self.register_buffer("scales", scales)
 
     def forward(self, x, *args, **kwargs):
         """