Support `model.to` for int8 weight-only quantized models
Summary:
registering the `w_int8` and `scales` fields as buffers so they get picked up by `model.to` and serialized in `state_dict`

Test Plan:
python test/quantization/test_quant_api.py -k test_int8_wo_quant_save_load
Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 4, 2024
1 parent eba4c36 commit bdb7bc2
Showing 2 changed files with 26 additions and 5 deletions.
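
Context on the fix: a tensor stored as a plain attribute (self.scales = scales) is invisible to nn.Module bookkeeping, so model.to(device) leaves it on the old device and state_dict() omits it entirely; register_buffer makes the tensor a tracked, non-trainable part of the module. A minimal sketch of the difference (toy module names, not code from this repo):

import torch
import torch.nn as nn

class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.scales = torch.ones(4)  # plain attribute: skipped by .to() and state_dict()

class BufferAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer: moved by .to(), saved in state_dict()
        self.register_buffer("scales", torch.ones(4))

print("scales" in PlainAttr().state_dict())   # False
print("scales" in BufferAttr().state_dict())  # True
if torch.cuda.is_available():
    print(PlainAttr().to("cuda").scales.device)   # cpu (left behind)
    print(BufferAttr().to("cuda").scales.device)  # cuda:0 (moved with the module)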
test/quantization/test_quant_api.py (24 additions, 2 deletions)
@@ -8,6 +8,7 @@
 # This test takes a long time to run
 import unittest
 import torch
+import os
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
@@ -18,9 +19,10 @@
     get_symmetric_quantization_config,
 )

-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.quant_api import apply_dynamic_quant
 from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    apply_dynamic_quant,
+    apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
 )
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int8_wo_quant_save_load(self):
+        m = M().eval().cpu()
+        apply_weight_only_int8_quant(m)
+        example_inputs = m.example_inputs()
+        ref = m(*example_inputs)
+        _TMP_FN = "_test.pt"
+        torch.save(m.state_dict(), _TMP_FN)
+
+        state_dict = torch.load(_TMP_FN)
+        os.remove(_TMP_FN)
+        m2 = M().eval()
+        apply_weight_only_int8_quant(m2)
+        m2.load_state_dict(state_dict)
+        m2 = m2.to(device="cuda")
+        example_inputs = map(lambda x: x.cuda(), example_inputs)
+        res = m2(*example_inputs)
+
+        torch.testing.assert_close(ref, res.cpu())
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
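
Outside the test harness, the round trip this test exercises looks roughly like the sketch below; MyModel stands in for any module with nn.Linear layers and is not defined in this repo:

import torch
from torchao.quantization.quant_api import apply_weight_only_int8_quant

model = MyModel().eval()
apply_weight_only_int8_quant(model)   # swap Linear layers for int8 weight-only variants
torch.save(model.state_dict(), "int8_model.pt")

model2 = MyModel().eval()
apply_weight_only_int8_quant(model2)  # quantize first so buffer names match the checkpoint
model2.load_state_dict(torch.load("int8_model.pt"))
model2 = model2.to("cuda")            # w_int8 and scales buffers now follow .to()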
torchao/quantization/weight_only.py (2 additions, 3 deletions)
@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
         scales = kwargs.pop("scales")
         super().__init__(*args, **kwargs)

-        self.w_int8 = w_int8
-
-        self.scales = scales
+        self.register_buffer("w_int8", w_int8)
+        self.register_buffer("scales", scales)

     def forward(self, x, *args, **kwargs):
         """
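
For background, weight-only int8 linear layers of this kind typically hold a symmetric per-output-channel int8 copy of the weight (w_int8) plus floating-point scales, and dequantize at call time. A rough sketch of that scheme, not torchao's actual implementation:

import torch
import torch.nn.functional as F

def quantize_per_channel(w: torch.Tensor):
    # one scale per output row; the clamp avoids division by zero for all-zero rows
    scales = (w.abs().amax(dim=1) / 127.0).clamp(min=1e-8)
    w_int8 = torch.round(w / scales.unsqueeze(1)).clamp(-127, 127).to(torch.int8)
    return w_int8, scales

def weight_only_int8_linear(x, w_int8, scales, bias=None):
    # dequantize on the fly, then run an ordinary floating-point matmul
    w = w_int8.to(x.dtype) * scales.unsqueeze(1).to(x.dtype)
    return F.linear(x, w, bias)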
