Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(perception_utils): refactor object_classification #2042

Original file line number Diff line number Diff line change
Expand Up @@ -23,72 +23,74 @@ namespace perception_utils
{
using autoware_auto_perception_msgs::msg::ObjectClassification;

inline std::uint8_t getHighestProbLabel(
inline ObjectClassification getHighestProbClassification(
const std::vector<ObjectClassification> & object_classifications)
{
std::uint8_t label = ObjectClassification::UNKNOWN;
float highest_prob = 0.0;
// TODO(Satoshi Tanaka): It might be simple if you use STL or range-v3.
for (const auto & object_classification : object_classifications) {
if (highest_prob < object_classification.probability) {
highest_prob = object_classification.probability;
label = object_classification.label;
}
if (object_classifications.empty()) {
return ObjectClassification{};
}
return label;
return *std::max_element(
std::begin(object_classifications), std::end(object_classifications),
[](const auto & a, const auto & b) { return a.probability < b.probability; });
}

inline std::uint8_t getHighestProbLabel(
const std::vector<ObjectClassification> & object_classifications)
{
auto classification = getHighestProbClassification(object_classifications);
return classification.label;
}

inline bool isVehicle(const uint8_t object_classification)
inline bool isVehicle(const uint8_t label)
{
return object_classification == ObjectClassification::BICYCLE ||
object_classification == ObjectClassification::BUS ||
object_classification == ObjectClassification::CAR ||
object_classification == ObjectClassification::MOTORCYCLE ||
object_classification == ObjectClassification::TRAILER ||
object_classification == ObjectClassification::TRUCK;
return label == ObjectClassification::BICYCLE || label == ObjectClassification::BUS ||
label == ObjectClassification::CAR || label == ObjectClassification::MOTORCYCLE ||
label == ObjectClassification::TRAILER || label == ObjectClassification::TRUCK;
}

inline bool isVehicle(const ObjectClassification & object_classification)
{
return isVehicle(object_classification.label);
}

inline bool isVehicle(const std::vector<ObjectClassification> & object_classifications)
{
auto highest_prob_classification = getHighestProbLabel(object_classifications);
return highest_prob_classification == ObjectClassification::BICYCLE ||
highest_prob_classification == ObjectClassification::BUS ||
highest_prob_classification == ObjectClassification::CAR ||
highest_prob_classification == ObjectClassification::MOTORCYCLE ||
highest_prob_classification == ObjectClassification::TRAILER ||
highest_prob_classification == ObjectClassification::TRUCK;
return isVehicle(highest_prob_classification);
yukke42 marked this conversation as resolved.
Show resolved Hide resolved
}

inline bool isCarLikeVehicle(const uint8_t label)
{
return label == ObjectClassification::BUS || label == ObjectClassification::CAR ||
label == ObjectClassification::TRAILER || label == ObjectClassification::TRUCK;
}

inline bool isCarLikeVehicle(const uint8_t object_classification)
inline bool isCarLikeVehicle(const ObjectClassification & object_classification)
{
return object_classification == ObjectClassification::BUS ||
object_classification == ObjectClassification::CAR ||
object_classification == ObjectClassification::TRAILER ||
object_classification == ObjectClassification::TRUCK;
return isCarLikeVehicle(object_classification.label);
}

inline bool isCarLikeVehicle(const std::vector<ObjectClassification> & object_classifications)
{
auto highest_prob_classification = getHighestProbLabel(object_classifications);
return highest_prob_classification == ObjectClassification::BUS ||
highest_prob_classification == ObjectClassification::CAR ||
highest_prob_classification == ObjectClassification::TRAILER ||
highest_prob_classification == ObjectClassification::TRUCK;
return isCarLikeVehicle(highest_prob_classification);
yukke42 marked this conversation as resolved.
Show resolved Hide resolved
}

inline bool isLargeVehicle(const uint8_t label)
{
return label == ObjectClassification::BUS || label == ObjectClassification::TRAILER ||
label == ObjectClassification::TRUCK;
}

inline bool isLargeVehicle(const uint8_t object_classification)
inline bool isLargeVehicle(const ObjectClassification & object_classification)
{
return object_classification == ObjectClassification::BUS ||
object_classification == ObjectClassification::TRAILER ||
object_classification == ObjectClassification::TRUCK;
return isLargeVehicle(object_classification.label);
}

inline bool isLargeVehicle(const std::vector<ObjectClassification> & object_classifications)
{
auto highest_prob_classification = getHighestProbLabel(object_classifications);
return highest_prob_classification == ObjectClassification::BUS ||
highest_prob_classification == ObjectClassification::TRAILER ||
highest_prob_classification == ObjectClassification::TRUCK;
return isLargeVehicle(highest_prob_classification);
yukke42 marked this conversation as resolved.
Show resolved Hide resolved
}
} // namespace perception_utils

Expand Down
61 changes: 49 additions & 12 deletions common/perception_utils/test/src/test_object_classification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include <gtest/gtest.h>

constexpr double epsilon = 1e-06;

namespace
{
autoware_auto_perception_msgs::msg::ObjectClassification createObjectClassification(
Expand All @@ -36,28 +38,63 @@ TEST(object_classification, test_getHighestProbLabel)
using perception_utils::getHighestProbLabel;

{ // empty
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classification;
std::uint8_t label = getHighestProbLabel(classification);
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
std::uint8_t label = getHighestProbLabel(classifications);
EXPECT_EQ(label, ObjectClassification::UNKNOWN);
}

{ // normal case
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classification;
classification.push_back(createObjectClassification(ObjectClassification::CAR, 0.5));
classification.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classification.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.5));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

std::uint8_t label = getHighestProbLabel(classification);
std::uint8_t label = getHighestProbLabel(classifications);
EXPECT_EQ(label, ObjectClassification::TRUCK);
}

{ // labels with the same probability
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classification;
classification.push_back(createObjectClassification(ObjectClassification::CAR, 0.8));
classification.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classification.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

std::uint8_t label = getHighestProbLabel(classification);
std::uint8_t label = getHighestProbLabel(classifications);
EXPECT_EQ(label, ObjectClassification::CAR);
}
}

TEST(object_classification, test_getHighestProbClassification)
{
using autoware_auto_perception_msgs::msg::ObjectClassification;
using perception_utils::getHighestProbClassification;

{ // empty
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
auto classification = getHighestProbClassification(classifications);
EXPECT_EQ(classification.label, ObjectClassification::UNKNOWN);
EXPECT_DOUBLE_EQ(classification.probability, 0.0);
}

{ // normal case
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.5));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));
kenji-miyake marked this conversation as resolved.
Show resolved Hide resolved

auto classification = getHighestProbClassification(classifications);
EXPECT_EQ(classification.label, ObjectClassification::TRUCK);
EXPECT_NEAR(classification.probability, 0.8, epsilon);
}

{ // labels with the same probability
std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> classifications;
classifications.push_back(createObjectClassification(ObjectClassification::CAR, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::TRUCK, 0.8));
classifications.push_back(createObjectClassification(ObjectClassification::BUS, 0.7));

auto classification = getHighestProbClassification(classifications);
EXPECT_EQ(classification.label, ObjectClassification::CAR);
EXPECT_NEAR(classification.probability, 0.8, epsilon);
}
}