From f6592096d8158d72e16821213e34d4b1c34e67d9 Mon Sep 17 00:00:00 2001 From: Satoshi Tanaka <16330533+scepter914@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:14:05 +0900 Subject: [PATCH] refactor(perception_utils): refactor object_classification (#2042) * refactor(perception_utils): refactor object_classification Signed-off-by: scepter914 * fix bug Signed-off-by: scepter914 * fix unittest Signed-off-by: scepter914 * refactor Signed-off-by: scepter914 * fix unit test Signed-off-by: scepter914 * remove redundant else Signed-off-by: scepter914 * refactor variable name Signed-off-by: scepter914 Signed-off-by: scepter914 --- .../object_classification.hpp | 86 ++++++++++--------- .../test/src/test_object_classification.cpp | 61 ++++++++++--- 2 files changed, 93 insertions(+), 54 deletions(-) diff --git a/common/perception_utils/include/perception_utils/object_classification.hpp b/common/perception_utils/include/perception_utils/object_classification.hpp index a73eac253b17d..dba9236f2c3c0 100644 --- a/common/perception_utils/include/perception_utils/object_classification.hpp +++ b/common/perception_utils/include/perception_utils/object_classification.hpp @@ -23,72 +23,74 @@ namespace perception_utils { using autoware_auto_perception_msgs::msg::ObjectClassification; -inline std::uint8_t getHighestProbLabel( +inline ObjectClassification getHighestProbClassification( const std::vector & object_classifications) { - std::uint8_t label = ObjectClassification::UNKNOWN; - float highest_prob = 0.0; - // TODO(Satoshi Tanaka): It might be simple if you use STL or range-v3. - for (const auto & object_classification : object_classifications) { - if (highest_prob < object_classification.probability) { - highest_prob = object_classification.probability; - label = object_classification.label; - } + if (object_classifications.empty()) { + return ObjectClassification{}; } - return label; + return *std::max_element( + std::begin(object_classifications), std::end(object_classifications), + [](const auto & a, const auto & b) { return a.probability < b.probability; }); +} + +inline std::uint8_t getHighestProbLabel( + const std::vector & object_classifications) +{ + auto classification = getHighestProbClassification(object_classifications); + return classification.label; } -inline bool isVehicle(const uint8_t object_classification) +inline bool isVehicle(const uint8_t label) { - return object_classification == ObjectClassification::BICYCLE || - object_classification == ObjectClassification::BUS || - object_classification == ObjectClassification::CAR || - object_classification == ObjectClassification::MOTORCYCLE || - object_classification == ObjectClassification::TRAILER || - object_classification == ObjectClassification::TRUCK; + return label == ObjectClassification::BICYCLE || label == ObjectClassification::BUS || + label == ObjectClassification::CAR || label == ObjectClassification::MOTORCYCLE || + label == ObjectClassification::TRAILER || label == ObjectClassification::TRUCK; +} + +inline bool isVehicle(const ObjectClassification & object_classification) +{ + return isVehicle(object_classification.label); } inline bool isVehicle(const std::vector & object_classifications) { - auto highest_prob_classification = getHighestProbLabel(object_classifications); - return highest_prob_classification == ObjectClassification::BICYCLE || - highest_prob_classification == ObjectClassification::BUS || - highest_prob_classification == ObjectClassification::CAR || - highest_prob_classification == ObjectClassification::MOTORCYCLE || - highest_prob_classification == ObjectClassification::TRAILER || - highest_prob_classification == ObjectClassification::TRUCK; + auto highest_prob_label = getHighestProbLabel(object_classifications); + return isVehicle(highest_prob_label); +} + +inline bool isCarLikeVehicle(const uint8_t label) +{ + return label == ObjectClassification::BUS || label == ObjectClassification::CAR || + label == ObjectClassification::TRAILER || label == ObjectClassification::TRUCK; } -inline bool isCarLikeVehicle(const uint8_t object_classification) +inline bool isCarLikeVehicle(const ObjectClassification & object_classification) { - return object_classification == ObjectClassification::BUS || - object_classification == ObjectClassification::CAR || - object_classification == ObjectClassification::TRAILER || - object_classification == ObjectClassification::TRUCK; + return isCarLikeVehicle(object_classification.label); } inline bool isCarLikeVehicle(const std::vector & object_classifications) { - auto highest_prob_classification = getHighestProbLabel(object_classifications); - return highest_prob_classification == ObjectClassification::BUS || - highest_prob_classification == ObjectClassification::CAR || - highest_prob_classification == ObjectClassification::TRAILER || - highest_prob_classification == ObjectClassification::TRUCK; + auto highest_prob_label = getHighestProbLabel(object_classifications); + return isCarLikeVehicle(highest_prob_label); +} + +inline bool isLargeVehicle(const uint8_t label) +{ + return label == ObjectClassification::BUS || label == ObjectClassification::TRAILER || + label == ObjectClassification::TRUCK; } -inline bool isLargeVehicle(const uint8_t object_classification) +inline bool isLargeVehicle(const ObjectClassification & object_classification) { - return object_classification == ObjectClassification::BUS || - object_classification == ObjectClassification::TRAILER || - object_classification == ObjectClassification::TRUCK; + return isLargeVehicle(object_classification.label); } inline bool isLargeVehicle(const std::vector & object_classifications) { - auto highest_prob_classification = getHighestProbLabel(object_classifications); - return highest_prob_classification == ObjectClassification::BUS || - highest_prob_classification == ObjectClassification::TRAILER || - highest_prob_classification == ObjectClassification::TRUCK; + auto highest_prob_label = getHighestProbLabel(object_classifications); + return isLargeVehicle(highest_prob_label); } } // namespace perception_utils diff --git a/common/perception_utils/test/src/test_object_classification.cpp b/common/perception_utils/test/src/test_object_classification.cpp index 1d89d73b7dccb..9758266642255 100644 --- a/common/perception_utils/test/src/test_object_classification.cpp +++ b/common/perception_utils/test/src/test_object_classification.cpp @@ -16,6 +16,8 @@ #include +constexpr double epsilon = 1e-06; + namespace { autoware_auto_perception_msgs::msg::ObjectClassification createObjectClassification( @@ -36,28 +38,63 @@ TEST(object_classification, test_getHighestProbLabel) using perception_utils::getHighestProbLabel; { // empty - std::vector classification; - std::uint8_t label = getHighestProbLabel(classification); + std::vector classifications; + std::uint8_t label = getHighestProbLabel(classifications); EXPECT_EQ(label, ObjectClassification::UNKNOWN); } { // normal case - std::vector classification; - classification.push_back(createObjectClassification(ObjectClassification::CAR, 0.5)); - classification.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8)); - classification.push_back(createObjectClassification(ObjectClassification::BUS, 0.7)); + std::vector classifications; + classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.5)); + classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8)); + classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7)); - std::uint8_t label = getHighestProbLabel(classification); + std::uint8_t label = getHighestProbLabel(classifications); EXPECT_EQ(label, ObjectClassification::TRUCK); } { // labels with the same probability - std::vector classification; - classification.push_back(createObjectClassification(ObjectClassification::CAR, 0.8)); - classification.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8)); - classification.push_back(createObjectClassification(ObjectClassification::BUS, 0.7)); + std::vector classifications; + classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.8)); + classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8)); + classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7)); - std::uint8_t label = getHighestProbLabel(classification); + std::uint8_t label = getHighestProbLabel(classifications); EXPECT_EQ(label, ObjectClassification::CAR); } } + +TEST(object_classification, test_getHighestProbClassification) +{ + using autoware_auto_perception_msgs::msg::ObjectClassification; + using perception_utils::getHighestProbClassification; + + { // empty + std::vector classifications; + auto classification = getHighestProbClassification(classifications); + EXPECT_EQ(classification.label, ObjectClassification::UNKNOWN); + EXPECT_DOUBLE_EQ(classification.probability, 0.0); + } + + { // normal case + std::vector classifications; + classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.5)); + classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8)); + classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7)); + + auto classification = getHighestProbClassification(classifications); + EXPECT_EQ(classification.label, ObjectClassification::TRUCK); + EXPECT_NEAR(classification.probability, 0.8, epsilon); + } + + { // labels with the same probability + std::vector classifications; + classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.8)); + classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8)); + classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7)); + + auto classification = getHighestProbClassification(classifications); + EXPECT_EQ(classification.label, ObjectClassification::CAR); + EXPECT_NEAR(classification.probability, 0.8, epsilon); + } +}