Skip to content

Commit

Permalink
refactor(perception_utils): refactor object_classification (#2042)
Browse files Browse the repository at this point in the history
* refactor(perception_utils): refactor object_classification

Signed-off-by: scepter914 <scepter914@gmail.com>

* fix bug

Signed-off-by: scepter914 <scepter914@gmail.com>

* fix unittest

Signed-off-by: scepter914 <scepter914@gmail.com>

* refactor

Signed-off-by: scepter914 <scepter914@gmail.com>

* fix unit test

Signed-off-by: scepter914 <scepter914@gmail.com>

* remove redundant else

Signed-off-by: scepter914 <scepter914@gmail.com>

* refactor variable name

Signed-off-by: scepter914 <scepter914@gmail.com>

Signed-off-by: scepter914 <scepter914@gmail.com>
  • Loading branch information
scepter914 authored Oct 18, 2022
1 parent a8f3a98 commit f659209
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,72 +23,74 @@ namespace perception_utils
{
using autoware_auto_perception_msgs::msg::ObjectClassification;

inline std::uint8_t getHighestProbLabel(
inline ObjectClassification getHighestProbClassification(
const std::vector<ObjectClassification> & object_classifications)
{
std::uint8_t label = ObjectClassification::UNKNOWN;
float highest_prob = 0.0;
// TODO(Satoshi Tanaka): It might be simple if you use STL or range-v3.
for (const auto & object_classification : object_classifications) {
if (highest_prob < object_classification.probability) {
highest_prob = object_classification.probability;
label = object_classification.label;
}
if (object_classifications.empty()) {
return ObjectClassification{};
}
return label;
return *std::max_element(
std::begin(object_classifications), std::end(object_classifications),
[](const auto & a, const auto & b) { return a.probability < b.probability; });
}

inline std::uint8_t getHighestProbLabel(
const std::vector<ObjectClassification> & object_classifications)
{
auto classification = getHighestProbClassification(object_classifications);
return classification.label;
}

inline bool isVehicle(const uint8_t object_classification)
inline bool isVehicle(const uint8_t label)
{
return object_classification == ObjectClassification::BICYCLE ||
object_classification == ObjectClassification::BUS ||
object_classification == ObjectClassification::CAR ||
object_classification == ObjectClassification::MOTORCYCLE ||
object_classification == ObjectClassification::TRAILER ||
object_classification == ObjectClassification::TRUCK;
return label == ObjectClassification::BICYCLE || label == ObjectClassification::BUS ||
label == ObjectClassification::CAR || label == ObjectClassification::MOTORCYCLE ||
label == ObjectClassification::TRAILER || label == ObjectClassification::TRUCK;
}

inline bool isVehicle(const ObjectClassification & object_classification)
{
return isVehicle(object_classification.label);
}

inline bool isVehicle(const std::vector<ObjectClassification> & object_classifications)
{
auto highest_prob_classification = getHighestProbLabel(object_classifications);
return highest_prob_classification == ObjectClassification::BICYCLE ||
highest_prob_classification == ObjectClassification::BUS ||
highest_prob_classification == ObjectClassification::CAR ||
highest_prob_classification == ObjectClassification::MOTORCYCLE ||
highest_prob_classification == ObjectClassification::TRAILER ||
highest_prob_classification == ObjectClassification::TRUCK;
auto highest_prob_label = getHighestProbLabel(object_classifications);
return isVehicle(highest_prob_label);
}

inline bool isCarLikeVehicle(const uint8_t label)
{
return label == ObjectClassification::BUS || label == ObjectClassification::CAR ||
label == ObjectClassification::TRAILER || label == ObjectClassification::TRUCK;
}

inline bool isCarLikeVehicle(const uint8_t object_classification)
inline bool isCarLikeVehicle(const ObjectClassification & object_classification)
{
return object_classification == ObjectClassification::BUS ||
object_classification == ObjectClassification::CAR ||
object_classification == ObjectClassification::TRAILER ||
object_classification == ObjectClassification::TRUCK;
return isCarLikeVehicle(object_classification.label);
}

inline bool isCarLikeVehicle(const std::vector<ObjectClassification> & object_classifications)
{
auto highest_prob_classification = getHighestProbLabel(object_classifications);
return highest_prob_classification == ObjectClassification::BUS ||
highest_prob_classification == ObjectClassification::CAR ||
highest_prob_classification == ObjectClassification::TRAILER ||
highest_prob_classification == ObjectClassification::TRUCK;
auto highest_prob_label = getHighestProbLabel(object_classifications);
return isCarLikeVehicle(highest_prob_label);
}

inline bool isLargeVehicle(const uint8_t label)
{
return label == ObjectClassification::BUS || label == ObjectClassification::TRAILER ||
label == ObjectClassification::TRUCK;
}

inline bool isLargeVehicle(const uint8_t object_classification)
inline bool isLargeVehicle(const ObjectClassification & object_classification)
{
return object_classification == ObjectClassification::BUS ||
object_classification == ObjectClassification::TRAILER ||
object_classification == ObjectClassification::TRUCK;
return isLargeVehicle(object_classification.label);
}

inline bool isLargeVehicle(const std::vector<ObjectClassification> & object_classifications)
{
auto highest_prob_classification = getHighestProbLabel(object_classifications);
return highest_prob_classification == ObjectClassification::BUS ||
highest_prob_classification == ObjectClassification::TRAILER ||
highest_prob_classification == ObjectClassification::TRUCK;
auto highest_prob_label = getHighestProbLabel(object_classifications);
return isLargeVehicle(highest_prob_label);
}
} // namespace perception_utils

Expand Down
61 changes: 49 additions & 12 deletions common/perception_utils/test/src/test_object_classification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include <gtest/gtest.h>

constexpr double epsilon = 1e-06;

namespace
{
autoware_auto_perception_msgs::msg::ObjectClassification createObjectClassification(
Expand All @@ -36,28 +38,63 @@ TEST(object_classification, test_getHighestProbLabel)
using perception_utils::getHighestProbLabel;

{ // empty
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classification;
std::uint8_t label = getHighestProbLabel(classification);
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
std::uint8_t label = getHighestProbLabel(classifications);
EXPECT_EQ(label, ObjectClassification::UNKNOWN);
}

{ // normal case
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classification;
classification.push_back(createObjectClassification(ObjectClassification::CAR, 0.5));
classification.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classification.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.5));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

std::uint8_t label = getHighestProbLabel(classification);
std::uint8_t label = getHighestProbLabel(classifications);
EXPECT_EQ(label, ObjectClassification::TRUCK);
}

{ // labels with the same probability
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classification;
classification.push_back(createObjectClassification(ObjectClassification::CAR, 0.8));
classification.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classification.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

std::uint8_t label = getHighestProbLabel(classification);
std::uint8_t label = getHighestProbLabel(classifications);
EXPECT_EQ(label, ObjectClassification::CAR);
}
}

TEST(object_classification, test_getHighestProbClassification)
{
using autoware_auto_perception_msgs::msg::ObjectClassification;
using perception_utils::getHighestProbClassification;

{ // empty
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
auto classification = getHighestProbClassification(classifications);
EXPECT_EQ(classification.label, ObjectClassification::UNKNOWN);
EXPECT_DOUBLE_EQ(classification.probability, 0.0);
}

{ // normal case
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.5));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

auto classification = getHighestProbClassification(classifications);
EXPECT_EQ(classification.label, ObjectClassification::TRUCK);
EXPECT_NEAR(classification.probability, 0.8, epsilon);
}

{ // labels with the same probability
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

auto classification = getHighestProbClassification(classifications);
EXPECT_EQ(classification.label, ObjectClassification::CAR);
EXPECT_NEAR(classification.probability, 0.8, epsilon);
}
}

0 comments on commit f659209

Please sign in to comment.