From 68d77364de68cbe9132d4936a0738957be43d964 Mon Sep 17 00:00:00 2001
From: Facebook Community Bot
Date: Tue, 28 May 2024 15:18:15 -0400
Subject: [PATCH] Re-sync with internal repository (#2271)

Add torchao_backend.py to the pt2 benchmark runner.
---
 .../dynamo/dynamobench/torchao_backend.py | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 userbenchmark/dynamo/dynamobench/torchao_backend.py

diff --git a/userbenchmark/dynamo/dynamobench/torchao_backend.py b/userbenchmark/dynamo/dynamobench/torchao_backend.py
new file mode 100644
index 0000000000..29e7d55d76
--- /dev/null
+++ b/userbenchmark/dynamo/dynamobench/torchao_backend.py
@@ -0,0 +1,58 @@
+from typing import Any, Callable
+
+import torch
+
+
+def setup_baseline():
+    torch._inductor.config.epilogue_fusion = False
+    torch._dynamo.config.automatic_dynamic_shapes = False
+    torch._dynamo.config.force_parameter_static_shapes = False
+    torch._dynamo.config.cache_size_limit = 10000
+    torch._inductor.config.force_fuse_int_mm_with_mul = True
+    torch._inductor.config.use_mixed_mm = True
+
+
+def torchao_optimize_ctx(quantization: str):
+    import torchao
+    from torchao.quantization import (
+        change_linear_weights_to_int4_woqtensors,
+        change_linear_weights_to_int8_dqtensors,
+        change_linear_weights_to_int8_woqtensors,
+    )
+
+    def inner(model_iter_fn: Callable):
+        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
+            # Quantize only on the first call; later calls reuse the
+            # already-quantized module.
+            if getattr(module, "_quantized", None) is None:
+                if quantization == "int8dynamic":
+                    change_linear_weights_to_int8_dqtensors(module)
+                elif quantization == "int8weightonly":
+                    change_linear_weights_to_int8_woqtensors(module)
+                elif quantization == "int4weightonly":
+                    change_linear_weights_to_int4_woqtensors(module)
+                elif quantization == "autoquant":
+                    torchao.autoquant(module, error_on_unseen=False)
+                    # Run one forward pass so autoquant can profile the
+                    # layers and pick kernels.
+                    if isinstance(example_inputs, dict):
+                        module(**example_inputs)
+                    else:
+                        module(*example_inputs)
+                    from torchao.quantization.autoquant import AUTOQUANT_CACHE
+
+                    assert (
+                        len(AUTOQUANT_CACHE) > 0
+                    ), f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
+                elif quantization == "noquant":
+                    pass
+                else:
+                    raise AssertionError(
+                        f"Unsupported quantization mode {quantization}."
+                    )
+                setattr(module, "_quantized", True)  # noqa: B010
+            model_iter_fn(module, example_inputs)
+
+        return _torchao_apply
+
+    return inner
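
Note for reviewers: a minimal sketch of how the new backend is meant to be
driven by the pt2 benchmark runner, assuming torchao is installed and the
patch is applied at the repo root. The toy model, inputs, and the run_iter
helper below are illustrative stand-ins, not part of this patch.

    import torch

    from userbenchmark.dynamo.dynamobench.torchao_backend import (
        setup_baseline,
        torchao_optimize_ctx,
    )

    # Toy stand-ins for a benchmark workload.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(8, 64),)

    def run_iter(mod, inputs):
        # One benchmark iteration: a single forward pass.
        return mod(*inputs)

    setup_baseline()
    apply_and_run = torchao_optimize_ctx("int8weightonly")(run_iter)
    # The first call swaps nn.Linear weights for int8 weight-only quantized
    # tensors and marks the module as _quantized; later calls only run run_iter.
    apply_and_run(model, example_inputs)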