Accuracy with TensorRT 10.7 and self-attention #4328
Try to use:

polygraphy run /model/ViT-SO400M-14-SigLIP-384.onnx --trt --onnxrt \
    --trt-outputs mark all \
    --onnx-outputs mark all
Yes, this fixes the accuracy issue (it prevents layer fusion), but performance is terrible (as expected).
It is often useful to reduce the model to the smallest possible subgraph that still triggers the failure; that makes it easier to pinpoint the cause. You can bisect/split the ONNX graph, and see polygraphy debug for an automated way to do it.
I understand. I am just reporting that the self-attention fusion in this case appears to have a bug which results in large errors even in fp32.
Use the bisection just to find which layer has the compute error.
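As a rough illustration of that bisection (not a command from this thread), one way to cut out a candidate subgraph is onnx.utils.extract_model; the tensor names below are hypothetical placeholders that would need to be read from the real graph (e.g. with Netron or polygraphy inspect model):

# Sketch: extract the subgraph between two intermediate tensors so the
# TRT-vs-ONNX-Runtime comparison can be rerun on a much smaller model.
# The tensor names are hypothetical placeholders, not from this issue.
import onnx.utils

onnx.utils.extract_model(
    "ViT-SO400M-14-SigLIP-384.onnx",                    # full model
    "attn_block.onnx",                                   # reduced model to test
    input_names=["/blocks.0/attn/MatMul_output_0"],      # placeholder name
    output_names=["/blocks.0/attn/Softmax_output_0"],    # placeholder name
)

The reduced model can then be fed back into the same polygraphy run comparison; polygraphy debug reduce can also automate this search.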
Seems like open_clip uses torch.nn.functional.scaled_dot_product_attention.
BTW, when you use trtexec, you can add the flag
@ohadravid Thank you so much! With the help of your reproducer I was able to fix the open_clip problem by monkey-patching scaled_dot_product_attention:

import math
import torch

def naive_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Unfused reference implementation of scaled dot-product attention.
    # dropout_p, is_causal and scale are accepted only for signature
    # compatibility and are ignored here.
    _, _B, _Nt, E = q.shape
    q = q * math.sqrt(1.0 / float(E))
    attn = q @ k.transpose(-2, -1)
    if attn_mask is not None:
        attn += attn_mask
    attn = attn.softmax(dim=-1)
    return attn @ v

# Swap in the naive implementation globally, before open_clip is imported.
torch.nn.functional.scaled_dot_product_attention = naive_sdpa

import open_clip
[...]
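For completeness, here is a hedged sketch of how the patched model might then be exported to ONNX; the model name, input shape, opset and export arguments are assumptions based on the file name mentioned in this issue, not the exact script used above:

# Sketch only: export the image tower after the SDPA monkey-patch above.
# Model name, input size and opset are assumptions, not taken from the issue;
# in practice the pretrained weights would be loaded via the pretrained= argument.
import torch
import open_clip  # imported after the monkey-patch, as in the snippet above

model, _, _ = open_clip.create_model_and_transforms("ViT-SO400M-14-SigLIP-384")
model.eval()

dummy = torch.randn(1, 3, 384, 384)  # SigLIP-384 expects 384x384 images
torch.onnx.export(
    model.visual,                     # image encoder only
    dummy,
    "ViT-SO400M-14-SigLIP-384.onnx",
    input_names=["image"],
    output_names=["embedding"],
    opset_version=17,
)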
Thanks for root-causing this, @ohadravid! Closing this issue; let's track the proper scaled dot-product attention fix in #4333.
I am trying to convert an open-clip (pip install open_clip_torch==2.30.0) model to TensorRT. Exporting the model produces a valid ONNX file, such that ONNX Runtime execution matches the PyTorch output of the original model.

To convert the ONNX model to a TensorRT engine, I do:

Note the magnitude of the relative error (p90=6.266!!). This happens on my RTX A4500 Laptop GPU (driver 560) and on my V100 (but there I use the tensorrt:24.06-py3 container, as TensorRT 10.7 no longer supports Volta). The FP16/BF16 case is even worse.

When I do the same conversion with --fp8, the error vanishes (note that the A4500 and V100 do not support FP8 kernels). I compared the trtexec verbose logs and found that in the fp32 case TensorRT recognizes the self-attention pattern, but in the FP8 case it does not.

This observation got me thinking... When I replace the /attn/Softmax nodes with a custom TensorRT softmax plugin, the TensorRT optimizer can no longer do the self-attention optimization, and the result is that I get TensorRT engines with acceptable accuracy (even in fp16).

My conclusion: somehow, for this model, the Myelin self-attention fusion is buggy.
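As a side note, the following sketch shows the kind of per-output relative-error summary being quoted above (p90 of |out − ref| / |ref|); the arrays are placeholders for outputs dumped by polygraphy or a custom harness, and this is not necessarily the exact metric polygraphy reports:

# Sketch: p90 of the element-wise relative error between a TensorRT output
# and an ONNX Runtime / PyTorch reference. Inputs here are placeholders.
import numpy as np

def rel_err_p90(out: np.ndarray, ref: np.ndarray, eps: float = 1e-12) -> float:
    rel = np.abs(out - ref) / (np.abs(ref) + eps)
    return float(np.percentile(rel, 90))

# e.g. rel_err_p90(trt_embedding, onnxrt_embedding)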