Support model.to int8 weight only quantized model #122

Merged · 1 commit · Apr 4, 2024
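Summary: the int8 weight-only quantized module stored its quantized weight and scales as plain Python attributes, so `model.to(device)` and `state_dict()` silently skipped them; this PR registers them as buffers, which makes a quantized model save, reload, and move to CUDA correctly. The workflow the new test exercises, as a hedged sketch (the `nn.Sequential` stand-in is illustrative, not the test's `M` module):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import apply_weight_only_int8_quant

# Illustrative stand-in for the test's M module.
model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 8)).eval()
apply_weight_only_int8_quant(model)           # weights become int8 buffers + scales

torch.save(model.state_dict(), "model.pt")    # buffers are included in state_dict

model2 = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 8)).eval()
apply_weight_only_int8_quant(model2)          # rebuild the same quantized structure
model2.load_state_dict(torch.load("model.pt"))
model2 = model2.to("cuda")                    # after this PR, the buffers move too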
26 changes: 24 additions & 2 deletions test/quantization/test_quant_api.py
@@ -8,6 +8,7 @@
 # This test takes a long time to run
 import unittest
 import torch
+import os
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
@@ -18,9 +19,10 @@
     get_symmetric_quantization_config,
 )
 
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.quant_api import apply_dynamic_quant
 from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    apply_dynamic_quant,
+    apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
 )
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int8_wo_quant_save_load(self):
+        m = M().eval().cpu()
+        apply_weight_only_int8_quant(m)
+        example_inputs = m.example_inputs()
+        ref = m(*example_inputs)
+        _TMP_FN = "_test.pt"
[Review comment · Contributor, on the `_TMP_FN = "_test.pt"` line above]: Using named temporary files might be better. (A sketch of that approach follows this diff.)
+        torch.save(m.state_dict(), _TMP_FN)
+
+        state_dict = torch.load(_TMP_FN)
+        os.remove(_TMP_FN)
+        m2 = M().eval()
+        apply_weight_only_int8_quant(m2)
+        m2.load_state_dict(state_dict)
+        m2 = m2.to(device="cuda")
+        example_inputs = map(lambda x: x.cuda(), example_inputs)
+        res = m2(*example_inputs)
+
+        torch.testing.assert_close(ref, res.cpu())
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
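A sketch of the reviewer's suggestion, assuming the `m` from the test above: `tempfile.NamedTemporaryFile` lets the OS pick a unique path instead of hard-coding `_test.pt` (`delete=False` keeps the file around until the explicit `os.remove`):

import os
import tempfile
import torch

# Let the OS choose a unique temp path rather than a fixed "_test.pt".
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    tmp_fn = f.name
try:
    torch.save(m.state_dict(), tmp_fn)
    state_dict = torch.load(tmp_fn)
finally:
    os.remove(tmp_fn)   # clean up even if save/load raises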
5 changes: 2 additions & 3 deletions torchao/quantization/weight_only.py

@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
         scales = kwargs.pop("scales")
         super().__init__(*args, **kwargs)
 
-        self.w_int8 = w_int8
-
-        self.scales = scales
+        self.register_buffer("w_int8", w_int8)
+        self.register_buffer("scales", scales)
 
     def forward(self, x, *args, **kwargs):
         """
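This two-line change is the actual fix: `nn.Module` only tracks parameters and registered buffers, so tensors held as plain attributes are neither moved by `.to()`/`.cuda()` nor included in `state_dict()`. A minimal sketch of the difference (hypothetical classes, not the torchao module):

import torch
import torch.nn as nn

class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_int8 = torch.zeros(2, 2, dtype=torch.int8)   # plain attribute

class AsBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w_int8", torch.zeros(2, 2, dtype=torch.int8))

print(list(PlainAttr().state_dict()))  # [] -- attribute is not serialized or moved
print(list(AsBuffer().state_dict()))   # ['w_int8'] -- saved, loaded, and moved by .to()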