From 5583d81e66664f2cb1686a0e23fbbbd554e5113e Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Thu, 25 Apr 2024 12:02:25 -0700
Subject: [PATCH 1/5] Composing autoquant with compile

Summary: this PR rewrites how torchao.autoquant works so that it
composes with torch.compile. Previously you had to do:

torchao.autoquant(model, input)
mod = torch.compile(model)
mod(input)

now you can do:

torchao.autoquant(torch.compile(model))
model(input)

The new method works with or without compile, and it is backward
compatible, so the old path also works. We use a forward pre-hook to
intercept the model call before torch.compile tracing occurs, at which
point we do the autoquantization and clean up all remaining hooks
before passing things off to the normal torch.compile tracing
functionality.

note: in the case of multiple inputs, you can also do:

model.forward_log_only(input)

to run the model forward with autoquant shape logging and prevent the
torch.compile tracing/autoquant quantization from occurring.
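For illustration, the multi-shape flow this enables looks roughly like
the following (a sketch; the input names are placeholders, not part of
the API):

    model = torchao.autoquant(torch.compile(model))
    model.forward_log_only(example_input)   # log this shape only; no quantization or tracing yet
    model(example_input2)                   # pre-hook fires: autoquantize, clean up, then compile
    out = model(example_input)              # later calls run the quantized, compiled model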
Test Plan: python test/integration/test_integration.py -k "autoquant"

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py | 74 ++++++++++++++++++++++++++--
 torchao/quantization/README.md       | 10 ++--
 torchao/quantization/autoquant.py    | 74 +++++++++++++++++++++++-----
 torchao/quantization/quant_api.py    |  4 +-
 4 files changed, 141 insertions(+), 21 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 0d11093fd1..60efc7a933 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1388,7 +1388,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
             torch.nn.ReLU(),
         ).to(device).to(dtype)
         out = model(example_input)
-        torchao.autoquant(model, example_input)
+        torchao.autoquant(model)
         out2 = model(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
@@ -1415,10 +1415,78 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
         ).to(device).to(dtype)
         example_input = torch.randn(m1, k, device=device, dtype=dtype)
         example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
-        torchao.quantization.change_linears_to_autoquantizable(model)
         out=model(example_input)
+        torchao.autoquant(model)
+        model.forward_log_only(example_input)
         model(example_input2)
-        torchao.quantization.change_autoquantizable_to_quantized(model)
+        out2 = model(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
+
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (1, 1, 128, 128),
+            (1, 32, 128, 128),
+            (32, 32, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m1 == 1 or m2 == 1:
+                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k,n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        example_input = torch.randn(m1, k, device=device, dtype=dtype)
+        example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
+        out = model(example_input)
+
+        torchao.autoquant(torch.compile(model))
+        model.forward_log_only(example_input)
+        model(example_input2)
+
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
+
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (1, 1, 128, 128),
+            (1, 32, 128, 128),
+            (32, 32, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m1 == 1 or m2 == 1:
+                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k,n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        example_input = torch.randn(m1, k, device=device, dtype=dtype)
+        example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
+        out = model(example_input)
+
+        torchao.autoquant(torch.compile(model))
+        model.forward_log_only(example_input)
+        model(example_input2)
+
+        out2 = model(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
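For reference, the tests above gate on SQNR >= 30, a
signal-to-quantization-noise check in dB; a minimal sketch of a
standard SQNR helper (the in-repo definition may differ in detail):

    import torch

    def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
        # 20 * log10(||ref|| / ||ref - test||); >= 30 dB means the
        # quantized output closely tracks the float reference
        return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))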
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index fc8dbf0137..9e4568a31d 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -28,11 +28,11 @@ torch._inductor.config.use_mixed_mm = True
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# perform autoquantization
-torchao.autoquant(model, (input))
+# perform autoquantization and torch.compile
+torchao.autoquant(torch.compile(model, mode='max-autotune'))
 
-# compile the model to improve performance
-model = torch.compile(model, mode='max-autotune')
+# pass in an input, which is used to pick the fastest quantization operations
+# and apply torch compilation.
 model(input)
 ```
@@ -167,6 +167,6 @@ model(input)
 
 ## Notes
 
-1. APIs have been hardware tested on A100 and T4(colab) 
+1. APIs have been hardware tested on A100 and T4(colab)
 2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
 3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 9f2b59f20a..58792a7ca9 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -74,6 +74,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
                 res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
                 update_cache(q_cls, shapes_and_dtype, res)
 
+    @torch.no_grad()
     def to_quantized(self, error_on_unseen, **kwargs):
         if error_on_unseen and self.logged_data == {}:
             raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
@@ -176,6 +177,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if func is aten.detach.default:
             return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
+@torch.no_grad()
 def do_autoquant_bench(op, *args, **kwargs):
     """
     runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
@@ -374,20 +376,68 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     torch._dynamo.reset()
 
 @torch.no_grad()
-def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs):
+def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
     """
-    Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
-    across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
-    and applies that type of quantization.
+    wraps the model with autoquant machinery; if example_input is provided, runs forward on it, otherwise returns
+    the prepared model. If the model is torch.compiled, autoquantization is performed on the original (inner)
+    model first, and the torch.compile run/tracing happens afterwards.
+
+    Example usage::
+
+        torchao.autoquant(torch.compile(model))
+        model(*example_input)
+
     """
-    if filter_fn is None:
-        from torchao.quantization.quant_api import _is_linear
-        filter_fn = _is_linear
+    # the hook we will use to intercept the model forward and perform
+    # autoquantization
+    def autoquant_prehook(module, args, kwargs):
+        module.forward_log_only(*args, **kwargs)
+        change_autoquantizable_to_quantized(
+            module,
+            **aq_kwargs,
+        )
+        module.clean_up_autoquant_hooks_and_attrs()
+        return args, kwargs
 
-    change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
-    if not isinstance(example_input, (tuple, list)):
-        assert isinstance(example_input, torch.Tensor)
+    # perform initial swap from linear weights
+    # to AutoQuantizableLinearWeight
+    change_linears_to_autoquantizable(
+        model,
+        filter_fn=filter_fn,
+        qtensor_class_list=qtensor_class_list,
+        mode=mode,
+        **aq_kwargs
+    )
+
+    # access the actual model inside the torch.compile wrapper if needed
+    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
+        real_model = model._orig_mod
+    else:
+        real_model = model
+
+    # we need a consistent way to run the model which bypasses both
+    # A) the torch.compile tracing (so we need to run the inner model directly)
+    # B) the autoquant_prehook we're about to register (so we call forward directly)
+    model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)
+
+    # the autoquant_prehook intercepts the forward call, performs autoquantization,
+    # and then deletes the hook. If model is a torch.compile wrapper, it then
+    # does the tracing/compile, since the prehook is naturally followed by the
+    # normal model run.
+    handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)
+
+    # note: the torch.compile wrapper (eval_frame) moves any attributes assigned
+    # to it onto the inner model, so we have to call delattr on the inner model
+    def clean_up_autoquant_hooks_and_attrs():
+        handle.remove()
+        delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
+        delattr(real_model, "forward_log_only")
+    model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs
+
+    # if an example input was provided, check it and run it
+    if isinstance(example_input, torch.Tensor):
         example_input = [example_input]
-    model(*example_input)
-    change_autoquantizable_to_quantized(model, **kwargs)
+    if isinstance(example_input, (tuple, list)):
+        model(*example_input)
+
+    return model
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 2dcd935912..a5a3a2b3db 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -34,6 +34,7 @@
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
 )
+from .autoquant import autoquant
 
 __all__ = [
@@ -46,7 +47,8 @@
     "Quantizer",
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
-    "Int4WeightOnlyQuantizer"
+    "Int4WeightOnlyQuantizer",
+    "autoquant"
 ]
 
 if TORCH_VERSION_AFTER_2_3:
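For context, the interception above is a standard self-removing forward
pre-hook; a minimal standalone sketch of the pattern (toy code, not the
torchao implementation):

    import torch

    model = torch.nn.Linear(8, 8)

    def one_time_prehook(module, args, kwargs):
        # one-time setup before forward (and before any torch.compile
        # tracing), then unregister so later calls go straight through
        print("setup runs once")
        handle.remove()
        return args, kwargs

    handle = model.register_forward_pre_hook(one_time_prehook, with_kwargs=True)
    model(torch.randn(2, 8))  # triggers the hook once
    model(torch.randn(2, 8))  # hook already removed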
From f08c3392f564c714bd969de9fe7e5f7f87618e55 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 6 May 2024 18:45:20 -0700
Subject: [PATCH 2/5] allowing error_on_unseen in autoquant func

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
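For reference, the call pattern this enables looks like the following
(a sketch; with error_on_unseen=False, to_quantized falls back to a
default instead of raising for layers that never saw an input):

    mod = torchao.autoquant(torch.compile(model), error_on_unseen=False)
    mod(example_input)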
---
 test/integration/test_integration.py | 61 ++++++++++++++++++++++++----
 torchao/quantization/autoquant.py    |  3 +-
 2 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 60efc7a933..881ddf7d11 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1449,11 +1449,11 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
         out = model(example_input)
 
-        torchao.autoquant(torch.compile(model))
-        model.forward_log_only(example_input)
-        model(example_input2)
+        mod = torchao.autoquant(torch.compile(model))
+        mod.forward_log_only(example_input)
+        mod(example_input2)
 
-        out2 = model(example_input)
+        out2 = mod(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
@@ -1480,17 +1480,60 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
             torch.nn.ReLU(),
         ).to(device).to(dtype)
         example_input = torch.randn(m1, k, device=device, dtype=dtype)
-        example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
+        example_input2 = torch.randn(m1+1, k, device=device, dtype=dtype)
         out = model(example_input)
 
-        torchao.autoquant(torch.compile(model))
-        model.forward_log_only(example_input)
-        model(example_input2)
+        mod = torchao.autoquant(torch.compile(model))
+        mod.forward_log_only(example_input)
+        mod(example_input2)
 
-        out2 = model(example_input)
+        out2 = mod(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (1, 1, 128, 128),
+            (1, 32, 128, 128),
+            (32, 32, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m1 == 1 or m2 == 1:
+                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+
+        class NeedsKwargs(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.rel = torch.nn.ReLU()
+                self.lin = torch.nn.Linear(k,n)
+
+            def forward(self, x, y):
+                x = self.rel(x)
+                z = self.lin(x + y)
+                return z
+
+        model = NeedsKwargs().to(device).to(dtype)
+        example_input = {
+            "x": torch.randn(m1, k, device=device, dtype=dtype),
+            "y": torch.randn(m1, k, device=device, dtype=dtype),
+        }
+        out = model(**example_input)
+
+        mod = torchao.autoquant(torch.compile(model))
+        mod.forward_log_only(**example_input)
+        mod(**example_input)
+
+        out2 = mod(**example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
 
 class TestAOTI(unittest.TestCase):
     @parameterized.expand(
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 58792a7ca9..ed24e70fe5 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -124,7 +124,7 @@ def count_shapes(self, do_print=True):
             torch._dynamo.reset()
             cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
         if shape_count is not None and shape_count > 1:
-            print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
+            print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
         if best_time >= cur_time:
             best_time = cur_time
             best_cls = q_cls
@@ -337,6 +337,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)
+    _ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized
     kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
     kwargs["mode"] = kwargs.get("mode", ["relu", None])
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

From 8be645b5c4135a5f4d369c586590ec785278188b Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 7 May 2024 20:06:34 -0700
Subject: [PATCH 3/5] update readme.md

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 80f4a932d5..c07409d6da 100644
--- a/README.md
+++ b/README.md
@@ -44,11 +44,8 @@ torch._inductor.config.use_mixed_mm = True
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# perform autoquantization
-torchao.autoquant(model, (input))
-
-# compile the model to recover performance
-model = torch.compile(model, mode='max-autotune')
+# perform autoquantization and compilation
+q_model = torchao.autoquant(torch.compile(model))
 model(input)
 ```

From 668a02eea63e885427686b9837db19f5f9620c9a Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 7 May 2024 20:28:35 -0700
Subject: [PATCH 4/5] trying to fix the error in CI on cleanup hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py | 64 ----------------------------
 torchao/quantization/autoquant.py    |  9 ++--
 2 files changed, 6 insertions(+), 67 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 881ddf7d11..e6da3e7340 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1393,36 +1393,6 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
-    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
-        [
-            (1, 1, 128, 128),
-            (1, 32, 128, 128),
-            (32, 32, 128, 128),
-        ]))
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
-    def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
-        if device != "cuda" or not torch.cuda.is_available():
-            self.skipTest(f"autoquant currently does not support {device}")
-        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
-            if dtype == torch.bfloat16:
-                self.skipTest(f"bfloat16 requires sm80+")
-            if m1 == 1 or m2 == 1:
-                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
-        model = torch.nn.Sequential(
-            torch.nn.ReLU(),
-            torch.nn.Linear(k,n),
-            torch.nn.ReLU(),
-        ).to(device).to(dtype)
-        example_input = torch.randn(m1, k, device=device, dtype=dtype)
-        example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
-        out=model(example_input)
-        torchao.autoquant(model)
-        model.forward_log_only(example_input)
-        model(example_input2)
-        out2 = model(example_input)
-        sqnr = SQNR(out, out2)
-        self.assertTrue(sqnr >= 30)
-
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
             (1, 1, 128, 128),
@@ -1457,40 +1427,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
-    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
-        [
-            (1, 1, 128, 128),
-            (1, 32, 128, 128),
-            (32, 32, 128, 128),
-        ]))
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
-    def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
-        if device != "cuda" and dtype != torch.bfloat16:
-            self.skipTest(f"autoquant currently does not support {device}")
-        if device != "cuda" or not torch.cuda.is_available():
-            self.skipTest(f"autoquant currently does not support {device}")
-        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
-            if dtype == torch.bfloat16:
-                self.skipTest(f"bfloat16 requires sm80+")
-            if m1 == 1 or m2 == 1:
-                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
-        model = torch.nn.Sequential(
-            torch.nn.ReLU(),
-            torch.nn.Linear(k,n),
-            torch.nn.ReLU(),
-        ).to(device).to(dtype)
-        example_input = torch.randn(m1, k, device=device, dtype=dtype)
-        example_input2 = torch.randn(m1+1, k, device=device, dtype=dtype)
-        out = model(example_input)
-
-        mod = torchao.autoquant(torch.compile(model))
-        mod.forward_log_only(example_input)
-        mod(example_input2)
-
-        out2 = mod(example_input)
-        sqnr = SQNR(out, out2)
-        self.assertTrue(sqnr >= 30)
-
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
             (1, 1, 128, 128),
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index ed24e70fe5..fc38c04169 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -430,9 +430,12 @@ def autoquant_prehook(module, args, kwargs):
     # note: the torch.compile wrapper (eval_frame) moves any attributes assigned
     # to it onto the inner model, so we have to call delattr on the inner model
     def clean_up_autoquant_hooks_and_attrs():
-        handle.remove()
-        delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
-        delattr(real_model, "forward_log_only")
+        try:
+            handle.remove()
+            delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
+            delattr(real_model, "forward_log_only")
+        except:
+            pass
     model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs
 
     # if an example input was provided, check it and run it

From b6347eb27f6c770e2ed8e46adc8610796b6fe8c8 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 7 May 2024 20:34:21 -0700
Subject: [PATCH 5/5] correct docs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md                      | 4 ++--
 torchao/quantization/README.md | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index c07409d6da..21a7195c27 100644
--- a/README.md
+++ b/README.md
@@ -45,8 +45,8 @@ model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
 # perform autoquantization and compilation
-q_model = torchao.autoquant(torch.compile(model))
-model(input)
+q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
+q_model(input)
 ```
 
 ### Sparsity
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 9e4568a31d..622ec1cbcf 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -29,7 +29,7 @@ model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
 # perform autoquantization and torch.compile
-torchao.autoquant(torch.compile(model, mode='max-autotune'))
+model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
 
 # pass in an input, which is used to pick the fastest quantization operations
 # and apply torch compilation.
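For completeness, an end-to-end sketch of the API as it stands after
this series (same toy model as the READMEs):

    import torch
    import torchao

    model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
    input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

    q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
    q_model(input)  # first call picks quantized kernels, then compiles; later calls are fast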