diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c523d12173..313209ba5a 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -128,14 +128,13 @@ TRTEngine::TRTEngine( for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) { auto binding_name = _out_binding_names[pyt_idx]; auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str()); - std::string engine_binded_name = cuda_engine->getIOTensorName(inputs_size + pyt_idx); - TORCHTRT_CHECK( - (binding_name == engine_binded_name), - "Could not find a TensorRT engine binding for output named " << binding_name); + TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name); TORCHTRT_CHECK( !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT), "Binding " << binding_name << " specified as output but found as input in TensorRT engine"); - LOG_DEBUG("Output binding name: " << binding_name << "pyt return idx: " << inputs_size + pyt_idx << ")"); + LOG_DEBUG( + "Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx + << ", Torch binding index: " << inputs_size + pyt_idx); out_binding_map[trt_idx] = pyt_idx; out_binding_names[pyt_idx] = binding_name; }