 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.linear_activation_scale import (
+    WeightTensorWithLinearActivationScaleMetadata,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
 )
+from torchao.quantization.utils import (
+    compute_error as SQNR,
+)


 class ToyLinearModel(torch.nn.Module):
@@ -34,16 +40,19 @@ def example_inputs(
         dtype=torch.bfloat16,
         device="cuda",
     ):
-        return [
-            torch.randn(
-                1,
-                sequence_length,
-                self.linear1.in_features,
-                dtype=dtype,
-                device=device,
-            )
-            for j in range(batch_size)
-        ]
+        # For SmoothQuant tests, we intentionally insert some outliers to input features
+        x = torch.randn(
+            batch_size,
+            sequence_length,
+            self.linear1.in_features,
+            dtype=dtype,
+            device=device,
+        )
+        n_outliers = max(1, int(x.size(-1) * 0.1))
+        # Randomly select outlier features
+        outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
+        x[:, :, outlier_indices] *= 10.0
+        return (x,)

     def forward(self, x):
         x = self.linear1(x)
@@ -52,7 +61,9 @@ def forward(self, x):
         return x


-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 @unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
 class TestSmoothQuant(unittest.TestCase):
     """SmoothQuant tests using only supported quantization configs."""
@@ -72,37 +83,25 @@ def setUpClass(cls):
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
     )
-    @common_utils.parametrize("device", ["cpu", "cuda"])
+    @common_utils.parametrize("device", device_list)
     @common_utils.parametrize("input_dtype", [torch.bfloat16])
     def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         """Test if SmoothQuant achieves lower loss than basic quantization."""
-        in_features = 64
-        out_features = 128
-
-        # Note: This is sanity check. For real run, consider Transformer model to reproduce.
-        X = torch.randn(16, in_features, dtype=input_dtype, device=device)
-        W = torch.randn(out_features, in_features, dtype=input_dtype, device=device)
-
         # Create linear layer
-        linear = (
-            torch.nn.Linear(in_features, out_features, bias=False)
-            .to(device)
-            .to(input_dtype)
-        )
-        with torch.no_grad():
-            linear.weight.copy_(W)
+        m = ToyLinearModel().eval().to(device).to(input_dtype)
+        x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)

         # Reference output
-        out_ref = linear(X)
+        out_ref = m(*x)

         # Step 1. Basic quantization
-        basic_model = deepcopy(linear)
+        basic_model = deepcopy(m)
         quantize_(basic_model, base_config)
-        out_basic = basic_model(X)
+        out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

-        # SmoothQuant quantization
-        model = deepcopy(linear)
+        # Step 2. SmoothQuant
+        model = deepcopy(m)
         config = SmoothQuantConfig(
             base_config=base_config,
             step=SmoothQuantStep.PREPARE,
@@ -111,18 +110,25 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
         quantize_(model, config)

         # Perform calibration with test data
-        model(X)
+        model(*x)

-        # Step 2. SmoothQuant
         config.step = SmoothQuantStep.CONVERT
         quantize_(model, config)
+        assert isinstance(
+            model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
+        )
+        assert isinstance(
+            model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
+        )

-        out_smoothquant = model(X)
+        out_smoothquant = model(*x)
         loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()

         assert loss_smoothquant < loss_base, (
             f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
         )
+        # Make sure the result is reasonable
+        self.assertGreater(SQNR(out_ref, out_smoothquant), 20.0)

     @common_utils.parametrize(
         "base_config",