From 32b6f0a27e11d7423dcbbab74ca129843e25236c Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Wed, 8 May 2024 16:13:38 -0400 Subject: [PATCH] Composing autoquant with compile (#175) * Composing autoquant with compile Summary: this PR rewrites how torchao.autoquant works so that it works with torch.compile. Previously you had to do: torchao.autoquant(model, input) mod=torch.compile(model) mod(input) now you can do torchao.autoquant(torch.compile(model)) model(input) The new method works with/without compile. Also this is BC so the old path also works. We use a forward_prehook to intercept the model call before torch.compile tracing occurs at which point we do the autoquantization and clean up all remaining hooks before passing things off to the normal torch.compile tracing functionality. note: in the case of multiple inputs, you can also do: model.forward_log_only(input) to run the model forward with autoquant shape logging and prevent the torch.compile tracing/autoquant quantization from occuring. Test Plan: python test/integration/test_integration.py -k "autoquant" Reviewers: Subscribers: Tasks: Tags: * Fused DoRA kernels (#216) * add dora kernels * allowing error_on_unseen in autoquant func Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Unified AffineQuantizedTensor subclass (#214) Summary: Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine and dequantize_affine ops) only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w Reviewers: Subscribers: Tasks: Tags: Co-authored-by: Mark Saroufim * add expecttest to requirements.txt (#225) * add expecttest to requirements.txt * update * Install dev-requirements.txt in doc build (#224) Install dev-requirements.txt --------- Co-authored-by: Mark Saroufim * Fix an error in subclass impl (#226) Summary: Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing Test Plan: CI Reviewers: Subscribers: Tasks: Tags: * update readme.md Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * trying to fix the error in CI on cleanup hooks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * correct docs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Some follow up fixes for quant primitives (#220) Summary: att Test Plan: python test/quantization/test_quant_primitives.py -k test_raises Reviewers: Subscribers: Tasks: Tags: * Composing autoquant with compile Summary: this PR rewrites how torchao.autoquant works so that it works with torch.compile. Previously you had to do: torchao.autoquant(model, input) mod=torch.compile(model) mod(input) now you can do torchao.autoquant(torch.compile(model)) model(input) The new method works with/without compile. Also this is BC so the old path also works. We use a forward_prehook to intercept the model call before torch.compile tracing occurs at which point we do the autoquantization and clean up all remaining hooks before passing things off to the normal torch.compile tracing functionality. note: in the case of multiple inputs, you can also do: model.forward_log_only(input) to run the model forward with autoquant shape logging and prevent the torch.compile tracing/autoquant quantization from occuring. Test Plan: python test/integration/test_integration.py -k "autoquant" Reviewers: Subscribers: Tasks: Tags: * allowing error_on_unseen in autoquant func Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * update readme.md Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * trying to fix the error in CI on cleanup hooks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * correct docs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --------- Co-authored-by: jeromeku Co-authored-by: Jerry Zhang Co-authored-by: Mark Saroufim Co-authored-by: Svetlana Karslioglu --- README.md | 9 ++-- test/integration/test_integration.py | 63 +++++++++++++++++++--- torchao/quantization/README.md | 10 ++-- torchao/quantization/autoquant.py | 80 +++++++++++++++++++++++----- torchao/quantization/quant_api.py | 4 +- 5 files changed, 133 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 80f4a932d5..21a7195c27 100644 --- a/README.md +++ b/README.md @@ -44,12 +44,9 @@ torch._inductor.config.use_mixed_mm = True model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') -# perform autoquantization -torchao.autoquant(model, (input)) - -# compile the model to recover performance -model = torch.compile(model, mode='max-autotune') -model(input) +# perform autoquantization and compilation +q_model = torchao.autoquant(torch.compile(model, mode='max-autotune')) +q_model(input) ``` ### Sparsity diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 0d11093fd1..e6da3e7340 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1388,7 +1388,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): torch.nn.ReLU(), ).to(device).to(dtype) out = model(example_input) - torchao.autoquant(model, example_input) + torchao.autoquant(model) out2 = model(example_input) sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) @@ -1400,7 +1400,9 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): (32, 32, 128, 128), ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") - def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n): + def test_autoquant_compile(self, device, dtype, m1, m2, k, n): + if device != "cuda" and dtype != torch.bfloat16: + self.skipTest(f"autoquant currently does not support {device}") if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1414,15 +1416,60 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n): torch.nn.ReLU(), ).to(device).to(dtype) example_input = torch.randn(m1, k, device=device, dtype=dtype) - example_input2 = torch.randn(m2, k, device=device, dtype=dtype) - torchao.quantization.change_linears_to_autoquantizable(model) - out=model(example_input) - model(example_input2) - torchao.quantization.change_autoquantizable_to_quantized(model) - out2 = model(example_input) + example_input2 = torch.randn(m1, k, device=device, dtype=dtype) + out = model(example_input) + + mod = torchao.autoquant(torch.compile(model)) + mod.forward_log_only(example_input) + mod(example_input2) + + out2 = mod(example_input) sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, + [ + (1, 1, 128, 128), + (1, 32, 128, 128), + (32, 32, 128, 128), + ])) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") + def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): + if device != "cuda" and dtype != torch.bfloat16: + self.skipTest(f"autoquant currently does not support {device}") + if device != "cuda" or not torch.cuda.is_available(): + self.skipTest(f"autoquant currently does not support {device}") + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): + if dtype == torch.bfloat16: + self.skipTest(f"bfloat16 requires sm80+") + if m1 == 1 or m2 == 1: + self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") + + class NeedsKwargs(torch.nn.Module): + def __init__(self): + super().__init__() + self.rel = torch.nn.ReLU() + self.lin = torch.nn.Linear(k,n) + + def forward(self, x, y): + x = self.rel(x) + z = self.lin(x + y) + return z + + model = NeedsKwargs().to(device).to(dtype) + example_input = { + "x": torch.randn(m1, k, device=device, dtype=dtype), + "y": torch.randn(m1, k, device=device, dtype=dtype), + } + out = model(**example_input) + + mod = torchao.autoquant(torch.compile(model)) + mod.forward_log_only(**example_input) + mod(**example_input) + + out2 = mod(**example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) class TestAOTI(unittest.TestCase): @parameterized.expand( diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index fc8dbf0137..622ec1cbcf 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -28,11 +28,11 @@ torch._inductor.config.use_mixed_mm = True model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') -# perform autoquantization -torchao.autoquant(model, (input)) +# perform autoquantization and torch.compile +model = torchao.autoquant(torch.compile(model, mode='max-autotune')) -# compile the model to improve performance -model = torch.compile(model, mode='max-autotune') +# pass in an input which is used in order to pick fastest quantization operations +# and apply torch compilation. model(input) ``` @@ -167,6 +167,6 @@ model(input) ## Notes -1. APIs have been hardware tested on A100 and T4(colab) +1. APIs have been hardware tested on A100 and T4(colab) 2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance. 3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible. diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 9f2b59f20a..fc38c04169 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -74,6 +74,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) update_cache(q_cls, shapes_and_dtype, res) + @torch.no_grad() def to_quantized(self, error_on_unseen, **kwargs): if error_on_unseen and self.logged_data == {}: raise RuntimeError("must run module normally to get shape, dtype info for autoquant") @@ -123,7 +124,7 @@ def count_shapes(self, do_print=True): torch._dynamo.reset() cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen if shape_count is not None and shape_count > 1: - print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") + print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") if best_time >= cur_time: best_time = cur_time best_cls = q_cls @@ -176,6 +177,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.detach.default: return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) +@torch.no_grad() def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead @@ -335,6 +337,7 @@ def change_linears_to_autoquantizable(model, **kwargs): """ from torchao.quantization.quant_api import _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) + _ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) kwargs["mode"] = kwargs.get("mode", ["relu", None]) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter @@ -374,20 +377,71 @@ def change_autoquantizable_to_quantized(model, **kwargs): torch._dynamo.reset() @torch.no_grad() -def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs): +def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs): """ - Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape - across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer - and applies that type of quantization. + wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model. + AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original + model and then letting the torch.compile run/tracing occur. + + Example usage:: + + torchao.autoquant(torch.compile(model)) + model(*example_input) + """ - if filter_fn is None: - from torchao.quantization.quant_api import _is_linear - filter_fn = _is_linear + # the hook we will use to intercept the model forward and perform + # autoquantization + def autoquant_prehook(module, args, kwargs): + module.forward_log_only(*args, **kwargs) + change_autoquantizable_to_quantized( + module, + **aq_kwargs, + ) + module.clean_up_autoquant_hooks_and_attrs() + return args, kwargs + + # perform initial swap from linear weights + # to AutoQuantizableLinearWeight + change_linears_to_autoquantizable( + model, + filter_fn=filter_fn, + qtensor_class_list=qtensor_class_list, + mode=mode, + **aq_kwargs + ) + + # access actual model of torch.compile wrapper if needed + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + real_model = model._orig_mod + else: + real_model = model + + # we need a consistent way to run the model which bypasses both + # A) the torch.compile tracing (so we need to run the inner model directly) + # B) the autoquant_prehook we're about to register (so we call forward directly) + model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs) + + # the autoquant_prehook intercepts the forward call and performs autoquantization + # and then deletes the hook. if model is a torch.compile wrapper, it then + # does the tracing/compile since the prehook is naturally followed by the normal. + # model run. + handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True) + + # note the torch.compile wrapper eval_frame moved the assignment of any assigned + # attributes to the inner model, so we have to call delattr on the inner model + def clean_up_autoquant_hooks_and_attrs(): + try: + handle.remove() + delattr(real_model, "clean_up_autoquant_hooks_and_attrs") + delattr(real_model, "forward_log_only") + except: + pass + model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs - change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) - if not isinstance(example_input, (tuple, list)): - assert isinstance(example_input, torch.Tensor) + # if example input was provided, check it and run it + if isinstance(example_input, torch.Tensor): example_input = [example_input] - model(*example_input) - change_autoquantizable_to_quantized(model, **kwargs) + if isinstance(example_input, (tuple, list)): + model(*example_input) + return model diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 2dcd935912..a5a3a2b3db 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -34,6 +34,7 @@ Int4WeightOnlyGPTQQuantizer, Int4WeightOnlyQuantizer, ) +from .autoquant import autoquant __all__ = [ @@ -46,7 +47,8 @@ "Quantizer", "TwoStepQuantizer", "Int4WeightOnlyGPTQQuantizer", - "Int4WeightOnlyQuantizer" + "Int4WeightOnlyQuantizer", + "autoquant" ] if TORCH_VERSION_AFTER_2_3: