From d63873081c9c1eca268a628ba315f714a97d4fc8 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Fri, 3 Feb 2023 11:58:25 -0800 Subject: [PATCH] fix: Bugfix in TRT Engine deserialization indexing (#1646) --- core/runtime/TRTEngine.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c523d12173..313209ba5a 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -128,14 +128,13 @@ TRTEngine::TRTEngine( for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) { auto binding_name = _out_binding_names[pyt_idx]; auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str()); - std::string engine_binded_name = cuda_engine->getIOTensorName(inputs_size + pyt_idx); - TORCHTRT_CHECK( - (binding_name == engine_binded_name), - "Could not find a TensorRT engine binding for output named " << binding_name); + TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name); TORCHTRT_CHECK( !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT), "Binding " << binding_name << " specified as output but found as input in TensorRT engine"); - LOG_DEBUG("Output binding name: " << binding_name << "pyt return idx: " << inputs_size + pyt_idx << ")"); + LOG_DEBUG( + "Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx + << ", Torch binding index: " << inputs_size + pyt_idx); out_binding_map[trt_idx] = pyt_idx; out_binding_names[pyt_idx] = binding_name; }