Skip to content

Commit

Permalink
feat(lidar_centerpoint): add IoU-based NMS (#1935)
Browse files Browse the repository at this point in the history
* feat(lidar_centerpoint): add IoU-based NMS

Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>

* feat: add magic number name

Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>

* feat: remove unnecessary headers

Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>

* fix: typo

Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>

* fix: typo

Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>

Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>
  • Loading branch information
yukke42 authored Sep 27, 2022
1 parent e3292d2 commit 8f9d39b
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 16 deletions.
5 changes: 3 additions & 2 deletions perception/lidar_centerpoint/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ message(STATUS "start to download")
if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL)
# Download trained models
set(DATA_PATH ${CMAKE_CURRENT_SOURCE_DIR}/data)
set(DONWLOAD_BASE_URL https://awf.ml.dev.web.auto/perception/models/centerpoint)
set(DOWNLOAD_BASE_URL https://awf.ml.dev.web.auto/perception/models/centerpoint)
execute_process(COMMAND mkdir -p ${DATA_PATH})

function(download VERSION FILE_NAME FILE_HASH)
message(STATUS "Checking and downloading ${FILE_NAME} ")
set(DOWNLOAD_URL ${DONWLOAD_BASE_URL}/${VERSION}/${FILE_NAME})
set(DOWNLOAD_URL ${DOWNLOAD_BASE_URL}/${VERSION}/${FILE_NAME})
set(FILE_PATH ${DATA_PATH}/${FILE_NAME})
set(STATUS_CODE 0)
message(STATUS "start ${FILE_NAME}")
Expand Down Expand Up @@ -130,6 +130,7 @@ if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL)
lib/ros_utils.cpp
lib/network/network_trt.cpp
lib/network/tensorrt_wrapper.cpp
lib/postprocess/non_maximum_suppression.cpp
lib/preprocess/pointcloud_densification.cpp
lib/preprocess/voxel_generator.cpp
)
Expand Down
23 changes: 13 additions & 10 deletions perception/lidar_centerpoint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ We trained the models using <https://github.com/open-mmlab/mmdetection3d>.

### Core Parameters

| Name | Type | Default Value | Description |
| ------------------------------- | ------ | ------------- | ----------------------------------------------------------- |
| `score_threshold` | float | `0.4` | detected objects with score less than threshold are ignored |
| `densification_world_frame_id` | string | `map` | the world frame id to fuse multi-frame pointcloud |
| `densification_num_past_frames` | int | `1` | the number of past frames to fuse with the current frame |
| `trt_precision` | string | `fp16` | TensorRT inference precision: `fp32` or `fp16` |
| `encoder_onnx_path` | string | `""` | path to VoxelFeatureEncoder ONNX file |
| `encoder_engine_path` | string | `""` | path to VoxelFeatureEncoder TensorRT Engine file |
| `head_onnx_path` | string | `""` | path to DetectionHead ONNX file |
| `head_engine_path` | string | `""` | path to DetectionHead TensorRT Engine file |
| Name | Type | Default Value | Description |
| ------------------------------- | ------------ | ------------- | ------------------------------------------------------------- |
| `score_threshold` | float | `0.4` | detected objects with score less than threshold are ignored |
| `densification_world_frame_id` | string | `map` | the world frame id to fuse multi-frame pointcloud |
| `densification_num_past_frames` | int | `1` | the number of past frames to fuse with the current frame |
| `trt_precision` | string | `fp16` | TensorRT inference precision: `fp32` or `fp16` |
| `encoder_onnx_path` | string | `""` | path to VoxelFeatureEncoder ONNX file |
| `encoder_engine_path` | string | `""` | path to VoxelFeatureEncoder TensorRT Engine file |
| `head_onnx_path` | string | `""` | path to DetectionHead ONNX file |
| `head_engine_path` | string | `""` | path to DetectionHead TensorRT Engine file |
| `nms_iou_target_class_names` | list[string] | - | target classes for IoU-based Non Maximum Suppression |
| `nms_iou_search_distance_2d` | double | - | If two objects are farther than the value, NMS isn't applied. |
| `nms_iou_threshold` | double | - | IoU threshold for the IoU-based Non Maximum Suppression |

## Assumptions / Known limits

Expand Down
5 changes: 5 additions & 0 deletions perception/lidar_centerpoint/config/centerpoint.param.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@
voxel_size: [0.32, 0.32, 8.0]
downsample_factor: 1
encoder_in_feature_size: 9
# post-process params
circle_nms_dist_threshold: 0.5
iou_nms_target_class_names: ["CAR"]
iou_nms_search_distance_2d: 10.0
iou_nms_threshold: 0.1
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@
voxel_size: [0.32, 0.32, 8.0]
downsample_factor: 2
encoder_in_feature_size: 9
# post-process params
circle_nms_dist_threshold: 0.5
iou_nms_target_class_names: ["CAR"]
iou_nms_search_distance_2d: 10.0
iou_nms_threshold: 0.1
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#ifndef LIDAR_CENTERPOINT__NODE_HPP_
#define LIDAR_CENTERPOINT__NODE_HPP_

#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp"

#include <lidar_centerpoint/centerpoint_trt.hpp>
#include <lidar_centerpoint/detection_class_remapper.hpp>
#include <rclcpp/rclcpp.hpp>
Expand Down Expand Up @@ -52,6 +54,7 @@ class LidarCenterPointNode : public rclcpp::Node
std::vector<std::string> class_names_;
bool has_twist_{false};

NonMaximumSuppression iou_bev_nms_;
DetectionClassRemapper detection_class_remapper_;

std::unique_ptr<CenterPointTRT> detector_ptr_{nullptr};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2022 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_
#define LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_

#include "lidar_centerpoint/ros_utils.hpp"

#include <Eigen/Eigen>

#include "autoware_auto_perception_msgs/msg/detected_object.hpp"

#include <string>
#include <vector>

namespace centerpoint
{
using autoware_auto_perception_msgs::msg::DetectedObject;

// TODO(yukke42): now only support IoU_BEV
enum class NMS_TYPE {
IoU_BEV
// IoU_3D
// Distance_2D
// Distance_3D
};

struct NMSParams
{
NMS_TYPE nms_type_{};
std::vector<std::string> target_class_names_{};
double search_distance_2d_{};
double iou_threshold_{};
// double distance_threshold_{};
};

std::vector<bool> classNamesToBooleanMask(const std::vector<std::string> & class_names)
{
std::vector<bool> mask;
constexpr std::size_t num_object_classification = 8;
mask.resize(num_object_classification);
for (const auto & class_name : class_names) {
const auto semantic_type = getSemanticType(class_name);
mask.at(semantic_type) = true;
}

return mask;
}

class NonMaximumSuppression
{
public:
void setParameters(const NMSParams &);

std::vector<DetectedObject> apply(const std::vector<DetectedObject> &);

private:
bool isTargetLabel(const std::uint8_t);

bool isTargetPairObject(const DetectedObject &, const DetectedObject &);

Eigen::MatrixXd generateIoUMatrix(const std::vector<DetectedObject> &);

NMSParams params_{};
std::vector<bool> target_class_mask_{};
};

} // namespace centerpoint

#endif // LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2022 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp"

#include "perception_utils/geometry.hpp"
#include "perception_utils/perception_utils.hpp"
#include "tier4_autoware_utils/tier4_autoware_utils.hpp"

namespace centerpoint
{

void NonMaximumSuppression::setParameters(const NMSParams & params)
{
assert(params.target_class_names_.size() == 8);
assert(params.search_distance_2d_ >= 0.0);
assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0);

params_ = params;
target_class_mask_ = classNamesToBooleanMask(params.target_class_names_);
}

bool NonMaximumSuppression::isTargetLabel(const uint8_t label)
{
if (label >= target_class_mask_.size()) {
return false;
}
return target_class_mask_.at(label);
}

bool NonMaximumSuppression::isTargetPairObject(
const DetectedObject & object1, const DetectedObject & object2)
{
const auto label1 = perception_utils::getHighestProbLabel(object1.classification);
const auto label2 = perception_utils::getHighestProbLabel(object2.classification);

if (isTargetLabel(label1) && isTargetLabel(label2)) {
return true;
}

const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_;
const auto sqr_dist_2d = tier4_autoware_utils::calcSquaredDistance2d(
perception_utils::getPose(object1), perception_utils::getPose(object2));
return sqr_dist_2d <= search_sqr_dist_2d;
}

Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix(
const std::vector<DetectedObject> & input_objects)
{
// NOTE(yukke42): row = target objects to be suppressed, col = source objects to be compared
Eigen::MatrixXd triangular_matrix =
Eigen::MatrixXd::Zero(input_objects.size(), input_objects.size());
for (std::size_t target_i = 0; target_i < input_objects.size(); ++target_i) {
for (std::size_t source_i = 0; source_i < target_i; ++source_i) {
const auto & target_obj = input_objects.at(target_i);
const auto & source_obj = input_objects.at(source_i);
if (!isTargetPairObject(target_obj, source_obj)) {
continue;
}

if (params_.nms_type_ == NMS_TYPE::IoU_BEV) {
const double iou = perception_utils::get2dIoU(target_obj, source_obj);
triangular_matrix(target_i, source_i) = iou;
// NOTE(yukke42): If the target object has any objects with iou > iou_threshold, it
// will be suppressed regardless of later results.
if (iou > params_.iou_threshold_) {
break;
}
}
}
}

return triangular_matrix;
}

std::vector<DetectedObject> NonMaximumSuppression::apply(
const std::vector<DetectedObject> & input_objects)
{
Eigen::MatrixXd iou_matrix = generateIoUMatrix(input_objects);

std::vector<DetectedObject> output_objects;
output_objects.reserve(input_objects.size());
for (std::size_t i = 0; i < input_objects.size(); ++i) {
const auto value = iou_matrix.row(i).maxCoeff();
if (params_.nms_type_ == NMS_TYPE::IoU_BEV) {
if (value <= params_.iou_threshold_) {
output_objects.emplace_back(input_objects.at(i));
}
}
}

return output_objects;
}
} // namespace centerpoint
1 change: 1 addition & 0 deletions perception/lidar_centerpoint/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

<depend>autoware_auto_perception_msgs</depend>
<depend>pcl_ros</depend>
<depend>perception_utils</depend>
<depend>rclcpp</depend>
<depend>rclcpp_components</depend>
<depend>tf2_eigen</depend>
Expand Down
25 changes: 21 additions & 4 deletions perception/lidar_centerpoint/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include <tf2_geometry_msgs/tf2_geometry_msgs.hpp>
#endif

#include <Eigen/Dense>
#include <Eigen/Geometry>

#include <memory>
#include <string>
#include <vector>
Expand All @@ -38,7 +41,7 @@ LidarCenterPointNode::LidarCenterPointNode(const rclcpp::NodeOptions & node_opti
const float score_threshold =
static_cast<float>(this->declare_parameter<double>("score_threshold", 0.35));
const float circle_nms_dist_threshold =
static_cast<float>(this->declare_parameter<double>("circle_nms_dist_threshold", 1.5));
static_cast<float>(this->declare_parameter<double>("circle_nms_dist_threshold"));
const float yaw_norm_threshold =
static_cast<float>(this->declare_parameter<double>("yaw_norm_threshold", 0.0));
const std::string densification_world_frame_id =
Expand Down Expand Up @@ -71,6 +74,16 @@ LidarCenterPointNode::LidarCenterPointNode(const rclcpp::NodeOptions & node_opti
detection_class_remapper_.setParameters(
allow_remapping_by_area_matrix, min_area_matrix, max_area_matrix);

{
NMSParams p;
p.nms_type_ = NMS_TYPE::IoU_BEV;
p.target_class_names_ =
this->declare_parameter<std::vector<std::string>>("iou_nms_target_class_names");
p.search_distance_2d_ = this->declare_parameter<double>("iou_nms_search_distance_2d");
p.iou_threshold_ = this->declare_parameter<double>("iou_nms_threshold");
iou_bev_nms_.setParameters(p);
}

NetworkParam encoder_param(encoder_onnx_path, encoder_engine_path, trt_precision);
NetworkParam head_param(head_onnx_path, head_engine_path, trt_precision);
DensificationParam densification_param(
Expand Down Expand Up @@ -129,14 +142,18 @@ void LidarCenterPointNode::pointCloudCallback(
return;
}

autoware_auto_perception_msgs::msg::DetectedObjects output_msg;
output_msg.header = input_pointcloud_msg->header;
std::vector<autoware_auto_perception_msgs::msg::DetectedObject> raw_objects;
raw_objects.reserve(det_boxes3d.size());
for (const auto & box3d : det_boxes3d) {
autoware_auto_perception_msgs::msg::DetectedObject obj;
box3DToDetectedObject(box3d, class_names_, has_twist_, obj);
output_msg.objects.emplace_back(obj);
raw_objects.emplace_back(obj);
}

autoware_auto_perception_msgs::msg::DetectedObjects output_msg;
output_msg.header = input_pointcloud_msg->header;
output_msg.objects = iou_bev_nms_.apply(raw_objects);

detection_class_remapper_.mapClasses(output_msg);

if (objects_sub_count > 0) {
Expand Down

0 comments on commit 8f9d39b

Please sign in to comment.