Skip to content

Commit

Permalink
fix: Bugfix in TRT Engine deserialization indexing (#1646)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored Feb 3, 2023
1 parent 4fc2935 commit d638730
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,13 @@ TRTEngine::TRTEngine(
for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
auto binding_name = _out_binding_names[pyt_idx];
auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
std::string engine_binded_name = cuda_engine->getIOTensorName(inputs_size + pyt_idx);
TORCHTRT_CHECK(
(binding_name == engine_binded_name),
"Could not find a TensorRT engine binding for output named " << binding_name);
TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name);
TORCHTRT_CHECK(
!(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
"Binding " << binding_name << " specified as output but found as input in TensorRT engine");
LOG_DEBUG("Output binding name: " << binding_name << "pyt return idx: " << inputs_size + pyt_idx << ")");
LOG_DEBUG(
"Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
<< ", Torch binding index: " << inputs_size + pyt_idx);
out_binding_map[trt_idx] = pyt_idx;
out_binding_names[pyt_idx] = binding_name;
}
Expand Down

0 comments on commit d638730

Please sign in to comment.