-
Notifications
You must be signed in to change notification settings - Fork 665
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(lidar_centerpoint): add IoU-based NMS (#1935)
* feat(lidar_centerpoint): add IoU-based NMS Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp> * feat: add magic number name Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp> * feat: remove unnecessary headers Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp> * fix: typo Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp> * fix: typo Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp> Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>
- Loading branch information
Showing
9 changed files
with
237 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
...ption/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
// Copyright 2022 TIER IV, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_ | ||
#define LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_ | ||
|
||
#include "lidar_centerpoint/ros_utils.hpp" | ||
|
||
#include <Eigen/Eigen> | ||
|
||
#include "autoware_auto_perception_msgs/msg/detected_object.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
namespace centerpoint | ||
{ | ||
using autoware_auto_perception_msgs::msg::DetectedObject; | ||
|
||
// TODO(yukke42): now only support IoU_BEV | ||
enum class NMS_TYPE { | ||
IoU_BEV | ||
// IoU_3D | ||
// Distance_2D | ||
// Distance_3D | ||
}; | ||
|
||
struct NMSParams | ||
{ | ||
NMS_TYPE nms_type_{}; | ||
std::vector<std::string> target_class_names_{}; | ||
double search_distance_2d_{}; | ||
double iou_threshold_{}; | ||
// double distance_threshold_{}; | ||
}; | ||
|
||
std::vector<bool> classNamesToBooleanMask(const std::vector<std::string> & class_names) | ||
{ | ||
std::vector<bool> mask; | ||
constexpr std::size_t num_object_classification = 8; | ||
mask.resize(num_object_classification); | ||
for (const auto & class_name : class_names) { | ||
const auto semantic_type = getSemanticType(class_name); | ||
mask.at(semantic_type) = true; | ||
} | ||
|
||
return mask; | ||
} | ||
|
||
class NonMaximumSuppression | ||
{ | ||
public: | ||
void setParameters(const NMSParams &); | ||
|
||
std::vector<DetectedObject> apply(const std::vector<DetectedObject> &); | ||
|
||
private: | ||
bool isTargetLabel(const std::uint8_t); | ||
|
||
bool isTargetPairObject(const DetectedObject &, const DetectedObject &); | ||
|
||
Eigen::MatrixXd generateIoUMatrix(const std::vector<DetectedObject> &); | ||
|
||
NMSParams params_{}; | ||
std::vector<bool> target_class_mask_{}; | ||
}; | ||
|
||
} // namespace centerpoint | ||
|
||
#endif // LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_ |
105 changes: 105 additions & 0 deletions
105
perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Copyright 2022 TIER IV, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp" | ||
|
||
#include "perception_utils/geometry.hpp" | ||
#include "perception_utils/perception_utils.hpp" | ||
#include "tier4_autoware_utils/tier4_autoware_utils.hpp" | ||
|
||
namespace centerpoint | ||
{ | ||
|
||
void NonMaximumSuppression::setParameters(const NMSParams & params) | ||
{ | ||
assert(params.target_class_names_.size() == 8); | ||
assert(params.search_distance_2d_ >= 0.0); | ||
assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0); | ||
|
||
params_ = params; | ||
target_class_mask_ = classNamesToBooleanMask(params.target_class_names_); | ||
} | ||
|
||
bool NonMaximumSuppression::isTargetLabel(const uint8_t label) | ||
{ | ||
if (label >= target_class_mask_.size()) { | ||
return false; | ||
} | ||
return target_class_mask_.at(label); | ||
} | ||
|
||
bool NonMaximumSuppression::isTargetPairObject( | ||
const DetectedObject & object1, const DetectedObject & object2) | ||
{ | ||
const auto label1 = perception_utils::getHighestProbLabel(object1.classification); | ||
const auto label2 = perception_utils::getHighestProbLabel(object2.classification); | ||
|
||
if (isTargetLabel(label1) && isTargetLabel(label2)) { | ||
return true; | ||
} | ||
|
||
const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_; | ||
const auto sqr_dist_2d = tier4_autoware_utils::calcSquaredDistance2d( | ||
perception_utils::getPose(object1), perception_utils::getPose(object2)); | ||
return sqr_dist_2d <= search_sqr_dist_2d; | ||
} | ||
|
||
Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix( | ||
const std::vector<DetectedObject> & input_objects) | ||
{ | ||
// NOTE(yukke42): row = target objects to be suppressed, col = source objects to be compared | ||
Eigen::MatrixXd triangular_matrix = | ||
Eigen::MatrixXd::Zero(input_objects.size(), input_objects.size()); | ||
for (std::size_t target_i = 0; target_i < input_objects.size(); ++target_i) { | ||
for (std::size_t source_i = 0; source_i < target_i; ++source_i) { | ||
const auto & target_obj = input_objects.at(target_i); | ||
const auto & source_obj = input_objects.at(source_i); | ||
if (!isTargetPairObject(target_obj, source_obj)) { | ||
continue; | ||
} | ||
|
||
if (params_.nms_type_ == NMS_TYPE::IoU_BEV) { | ||
const double iou = perception_utils::get2dIoU(target_obj, source_obj); | ||
triangular_matrix(target_i, source_i) = iou; | ||
// NOTE(yukke42): If the target object has any objects with iou > iou_threshold, it | ||
// will be suppressed regardless of later results. | ||
if (iou > params_.iou_threshold_) { | ||
break; | ||
} | ||
} | ||
} | ||
} | ||
|
||
return triangular_matrix; | ||
} | ||
|
||
std::vector<DetectedObject> NonMaximumSuppression::apply( | ||
const std::vector<DetectedObject> & input_objects) | ||
{ | ||
Eigen::MatrixXd iou_matrix = generateIoUMatrix(input_objects); | ||
|
||
std::vector<DetectedObject> output_objects; | ||
output_objects.reserve(input_objects.size()); | ||
for (std::size_t i = 0; i < input_objects.size(); ++i) { | ||
const auto value = iou_matrix.row(i).maxCoeff(); | ||
if (params_.nms_type_ == NMS_TYPE::IoU_BEV) { | ||
if (value <= params_.iou_threshold_) { | ||
output_objects.emplace_back(input_objects.at(i)); | ||
} | ||
} | ||
} | ||
|
||
return output_objects; | ||
} | ||
} // namespace centerpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters