diff --git a/perception/lidar_centerpoint/CMakeLists.txt b/perception/lidar_centerpoint/CMakeLists.txt index 646267b0d7a61..c512e06b3bf0e 100644 --- a/perception/lidar_centerpoint/CMakeLists.txt +++ b/perception/lidar_centerpoint/CMakeLists.txt @@ -70,12 +70,12 @@ message(STATUS "start to download") if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL) # Download trained models set(DATA_PATH ${CMAKE_CURRENT_SOURCE_DIR}/data) - set(DONWLOAD_BASE_URL https://awf.ml.dev.web.auto/perception/models/centerpoint) + set(DOWNLOAD_BASE_URL https://awf.ml.dev.web.auto/perception/models/centerpoint) execute_process(COMMAND mkdir -p ${DATA_PATH}) function(download VERSION FILE_NAME FILE_HASH) message(STATUS "Checking and downloading ${FILE_NAME} ") - set(DOWNLOAD_URL ${DONWLOAD_BASE_URL}/${VERSION}/${FILE_NAME}) + set(DOWNLOAD_URL ${DOWNLOAD_BASE_URL}/${VERSION}/${FILE_NAME}) set(FILE_PATH ${DATA_PATH}/${FILE_NAME}) set(STATUS_CODE 0) message(STATUS "start ${FILE_NAME}") @@ -130,6 +130,7 @@ if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL) lib/ros_utils.cpp lib/network/network_trt.cpp lib/network/tensorrt_wrapper.cpp + lib/postprocess/non_maximum_suppression.cpp lib/preprocess/pointcloud_densification.cpp lib/preprocess/voxel_generator.cpp ) diff --git a/perception/lidar_centerpoint/README.md b/perception/lidar_centerpoint/README.md index 0957749e16c83..72eb5498b71d4 100644 --- a/perception/lidar_centerpoint/README.md +++ b/perception/lidar_centerpoint/README.md @@ -30,16 +30,19 @@ We trained the models using . 
### Core Parameters -| Name | Type | Default Value | Description | -| ------------------------------- | ------ | ------------- | ----------------------------------------------------------- | -| `score_threshold` | float | `0.4` | detected objects with score less than threshold are ignored | -| `densification_world_frame_id` | string | `map` | the world frame id to fuse multi-frame pointcloud | -| `densification_num_past_frames` | int | `1` | the number of past frames to fuse with the current frame | -| `trt_precision` | string | `fp16` | TensorRT inference precision: `fp32` or `fp16` | -| `encoder_onnx_path` | string | `""` | path to VoxelFeatureEncoder ONNX file | -| `encoder_engine_path` | string | `""` | path to VoxelFeatureEncoder TensorRT Engine file | -| `head_onnx_path` | string | `""` | path to DetectionHead ONNX file | -| `head_engine_path` | string | `""` | path to DetectionHead TensorRT Engine file | +| Name | Type | Default Value | Description | +| ------------------------------- | ------------ | ------------- | ------------------------------------------------------------- | +| `score_threshold` | float | `0.4` | detected objects with score less than threshold are ignored | +| `densification_world_frame_id` | string | `map` | the world frame id to fuse multi-frame pointcloud | +| `densification_num_past_frames` | int | `1` | the number of past frames to fuse with the current frame | +| `trt_precision` | string | `fp16` | TensorRT inference precision: `fp32` or `fp16` | +| `encoder_onnx_path` | string | `""` | path to VoxelFeatureEncoder ONNX file | +| `encoder_engine_path` | string | `""` | path to VoxelFeatureEncoder TensorRT Engine file | +| `head_onnx_path` | string | `""` | path to DetectionHead ONNX file | +| `head_engine_path` | string | `""` | path to DetectionHead TensorRT Engine file | +| `nms_iou_target_class_names` | list[string] | - | target classes for IoU-based Non Maximum Suppression | +| `nms_iou_search_distance_2d` | double | - | If two 
objects are farther than the value, NMS isn't applied. |
+| `nms_iou_threshold`             | double       | -             | IoU threshold for the IoU-based Non Maximum Suppression       |
 
 ## Assumptions / Known limits
diff --git a/perception/lidar_centerpoint/config/centerpoint.param.yaml b/perception/lidar_centerpoint/config/centerpoint.param.yaml
index 3810c864738af..565c769a08c72 100644
--- a/perception/lidar_centerpoint/config/centerpoint.param.yaml
+++ b/perception/lidar_centerpoint/config/centerpoint.param.yaml
@@ -7,3 +7,8 @@
     voxel_size: [0.32, 0.32, 8.0]
     downsample_factor: 1
     encoder_in_feature_size: 9
+    # post-process params
+    circle_nms_dist_threshold: 0.5
+    iou_nms_target_class_names: ["CAR"]
+    iou_nms_search_distance_2d: 10.0
+    iou_nms_threshold: 0.1
diff --git a/perception/lidar_centerpoint/config/centerpoint_tiny.param.yaml b/perception/lidar_centerpoint/config/centerpoint_tiny.param.yaml
index c93b59ac0b8f9..e5ae31efc775d 100644
--- a/perception/lidar_centerpoint/config/centerpoint_tiny.param.yaml
+++ b/perception/lidar_centerpoint/config/centerpoint_tiny.param.yaml
@@ -7,3 +7,8 @@
     voxel_size: [0.32, 0.32, 8.0]
     downsample_factor: 2
     encoder_in_feature_size: 9
+    # post-process params
+    circle_nms_dist_threshold: 0.5
+    iou_nms_target_class_names: ["CAR"]
+    iou_nms_search_distance_2d: 10.0
+    iou_nms_threshold: 0.1
diff --git a/perception/lidar_centerpoint/include/lidar_centerpoint/node.hpp b/perception/lidar_centerpoint/include/lidar_centerpoint/node.hpp
index 7477442a1060d..474c9884eb36f 100644
--- a/perception/lidar_centerpoint/include/lidar_centerpoint/node.hpp
+++ b/perception/lidar_centerpoint/include/lidar_centerpoint/node.hpp
@@ -15,6 +15,8 @@
 #ifndef LIDAR_CENTERPOINT__NODE_HPP_
 #define LIDAR_CENTERPOINT__NODE_HPP_
 
+#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp"
+
 #include
 #include
 #include
@@ -52,6 +54,7 @@ class LidarCenterPointNode : public rclcpp::Node
   std::vector<std::string> class_names_;
   bool has_twist_{false};
 
+  NonMaximumSuppression iou_bev_nms_;
   DetectionClassRemapper detection_class_remapper_;
 
   std::unique_ptr<CenterPointTRT> detector_ptr_{nullptr};
diff --git a/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp b/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp
new file mode 100644
index 0000000000000..78e362f77e7f1
--- /dev/null
+++ b/perception/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp
@@ -0,0 +1,81 @@
+// Copyright 2022 TIER IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_
+#define LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_
+
+#include "lidar_centerpoint/ros_utils.hpp"
+
+#include <Eigen/Core>
+
+#include "autoware_auto_perception_msgs/msg/detected_object.hpp"
+
+#include <string>
+#include <vector>
+
+namespace centerpoint
+{
+using autoware_auto_perception_msgs::msg::DetectedObject;
+
+// TODO(yukke42): now only support IoU_BEV
+enum class NMS_TYPE {
+  IoU_BEV
+  // IoU_3D
+  // Distance_2D
+  // Distance_3D
+};
+
+struct NMSParams
+{
+  NMS_TYPE nms_type_{};
+  std::vector<std::string> target_class_names_{};
+  double search_distance_2d_{};
+  double iou_threshold_{};
+  // double distance_threshold_{};
+};
+
+inline std::vector<bool> classNamesToBooleanMask(const std::vector<std::string> & class_names)
+{
+  std::vector<bool> mask;
+  constexpr std::size_t num_object_classification = 8;
+  mask.resize(num_object_classification);
+  for (const auto & class_name : class_names) {
+    const auto semantic_type = getSemanticType(class_name);
+    mask.at(semantic_type) = true;
+  }
+
+  return mask;
+}
+
+class NonMaximumSuppression
+{
+public:
+  void setParameters(const NMSParams &);
+
+  std::vector<DetectedObject> apply(const std::vector<DetectedObject> &);
+
+private:
+  bool isTargetLabel(const std::uint8_t);
+
+  bool isTargetPairObject(const DetectedObject &, const DetectedObject &);
+
+  Eigen::MatrixXd generateIoUMatrix(const std::vector<DetectedObject> &);
+
+  NMSParams params_{};
+  std::vector<bool> target_class_mask_{};
+};
+
+}  // namespace centerpoint
+
+#endif  // LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_
diff --git a/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp b/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp
new file mode 100644
index 0000000000000..82b9ca673061f
--- /dev/null
+++ b/perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp
@@ -0,0 +1,105 @@
+// Copyright 2022 TIER IV, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp"
+
+#include "perception_utils/geometry.hpp"
+#include "perception_utils/perception_utils.hpp"
+#include "tier4_autoware_utils/tier4_autoware_utils.hpp"
+
+namespace centerpoint
+{
+
+void NonMaximumSuppression::setParameters(const NMSParams & params)
+{
+  assert(params.target_class_names_.size() <= 8);
+  assert(params.search_distance_2d_ >= 0.0);
+  assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0);
+
+  params_ = params;
+  target_class_mask_ = classNamesToBooleanMask(params.target_class_names_);
+}
+
+bool NonMaximumSuppression::isTargetLabel(const uint8_t label)
+{
+  if (label >= target_class_mask_.size()) {
+    return false;
+  }
+  return target_class_mask_.at(label);
+}
+
+bool NonMaximumSuppression::isTargetPairObject(
+  const DetectedObject & object1, const DetectedObject & object2)
+{
+  const auto label1 = perception_utils::getHighestProbLabel(object1.classification);
+  const auto label2 = perception_utils::getHighestProbLabel(object2.classification);
+
+  if (isTargetLabel(label1) && isTargetLabel(label2)) {
+    return true;
+  }
+
+  const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_;
+  const auto sqr_dist_2d = tier4_autoware_utils::calcSquaredDistance2d(
+    perception_utils::getPose(object1), perception_utils::getPose(object2));
+  return sqr_dist_2d <= search_sqr_dist_2d;
+}
+
+Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix(
+  const std::vector<DetectedObject> & input_objects)
+{
+  // NOTE(yukke42): row = target objects to be suppressed, col = source objects to be compared
+  Eigen::MatrixXd triangular_matrix =
+    Eigen::MatrixXd::Zero(input_objects.size(), input_objects.size());
+  for (std::size_t target_i = 0; target_i < input_objects.size(); ++target_i) {
+    for (std::size_t source_i = 0; source_i < target_i; ++source_i) {
+      const auto & target_obj = input_objects.at(target_i);
+      const auto & source_obj = input_objects.at(source_i);
+      if (!isTargetPairObject(target_obj, source_obj)) {
+        continue;
+      }
+
+      if (params_.nms_type_ == NMS_TYPE::IoU_BEV) {
+        const double iou = perception_utils::get2dIoU(target_obj, source_obj);
+        triangular_matrix(target_i, source_i) = iou;
+        // NOTE(yukke42): If the target object has any objects with iou > iou_threshold, it
+        // will be suppressed regardless of later results.
+        if (iou > params_.iou_threshold_) {
+          break;
+        }
+      }
+    }
+  }
+
+  return triangular_matrix;
+}
+
+std::vector<DetectedObject> NonMaximumSuppression::apply(
+  const std::vector<DetectedObject> & input_objects)
+{
+  Eigen::MatrixXd iou_matrix = generateIoUMatrix(input_objects);
+
+  std::vector<DetectedObject> output_objects;
+  output_objects.reserve(input_objects.size());
+  for (std::size_t i = 0; i < input_objects.size(); ++i) {
+    const auto value = iou_matrix.row(i).maxCoeff();
+    if (params_.nms_type_ == NMS_TYPE::IoU_BEV) {
+      if (value <= params_.iou_threshold_) {
+        output_objects.emplace_back(input_objects.at(i));
+      }
+    }
+  }
+
+  return output_objects;
+}
+}  // namespace centerpoint
diff --git a/perception/lidar_centerpoint/package.xml b/perception/lidar_centerpoint/package.xml
index 71e8f71f6212f..4372b1848d08b 100644
--- a/perception/lidar_centerpoint/package.xml
+++ b/perception/lidar_centerpoint/package.xml
@@ -13,6 +13,7 @@
   <depend>autoware_auto_perception_msgs</depend>
   <depend>pcl_ros</depend>
+  <depend>perception_utils</depend>
   <depend>rclcpp</depend>
   <depend>rclcpp_components</depend>
   <depend>tf2_eigen</depend>
diff --git a/perception/lidar_centerpoint/src/node.cpp b/perception/lidar_centerpoint/src/node.cpp
index 9901f9a325953..f85b9006a7dee 100644
--- a/perception/lidar_centerpoint/src/node.cpp
+++ b/perception/lidar_centerpoint/src/node.cpp
@@ -26,6 +26,9 @@
 #include
 #endif
 
+#include
+#include
+
 #include
 #include
 #include
@@ -38,7 +41,7 @@ LidarCenterPointNode::LidarCenterPointNode(const rclcpp::NodeOptions & node_opti
   const float score_threshold =
     static_cast<float>(this->declare_parameter("score_threshold", 0.35));
   const float circle_nms_dist_threshold =
-    static_cast<float>(this->declare_parameter("circle_nms_dist_threshold", 1.5));
+    static_cast<float>(this->declare_parameter<double>("circle_nms_dist_threshold"));
   const float yaw_norm_threshold =
     static_cast<float>(this->declare_parameter("yaw_norm_threshold", 0.0));
   const std::string densification_world_frame_id =
@@ -71,6 +74,16 @@ LidarCenterPointNode::LidarCenterPointNode(const rclcpp::NodeOptions & node_opti
   detection_class_remapper_.setParameters(
     allow_remapping_by_area_matrix, min_area_matrix, max_area_matrix);
 
+  {
+    NMSParams p;
+    p.nms_type_ = NMS_TYPE::IoU_BEV;
+    p.target_class_names_ =
+      this->declare_parameter<std::vector<std::string>>("iou_nms_target_class_names");
+    p.search_distance_2d_ = this->declare_parameter<double>("iou_nms_search_distance_2d");
+    p.iou_threshold_ = this->declare_parameter<double>("iou_nms_threshold");
+    iou_bev_nms_.setParameters(p);
+  }
+
   NetworkParam encoder_param(encoder_onnx_path, encoder_engine_path, trt_precision);
   NetworkParam head_param(head_onnx_path, head_engine_path, trt_precision);
   DensificationParam densification_param(
@@ -129,14 +142,18 @@ void LidarCenterPointNode::pointCloudCallback(
     return;
   }
 
-  autoware_auto_perception_msgs::msg::DetectedObjects output_msg;
-  output_msg.header = input_pointcloud_msg->header;
+  std::vector<autoware_auto_perception_msgs::msg::DetectedObject> raw_objects;
+  raw_objects.reserve(det_boxes3d.size());
   for (const auto & box3d : det_boxes3d) {
     autoware_auto_perception_msgs::msg::DetectedObject obj;
     box3DToDetectedObject(box3d, class_names_, has_twist_, obj);
-    output_msg.objects.emplace_back(obj);
+    raw_objects.emplace_back(obj);
   }
 
+  autoware_auto_perception_msgs::msg::DetectedObjects output_msg;
+  output_msg.header = input_pointcloud_msg->header;
+  output_msg.objects = iou_bev_nms_.apply(raw_objects);
+
   detection_class_remapper_.mapClasses(output_msg);
 
   if (objects_sub_count > 0) {