TensorrtExecutionProvider slower than CUDAExecutionProvider: Faster-rcnn [Performance] #17434
Comments
The onnx-trt parser filters out the NonMaxSuppression, NonZero and RoiAlign nodes.
TensorRT supports the nmsPlugin and roiAlignPlugin. Probably we can replace the ONNX NonMaxSuppression and RoiAlign nodes with those two TRT plugins to see the latency?
Typically the nodes from NonMaxSuppression onwards are selecting the best bounding boxes. These are relatively cheap operations where it's more efficient to stay on CPU than to go back to the GPU. In the NNAPI EP we have the option to set an operator after which NNAPI is not used, and we do that for NonMaxSuppression. Maybe something similar would also work for TRT/CUDA for this type of model.
So, even if, according to @skottmckay, these 3 operators are cheaper on CPU, can we try to keep them on the GPU to avoid the overhead of moving the data between CPU and GPU (in my case images of 13 MB)? Is that the goal/capability of the nmsPlugin and roiAlignPlugin? I am ready to try. Any example of how to do that? Shall I modify the model code, the resulting ONNX, or is it merely a declaration in the onnxruntime TensorRT EP configuration? What about the third operator, NonZero? I could not find a plugin for it; is there any possibility to keep it on GPU to avoid memory transfers caused by another subgraph split?
If I want to test the performance I get by not filtering out these operators, by commenting out the lines at https://github.com/onnx/onnx-tensorrt/blob/main/ModelImporter.cpp#L377, then where shall I place the modified ModelImporter.cpp file before recompiling onnxruntime? I am recompiling onnxruntime with the NVIDIA GPU and TensorRT EPs in my docker image with:
What if I compile onnxruntime with --use_tensorrt_builtin_parser: will the nodes be filtered out?
No change if I recompile onnxruntime with --use_tensorrt_builtin_parser.
Here are the steps to build the OSS onnx-tensorrt parser without filtering out those operators:
I tested the non-filtering onnx-tensorrt parser with Faster R-CNN from the ONNX model zoo, and it can include those nodes for TRT, but it failed to build the TRT engine. I need to investigate further, but you can try your Faster-RCNN model. Update: checked with Nvidia, those nodes should only work with the TRT API.
I think we can try the TRT plugins. Please see the doc here. You need to modify the graph and replace the ONNX nodes with the corresponding TRT plugin nodes.
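A minimal sketch of what that graph edit could look like with onnx-graphsurgeon. The plugin op type ("EfficientNMS_TRT"), its attribute values, and the assumption that the existing node inputs/outputs can be kept unchanged are all illustrative; the real plugin expects its own input/output layout, so follow the TRT plugin doc referenced above for the exact contract:

```python
# Hypothetical sketch: rewrite ONNX NonMaxSuppression nodes into a TRT NMS plugin op.
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("faster-rcnn-inferred.onnx"))  # placeholder path

for node in graph.nodes:
    if node.op == "NonMaxSuppression":
        node.op = "EfficientNMS_TRT"      # TRT plugin op type (illustrative)
        node.attrs = {
            "plugin_version": "1",
            "score_threshold": 0.05,      # example values only
            "iou_threshold": 0.5,
            "max_output_boxes": 100,
        }

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "faster-rcnn-trt-plugins.onnx")
```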
Thanks a lot @chilo-ms: I will try to integrate the 2 plugins in my model to test the performance improvement. Hoping that the ONNX Runtime TRT EP uses the TRT API enqueueV3 asap.
After discussing with NVIDIA how to integrate the plugins, we found out that NMS and NonZero ARE implemented natively in TensorRT, cf.
For RoiAlign, the only way is via the TRT plugin, but is there a way to have the TRT EP call the native TRT layers, to avoid data transfers between CPU and GPU?
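For reference, recent TensorRT versions do expose these as native network layers in the Python API. A rough sketch, assuming TRT 8.5+ and purely illustrative input shapes; verify the exact layer methods and expected tensor shapes against your TRT version:

```python
# Sketch only: show that TensorRT has native NonZero and NMS network layers (TRT 8.5+).
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

mask = network.add_input("mask", trt.float32, (1, 1000))
nonzero = network.add_non_zero(mask)                             # native NonZero layer

boxes = network.add_input("boxes", trt.float32, (1, 1000, 4))
scores = network.add_input("scores", trt.float32, (1, 1000, 1))
max_boxes = network.add_constant((), np.array(100, dtype=np.int32))
nms = network.add_nms(boxes, scores, max_boxes.get_output(0))    # native NMS layer
```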
In 1.16.0 there is a new session option, disable_cpu_ep_fallback. How can we set it? And will this prevent NonZero and NMS from falling back to the CPU EP?
@datinje Please see here for how to use disable_cpu_ep_fallback. But in your case, you still need the CUDA EP or CPU to run those three nodes if you don't want to use TRT plugins.
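For the "how to set it" part, a minimal sketch with the Python API, assuming the session config key `session.disable_cpu_ep_fallback`; the model path and provider list are placeholders:

```python
# Minimal sketch: disable CPU EP fallback via a session config entry.
# If any node can only run on the CPU EP (e.g. NonZero/NMS/RoiAlign with ORT TRT),
# session creation will fail with the error discussed later in this thread.
import onnxruntime as ort

so = ort.SessionOptions()
so.add_session_config_entry("session.disable_cpu_ep_fallback", "1")

sess = ort.InferenceSession(
    "faster-rcnn-inferred.onnx",            # placeholder model path
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
)
```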
As stated above by @chilo-ms, I tried in 1.16 to disable CPU EP fallback, to avoid moving ONNX operators to the CPU when the onnxruntime parser decides so, but the effect is not to keep the operators on the GPU with TRT as expected; it prevents the program from continuing. Then what is the purpose of this option? The main interest for me would be for onnxruntime to keep the operators on the GPU even if they are faster on CPU, because the overhead of transferring the data would offset the benefit.
2023-10-31 11:27:23.916547026 [E:onnxruntime:, inference_session.cc:1678 Initialize] This session contains graph nodes that are assigned to the default CPU EP, but fallback to CPU EP has been explicitly disabled by the user.
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : This session contains graph nodes that are assigned to the default CPU EP, but fallback to CPU EP has been explicitly disabled by the user.
One of the purposes of disable_cpu_ep_fallback is to make sure all the nodes are placed on GPU before ORT starts to run inference. ORT may place some nodes on CPU for performance, but in some cases that might not be what you want, so this option works as a check. However, in your case the error you got is expected, because the current ORT TRT doesn't support NonZero, NMS and RoiAlign, and the CPU is the only EP that can run these nodes. So you should only use disable_cpu_ep_fallback if all the nodes in your model are supported by ORT TRT; otherwise you will get this error. As I mentioned previously, you can try the following steps:
Then you can see that ORT TRT can run all the nodes of your Faster-RCNN model except RoiAlign.
It's possible that a subgraph of the "If" control flow op has no nodes. The TRT EP should consider this kind of subgraph fully supported by TRT. The Faster-RCNN model mentioned in this issue (#17434) is such a case.
Closing, since I realized that with ORT 1.16.3 I succeeded in running my model with TRT and it is faster than the CUDA EP in TF32.
I tested my model again with the latest onnxruntime 1.17.1 and got the same performance results between the TRT EP and the CUDA EP.
Even the NonZero op seems to be implemented in TRT: could it be supported in the ONNX Runtime TRT EP?
@jcdatin We are testing TRT EP + TRT DDS output support (meaning including the NMS/NonZero/RoiAlign operators) to see the performance and then decide whether to enable this feature in the ORT official release. If you could help test it and provide feedback, that would be great! Thank you!
Sure! I will help.
Shall --use_tensorrt_oss_parser REPLACE --use_tensorrt_builtin_parser or simply complement it?
That's weird, there is no change in terms of EPContext/Embedded engine feature between ORT 1.17 and ORT 1.18.
Yes, please use the latest TRT 10.3, which fixes issues when running Faster-RCNN.
Rebuilt ORT 1.18.1 with TRT 10.3.0.26 (and cuDNN 9.3.0.75), with CUDA 12.2.
First observation (when not using the TRT embedded context): first things first, do you know why I still have these nodes on the CPU EP? Shall I remove the option --use_tensorrt_oss_parser? (I am going to try.)
Second observation, when using the TRT embedded context with the config above: I am getting the same error as with TRT 10.0. This used to work with TRT 8.6/cuDNN 8.9 and ORT build bb19722.
So far these are too big regressions for me to use ORT 1.18.1 and beyond.
Other question: what is the onnxruntime graph optimization level to use in conjunction with the TRT EP (which has its own optimizations)?
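In case it helps, a sketch of how the graph optimization level is set alongside the TRT EP. Whether to lower the level is a judgment call: TRT performs its own graph optimizations, so ORT-side optimizations that introduce fused ops can shrink what the TRT parser is able to take. Paths and providers below are placeholders:

```python
# Sketch: run with reduced ORT graph optimizations so the graph handed to the
# TRT parser stays close to the original ONNX.
import onnxruntime as ort

so = ort.SessionOptions()
# ORT_ENABLE_BASIC keeps only basic, provider-independent optimizations;
# ORT_ENABLE_ALL is the default.
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

sess = ort.InferenceSession(
    "faster-rcnn-inferred.onnx",
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
)
```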
Tried to build ORT 1.18.1 with TRT 10.3 without --use_tensorrt_oss_parser, and the following nodes are still on the CPU EP:
First things first, can you investigate why the DDS nodes are not on the TRT EP?
Let me reply to the first question. I agree it's a bit complicated to enable DDS, as I mentioned here. One thing to note: when running the NMS node, TRT EP + TRT 10.3 takes much longer to finish (compared to TRT 8.6). We are still investigating the issue. If possible, could you share your model with us to test? Or could you help test from your side?
@jcdatin
@chilo-ms: thanks for your answer. I was on vacation. I will try DDS with your branch and investigate the TRT EP with TRT 10.3 for the NMS node. I will also check that the TRT embedded context works once all nodes are on the TRT EP.
Nvidia informed me that the NMS performance issue is a known problem that will be fixed in TRT 10.6.
Yeah, the NMS regression in TRT 10 is a known issue and Nvidia has been investigating it.
TRT 10.6 is out, as well as onnxruntime 1.20. But I see some restrictions:
What is the version of CUDA supported by ORT? I am using 12.2, and TRT 10.6 seems to require 12.6.
Re: ORT 1.20 only supports TRT 10.4 and 10.5 (and I need TRT 10.6). ORT 1.20 supporting TRT 10.4 and 10.5 means our CIs are tested against those TRT versions and the prebuilt packages are built against them.
Re: Previous ORT and TRT 10.x could not dispatch NMS or NonZero ops to TRT, so I have to take TRT 10.6: will ORT still dispatch NMS/NonZero to TRT? I prefer the TRT perf limitation to ORT dispatching these DDS ops to the CPU. Starting from TRT 10.7 (which is not released yet), TRT will completely enable DDS ops, i.e. ORT will dispatch NMS/NonZero/RoiAlign to TRT by default. Before TRT 10.7, users need to build ORT with the open-source parser to achieve this. But please be aware of the known DDS perf issue from TRT 10.0 to 10.7 (Nvidia likely won't fix the issue in TRT 10.7).
Re: what is the version of CUDA supported by ORT: I am using 12.2 and TRT 10.6 seems to require 12.6.
Thank you @chilo-ms, I am building and testing 1.20.0 with TRT 10.6 and the OSS TRT parser. I will report on the TRT 10.6 DDS operator performance degradation. When TRT 10.7 is available I will test it with ORT 1.20.1, its empty trt_op_types_to_exclude list and the default TRT parser. Will keep you posted.
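When that option is available in your build, a hedged sketch of how clearing the exclusion list would look through the TRT EP provider options; the option name is taken from the discussion above, so confirm your ORT release actually supports it:

```python
# Sketch: pass an empty trt_op_types_to_exclude so DDS ops are not excluded from TRT.
import onnxruntime as ort

trt_options = {
    "device_id": 0,
    "trt_op_types_to_exclude": "",   # per the thread, the default excludes the DDS ops
}

sess = ort.InferenceSession(
    "faster-rcnn-inferred.onnx",     # placeholder model path
    providers=[("TensorrtExecutionProvider", trt_options),
               "CUDAExecutionProvider"],
)
```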
I am getting an ORT 1.20.0 compilation error when building with TRT 10.6 (TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz) over CUDA 12.2, cf.: In file included from /onnxruntime/build/Linux/Release/_deps/onnx_tensorrt-src/onnxErrorRecorder.cpp:5:
Please specify the correct onnx-tensorrt commit in the cmake/deps.txt of your ORT repo.
I am using the nightly build ORT_VERSION=bb1972264b, which is roughly based on 1.18.1 (this is the only onnxruntime version I can use with full inference speed on my Faster-RCNN model with TRT 8.6.1).
My bad, forget the above. I am using ORT 1.20.0 and TRT 10.6, as I said above.
ORT TRT works with either the built-in TensorRT parser or the OSS TensorRT parser. Currently, the built-in TensorRT parser (from version 10.0 to 10.7) disables DDS. Line 40 in deps.txt (which points to a specific commit/branch of onnx-tensorrt; you can change it to use a different TRT version) is only used when you manually build ORT with --use_tensorrt_oss_parser.
In your case (placing DDS ops on TRT), please don't use ORT patch release 1.20.1. Then you will be able to run ORT TRT + TRT 10 with the DDS ops run by TRT.
Sorry for the confusion and inconvenience. Nvidia has root-caused the perf issue of running DDS ops and they are now working on a better solution.
Actually, I got the info that Nvidia will NOT implement an official fix in TRT for this "regression" in DDS ops, not even in TRT 10.8. Here is Nvidia's recommendation for DDS operators when used with ORT:
I then have no other way than to try.
Unfortunately I am getting a regression with TRT on my Faster-RCNN model:
terminate called after throwing an instance of 'Ort::Exception'
what(): User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options. Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input. Following input(s) has no associated shape profiles provided: /model/my_model/rpn/Squeeze_2_output_0,/model/my_model/rpn/Squeeze_1_output_0,/model/my_model/rpn/Reshape_17_output_0,/model/my_model/rpn/NonZero_output_0
The same model used to work well with ORT 1.18.0 and TRT 8.6.1 (I used the onnxruntime tool symbolic_shape_infer.py to infer dimensions), as:
python /usr/local/lib/python3.10/dist-packages/onnxruntime/tools/symbolic_shape_infer.py --input=faster-rcnn.onnx --output=faster-rcnn-inferred.onnx --auto_merge
Update:
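For reference, a sketch of how explicit shape profiles can be passed through the TRT EP provider options. The input name and the min/opt/max ranges below are illustrative placeholders built from the error message; they must match your model's actual dynamic-shape inputs:

```python
# Sketch: provide min/opt/max shape profiles for the dynamic-shape inputs listed
# in the error. Names and ranges are illustrative only.
import onnxruntime as ort

trt_options = {
    "trt_profile_min_shapes": "/model/my_model/rpn/NonZero_output_0:1x1",
    "trt_profile_opt_shapes": "/model/my_model/rpn/NonZero_output_0:1x500",
    "trt_profile_max_shapes": "/model/my_model/rpn/NonZero_output_0:1x1000",
}

sess = ort.InferenceSession(
    "faster-rcnn-inferred.onnx",     # placeholder model path
    providers=[("TensorrtExecutionProvider", trt_options),
               "CUDAExecutionProvider"],
)
```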
Update 0: the error is then
2025-01-30 13:12:57.087270906 [V:onnxruntime:, execution_frame.cc:563 AllocateMLValueTensorSelfOwnBufferHelper] For ort_value with index: 72, block in memory pattern size is: 14450688 but the actual size is: 2809856, fall back to default allocation behavior
Update 4:
Update 5: here is the build command: CUDA_VERSION=12.4. To recap, in ORT 1.20.1: can you help?
@jcdatin I saw that in 1.20.1 the OSS parser was using version 10.4-GA-ORT-DDS, and your
Now, with the update to deps.txt, ORT 1.20.1 builds with the TRT OSS parser.
2025-02-02 12:24:01.107927206 [V:onnxruntime:ivpSelectorInference, tensorrt_execution_provider.cc:2479 GetCapability] There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable
2025-02-02 12:24:03.521702534 [V:onnxruntime:, session_state.cc:1154 VerifyEachNodeIsAssignedToAnEp] Node(s) placed on [CPUExecutionProvider]. Number of nodes: 13
So not only did I get NonZero, NMS and RoiAlign allocated to CPU, but also ScatterND. What shall I do? I would like to test the NVIDIA workaround above to get back the performance I have with the 1.18.0 nightly build. Here is the build command:
Re: What shall I do? I would like to test the NVIDIA workaround above to get back the performance I have in the 1.18.0 nightly build. Here is the PR that works around the potential DDS node perf issue. Feel free to give it a try as well.
Using your PR branch (whichever parser, built-in or OSS, I use), I am crashing out of memory.
Also, I noticed that the ScatterND op gets an error when parsed by ORT (I don't have the problem with trtexec on the same model):
2025-02-04 12:16:36.012390098 [E:onnxruntime:ivpSelectorInference, tensorrt_execution_provider.h:88 log] [2025-02-04 12:16:36 ERROR] In node 302 with name: /model/my_model/rpn/ScatterND and operator: ScatterND (importScatterND): UNSUPPORTED_NODE_ATTR: Assertion failed: !attrs.count("reduction"): Attribute reduction is not supported.
Then ScatterND is allocated to a CPU node:
2025-02-04 12:16:36.371363603 [V:onnxruntime:, session_state.cc:1249 VerifyEachNodeIsAssignedToAnEp] Node(s) placed on [TensorrtExecutionProvider]. Number of nodes: 11
Did you make a change to the ScatterND node? Any regression in its operator autotest?
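To check whether the exported model actually carries the `reduction` attribute that the parser assertion complains about, a small inspection sketch (model path is a placeholder):

```python
# Sketch: list every ScatterND node and print its attributes, to see whether a
# "reduction" attribute is present in the exported model.
import onnx

model = onnx.load("faster-rcnn-inferred.onnx")
for node in model.graph.node:
    if node.op_type == "ScatterND":
        attrs = {a.name: onnx.helper.get_attribute_value(a) for a in node.attribute}
        print(node.name, attrs)
```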
Thanks for reporting this.
With the built-in parser or the OSS parser in the build?
Describe the issue
On my Faster-RCNN-RPN models doing pattern detection, after considerable effort to infer with the TensorRT EP (see #16886, which shows that I have simplified the model and inferred the shapes of the model nodes before submitting to TRT), I found that the TRT EP is about 30% slower than the CUDA EP in FP32 (and in TF32); only with FP16 does the TRT EP almost catch up.
I only report here the second inference, not the warm-up one (which is considerably slower, which is normal).
After looking at the VERBOSE mode logs, I found out that not all the nodes are running on TRT: one is still on CPU and 5 are on the CUDA EP. That causes many memory transfers between host and GPU, which I suppose is the reason. So my question is: why are there still nodes on the CPU and CUDA EPs? Can this be fixed?
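For context, the placement log below comes from verbose session logging; a minimal sketch of how to reproduce it with the Python API (the model path is a placeholder):

```python
# Sketch: enable verbose logging so VerifyEachNodeIsAssignedToAnEp prints which EP
# each node is placed on (TensorRT, CUDA or CPU).
import onnxruntime as ort

so = ort.SessionOptions()
so.log_severity_level = 0   # 0 = VERBOSE

sess = ort.InferenceSession(
    "faster-rcnn-inferred.onnx",
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider",
               "CPUExecutionProvider"],
)
```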
Here are the logs :
2023-09-06 16:45:59.604024060 [V:onnxruntime:, session_state.cc:1149 VerifyEachNodeIsAssignedToAnEp] Node placements
2023-09-06 16:45:59.604038849 [V:onnxruntime:, session_state.cc:1155 VerifyEachNodeIsAssignedToAnEp] Node(s) placed on [TensorrtExecutionProvider]. Number of nodes: 11
2023-09-06 16:45:59.604042765 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_0 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_0_0)
2023-09-06 16:45:59.604046398 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_1 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_1_1)
2023-09-06 16:45:59.604049385 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_2 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_2_2)
2023-09-06 16:45:59.604052381 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_3 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_3_3)
2023-09-06 16:45:59.604055213 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_4 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_4_4)
2023-09-06 16:45:59.604057978 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_5 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_5_5)
2023-09-06 16:45:59.604060720 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] TRTKernel_graph_torch_jit_15684953649142847852_6 (TensorrtExecutionProvider_TRTKernel_graph_torch_jit_15684953649142847852_6_6)
2023-09-06 16:45:59.604063521 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] MemcpyFromHost (Memcpy)
2023-09-06 16:45:59.604066111 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] MemcpyToHost (Memcpy_token_422)
2023-09-06 16:45:59.604068754 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] MemcpyToHost (Memcpy_token_423)
2023-09-06 16:45:59.604078119 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] MemcpyToHost (Memcpy_token_424)
2023-09-06 16:45:59.604081367 [V:onnxruntime:, session_state.cc:1155 VerifyEachNodeIsAssignedToAnEp] Node(s) placed on [CPUExecutionProvider]. Number of nodes: 1
2023-09-06 16:45:59.604086459 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] RoiAlign (/model/roi_heads/box_pooler/level_poolers.0/RoiAlign)
2023-09-06 16:45:59.604093948 [V:onnxruntime:, session_state.cc:1155 VerifyEachNodeIsAssignedToAnEp] Node(s) placed on [CUDAExecutionProvider]. Number of nodes: 5
2023-09-06 16:45:59.604099017 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] NonZero (/model/proposal_generator/NonZero)
2023-09-06 16:45:59.604103942 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] NonMaxSuppression (NonMaxSuppression_497)
2023-09-06 16:45:59.604108777 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] NonZero (/model/roi_heads/NonZero)
2023-09-06 16:45:59.604113159 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] NonMaxSuppression (NonMaxSuppression_796)
2023-09-06 16:45:59.604117903 [V:onnxruntime:, session_state.cc:1157 VerifyEachNodeIsAssignedToAnEp] NonZero (/model/NonZero)
I got the same issue with both the C++ and Python runtime APIs.
To reproduce
I can't share my model for IP reasons, but I see similar issues with the public Detectron2 model zoo faster-rcnn-rpn (see #16886 for how to run it), though with that one even more nodes fall back to CPU and CUDA, among which are the nodes in bold above. So maybe the fixes found by investigating that one will lead to the same fixes here.
Urgency
I have been blocked for several months trying to run the model on the TRT EP (see #16886; thanks to the ORT staff who helped me), only to find out that it may not be worth it. It looks like I am not far off: only 3 operators/nodes remain to move onto the TRT EP. But time is running out; in a couple of months I will need to freeze the model to certify the results, with no second chance to certify with TRT FP16 or, better, INT8. I am expecting a 2x perf improvement in TRT FP16 and another 2x improvement in INT8 (accuracy is still excellent in FP16).
Platform
Linux
OS Version
SLES15 SP4
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.15.1+ (using main latest for a fix to build TRT EP)
ONNX Runtime API
Python
Architecture
X64
Execution Provider
TensorRT
Execution Provider Library Version
TensorRT 8.6.1
Model File
I can't, but one could use faster-rcnn-rpn from the Detectron2 model zoo (see #16886).
Is this a quantized model?
No