Skip to content

Commit

Permalink
Composing autoquant with compile
Browse files Browse the repository at this point in the history
Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles committed Apr 25, 2024
1 parent 639432b commit 144b03d
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 21 deletions.
40 changes: 37 additions & 3 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
torch.nn.ReLU(),
).to(device).to(dtype)
out = model(example_input)
torchao.autoquant(model, example_input)
torchao.autoquant(model)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)
Expand All @@ -1396,10 +1396,44 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
).to(device).to(dtype)
example_input = torch.randn(m1, k, device=device, dtype=dtype)
example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
torchao.quantization.change_linears_to_autoquantizable(model)
out=model(example_input)
torchao.autoquant(model)
model.forward_log_only(example_input)
model(example_input2)
torchao.quantization.change_autoquantizable_to_quantized(model)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(1, 1, 128, 128),
(1, 32, 128, 128),
(32, 32, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")
if m1 == 1 or m2 == 1:
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to(device).to(dtype)
example_input = torch.randn(m1, k, device=device, dtype=dtype)
example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
out = model(example_input)

torchao.autoquant(torch.compile(model))
model.forward_log_only(example_input)
model(example_input2)

out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)
Expand Down
10 changes: 5 additions & 5 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ torch._inductor.config.use_mixed_mm = True
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, (input))
# perform autoquantization and torch.compile
torchao.autoquant(torch.compile(model, mode='max-autotune'))

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
# pass in an input which is used in order to pick fastest quantization operations
# and apply torch compilation.
model(input)
```

Expand Down Expand Up @@ -167,6 +167,6 @@ model(input)

## Notes

1. APIs have been hardware tested on A100 and T4(colab)
1. APIs have been hardware tested on A100 and T4(colab)
2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
74 changes: 62 additions & 12 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
update_cache(q_cls, shapes_and_dtype, res)

@torch.no_grad()
def to_quantized(self, error_on_unseen, **kwargs):
if error_on_unseen and self.logged_data == {}:
raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
Expand Down Expand Up @@ -176,6 +177,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if func is aten.detach.default:
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

@torch.no_grad()
def do_autoquant_bench(op, *args, **kwargs):
"""
runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
Expand Down Expand Up @@ -374,20 +376,68 @@ def change_autoquantizable_to_quantized(model, **kwargs):
torch._dynamo.reset()

@torch.no_grad()
def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs):
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
"""
Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
and applies that type of quantization.
wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
model and then letting the torch.compile run/tracing occur.
Example usage::
torchao.autoquant(torch.compile(model))
model(*example_input)
"""
if filter_fn is None:
from torchao.quantization.quant_api import _is_linear
filter_fn = _is_linear
# the hook we will use to intercept the model forward and perform
# autoquantization
def autoquant_prehook(module, args, kwargs):
module.forward_log_only(*args, **kwargs)
change_autoquantizable_to_quantized(
module,
**aq_kwargs,
)
module.clean_up_autoquant_hooks_and_attrs()
return args, kwargs

change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
if not isinstance(example_input, (tuple, list)):
assert isinstance(example_input, torch.Tensor)
# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
change_linears_to_autoquantizable(
model,
filter_fn=filter_fn,
qtensor_class_list=qtensor_class_list,
mode=mode,
**aq_kwargs
)

# access actual model of torch.compile wrapper if needed
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
real_model = model._orig_mod
else:
real_model = model

# we need a consistent way to run the model which bypasses both
# A) the torch.compile tracing (so we need to run the inner model directly)
# B) the autoquant_prehook we're about to register (so we call forward directly)
model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)

# the autoquant_prehook intercepts the forward call and performs autoquantization
# and then deletes the hook. if model is a torch.compile wrapper, it then
# does the tracing/compile since the prehook is naturally followed by the normal.
# model run.
handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)

# note the torch.compile wrapper eval_frame moved the assignment of any assigned
# attributes to the inner model, so we have to call delattr on the inner model
def clean_up_autoquant_hooks_and_attrs():
handle.remove()
delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
delattr(real_model, "forward_log_only")
model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs

# if example input was provided, check it and run it
if isinstance(example_input, torch.Tensor):
example_input = [example_input]
model(*example_input)
change_autoquantizable_to_quantized(model, **kwargs)
if isinstance(example_input, (tuple, list)):
model(*example_input)

return model
4 changes: 3 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
from .autoquant import autoquant


__all__ = [
Expand All @@ -46,7 +47,8 @@
"Quantizer",
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer"
"Int4WeightOnlyQuantizer",
"autoquant"
]

if TORCH_VERSION_AFTER_2_3:
Expand Down

0 comments on commit 144b03d

Please sign in to comment.