Skip to content

Commit

Permalink
Composing autoquant with compile (#175)
Browse files Browse the repository at this point in the history
* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* Fused DoRA kernels (#216)

* add dora kernels

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Unified AffineQuantizedTensor subclass (#214)

Summary:
Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>

* add expecttest to requirements.txt (#225)

* add expecttest to requirements.txt

* update

* Install dev-requirements.txt in doc build (#224)

Install dev-requirements.txt

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>

* Fix an error in subclass impl (#226)

Summary:
Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Some follow up fixes for quant primitives (#220)

Summary:
att

Test Plan:
python test/quantization/test_quant_primitives.py -k test_raises

Reviewers:

Subscribers:

Tasks:

Tags:

* Composing autoquant with compile

Summary:

this PR rewrites how torchao.autoquant works so that it works with
torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod=torch.compile(model)
mod(input)

now you can do
torchao.autoquant(torch.compile(model))
model(input)

The new method works with/without compile. Also this is BC so the old
path also works.

We use a forward_prehook to intercept the model call before
torch.compile tracing occurs at which point we do the autoquantization
and clean up all remaining hooks before passing things off to the
normal torch.compile tracing functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input) to run the model forward with autoquant
shape logging and prevent the torch.compile tracing/autoquant
quantization from occuring.

Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:

* allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
  • Loading branch information
5 people authored May 8, 2024
1 parent 63c5ac5 commit f6d56ca
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 33 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,9 @@ torch._inductor.config.use_mixed_mm = True
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, (input))

# compile the model to recover performance
model = torch.compile(model, mode='max-autotune')
model(input)
# perform autoquantization and compilation
q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
q_model(input)
```

### Sparsity
Expand Down
63 changes: 55 additions & 8 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
torch.nn.ReLU(),
).to(device).to(dtype)
out = model(example_input)
torchao.autoquant(model, example_input)
torchao.autoquant(model)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)
Expand All @@ -1400,7 +1400,9 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
(32, 32, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
Expand All @@ -1414,15 +1416,60 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
torch.nn.ReLU(),
).to(device).to(dtype)
example_input = torch.randn(m1, k, device=device, dtype=dtype)
example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
torchao.quantization.change_linears_to_autoquantizable(model)
out=model(example_input)
model(example_input2)
torchao.quantization.change_autoquantizable_to_quantized(model)
out2 = model(example_input)
example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
out = model(example_input)

mod = torchao.autoquant(torch.compile(model))
mod.forward_log_only(example_input)
mod(example_input2)

out2 = mod(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(1, 1, 128, 128),
(1, 32, 128, 128),
(32, 32, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")
if m1 == 1 or m2 == 1:
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")

class NeedsKwargs(torch.nn.Module):
def __init__(self):
super().__init__()
self.rel = torch.nn.ReLU()
self.lin = torch.nn.Linear(k,n)

def forward(self, x, y):
x = self.rel(x)
z = self.lin(x + y)
return z

model = NeedsKwargs().to(device).to(dtype)
example_input = {
"x": torch.randn(m1, k, device=device, dtype=dtype),
"y": torch.randn(m1, k, device=device, dtype=dtype),
}
out = model(**example_input)

mod = torchao.autoquant(torch.compile(model))
mod.forward_log_only(**example_input)
mod(**example_input)

out2 = mod(**example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

class TestAOTI(unittest.TestCase):
@parameterized.expand(
Expand Down
10 changes: 5 additions & 5 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ torch._inductor.config.use_mixed_mm = True
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, (input))
# perform autoquantization and torch.compile
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
# pass in an input which is used in order to pick fastest quantization operations
# and apply torch compilation.
model(input)
```

Expand Down Expand Up @@ -167,6 +167,6 @@ model(input)

## Notes

1. APIs have been hardware tested on A100 and T4(colab)
1. APIs have been hardware tested on A100 and T4(colab)
2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
80 changes: 67 additions & 13 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
update_cache(q_cls, shapes_and_dtype, res)

@torch.no_grad()
def to_quantized(self, error_on_unseen, **kwargs):
if error_on_unseen and self.logged_data == {}:
raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
Expand Down Expand Up @@ -123,7 +124,7 @@ def count_shapes(self, do_print=True):
torch._dynamo.reset()
cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
if shape_count is not None and shape_count > 1:
print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
if best_time >= cur_time:
best_time = cur_time
best_cls = q_cls
Expand Down Expand Up @@ -176,6 +177,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if func is aten.detach.default:
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

@torch.no_grad()
def do_autoquant_bench(op, *args, **kwargs):
"""
runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
Expand Down Expand Up @@ -335,6 +337,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
"""
from torchao.quantization.quant_api import _is_linear
filter_fn = kwargs.pop("filter_fn", _is_linear)
_ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized
kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
kwargs["mode"] = kwargs.get("mode", ["relu", None])
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
Expand Down Expand Up @@ -374,20 +377,71 @@ def change_autoquantizable_to_quantized(model, **kwargs):
torch._dynamo.reset()

@torch.no_grad()
def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs):
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
"""
Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
and applies that type of quantization.
wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
model and then letting the torch.compile run/tracing occur.
Example usage::
torchao.autoquant(torch.compile(model))
model(*example_input)
"""
if filter_fn is None:
from torchao.quantization.quant_api import _is_linear
filter_fn = _is_linear
# the hook we will use to intercept the model forward and perform
# autoquantization
def autoquant_prehook(module, args, kwargs):
module.forward_log_only(*args, **kwargs)
change_autoquantizable_to_quantized(
module,
**aq_kwargs,
)
module.clean_up_autoquant_hooks_and_attrs()
return args, kwargs

# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
change_linears_to_autoquantizable(
model,
filter_fn=filter_fn,
qtensor_class_list=qtensor_class_list,
mode=mode,
**aq_kwargs
)

# access actual model of torch.compile wrapper if needed
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
real_model = model._orig_mod
else:
real_model = model

# we need a consistent way to run the model which bypasses both
# A) the torch.compile tracing (so we need to run the inner model directly)
# B) the autoquant_prehook we're about to register (so we call forward directly)
model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)

# the autoquant_prehook intercepts the forward call and performs autoquantization
# and then deletes the hook. if model is a torch.compile wrapper, it then
# does the tracing/compile since the prehook is naturally followed by the normal.
# model run.
handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)

# note the torch.compile wrapper eval_frame moved the assignment of any assigned
# attributes to the inner model, so we have to call delattr on the inner model
def clean_up_autoquant_hooks_and_attrs():
try:
handle.remove()
delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
delattr(real_model, "forward_log_only")
except:
pass
model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs

change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
if not isinstance(example_input, (tuple, list)):
assert isinstance(example_input, torch.Tensor)
# if example input was provided, check it and run it
if isinstance(example_input, torch.Tensor):
example_input = [example_input]
model(*example_input)
change_autoquantizable_to_quantized(model, **kwargs)
if isinstance(example_input, (tuple, list)):
model(*example_input)

return model
4 changes: 3 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
from .autoquant import autoquant


__all__ = [
Expand All @@ -46,7 +47,8 @@
"Quantizer",
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer"
"Int4WeightOnlyQuantizer",
"autoquant"
]

if TORCH_VERSION_AFTER_2_3:
Expand Down

0 comments on commit f6d56ca

Please sign in to comment.