From 7b03ef351b3974cf5b4ebdda5447ff005d3be757 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:00:37 -0400 Subject: [PATCH] adding default inductor config settings (#423) * adding default inductor config settings Summary: making autoquant and quantize apis call a new recommended_inductor_config_setter util to set recommended apis also update groupsize -> groupsize in generate.py Test Plan: sh benchmarks.sh comparison of different config combinations for matmul precision, mixed_mm and coordinate_descent tok/s= 9.14, mem/s= 60.55 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, tok/s=147.02, mem/s= 973.53 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, tok/s= 9.23, mem/s= 61.11 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, tok/s=139.59, mem/s= 924.33 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, tok/s= 9.10, mem/s= 60.26 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, tok/s=146.98, mem/s= 973.23 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, tok/s= 9.28, mem/s= 61.48 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, tok/s=146.90, mem/s= 972.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, tok/s= 9.08, mem/s= 60.09 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, tok/s=137.58, mem/s= 911.00 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, tok/s= 9.19, mem/s= 60.87 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, tok/s=166.02, mem/s=1099.30 GB/s, peak_mem= 8.97 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, Reviewers: Subscribers: Tasks: Tags: * fixing tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * fix weight only failures Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * fixing new broken test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * fixing autoquant test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * testing if inductor config is the issue Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * are inductor configs somehow being set? Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * when is coordinate descent tuning beinng enabled? Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * reset inductor config for tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * more test fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * adding warning Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * handling of errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * option to supress autoquant errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 37 ++++++++++++++++++++-------- torchao/_models/llama/eval.py | 6 ++--- torchao/_models/llama/generate.py | 16 +++++------- torchao/quantization/README.md | 10 +++----- torchao/quantization/autoquant.py | 35 ++++++++++++++++++++++---- torchao/quantization/quant_api.py | 6 ++++- torchao/quantization/utils.py | 18 ++++++++++++++ 7 files changed, 92 insertions(+), 36 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index b4fbcb152a..4d5a2c511c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -98,21 +98,21 @@ def _int8wo_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int8_weight_only()) + quantize(mod, int8_weight_only(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_woqtensors(mod) def _int8da_int8w_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int8_dynamic_activation_int8_weight()) + quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_dqtensors(mod) def _int4wo_api(mod): if TORCH_VERSION_AFTER_2_4: - quantize(mod, int4_weight_only()) + quantize(mod, int4_weight_only(), set_inductor_config=False) unwrap_tensor_subclass(mod) else: change_linear_weights_to_int4_woqtensors(mod) @@ -124,6 +124,13 @@ def _int4wo_api(mod): _int4wo_api, ] +def undo_recommended_configs(): + torch._inductor.config.coordinate_descent_tuning = False + torch._inductor.config.coordinate_descent_check_all_directions = False + torch._inductor.config.force_fuse_int_mm_with_mul = False + torch._inductor.config.fx_graph_cache = False + torch._inductor.config.triton.unique_kernel_names = False + torch.set_float32_matmul_precision("highest") def combine_parameters(a, b): new_tuples = [] @@ -689,6 +696,7 @@ def test_int8_dynamic_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int8_weight_only_quant_subclass(self, device, dtype): + undo_recommended_configs() self._test_lin_weight_subclass_impl( Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype ) @@ -794,6 +802,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): + undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype ) @@ -879,6 +888,7 @@ def test_weight_only_quant(self): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_force_mixed_mm(self, device, dtype): + undo_recommended_configs() if device != "cuda": self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): @@ -907,6 +917,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_use_mixed_mm(self, device, dtype): + undo_recommended_configs() if device != "cuda": self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): @@ -1004,6 +1015,7 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): + undo_recommended_configs() self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -1153,6 +1165,7 @@ class TestAutoQuant(unittest.TestCase): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): + undo_recommended_configs() print("(m, k, n): ", (m, k, n)) if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1173,7 +1186,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): torch.nn.ReLU(), ).to(device).to(dtype) out = model(example_input) - torchao.autoquant(model) + torchao.autoquant(model, set_inductor_config=False) out2 = model(example_input) sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) @@ -1186,6 +1199,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1202,7 +1216,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) - mod = torchao.autoquant(torch.compile(model), manual=True) + mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) mod(example_input) mod(example_input2) mod.finalize_autoquant() @@ -1214,6 +1228,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_manual(self, device, dtype): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1229,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype): example_input2 = torch.randn(m2, k, device=device, dtype=dtype) out = model(example_input) - mod = torchao.autoquant(torch.compile(model), manual=True) + mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False) mod(example_input) mod(example_input2) mod.finalize_autoquant() @@ -1237,7 +1252,7 @@ def test_autoquant_manual(self, device, dtype): sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) - mod2 = torchao.autoquant(model, manual=True) + mod2 = torchao.autoquant(model, manual=True, set_inductor_config=False) mod2(example_input) mod2(example_input2) mod2.finalize_autoquant() @@ -1254,6 +1269,7 @@ def test_autoquant_manual(self, device, dtype): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1280,7 +1296,7 @@ def forward(self, x, y): } out = model(**example_input) - mod = torchao.autoquant(torch.compile(model)) + mod = torchao.autoquant(torch.compile(model), set_inductor_config=False) mod(**example_input) out2 = mod(**example_input) @@ -1293,6 +1309,7 @@ def forward(self, x, y): ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): + undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1316,7 +1333,7 @@ def forward(self, x): x_in = torch.randn(m, k, device=device, dtype=dtype) model = DoubleAccess().to(device).to(dtype) model(x_in) - torchao.autoquant(model) + torchao.autoquant(model, set_inductor_config=False) assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight) model(x_in) @@ -1443,7 +1460,7 @@ def test_get_model_size_autoquant(self, device, dtype): qtensor_class_list = ( AQWeightOnlyQuantizedLinearWeight2, ) - mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list) + mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False) mod(example_input) size2 = torchao.utils.get_model_size_in_bytes(mod) self.assertTrue(size2 < size) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 36e5085018..73deafffec 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -23,9 +23,6 @@ from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer from torchao._models.llama.model import prepare_inputs_for_model -torch._inductor.config.fx_graph_cache = True -torch._inductor.config.force_fuse_int_mm_with_mul = True - def run_evaluation( checkpoint_path: Path, tasks: List[str], @@ -41,6 +38,9 @@ def run_evaluation( pad_calibration_inputs: Optional[bool] = False, ): """Runs the evaluation of a model using LM Eval.""" + + torchao.quantization.utils.recommended_inductor_config_setter() + assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 34e7ca82b2..8142f80bb8 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -22,13 +22,6 @@ def device_sync(device): else: print(f"device={device} is not yet suppported") - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future -torch._inductor.config.force_fuse_int_mm_with_mul = True -# torch._inductor.config.use_mixed_mm = True - default_device = 'cuda' if torch.cuda.is_available() else 'cpu' # support running without installing as a package @@ -163,6 +156,9 @@ def main( ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. """ + + torchao.quantization.utils.recommended_inductor_config_setter() + assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) @@ -203,7 +199,7 @@ def main( if "int4wo" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize(model, int4_weight_only(groupsize=groupsize)) + quantize(model, int4_weight_only(group_size=groupsize)) if "autoquant" == quantization: model = autoquant(model, manual=True) @@ -339,8 +335,8 @@ def callback(x): parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument("--quantization", type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') parser.add_argument('--profile', type=Path, default=None, help='Profile path.') diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index a6e95d0bed..086065b8da 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -30,10 +30,6 @@ of the activations that the different linear layers see, it then benchmarks thes import torch import torchao -# inductor settings which improve torch.compile performance for quantized modules -torch._inductor.config.force_fuse_int_mm_with_mul = True -torch._inductor.config.use_mixed_mm = True - # Plug in your model and example input model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') @@ -107,9 +103,6 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune') group_size = 32 m = quantize(m, int4_weight_only(group_size=group_size)) -torch._inductor.config.force_fuse_int_mm_with_mul = True -torch._inductor.config.use_mixed_mm = True - # temporary workaround for tensor subclass + torch.compile from torchao.quantization.utils import unwrap_tensor_subclass m = unwrap_tensor_subclass(m) @@ -163,6 +156,9 @@ m = torch.export.export(m_unwrapped, example_inputs).module() torch._export.aot_compile(m_unwrapped, example_inputs) ``` +### Automatic Inductor Configuration +The `quantize` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. + ### Other Available Quantization Techniques #### A8W8 Dynamic Quantization diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 18a58cd17f..83d7837d3e 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -1,4 +1,5 @@ import torch +import torchao from .subclass import ( # noqa Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, @@ -90,7 +91,11 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): with torch.no_grad(): act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device) - res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + try: + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + except Exception as e: + print(f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}") + res = torch.inf update_cache(q_cls, shapes_and_dtype, res) @torch.no_grad() @@ -407,16 +412,21 @@ def _change_linears_to_autoquantizable(model, **kwargs): filter_fn if filter_fn is not None else _is_linear, ) -def _change_autoquantizable_to_quantized(model, **kwargs): +def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, **kwargs): """ Converts AutoQuantizableLinearWeight tensor subclasses to various quantized/non-quantized tensor subclasses depending on benchmark results. Expectation is that these modules are torch.compiled afterwards. """ - hold = torch._dynamo.config.automatic_dynamic_shapes + hold_automatic_dynamic_shapes = torch._dynamo.config.automatic_dynamic_shapes torch._dynamo.config.automatic_dynamic_shapes = False + if supress_autoquant_errors: + hold_supress_errors = torch._dynamo.config.suppress_errors + torch._dynamo.config.suppress_errors = True + import logging + torch._logging.set_logs(inductor=logging.CRITICAL, dynamo=logging.CRITICAL) filter_fn = kwargs.pop( "filter_fn", lambda mod, *args: @@ -432,7 +442,13 @@ def _change_autoquantizable_to_quantized(model, **kwargs): ), filter_fn, ) - torch._dynamo.config.automatic_dynamic_shapes = hold + # undo dynamic shape change + torch._dynamo.config.automatic_dynamic_shapes = hold_automatic_dynamic_shapes + + # undo error supression + if supress_autoquant_errors: + torch._dynamo.config.suppress_errors = hold_supress_errors + torch._logging.set_logs() torch._dynamo.reset() # TODO: example_input seems weird to include in the API @@ -443,8 +459,11 @@ def autoquant( model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, - filter_fn=None, mode=["interpolate", .85], + filter_fn=None, + mode=["interpolate", .85], manual=False, + set_inductor_config=True, + supress_autoquant_errors=True, **aq_kwargs ): """ @@ -477,6 +496,8 @@ def autoquant( and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85]. manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged. + set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) + supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True) **aq_kwargs: Additional keyword arguments for the autoquantization process. Returns: @@ -493,6 +514,9 @@ def autoquant( model(*example_input2) model.finalize_autoquant() """ + if set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + # perform initial swap from linear weights # to AutoQuantizableLinearWeight @@ -539,6 +563,7 @@ def autoquant_prehook(module, args, kwargs): def finalize_autoquant(): _change_autoquantizable_to_quantized( real_model, + supress_autoquant_errors, **aq_kwargs, ) if hasattr(real_model, "old_forward"): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6f7f549704..33821f1d82 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -16,6 +16,7 @@ """ import torch +import torchao import torch.nn as nn import torch.nn.functional as F from typing import Any, Callable, Union, Dict, Optional @@ -258,7 +259,7 @@ def insert_subclass(lin): return insert_subclass -def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module: +def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module: """Convert the weight of linear modules in the model with `apply_tensor_subclass` Args: @@ -266,6 +267,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance) filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module + set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) Example:: @@ -306,6 +308,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: m = quantize(m, apply_weight_quant, filter_fn) """ + if set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() if isinstance(apply_tensor_subclass, str): if apply_tensor_subclass not in _APPLY_TS_TABLE: raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}") diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3e3943c93c..d158862147 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -36,6 +36,7 @@ "groupwise_affine_dequantize_tensor", "per_token_dynamic_quant", "get_group_qparams_symmetric", + "recommended_inductor_config_setter" ] try: @@ -456,3 +457,20 @@ def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor: input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype ) return input.to(orig_dtype) + +def recommended_inductor_config_setter(): + """ + Set inductor config to use the following optimizations which have been showed to improve performance for quantized models: + coordinate_descent_tuning = True + coordinate_descent_check_all_directions = True + force_fuse_int_mm_with_mul = True + fx_graph_cache = True + triton.unique_kernel_names = True + torch.set_float32_matmul_precision("high") + """ + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.coordinate_descent_check_all_directions = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._inductor.config.fx_graph_cache = True + torch._inductor.config.triton.unique_kernel_names = True + torch.set_float32_matmul_precision("high")