fixing autoquant bug (#265)
Summary:

In some model topologies the same weight is accessed by multiple modules, which
caused a bug where that weight would get autoquantized multiple times.
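
For illustration, a minimal sketch of the guard this change boils down to (it mirrors the _is_linear tweak in quant_api.py below); the standalone predicate name here is hypothetical, but AutoQuantizableLinearWeight is the real torchao class:

import torch
from torchao.quantization.autoquant import AutoQuantizableLinearWeight

def is_unwrapped_linear(mod, *args):
    # Skip a Linear whose weight was already swapped (e.g. a Parameter shared
    # with another module), so each shared weight is only wrapped once.
    return (
        isinstance(mod, torch.nn.Linear)
        and hasattr(mod, "weight")
        and not isinstance(mod.weight, AutoQuantizableLinearWeight)
    )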

Also fixed a shape issue with x_scales that showed up in some situations with the
new primitives.
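
The reshape fix below comes down to broadcasting: per-row activation scales of shape (M,) need to be a column of shape (M, 1) to rescale an (M, N) matmul result. A small self-contained sketch with made-up sizes (not the benchmarked kernel; the int8 accumulation is simulated in fp32 for portability):

import torch

M, K, N = 16, 128, 128
x_int8 = torch.randint(-128, 127, (M, K), dtype=torch.int8)
w_int8 = torch.randint(-128, 127, (N, K), dtype=torch.int8)
x_scales = torch.rand(M)  # 1-D per-row scales

# simulate the int8 matmul accumulation in fp32 for portability
acc = x_int8.float() @ w_int8.t().float()

out = acc * x_scales.reshape(-1, 1)  # (M, 1) broadcasts over the (M, N) result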

Also changed the default autoquant mode to interpolation, which seems to work
better for torchbench benchmarking.
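
A quick usage sketch of the new default (the signature change is in autoquant.py below); the model here is just a stand-in, and in practice autoquant targets CUDA bfloat16 models:

import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(128, 128))
model = torchao.autoquant(model)  # now defaults to mode=["interpolate", .85]
# model = torchao.autoquant(model, mode=["relu", None])  # previous default, still available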

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles authored May 24, 2024
1 parent 49755f6 commit 163cb93
Showing 3 changed files with 44 additions and 4 deletions.
41 changes: 40 additions & 1 deletion test/integration/test_integration.py
@@ -61,7 +61,8 @@
AQInt8DynamicallyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
AQWeightOnlyQuantizedLinearWeight3
AQWeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,

)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
@@ -1471,6 +1472,44 @@ def forward(self, x, y):
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(16, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_double_access(self, device, dtype, m, k, n):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")

class DoubleAccess(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin1 = torch.nn.Linear(k, n)
self.lin2 = torch.nn.Linear(n, k)
self.lin3 = torch.nn.Linear(k, n)
self.lin3.weight = self.lin1.weight

def forward(self, x):
x = self.lin1(x)
x = self.lin2(x)
x = self.lin3(x)
return x

x_in = torch.randn(m, k, device=device, dtype=dtype)
model = DoubleAccess().to(device).to(dtype)
model(x_in)
torchao.autoquant(model)
assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
model(x_in)




class TestAOTI(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
4 changes: 2 additions & 2 deletions torchao/quantization/autoquant.py
@@ -252,7 +252,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -384,7 +384,7 @@ def change_autoquantizable_to_quantized(model, **kwargs):
torch._dynamo.reset()

@torch.no_grad()
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
"""
wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
3 changes: 2 additions & 1 deletion torchao/quantization/quant_api.py
@@ -34,7 +34,7 @@
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
from .autoquant import autoquant
from .autoquant import autoquant, AutoQuantizableLinearWeight


__all__ = [
@@ -91,6 +91,7 @@ def _is_linear(mod, *args):
isinstance(mod, torch.nn.Linear)
and hasattr(mod, "weight")
and not isinstance(mod.weight, QuantizedLinearWeightBase)
and not isinstance(mod.weight, AutoQuantizableLinearWeight)
)


