From 7b412e3ba6bde4b366e555dafbcb5b97669211a1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 13 Sep 2024 14:26:53 -0700 Subject: [PATCH] Fix inference_mode Summary: Fixes: https://github.com/pytorch/ao/issues/875 Test Plan: Test locally with tutorials/quantize_vit/run_vit_b_quant.py with: ``` with torch.inference_mode(): benchmark_model(model, 20, inputs) ``` but can't repro the issue in unit tests Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/affine_quantized_tensor.py | 2 +- torchao/quantization/linear_activation_quantized_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 142a49a368..5ab937f6c3 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1483,7 +1483,7 @@ def _register_aqt_quantized_linear_dispatches(): _register_aqt_quantized_linear_dispatches() -@implements(torch.nn.functional.linear) +@implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( args[0], diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 7fe76b20fa..f2eae07152 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -91,7 +91,7 @@ def to(self, *args, **kwargs): implements = LinearActivationQuantizedTensor.implements -@implements(torch.nn.functional.linear) +@implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( args[0],