fix(//core/runtime): Support more delimiter variants
Signed-off-by: Naren Dasan <narens@nvidia.com>
Signed-off-by: Naren Dasan <naren@narendasan.com>
narendasan committed Apr 27, 2022
1 parent 67e320c commit 819c911
Showing 2 changed files with 41 additions and 1 deletion.
40 changes: 39 additions & 1 deletion core/runtime/TRTEngine.cpp
@@ -53,13 +53,20 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
   TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");
 
   exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
 
   uint64_t inputs = 0;
   uint64_t outputs = 0;
 
   for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
     std::string bind_name = cuda_engine->getBindingName(x);
-    std::string idx_s = bind_name.substr(bind_name.find("_") + 1);
+    auto delim = bind_name.find(".");
+    if (delim == std::string::npos) {
+      delim = bind_name.find("_");
+      TORCHTRT_CHECK(delim != std::string::npos, "Unable to determine binding index for input " << bind_name << "\nEnsure module was compile with Torch-TensorRT.ts");
+    }
+
+    std::string idx_s = bind_name.substr(delim + 1);
     uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
 
     if (cuda_engine->bindingIsInput(x)) {
@@ -71,6 +78,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
     }
   }
   num_io = std::make_pair(inputs, outputs);
+
+  LOG_DEBUG(*this);
 }
 
 TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
@@ -82,6 +91,34 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
   return (*this);
 }
 
+std::string TRTEngine::to_str() const {
+  std::stringstream ss;
+  ss << "Torch-TensorRT TensorRT Engine:" << std::endl;
+  ss << " Name: " << name << std::endl;
+  ss << " Inputs: [" << std::endl;
+  for (uint64_t i = 0; i < num_io.first; i++) {
+    ss << " id: " << i << std::endl;
+    ss << " shape: " << exec_ctx->getBindingDimensions(i) << std::endl;
+    ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(i)) << std::endl;
+  }
+  ss << " ]" << std::endl;
+  ss << " Outputs: [" << std::endl;
+  for (uint64_t o = 0; o < num_io.second; o++) {
+    ss << " id: " << o << std::endl;
+    ss << " shape: " << exec_ctx->getBindingDimensions(o) << std::endl;
+    ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(o)) << std::endl;
+  }
+  ss << " ]" << std::endl;
+  ss << " Device: " << device_info << std::endl;
+
+  return ss.str();
+}
+
+std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
+  os << engine.to_str();
+  return os;
+}
+
 // TODO: Implement a call method
 // c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
 //   auto input_vec = inputs.vec();
@@ -96,6 +133,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def(torch::init<std::vector<std::string>>())
         // TODO: .def("__call__", &TRTEngine::Run)
        // TODO: .def("run", &TRTEngine::Run)
+        .def("__str__", &TRTEngine::to_str)
         .def_pickle(
             [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
               // Serialize TensorRT engine
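
For context, the delimiter handling added above can be exercised on its own. The following is a minimal standalone sketch and is not part of the diff; parse_binding_index and the sample binding names "input_0", "input.1", and "x" are hypothetical, but the find/fallback logic mirrors the new constructor code.

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical standalone helper mirroring the commit's delimiter logic:
// prefer a "." delimiter (e.g. "input.1"), fall back to "_" (e.g. "input_0"),
// and report failure if neither is found (the real code raises a
// TORCHTRT_CHECK error instead of returning false).
bool parse_binding_index(const std::string& bind_name, uint64_t& idx_out) {
  auto delim = bind_name.find(".");
  if (delim == std::string::npos) {
    delim = bind_name.find("_");
    if (delim == std::string::npos) {
      return false;
    }
  }
  std::string idx_s = bind_name.substr(delim + 1);
  idx_out = static_cast<uint64_t>(std::stoi(idx_s));
  return true;
}

int main() {
  // Sample binding names covering both delimiter variants plus a malformed one.
  for (const std::string& name : {"input_0", "input.1", "x"}) {
    uint64_t idx = 0;
    if (parse_binding_index(name, idx)) {
      std::cout << name << " -> index " << idx << std::endl;
    } else {
      std::cout << name << " -> no recognizable delimiter" << std::endl;
    }
  }
  return 0;
}

With the fallback in place, binding names using either delimiter resolve to a numeric index, which is what the commit message means by supporting more delimiter variants.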
2 changes: 2 additions & 0 deletions core/runtime/runtime.h
@@ -59,6 +59,8 @@ struct TRTEngine : torch::CustomClassHolder {
   TRTEngine(std::vector<std::string> serialized_info);
   TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);
   TRTEngine& operator=(const TRTEngine& other);
+  std::string to_str() const;
+  friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
 };
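
The two declarations added to runtime.h follow the common pattern of pairing a to_str() member with a friend stream inserter, which is what lets the constructor above log the whole engine via LOG_DEBUG(*this). Below is a minimal sketch of that pattern using a hypothetical Widget type, not tied to TRTEngine's internals.

#include <iostream>
#include <sstream>
#include <string>

// Hypothetical type illustrating the to_str() + friend operator<< pattern
// that TRTEngine adopts in this commit.
struct Widget {
  std::string name;
  int inputs;
  int outputs;

  std::string to_str() const {
    std::stringstream ss;
    ss << "Widget:" << std::endl;
    ss << "  Name: " << name << std::endl;
    ss << "  Inputs: " << inputs << std::endl;
    ss << "  Outputs: " << outputs << std::endl;
    return ss.str();
  }

  // The friend declaration gives the free operator<< access to members if it
  // ever needs them; here it simply forwards to to_str().
  friend std::ostream& operator<<(std::ostream& os, const Widget& w);
};

std::ostream& operator<<(std::ostream& os, const Widget& w) {
  os << w.to_str();
  return os;
}

int main() {
  Widget w{"my_engine", 2, 1};
  std::cout << w; // prints the same text that w.to_str() builds
  return 0;
}

Registering the same member through .def("__str__", &TRTEngine::to_str) in the TorchScript class registration presumably makes the same summary available when the bound engine object is printed from Python or TorchScript.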
