AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.linear.default #875

Closed
agunapal opened this issue Sep 11, 2024 · 1 comment · Fixed by #885

Comments

@agunapal

I am using nightly torchao and torch

torchao                  0.6.0.dev20240910+cu121
torch                    2.5.0.dev20240909+cu121

Getting the following error

  File "/home/agunapal/torch_ao/vit_ao.py", line 23, in <module>
    benchmark_model(model, 20, input)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/utils.py", line 74, in benchmark_model
    model(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
    x = self.encoder(x)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
    return self.ln(self.layers(self.dropout(input)))
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
    x, _ = self.self_attention(x, x, x, need_weights=False)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1368, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/functional.py", line 5984, in multi_head_attention_forward
    return handle_torch_function(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/overrides.py", line 1739, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/utils.py", line 375, in _dispatch__torch_function__
    return func(*args, **kwargs)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torch/nn/functional.py", line 6285, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/utils.py", line 389, in _dispatch__torch_dispatch__
    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.linear.default

The issue happens when the model runs under `with torch.inference_mode():`.

The issue is not seen under `with torch.no_grad():`.
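
For readers comparing the two contexts, here is a minimal sketch (not part of the original report) showing that `torch.no_grad()` and `torch.inference_mode()` are distinct states. It does not reproduce the torchao error and makes no claim about its root cause; it only shows that inference mode sets an extra flag and produces inference tensors, which may be related to why operator dispatch behaves differently.

```python
import torch

x = torch.ones(2, 2)

# Under no_grad, gradient tracking is off but inference mode is not enabled.
with torch.no_grad():
    y_no_grad = x * 2
    no_grad_flags = (torch.is_grad_enabled(), torch.is_inference_mode_enabled())

# Under inference_mode, gradient tracking is off and inference mode is enabled.
with torch.inference_mode():
    y_inference = x * 2
    inference_flags = (torch.is_grad_enabled(), torch.is_inference_mode_enabled())

# Tensors created inside inference_mode are marked as inference tensors.
print("no_grad:       ", no_grad_flags, y_no_grad.is_inference())
print("inference_mode:", inference_flags, y_inference.is_inference())
```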

Here is the code for repro

import torch
import torchao

from torchvision.models import vit_b_16, ViT_B_16_Weights 
from torchao.utils import benchmark_model
from torchao.quantization import int8_weight_only, quantize_

torch.set_float32_matmul_precision('high')

dtype  = torch.float32
device = "cuda"
N = 1

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
quantize_(model, int8_weight_only())
model = torch.compile(model, mode='max-autotune').to(device).to(dtype)
method = "int8 quantize followed by compile"
input = (torch.randn(N, 3, 224, 224).to(device).to(dtype),)
result = []  # collects (method, batch size, elapsed time) tuples

with torch.inference_mode():
    # warmup
    benchmark_model(model, 20, input)
    # benchmark
    result.append((method, N, benchmark_model(model, 100, input)))



for (method, N, elapsed_time) in result:
    print(f"batch_size={N} : elapsed time {elapsed_time:.3f} ms :  {method} ")

@jerryzh168
Contributor

Thanks for reporting the issue. I think this should be straightforward to add; it should be similar to the existing handler for `torch.nn.functional.linear`:

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if not input_tensor.is_floating_point():
        raise NotImplementedError(f"{func} is not implemented for non floating point input")
    # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
    # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
    # make the branches easier to understand in `_quantized_linear_op`
    try:
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except QuantizedLinearNotImplementedError as e:
        # fallback path is only called when user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
            raise e
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
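
To illustrate why registering one more handler would resolve the error, here is a pure-Python sketch of a handler-table dispatch pattern. It is an assumption-laden illustration, not torchao's actual implementation: `QuantizedTensorSketch`, its `implements` and `dispatch` methods, and the string op names are all hypothetical stand-ins, and the `linear` helper works on plain Python lists rather than tensors.

```python
# Pure-Python sketch of a handler-table dispatch pattern; all names are hypothetical.
from typing import Callable, Dict


class QuantizedTensorSketch:
    """Stand-in for a tensor subclass that routes ops through a handler table."""

    _handlers: Dict[str, Callable] = {}

    @classmethod
    def implements(cls, op_name: str) -> Callable:
        """Return a decorator that registers a handler under `op_name`."""
        def decorator(fn: Callable) -> Callable:
            cls._handlers[op_name] = fn
            return fn
        return decorator

    @classmethod
    def dispatch(cls, op_name: str, *args):
        """Run the registered handler, or raise like the traceback in this issue."""
        if op_name not in cls._handlers:
            raise NotImplementedError(
                f"{cls.__name__} dispatch: attempting to run unimplemented "
                f"operator/function: {op_name}"
            )
        return cls._handlers[op_name](*args)


def linear(x, weight, bias=None):
    """List-based linear: out[j] = sum_i x[i] * weight[j][i] (+ bias[j])."""
    out = [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


# At first, only the function-level name has a handler.
QuantizedTensorSketch.implements("torch.nn.functional.linear")(linear)

# Looking up the operator-level name fails because nothing is registered for it.
try:
    QuantizedTensorSketch.dispatch("aten.linear.default", [1.0, 2.0], [[1.0, 0.0]])
except NotImplementedError as err:
    error_message = str(err)

# Registering the same handler under the second name resolves the lookup.
QuantizedTensorSketch.implements("aten.linear.default")(linear)
result = QuantizedTensorSketch.dispatch(
    "aten.linear.default", [1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]], [0.5, 0.5]
)
print(error_message)
print(result)
```

The design point is that each op name is a separate key: a handler registered for the function-level entry point does not cover the operator-level one, so both need an entry if either can reach the dispatcher.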

jerryzh168 added a commit to jerryzh168/ao that referenced this issue Sep 13, 2024
Summary:
Fixes: pytorch#875

Test Plan:
Test locally with tutorials/quantize_vit/run_vit_b_quant.py
with:
```
with torch.inference_mode():
    benchmark_model(model, 20, inputs)
```

but can't repro the issue in unit tests

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added a commit that referenced this issue Sep 13, 2024
jainapurva pushed a commit that referenced this issue Sep 22, 2024
jainapurva pushed a commit that referenced this issue Sep 23, 2024