Commit 389b91d

Improve SmoothQuant test cases
1 parent 5cbbd73 commit 389b91d


test/prototype/test_smoothquant.py

Lines changed: 40 additions & 34 deletions
@@ -15,9 +15,15 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.linear_activation_scale import (
+    WeightTensorWithLinearActivationScaleMetadata,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
 )
+from torchao.quantization.utils import (
+    compute_error as SQNR,
+)


 class ToyLinearModel(torch.nn.Module):
@@ -34,16 +40,19 @@ def example_inputs(
         dtype=torch.bfloat16,
         device="cuda",
     ):
-        return [
-            torch.randn(
-                1,
-                sequence_length,
-                self.linear1.in_features,
-                dtype=dtype,
-                device=device,
-            )
-            for j in range(batch_size)
-        ]
+        # For SmoothQuant tests, we intentionally insert outliers into the input features
+        x = torch.randn(
+            batch_size,
+            sequence_length,
+            self.linear1.in_features,
+            dtype=dtype,
+            device=device,
+        )
+        n_outliers = max(1, int(x.size(-1) * 0.1))
+        # Randomly select outlier features
+        outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
+        x[:, :, outlier_indices] *= 10.0
+        return (x,)

     def forward(self, x):
         x = self.linear1(x)
@@ -52,7 +61,9 @@ def forward(self, x):
         return x


-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
 class TestSmoothQuant(unittest.TestCase):
     """SmoothQuant tests using only supported quantization configs."""
@@ -72,37 +83,25 @@ def setUpClass(cls):
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
     )
-    @common_utils.parametrize("device", ["cpu", "cuda"])
+    @common_utils.parametrize("device", device_list)
     @common_utils.parametrize("input_dtype", [torch.bfloat16])
     def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         """Test if SmoothQuant achieves lower loss than basic quantization."""
-        in_features = 64
-        out_features = 128
-
-        # Note: This is sanity check. For real run, consider Transformer model to reproduce.
-        X = torch.randn(16, in_features, dtype=input_dtype, device=device)
-        W = torch.randn(out_features, in_features, dtype=input_dtype, device=device)
-
         # Create linear layer
-        linear = (
-            torch.nn.Linear(in_features, out_features, bias=False)
-            .to(device)
-            .to(input_dtype)
-        )
-        with torch.no_grad():
-            linear.weight.copy_(W)
+        m = ToyLinearModel().eval().to(device).to(input_dtype)
+        x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)

         # Reference output
-        out_ref = linear(X)
+        out_ref = m(*x)

         # Step 1. Basic quantization
-        basic_model = deepcopy(linear)
+        basic_model = deepcopy(m)
         quantize_(basic_model, base_config)
-        out_basic = basic_model(X)
+        out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

-        # SmoothQuant quantization
-        model = deepcopy(linear)
+        # Step 2. SmoothQuant
+        model = deepcopy(m)
         config = SmoothQuantConfig(
             base_config=base_config,
             step=SmoothQuantStep.PREPARE,
@@ -111,18 +110,25 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         quantize_(model, config)

         # Perform calibration with test data
-        model(X)
+        model(*x)

-        # Step 2. SmoothQuant
         config.step = SmoothQuantStep.CONVERT
         quantize_(model, config)
+        assert isinstance(
+            model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        assert isinstance(
+            model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
+        )

-        out_smoothquant = model(X)
+        out_smoothquant = model(*x)
         loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()

         assert loss_smoothquant < loss_base, (
             f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
         )
+        # Make sure the result is reasonable
+        self.assertGreater(SQNR(out_ref, out_smoothquant), 20.0)

     @common_utils.parametrize(
         "base_config",

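Two details of the test changes deserve a note. The outlier injection in example_inputs manufactures exactly the activation-outlier pattern SmoothQuant targets: following the SmoothQuant paper, each input channel j is divided by a smoothing factor s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), which migrates outlier magnitude from activations into the weights. The new SQNR assertion then checks overall fidelity in decibels. The sketch below illustrates both pieces under stated assumptions: torchao's compute_error is assumed to return SQNR as 20 * log10(|ref| / |ref - test|), and the example tensors are invented:

import torch

def smoothing_factors(x_absmax, w_absmax, alpha=0.5):
    # Per-channel SmoothQuant factors: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
    return x_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)

def sqnr_db(ref, test):
    # SQNR in dB; 20 dB means the error norm is ~10% of the signal norm.
    return (20.0 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))).item()

# Channels scaled 10x (as example_inputs does) receive the largest factors,
# so dividing activations by s flattens precisely the outlier channels.
x_absmax = torch.tensor([1.0, 10.0, 1.0])  # activation per-channel max-abs
w_absmax = torch.tensor([0.5, 0.5, 0.5])   # weight per-channel max-abs
print(smoothing_factors(x_absmax, w_absmax))  # tensor([1.4142, 4.4721, 1.4142])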