Skip to content

Commit fbb2f2b

Browse files
committed
inline quantization API in benchmarks
1 parent: 964b7e4 · commit: fbb2f2b

File tree

1 file changed

+28
-53
lines changed

1 file changed

+28
-53
lines changed

benchmarks/benchmark_aq.py

Lines changed: 28 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -18,26 +18,6 @@
1818
)
1919

2020

21-
def _int8wo_api(mod, **kwargs):
22-
quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False)
23-
24-
25-
def _int8da_int8w_api(mod, **kwargs):
26-
quantize_(
27-
mod,
28-
Int8DynamicActivationInt8WeightConfig(**kwargs),
29-
set_inductor_config=False,
30-
)
31-
32-
33-
def _int4wo_api(mod, **kwargs):
34-
kwargs_copy = kwargs.copy()
35-
if "groupsize" in kwargs_copy:
36-
kwargs_copy["group_size"] = kwargs_copy["groupsize"]
37-
del kwargs_copy["groupsize"]
38-
quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)
39-
40-
4121
class ToyLinearModel(torch.nn.Module):
4222
"""Single linear for m * k * n problem size"""
4323

@@ -117,26 +97,14 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
11797

11898

11999
@torch.no_grad
120-
def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
121-
if kwargs is None:
122-
kwargs = {}
123-
100+
def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
124101
m = ToyLinearModel(
125102
M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
126103
).eval()
127104
m_bf16 = copy.deepcopy(m)
128-
m_ref = copy.deepcopy(m)
129105
example_inputs = m.example_inputs()
130106

131-
api(m, **kwargs)
132-
133-
# reference
134-
ref_api(m_ref, **kwargs)
135-
136-
res = m(*example_inputs)
137-
ref = m_ref(*example_inputs)
138-
139-
assert torch.equal(res, ref)
107+
api(m, config) # Pass both model and config
140108

141109
# perf comparison
142110
from torchao.utils import benchmark_model
@@ -146,22 +114,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
146114
RUNS = 100
147115

148116
torch._dynamo.reset()
149-
m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
150-
benchmark_model(m_ref, WARMUP, example_inputs)
151-
ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
117+
m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
118+
benchmark_model(m_bf16, WARMUP, example_inputs)
119+
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
152120

153121
torch._dynamo.reset()
154122
m = torch.compile(m, mode="max-autotune", fullgraph=True)
155123
benchmark_model(m, WARMUP, example_inputs)
156124
elapsed_time = benchmark_model(m, RUNS, example_inputs)
157125

158-
torch._dynamo.reset()
159-
m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
160-
benchmark_model(m_bf16, WARMUP, example_inputs)
161-
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
162-
163126
print(
164-
f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
127+
f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
165128
)
166129

167130

@@ -170,20 +133,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
170133
(20, 2048, 2048),
171134
]
172135

173-
print("_int8da_int8w_api")
174-
136+
print("Int8DynamicActivationInt8WeightConfig")
175137
for M, N, K in all_shapes:
176138
_bench_quantized_tensor_subclass_perf(
177-
_int8da_int8w_api, _int8da_int8w_api, M, N, K
139+
quantize_,
140+
Int8DynamicActivationInt8WeightConfig(),
141+
M,
142+
N,
143+
K,
178144
)
179145

180-
print("_int8wo_api")
181-
146+
print("Int8WeightOnlyConfig")
182147
for M, N, K in all_shapes:
183-
_bench_quantized_tensor_subclass_perf(_int8wo_api, _int8wo_api, M, N, K)
184-
185-
print("_int4wo_api")
186-
kwargs = {"groupsize": 32, "version": 1}
148+
_bench_quantized_tensor_subclass_perf(
149+
quantize_,
150+
Int8WeightOnlyConfig(),
151+
M,
152+
N,
153+
K,
154+
)
187155

156+
print("Int4WeightOnlyConfig")
188157
for M, N, K in all_shapes:
189-
_bench_quantized_tensor_subclass_perf(_int4wo_api, _int4wo_api, M, N, K, kwargs)
158+
_bench_quantized_tensor_subclass_perf(
159+
quantize_,
160+
Int4WeightOnlyConfig(group_size=32),
161+
M,
162+
N,
163+
K,
164+
)

0 commit comments

Comments
 (0)