Skip to content

Commit

Permalink
Simplify bindings to ICudaEngine
Browse files Browse the repository at this point in the history
It turns out that `bindgen` can actually parse the `NvInfer.h` file
pretty well in most situations. This removes much of the need for
the extra indirection and dynamic allocation that we were doing
before to wrap it in a C API.

The only outstanding issue is that it doesn't automatically generate
bindings to pure virtual functions. All member functions in the
`NvInfer.h` file are pure virtual, so we just get the types but
not their methods. To fix this we still need a very thin wrapper
that accepts the type as the first parameter and then calls the
appropriate function directly on that type.

There is an issue tracking this
rust-lang/rust-bindgen#27
that hopefully will remove the need for this thin wrapper
completely.
  • Loading branch information
mstallmo committed Dec 16, 2020
1 parent 943c8fe commit 8fb3b29
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 142 deletions.
6 changes: 2 additions & 4 deletions tensorrt-sys/trt-sys/TRTBuilder/TRTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "TRTBuilder.h"
#include "../TRTNetworkDefinition/TRTNetworkDefinitionInternal.hpp"
#include "../TRTLogger/TRTLoggerInternal.hpp"
#include "../TRTCudaEngine/TRTCudaEngineInternal.hpp"
#include "../TRTLayer/TRTLayerInternal.hpp"
#include "../TRTUtils.hpp"

Expand Down Expand Up @@ -190,11 +189,10 @@ Network_t *create_network_v2(Builder_t *builder, uint32_t flags) {
}
#endif

Engine_t *build_cuda_engine(Builder_t *builder, Network_t *network) {
nvinfer1::ICudaEngine *build_cuda_engine(Builder_t *builder, Network_t *network) {
auto& b = builder->internal_builder;

auto engine = b->buildCudaEngine(network->getNetworkDefinition());
return create_engine(engine);
return b->buildCudaEngine(network->getNetworkDefinition());
}

void builder_reset(Builder_t* builder, Network_t* network) {
Expand Down
4 changes: 3 additions & 1 deletion tensorrt-sys/trt-sys/TRTBuilder/TRTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef LIBTRT_TRTBUILDER_H
#define LIBTRT_TRTBUILDER_H

#include <NvInfer.h>
#include "../TRTLogger/TRTLogger.h"
#include "../TRTNetworkDefinition/TRTNetworkDefinition.h"
#include "../TRTCudaEngine/TRTCudaEngine.h"
Expand Down Expand Up @@ -84,11 +85,12 @@ Network_t *create_network(Builder_t* builder);
Network_t *create_network_v2(Builder_t* builder, uint32_t flags);
#endif

Engine_t *build_cuda_engine(Builder_t* builder, Network_t* network);
void builder_reset(Builder_t* builder, Network_t* network);

#ifdef __cplusplus
};
#endif

nvinfer1::ICudaEngine *build_cuda_engine(Builder_t *builder, Network_t *network);

#endif //LIBTRT_TRTBUILDER_H
101 changes: 33 additions & 68 deletions tensorrt-sys/trt-sys/TRTCudaEngine/TRTCudaEngine.cpp
Original file line number Diff line number Diff line change
@@ -1,65 +1,36 @@
//
// Created by mason on 8/26/19.
//
#include <memory>
#include <cstring>
#include <cstdlib>

#include "../TRTHostMemory/TRTHostMemoryInternal.hpp"
#include "../TRTContext/TRTContextInternal.hpp"
#include "../TRTUtils.hpp"
#include "TRTCudaEngineInternal.hpp"
#include "TRTCudaEngine.h"

struct Engine {
using ICudaEnginePtr = std::unique_ptr<nvinfer1::ICudaEngine, TRTDeleter<nvinfer1::ICudaEngine>>;
ICudaEnginePtr internal_engine;

explicit Engine (nvinfer1::ICudaEngine* engine) {
internal_engine = ICudaEnginePtr(engine);
}
};

Engine_t* create_engine(nvinfer1::ICudaEngine* engine) {
return new Engine(engine);
}

void engine_destroy(Engine_t* engine) {
if (engine == nullptr)
return;

delete engine;
void engine_destroy(nvinfer1::ICudaEngine* engine) {
engine->destroy();
}

int engine_get_nb_bindings(Engine_t* engine) {
if (engine == nullptr)
return -1;

return engine->internal_engine->getNbBindings();
int engine_get_nb_bindings(nvinfer1::ICudaEngine* engine) {
return engine->getNbBindings();
}

int engine_get_binding_index(Engine_t* engine, const char* op_name) {
if (engine == nullptr)
return -1;

return engine->internal_engine->getBindingIndex(op_name);
int engine_get_binding_index(nvinfer1::ICudaEngine* engine, const char* op_name) {
return engine->getBindingIndex(op_name);
}

const char* engine_get_binding_name(Engine_t* engine, int binding_index) {
if (engine == nullptr)
return "";

return engine->internal_engine->getBindingName(binding_index);
const char* engine_get_binding_name(nvinfer1::ICudaEngine* engine, int binding_index) {
return engine->getBindingName(binding_index);
}

bool engine_binding_is_input(Engine_t *engine, int binding_index) {
return engine->internal_engine->bindingIsInput(binding_index);
bool engine_binding_is_input(nvinfer1::ICudaEngine *engine, int binding_index) {
return engine->bindingIsInput(binding_index);
}

Dims_t* engine_get_binding_dimensions(Engine_t *engine, int binding_index) {
if (engine == nullptr)
return nullptr;
Dims_t* engine_get_binding_dimensions(nvinfer1::ICudaEngine *engine, int binding_index) {

nvinfer1::Dims nvdims = engine->internal_engine->getBindingDimensions(binding_index);
nvinfer1::Dims nvdims = engine->getBindingDimensions(binding_index);
auto dims = static_cast<Dims_t *>(malloc(sizeof(Dims_t)));
dims->nbDims = nvdims.nbDims;
memcpy(dims->d, nvdims.d, nvinfer1::Dims::MAX_DIMS * sizeof(int));
Expand All @@ -68,50 +39,44 @@ Dims_t* engine_get_binding_dimensions(Engine_t *engine, int binding_index) {
return dims;
}

DataType_t engine_get_binding_data_type(Engine_t *engine, int binding_index) {
return static_cast<DataType_t>(engine->internal_engine->getBindingDataType(binding_index));
DataType_t engine_get_binding_data_type(nvinfer1::ICudaEngine *engine, int binding_index) {
return static_cast<DataType_t>(engine->getBindingDataType(binding_index));
}

int engine_get_max_batch_size(Engine_t *engine) {
return engine->internal_engine->getMaxBatchSize();
int engine_get_max_batch_size(nvinfer1::ICudaEngine *engine) {
return engine->getMaxBatchSize();
}

int engine_get_nb_layers(Engine_t *engine) {
return engine->internal_engine->getNbLayers();
int engine_get_nb_layers(nvinfer1::ICudaEngine *engine) {
return engine->getNbLayers();
}

size_t engine_get_workspace_size(Engine_t *engine) {
return engine->internal_engine->getWorkspaceSize();
size_t engine_get_workspace_size(nvinfer1::ICudaEngine *engine) {
return engine->getWorkspaceSize();
}

Context_t* engine_create_execution_context(Engine_t* engine) {
if (engine == nullptr)
return nullptr;

nvinfer1::IExecutionContext *context = engine->internal_engine->createExecutionContext();
Context_t* engine_create_execution_context(nvinfer1::ICudaEngine* engine) {
nvinfer1::IExecutionContext *context = engine->createExecutionContext();
return create_execution_context(context);
}

Context_t* engine_create_execution_context_without_device_memory(Engine_t *engine) {
nvinfer1::IExecutionContext *context = engine->internal_engine->createExecutionContextWithoutDeviceMemory();
Context_t* engine_create_execution_context_without_device_memory(nvinfer1::ICudaEngine *engine) {
nvinfer1::IExecutionContext *context = engine->createExecutionContextWithoutDeviceMemory();
return create_execution_context(context);
}

HostMemory_t* engine_serialize(Engine_t* engine) {
if (engine == nullptr)
return nullptr;

return create_host_memory(engine->internal_engine->serialize());
HostMemory_t* engine_serialize(nvinfer1::ICudaEngine* engine) {
return create_host_memory(engine->serialize());
}

TensorLocation_t engine_get_location(Engine_t *engine, int binding_index) {
return static_cast<TensorLocation_t>(engine->internal_engine->getLocation(binding_index));
TensorLocation_t engine_get_location(nvinfer1::ICudaEngine *engine, int binding_index) {
return static_cast<TensorLocation_t>(engine->getLocation(binding_index));
}

size_t engine_get_device_memory_size(Engine_t *engine) {
return engine->internal_engine->getDeviceMemorySize();
size_t engine_get_device_memory_size(nvinfer1::ICudaEngine *engine) {
return engine->getDeviceMemorySize();
}

bool engine_is_refittable(Engine_t *engine) {
return engine->internal_engine->isRefittable();
bool engine_is_refittable(nvinfer1::ICudaEngine *engine) {
return engine->isRefittable();
}
59 changes: 18 additions & 41 deletions tensorrt-sys/trt-sys/TRTCudaEngine/TRTCudaEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,28 @@
#ifndef LIBTRT_TRTCUDAENGINE_H
#define LIBTRT_TRTCUDAENGINE_H

#include <NvInfer.h>

#include "../TRTContext/TRTContext.h"
#include "../TRTHostMemory/TRTHostMemory.h"
#include "../TRTDims/TRTDims.h"
#include "../TRTEnums.h"

#ifdef __cplusplus
extern "C" {
#endif

struct Engine;
typedef struct Engine Engine_t;

void engine_destroy(Engine_t* engine);

Context_t* engine_create_execution_context(Engine_t* engine);
Context_t* engine_create_execution_context_without_device_memory(Engine_t *engine);

int engine_get_nb_bindings(Engine_t* engine);

int engine_get_binding_index(Engine_t *engine, const char* op_name);

const char* engine_get_binding_name(Engine_t* engine, int binding_index);

bool engine_binding_is_input(Engine_t *engine, int binding_index);

Dims_t* engine_get_binding_dimensions(Engine_t *engine, int binding_index);

DataType_t engine_get_binding_data_type(Engine_t *engine, int binding_index);

int engine_get_max_batch_size(Engine_t *engine);

int engine_get_nb_layers(Engine_t *engine);

size_t engine_get_workspace_size(Engine_t *engine);

HostMemory_t* engine_serialize(Engine_t* engine);

TensorLocation_t engine_get_location(Engine_t *engine, int binding_index);

size_t engine_get_device_memory_size(Engine_t *engine);

bool engine_is_refittable(Engine_t *engine);

#ifdef __cplusplus
};
#endif
void engine_destroy(nvinfer1::ICudaEngine* engine);
Context_t* engine_create_execution_context(nvinfer1::ICudaEngine* engine);
Context_t* engine_create_execution_context_without_device_memory(nvinfer1::ICudaEngine *engine);
int engine_get_nb_bindings(nvinfer1::ICudaEngine* engine);
int engine_get_binding_index(nvinfer1::ICudaEngine *engine, const char* op_name);
const char* engine_get_binding_name(nvinfer1::ICudaEngine* engine, int binding_index);
bool engine_binding_is_input(nvinfer1::ICudaEngine *engine, int binding_index);
Dims_t* engine_get_binding_dimensions(nvinfer1::ICudaEngine *engine, int binding_index);
DataType_t engine_get_binding_data_type(nvinfer1::ICudaEngine *engine, int binding_index);
int engine_get_max_batch_size(nvinfer1::ICudaEngine *engine);
int engine_get_nb_layers(nvinfer1::ICudaEngine *engine);
size_t engine_get_workspace_size(nvinfer1::ICudaEngine *engine);
HostMemory_t* engine_serialize(nvinfer1::ICudaEngine* engine);
TensorLocation_t engine_get_location(nvinfer1::ICudaEngine *engine, int binding_index);
size_t engine_get_device_memory_size(nvinfer1::ICudaEngine *engine);
bool engine_is_refittable(nvinfer1::ICudaEngine *engine);

#endif //LIBTRT_TRTCUDAENGINE_H
14 changes: 0 additions & 14 deletions tensorrt-sys/trt-sys/TRTCudaEngine/TRTCudaEngineInternal.hpp

This file was deleted.

7 changes: 2 additions & 5 deletions tensorrt-sys/trt-sys/TRTRuntime/TRTRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "TRTRuntime.h"

#include "../TRTLogger/TRTLoggerInternal.hpp"
#include "../TRTCudaEngine/TRTCudaEngineInternal.hpp"
#include "../TRTUtils.hpp"

struct Runtime {
Expand All @@ -28,10 +27,8 @@ void destroy_infer_runtime(Runtime_t* runtime) {
delete runtime;
}

Engine_t* deserialize_cuda_engine(Runtime_t* runtime, const void* blob, unsigned long long size) {
nvinfer1::ICudaEngine* engine = runtime->internal_runtime->deserializeCudaEngine(blob, size, nullptr);

return create_engine(engine);
nvinfer1::ICudaEngine *deserialize_cuda_engine(Runtime_t *runtime, const void *blob, unsigned long long size) {
return runtime->internal_runtime->deserializeCudaEngine(blob, size, nullptr);
}

int runtime_get_nb_dla_cores(Runtime_t *runtime) {
Expand Down
11 changes: 6 additions & 5 deletions tensorrt-sys/trt-sys/TRTRuntime/TRTRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef TENSRORT_SYS_TRTRUNTIME_H
#define TENSRORT_SYS_TRTRUNTIME_H

#include <NvInfer.h>
#include "../TRTLogger/TRTLogger.h"
#include "../TRTCudaEngine/TRTCudaEngine.h"

Expand All @@ -15,17 +16,17 @@ extern "C" {
struct Runtime;
typedef struct Runtime Runtime_t;

Runtime_t* create_infer_runtime(Logger_t* logger);
void destroy_infer_runtime(Runtime_t* runtime);
Runtime_t *create_infer_runtime(Logger_t *logger);
void destroy_infer_runtime(Runtime_t *runtime);

Engine_t* deserialize_cuda_engine(Runtime_t* runtime, const void* blob, unsigned long long size);
int runtime_get_nb_dla_cores(Runtime_t* runtime);
int runtime_get_nb_dla_cores(Runtime_t *runtime);
int runtime_get_dla_core(Runtime_t *runtime);
void runtime_set_dla_core(Runtime_t *runtime, int dla_core);


#ifdef __cplusplus
};
#endif

nvinfer1::ICudaEngine *deserialize_cuda_engine(Runtime_t *runtime, const void *blob, unsigned long long size);

#endif //TENSRORT_SYS_TRTRUNTIME_H
1 change: 1 addition & 0 deletions tensorrt-sys/trt-sys/tensorrt_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#ifndef TENSRORT_SYS_TENSORRT_API_H
#define TENSRORT_SYS_TENSORRT_API_H
#include <NvInfer.h>

#include "TRTEnums.h"
#include "TRTLogger/TRTLogger.h"
Expand Down
15 changes: 11 additions & 4 deletions tensorrt/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tensorrt_sys::{
engine_get_binding_name, engine_get_device_memory_size, engine_get_location,
engine_get_max_batch_size, engine_get_nb_bindings, engine_get_nb_layers,
engine_get_workspace_size, engine_is_refittable, engine_serialize, host_memory_get_data,
host_memory_get_size,
host_memory_get_size, nvinfer1_ICudaEngine,
};

#[repr(C)]
Expand All @@ -33,12 +33,17 @@ pub enum TensorLocation {

#[derive(Debug)]
pub struct Engine {
pub(crate) internal_engine: *mut tensorrt_sys::Engine_t,
pub(crate) internal_engine: *mut nvinfer1_ICudaEngine,
}

impl Engine {
pub fn get_nb_bindings(&self) -> i32 {
unsafe { engine_get_nb_bindings(self.internal_engine) }
let res = if !self.internal_engine.is_null() {
unsafe { engine_get_nb_bindings(self.internal_engine) }
} else {
0
};
res
}

pub fn get_binding_name(&self, binding_index: i32) -> Option<String> {
Expand Down Expand Up @@ -141,7 +146,9 @@ unsafe impl Send for Engine {}

impl Drop for Engine {
fn drop(&mut self) {
unsafe { engine_destroy(self.internal_engine) };
if !self.internal_engine.is_null() {
unsafe { engine_destroy(self.internal_engine) };
}
}
}

Expand Down

0 comments on commit 8fb3b29

Please sign in to comment.