18
18
)
19
19
20
20
21
- def _int8wo_api (mod , ** kwargs ):
22
- quantize_ (mod , Int8WeightOnlyConfig (** kwargs ), set_inductor_config = False )
23
-
24
-
25
- def _int8da_int8w_api (mod , ** kwargs ):
26
- quantize_ (
27
- mod ,
28
- Int8DynamicActivationInt8WeightConfig (** kwargs ),
29
- set_inductor_config = False ,
30
- )
31
-
32
-
33
- def _int4wo_api (mod , ** kwargs ):
34
- kwargs_copy = kwargs .copy ()
35
- if "groupsize" in kwargs_copy :
36
- kwargs_copy ["group_size" ] = kwargs_copy ["groupsize" ]
37
- del kwargs_copy ["groupsize" ]
38
- quantize_ (mod , Int4WeightOnlyConfig (** kwargs_copy ), set_inductor_config = False )
39
-
40
-
41
21
class ToyLinearModel (torch .nn .Module ):
42
22
"""Single linear for m * k * n problem size"""
43
23
@@ -117,26 +97,14 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
117
97
118
98
119
99
@torch .no_grad
120
- def _bench_quantized_tensor_subclass_perf (api , ref_api , M , N , K , kwargs = None ):
121
- if kwargs is None :
122
- kwargs = {}
123
-
100
+ def _bench_quantized_tensor_subclass_perf (api , config , M , N , K ):
124
101
m = ToyLinearModel (
125
102
M , N , K , has_bias = True , dtype = torch .bfloat16 , device = "cuda"
126
103
).eval ()
127
104
m_bf16 = copy .deepcopy (m )
128
- m_ref = copy .deepcopy (m )
129
105
example_inputs = m .example_inputs ()
130
106
131
- api (m , ** kwargs )
132
-
133
- # reference
134
- ref_api (m_ref , ** kwargs )
135
-
136
- res = m (* example_inputs )
137
- ref = m_ref (* example_inputs )
138
-
139
- assert torch .equal (res , ref )
107
+ api (m , config ) # Pass both model and config
140
108
141
109
# perf comparison
142
110
from torchao .utils import benchmark_model
@@ -146,22 +114,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
146
114
RUNS = 100
147
115
148
116
torch ._dynamo .reset ()
149
- m_ref = torch .compile (m_ref , mode = "max-autotune" , fullgraph = True )
150
- benchmark_model (m_ref , WARMUP , example_inputs )
151
- ref_elapsed_time = benchmark_model (m_ref , RUNS , example_inputs )
117
+ m_bf16 = torch .compile (m_bf16 , mode = "max-autotune" , fullgraph = True )
118
+ benchmark_model (m_bf16 , WARMUP , example_inputs )
119
+ bf16_elapsed_time = benchmark_model (m_bf16 , RUNS , example_inputs )
152
120
153
121
torch ._dynamo .reset ()
154
122
m = torch .compile (m , mode = "max-autotune" , fullgraph = True )
155
123
benchmark_model (m , WARMUP , example_inputs )
156
124
elapsed_time = benchmark_model (m , RUNS , example_inputs )
157
125
158
- torch ._dynamo .reset ()
159
- m_bf16 = torch .compile (m_bf16 , mode = "max-autotune" , fullgraph = True )
160
- benchmark_model (m_bf16 , WARMUP , example_inputs )
161
- bf16_elapsed_time = benchmark_model (m_bf16 , RUNS , example_inputs )
162
-
163
126
print (
164
- f"{ (M , N , K )} : elapsed time: { elapsed_time } , ref elapsed time: { ref_elapsed_time } , bf16 elapsed time: { bf16_elapsed_time } "
127
+ f"{ (M , N , K )} : elapsed time: { elapsed_time } , bf16 elapsed time: { bf16_elapsed_time } "
165
128
)
166
129
167
130
@@ -170,20 +133,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
170
133
(20 , 2048 , 2048 ),
171
134
]
172
135
173
- print ("_int8da_int8w_api" )
174
-
136
+ print ("Int8DynamicActivationInt8WeightConfig" )
175
137
for M , N , K in all_shapes :
176
138
_bench_quantized_tensor_subclass_perf (
177
- _int8da_int8w_api , _int8da_int8w_api , M , N , K
139
+ quantize_ ,
140
+ Int8DynamicActivationInt8WeightConfig (),
141
+ M ,
142
+ N ,
143
+ K ,
178
144
)
179
145
180
- print ("_int8wo_api" )
181
-
146
+ print ("Int8WeightOnlyConfig" )
182
147
for M , N , K in all_shapes :
183
- _bench_quantized_tensor_subclass_perf (_int8wo_api , _int8wo_api , M , N , K )
184
-
185
- print ("_int4wo_api" )
186
- kwargs = {"groupsize" : 32 , "version" : 1 }
148
+ _bench_quantized_tensor_subclass_perf (
149
+ quantize_ ,
150
+ Int8WeightOnlyConfig (),
151
+ M ,
152
+ N ,
153
+ K ,
154
+ )
187
155
156
+ print ("Int4WeightOnlyConfig" )
188
157
for M , N , K in all_shapes :
189
- _bench_quantized_tensor_subclass_perf (_int4wo_api , _int4wo_api , M , N , K , kwargs )
158
+ _bench_quantized_tensor_subclass_perf (
159
+ quantize_ ,
160
+ Int4WeightOnlyConfig (group_size = 32 ),
161
+ M ,
162
+ N ,
163
+ K ,
164
+ )
0 commit comments