Update TensorRT BasicHandler and add yolov5 cpp file #417

Merged: 15 commits, Jul 22, 2024
28 changes: 27 additions & 1 deletion examples/lite/cv/test_lite_yolov5.cpp
@@ -8,7 +8,7 @@ static void test_default()
{
std::string onnx_path = "../../../examples/hub/onnx/cv/yolov5s.onnx";
std::string test_img_path = "../../../examples/lite/resources/test_lite_yolov5_1.jpg";
std::string save_img_path = "../../../examples/logs/test_lite_yolov5_1.jpg";
std::string save_img_path = "../../../examples/logs/test_lite_yolov5_1647_onnx.jpg";

// 1. Test Default Engine ONNXRuntime
lite::cv::detection::YoloV5 *yolov5 = new lite::cv::detection::YoloV5(onnx_path); // default
@@ -129,8 +129,34 @@ static void test_tnn()
#endif
}


static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
std::string engine_path = "../../../examples/hub/trt/yolov5s_fp32.engine";
std::string test_img_path = "../../../examples/lite/resources/test_lite_yolov5_1.jpg";
std::string save_img_path = "../../../examples/logs/test_lite_yolov5_1647.jpg";

// 1. Test TensorRT Engine
lite::trt::cv::detection::YOLOV5 *yolov5 = new lite::trt::cv::detection::YOLOV5(engine_path);
std::vector<lite::types::Boxf> detected_boxes;
cv::Mat img_bgr = cv::imread(test_img_path);
yolov5->detect(img_bgr, detected_boxes);

lite::utils::draw_boxes_inplace(img_bgr, detected_boxes);

cv::imwrite(save_img_path, img_bgr);

std::cout << "Default Version Detected Boxes Num: " << detected_boxes.size() << std::endl;

delete yolov5;
#endif
}


static void test_lite()
{
test_tensorrt();
test_default();
test_onnxruntime();
test_mnn();
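Note: test_tensorrt() assumes a serialized engine already exists at examples/hub/trt/yolov5s_fp32.engine; this PR does not add a build step for it. One common way (an assumption, not part of this diff) to produce such an engine from the yolov5s.onnx model used by test_default() is TensorRT's trtexec tool, e.g. trtexec --onnx=yolov5s.onnx --saveEngine=yolov5s_fp32.engine, run on the same GPU and TensorRT version used at inference time.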
3 changes: 3 additions & 0 deletions lite/models.h
@@ -123,6 +123,7 @@
#include "lite/trt/core/trt_utils.h"
#include "lite/trt/core/trt_core.h"
#include "lite/trt/cv/trt_yolofacev8.h"
#include "lite/trt/cv/trt_yolov5.h"
#endif

// ENABLE_MNN
@@ -675,12 +676,14 @@ namespace lite{
namespace cv
{
typedef trtcv::TRTYoloFaceV8 _TRT_YOLOFaceNet;
typedef trtcv::TRTYoloV5 _TRT_YOLOv5;
namespace classification
{

}
namespace detection
{
typedef _TRT_YOLOv5 YOLOV5;

}
namespace face
1 change: 1 addition & 0 deletions lite/trt/core/trt_core.h
@@ -11,6 +11,7 @@

namespace trtcv{
class LITE_EXPORTS TRTYoloFaceV8; // [1] * reference: https://github.com/derronqi/yolov8-face
class LITE_EXPORTS TRTYoloV5; // [2] * reference: https://github.com/ultralytics/yolov5
}

namespace trtcv{
62 changes: 37 additions & 25 deletions lite/trt/core/trt_handler.cpp
@@ -15,8 +15,9 @@ BasicTRTHandler::BasicTRTHandler(const std::string &_trt_model_path, unsigned in

BasicTRTHandler::~BasicTRTHandler() {
// engine/context are smart pointers and need no manual release; free the CUDA buffers and stream here
cudaFree(buffers[0]);
cudaFree(buffers[1]);
for (auto buffer : buffers) {
cudaFree(buffer);
}
cudaStreamDestroy(stream);
}

@@ -50,31 +51,42 @@ void BasicTRTHandler::initialize_handler() {
}
cudaStreamCreate(&stream);


auto input_name = trt_engine->getIOTensorName(0);
auto output_name = trt_engine->getIOTensorName(1);


nvinfer1::Dims input_dims = trt_engine->getTensorShape(input_name);
nvinfer1::Dims output_dims = trt_engine->getTensorShape(output_name);

input_tensor_size = 1;
for (int i = 0; i < input_dims.nbDims; ++i) {
input_node_dims.push_back(input_dims.d[i]);
input_tensor_size *= input_dims.d[i];
}

output_tensor_size = 1;
for (int i = 0; i < output_dims.nbDims; ++i) {
output_node_dims.push_back(output_dims.d[i]);
output_tensor_size *= output_dims.d[i];
// support one input and a flexible number of outputs
int num_io_tensors = trt_engine->getNbIOTensors(); // total number of input and output tensors
buffers.resize(num_io_tensors);

for (int i = 0; i < num_io_tensors; ++i) {
auto tensor_name = trt_engine->getIOTensorName(i);
nvinfer1::Dims tensor_dims = trt_engine->getTensorShape(tensor_name);

// input
if (i==0)
{
size_t tensor_size = 1;
for (int j = 0; j < tensor_dims.nbDims; ++j) {
tensor_size *= tensor_dims.d[j];
input_node_dims.push_back(tensor_dims.d[j]);
}
cudaMalloc(&buffers[i], tensor_size * sizeof(float));
trt_context->setTensorAddress(tensor_name, buffers[i]);
continue;
}

// output
size_t tensor_size = 1;

std::vector<int64_t> output_node;
for (int j = 0; j < tensor_dims.nbDims; ++j) {
output_node.push_back(tensor_dims.d[j]);
tensor_size *= tensor_dims.d[j];
}
output_node_dims.push_back(output_node);

cudaMalloc(&buffers[i], tensor_size * sizeof(float));
trt_context->setTensorAddress(tensor_name, buffers[i]);
output_tensor_size++;
}

cudaMalloc(&buffers[0], input_tensor_size * sizeof(float));
cudaMalloc(&buffers[1], output_tensor_size * sizeof(float));

trt_context->setTensorAddress(input_name, buffers[0]);
trt_context->setTensorAddress(output_name, buffers[1]);

}

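With this change, IO tensor 0 is treated as the single input and every remaining IO tensor as an output: buffers[i] holds the device memory for tensor i, while output_node_dims[i - 1] records the shape of output i. A minimal sketch (not part of this PR; read_output is a hypothetical helper name) of how a class derived from BasicTRTHandler could copy its k-th output back to the host after trt_context->enqueueV3(stream):

// Sketch only: buffers, output_node_dims and stream are the handler members shown below.
std::vector<float> read_output(size_t k)
{
    const auto &dims = output_node_dims.at(k);        // shape of the k-th output tensor
    size_t count = 1;
    for (int64_t d : dims) count *= static_cast<size_t>(d);

    std::vector<float> host(count);                   // released automatically, no raw new[]
    cudaMemcpyAsync(host.data(), buffers.at(k + 1),   // buffers[0] holds the input
                    count * sizeof(float),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);                    // wait for the copy to complete
    return host;
}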
7 changes: 3 additions & 4 deletions lite/trt/core/trt_handler.h
@@ -16,14 +16,13 @@ namespace trtcore{
std::unique_ptr<nvinfer1::IExecutionContext> trt_context;

Logger trt_logger;
// single input and single output
void* buffers[2];
std::vector<void*> buffers;
cudaStream_t stream;

std::vector<int64_t> input_node_dims;
std::vector<int64_t> output_node_dims;
std::vector<std::vector<int64_t>> output_node_dims;
std::size_t input_tensor_size = 1;
std::size_t output_tensor_size = 1;
std::size_t output_tensor_size = 0;

const char* trt_model_path = nullptr;
const char* log_id = nullptr;
6 changes: 3 additions & 3 deletions lite/trt/cv/trt_yolofacev8.cpp
@@ -95,7 +95,7 @@ cv::Mat TRTYoloFaceV8::normalize(cv::Mat srcimg) {
void TRTYoloFaceV8::generate_box(float *trt_outputs, std::vector<lite::types::Boxf> &boxes, float conf_threshold,
float iou_threshold) {

int num_box = output_node_dims[2];
int num_box = output_node_dims[0][2];
std::vector<lite::types::BoundingBoxType<float, float>> bounding_box_raw;
std::vector<float> score_raw;
for (int i = 0; i < num_box; i++)
@@ -152,9 +152,9 @@ void TRTYoloFaceV8::detect(const cv::Mat &mat, std::vector<lite::types::Boxf> &b
return;
}

float* output = new float[output_node_dims[0] * output_node_dims[1] * output_node_dims[2]];
float* output = new float[output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2]];

cudaMemcpyAsync(output, buffers[1], output_node_dims[0] * output_node_dims[1] * output_node_dims[2] * sizeof(float),
cudaMemcpyAsync(output, buffers[1], output_node_dims[0][0] * output_node_dims[0][1] * output_node_dims[0][2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
// 4. generate box
generate_box(output,boxes,0.45f,0.5f);
178 changes: 178 additions & 0 deletions lite/trt/cv/trt_yolov5.cpp
@@ -0,0 +1,178 @@
//
// Created by root on 7/20/24.
//

#include "trt_yolov5.h"
using trtcv::TRTYoloV5;

void TRTYoloV5::resize_unscale(const cv::Mat &mat, cv::Mat &mat_rs,
int target_height, int target_width,
YoloV5ScaleParams &scale_params)
{
if (mat.empty()) return;
int img_height = static_cast<int>(mat.rows);
int img_width = static_cast<int>(mat.cols);

mat_rs = cv::Mat(target_height, target_width, CV_8UC3,
cv::Scalar(114, 114, 114));
// scale ratio (new / old) new_shape(h,w)
float w_r = (float) target_width / (float) img_width;
float h_r = (float) target_height / (float) img_height;
float r = std::min(w_r, h_r);
// compute padding
int new_unpad_w = static_cast<int>((float) img_width * r); // floor
int new_unpad_h = static_cast<int>((float) img_height * r); // floor
int pad_w = target_width - new_unpad_w; // >=0
int pad_h = target_height - new_unpad_h; // >=0

int dw = pad_w / 2;
int dh = pad_h / 2;

// resize with unscaling
cv::Mat new_unpad_mat;
// cv::Mat new_unpad_mat = mat.clone(); // may not need clone.
cv::resize(mat, new_unpad_mat, cv::Size(new_unpad_w, new_unpad_h));
new_unpad_mat.copyTo(mat_rs(cv::Rect(dw, dh, new_unpad_w, new_unpad_h)));

// record scale params.
scale_params.r = r;
scale_params.dw = dw;
scale_params.dh = dh;
scale_params.new_unpad_w = new_unpad_w;
scale_params.new_unpad_h = new_unpad_h;
scale_params.flag = true;
}

void TRTYoloV5::nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output,
float iou_threshold, unsigned int topk, unsigned int nms_type)
{
if (nms_type == NMS::BLEND) lite::utils::blending_nms(input, output, iou_threshold, topk);
else if (nms_type == NMS::OFFSET) lite::utils::offset_nms(input, output, iou_threshold, topk);
else lite::utils::hard_nms(input, output, iou_threshold, topk);
}


cv::Mat TRTYoloV5::normalized(const cv::Mat input_image) {
cv::Mat canvas;
cv::cvtColor(input_image,canvas,cv::COLOR_BGR2RGB);
canvas.convertTo(canvas,CV_32F,1.0 / 255.0,0);
return canvas;
}


void TRTYoloV5::generate_bboxes(const trtcv::TRTYoloV5::YoloV5ScaleParams &scale_params,
std::vector<types::Boxf> &bbox_collection, float* output, float score_threshold,
int img_height, int img_width) {
auto pred_dims = output_node_dims[0];
const unsigned int num_anchors = pred_dims.at(1); // n = ?
const unsigned int num_classes = pred_dims.at(2) - 5;

float r_ = scale_params.r;
int dw_ = scale_params.dw;
int dh_ = scale_params.dh;

bbox_collection.clear();
unsigned int count = 0;
for (unsigned int i = 0; i < num_anchors; ++i)
{
float obj_conf = output[i * pred_dims.at(2) + 4];
if (obj_conf < score_threshold) continue; // filter first.

float cls_conf = output[i * pred_dims.at(2) + 5];
unsigned int label = 0;
for (unsigned int j = 0; j < num_classes; ++j)
{
float tmp_conf = output[i * pred_dims.at(2) + 5 + j];
if (tmp_conf > cls_conf)
{
cls_conf = tmp_conf;
label = j;
}
}
float conf = obj_conf * cls_conf; // cls_conf (0.,1.)
if (conf < score_threshold) continue; // filter

float cx = output[i * pred_dims.at(2)];
float cy = output[i * pred_dims.at(2) + 1];
float w = output[i * pred_dims.at(2) + 2];
float h = output[i * pred_dims.at(2) + 3];
float x1 = ((cx - w / 2.f) - (float) dw_) / r_;
float y1 = ((cy - h / 2.f) - (float) dh_) / r_;
float x2 = ((cx + w / 2.f) - (float) dw_) / r_;
float y2 = ((cy + h / 2.f) - (float) dh_) / r_;

types::Boxf box;
box.x1 = std::max(0.f, x1);
box.y1 = std::max(0.f, y1);
box.x2 = std::min(x2, (float) img_width - 1.f);
box.y2 = std::min(y2, (float) img_height - 1.f);
box.score = conf;
box.label = label;
box.label_text = class_names[label];
box.flag = true;
bbox_collection.push_back(box);

count += 1; // limit boxes for nms.
if (count > max_nms)
break;
}

#if LITETRT_DEBUG
std::cout << "detected num_anchors: " << num_anchors << "\n";
std::cout << "generate_bboxes num: " << bbox_collection.size() << "\n";
#endif

}
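For reference, with the stock 640x640 yolov5s export the prediction tensor is typically [1, 25200, 85], where each of the 25200 rows holds cx, cy, w, h, objectness and 80 class scores; the loop above indexes rows as i * pred_dims.at(2) on that layout. The shape is noted only for illustration; the code itself relies solely on pred_dims.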



void TRTYoloV5::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes, float score_threshold,
float iou_threshold, unsigned int topk, unsigned int nms_type) {

if (mat.empty()) return;
const int input_height = input_node_dims.at(2);
const int input_width = input_node_dims.at(3);
int img_height = static_cast<int>(mat.rows);
int img_width = static_cast<int>(mat.cols);

// resize & unscale
cv::Mat mat_rs;
YoloV5ScaleParams scale_params;
resize_unscale(mat, mat_rs, input_height, input_width, scale_params);

cv::Mat normalized_image = normalized(mat_rs);

//1. make the input
auto input = trtcv::utils::transform::create_tensor(normalized_image,input_node_dims,trtcv::utils::transform::CHW);


//2. infer
cudaMemcpyAsync(buffers[0], input, input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);

bool status = trt_context->enqueueV3(stream);
cudaStreamSynchronize(stream);
if (!status){
std::cerr << "Failed to infer by TensorRT." << std::endl;
return;
}

// Synchronize the stream to ensure all operations are complete
cudaStreamSynchronize(stream);
// get the first output dim
auto pred_dims = output_node_dims[0];

float* output = new float[pred_dims[0] * pred_dims[1] * pred_dims[2]];
Owner: The memory allocated here is still not released properly.

Collaborator Author (wangzijian1010): The fix ended up in the yolov8face file instead, haha @DefTruth

cudaMemcpyAsync(output, buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);

//3. generate the boxes
std::vector<types::Boxf> bbox_collection;
generate_bboxes(scale_params, bbox_collection, output, score_threshold, img_height, img_width);
nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type);
}
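Regarding the review note above about the unreleased float* output: a leak-free variant of the read-back, sketched here as an assumption under the same member names rather than as the merged code, could let std::vector own the host buffer:

// Sketch only, not the code merged in this PR: the vector is released when detect() returns.
auto pred_dims = output_node_dims[0];
std::vector<float> output(pred_dims[0] * pred_dims[1] * pred_dims[2]);
cudaMemcpyAsync(output.data(), buffers[1],
                output.size() * sizeof(float),
                cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
generate_bboxes(scale_params, bbox_collection, output.data(), score_threshold, img_height, img_width);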

