Skip to content

Commit

Permalink
add yolox implementation (#418)
Browse files Browse the repository at this point in the history
* test yolox code; TODO: clean up code

* add tensorrt yolox test code

* update var name to trt engine model path

* update transform func to avoid pointers not being freed

* update transform func to avoid pointers not being freed

* update infer code to use new transform func
  • Loading branch information
wangzijian1010 authored Jul 24, 2024
1 parent 6f60bb1 commit 42eea2b
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 27 deletions.
27 changes: 27 additions & 0 deletions examples/lite/cv/test_lite_yolox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,40 @@ static void test_tnn()
#endif
}

// Runs the TensorRT YoloX detector once on a sample image and writes the
// annotated image to save_img_path. The whole body is compiled out unless
// ENABLE_TENSORRT is defined, in which case this function does nothing.
static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT

  std::string engine_path = "../../..//examples/hub/trt/yolox_s_fp32.engine";
  std::string test_img_path = "../../..//examples/lite/resources/test_lite_yolox_2.jpg";
  std::string save_img_path = "../../..//examples/logs/test_lite_yolox_trt_4.jpg";

  // 2. Test Specific Engine TensorRT
  // The detector lives in automatic storage rather than behind new/delete,
  // so it is destroyed on every exit path. With the raw pointer, an
  // exception thrown by detect(), draw_boxes_inplace() or cv::imwrite()
  // skipped the trailing delete and leaked the detector.
  lite::trt::cv::detection::YoloX yolox(engine_path);

  std::vector<lite::types::Boxf> detected_boxes;
  cv::Mat img_bgr = cv::imread(test_img_path);
  yolox.detect(img_bgr, detected_boxes);

  // Boxes are drawn onto img_bgr itself; that same Mat is then saved.
  lite::utils::draw_boxes_inplace(img_bgr, detected_boxes);

  cv::imwrite(save_img_path, img_bgr);
#endif
}



// Entry point for this example: invokes each backend-specific test exactly
// once, in the fixed order listed in the table below.
static void test_lite()
{
  // Table of backend tests; the order here is the execution order.
  void (*const backend_tests[])() = {
      test_default,
      test_onnxruntime,
      test_mnn,
      test_ncnn,
      test_tnn,
      test_tensorrt
  };

  for (auto run_backend_test : backend_tests)
  {
    run_backend_test();
  }
}

int main(__unused int argc, __unused char *argv[])
Expand Down
4 changes: 3 additions & 1 deletion lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
#include "lite/trt/core/trt_core.h"
#include "lite/trt/cv/trt_yolofacev8.h"
#include "lite/trt/cv/trt_yolov5.h"
#include "lite/trt/cv/trt_yolox.h"
#endif

// ENABLE_MNN
Expand Down Expand Up @@ -677,14 +678,15 @@ namespace lite{
{
typedef trtcv::TRTYoloFaceV8 _TRT_YOLOFaceNet;
typedef trtcv::TRTYoloV5 _TRT_YOLOv5;
typedef trtcv::TRTYoloX _TRT_YoloX;
namespace classification
{

}
namespace detection
{
typedef _TRT_YOLOv5 YOLOV5;

typedef _TRT_YoloX YoloX;
}
namespace face
{
Expand Down
1 change: 1 addition & 0 deletions lite/trt/core/trt_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
namespace trtcv{
class LITE_EXPORTS TRTYoloFaceV8; // [1] * reference: https://github.com/derronqi/yolov8-face
class LITE_EXPORTS TRTYoloV5; // [2] * reference: https://github.com/ultralytics/yolov5
class LITE_EXPORTS TRTYoloX; // [3] * reference: https://github.com/Megvii-BaseDetection/YOLOX
}

namespace trtcv{
Expand Down
11 changes: 7 additions & 4 deletions lite/trt/core/trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "trt_utils.h"


float* trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<int64_t> input_node_dims,unsigned int data_format){
void trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<float> &input_vector,std::vector<int64_t> input_node_dims,unsigned int data_format){
// make mat to float type's vector

const unsigned int rows = mat.rows;
Expand All @@ -19,12 +19,11 @@ float* trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<int
if (input_node_dims.size() != 4) throw std::runtime_error("dims mismatch.");
if (input_node_dims.at(0) != 1) throw std::runtime_error("batch != 1");


if (data_format == transform::CHW)
{
const unsigned int target_tensor_size = rows * cols * channels;
// input vector's size
float* input_vector = new float [target_tensor_size];
input_vector.resize(target_tensor_size);

for (int c = 0; c < channels; ++c)
{
Expand All @@ -36,8 +35,12 @@ float* trtcv::utils::transform::create_tensor(const cv::Mat &mat,std::vector<int
}
}
}
return input_vector;

}else
{
throw std::runtime_error("data_format must be transform::CHW!");
}

}


Expand Down
2 changes: 1 addition & 1 deletion lite/trt/core/trt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace trtcv
{
CHW = 0, HWC =1
};
LITE_EXPORTS float* create_tensor(const cv::Mat &mat,std::vector<int64_t> input_node_dims,unsigned int data_format = CHW);
LITE_EXPORTS void create_tensor(const cv::Mat &mat,std::vector<float> &input_vector,std::vector<int64_t> input_node_dims,unsigned int data_format = CHW);


}
Expand Down
16 changes: 6 additions & 10 deletions lite/trt/cv/trt_yolofacev8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,30 +140,26 @@ void TRTYoloFaceV8::detect(const cv::Mat &mat, std::vector<lite::types::Boxf> &b
cv::Mat normalized_image = normalize(mat);

// 2.trans to input vector
auto input = trtcv::utils::transform::create_tensor(normalized_image,input_node_dims,trtcv::utils::transform::CHW);
std::vector<float> input;
trtcv::utils::transform::create_tensor(normalized_image,input,input_node_dims,trtcv::utils::transform::CHW);

// 3. infer
cudaMemcpyAsync(buffers[0], input, input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
bool status = trt_context->enqueueV3(stream);

delete[] input;
input = nullptr;

if (!status){
std::cerr << "Failed to infer by TensorRT." << std::endl;
return;
}

float* output = new float[output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2]];
std::vector<float> output(output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2]);

cudaMemcpyAsync(output, buffers[1], output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2] * sizeof(float),
cudaMemcpyAsync(output.data(), buffers[1], output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
// 4. generate box
generate_box(output,boxes,0.45f,0.5f);
generate_box(output.data(),boxes,0.45f,0.5f);

// free pointer
delete[] output;
output = nullptr;

}
16 changes: 7 additions & 9 deletions lite/trt/cv/trt_yolov5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ void TRTYoloV5::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_bo
cv::Mat normalized_image = normalized(mat_rs);

//1. make the input
auto input = trtcv::utils::transform::create_tensor(normalized_image,input_node_dims,trtcv::utils::transform::CHW);
std::vector<float> input;
trtcv::utils::transform::create_tensor(normalized_image,input,input_node_dims,trtcv::utils::transform::CHW);

//2. infer
cudaMemcpyAsync(buffers[0], input, input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
delete[] input;
input = nullptr;


bool status = trt_context->enqueueV3(stream);
cudaStreamSynchronize(stream);
Expand All @@ -164,18 +164,16 @@ void TRTYoloV5::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_bo
// get the first output dim
auto pred_dims = output_node_dims[0];

float* output = new float[pred_dims[0] * pred_dims[1] * pred_dims[2]];
std::vector<float> output(pred_dims[0] * pred_dims[1] * pred_dims[2]);

cudaMemcpyAsync(output, buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyAsync(output.data(), buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);

//3. generate the boxes
std::vector<types::Boxf> bbox_collection;
generate_bboxes(scale_params, bbox_collection, output, score_threshold, img_height, img_width);
generate_bboxes(scale_params, bbox_collection, output.data(), score_threshold, img_height, img_width);
nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type);
delete[] output;
output = nullptr;
}


4 changes: 2 additions & 2 deletions lite/trt/cv/trt_yolov5.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace trtcv
class LITE_EXPORTS TRTYoloV5 : public BasicTRTHandler
{
public:
explicit TRTYoloV5(const std::string &_onnx_path, unsigned int _num_threads = 1) :
BasicTRTHandler(_onnx_path, _num_threads)
explicit TRTYoloV5(const std::string &_trt_model_path, unsigned int _num_threads = 1) :
BasicTRTHandler(_trt_model_path, _num_threads)
{};

~TRTYoloV5() override = default;
Expand Down
Loading

0 comments on commit 42eea2b

Please sign in to comment.