From 20d8ebea3ee37748101994986aeaffc553467cd9 Mon Sep 17 00:00:00 2001 From: Louis Jean Date: Wed, 5 Apr 2023 15:56:47 +0000 Subject: [PATCH] feat(torch): add map metrics with arbitrary iou threshold --- src/backends/torch/torchlib.cc | 39 +++++++++--- src/backends/torch/torchlib.h | 3 +- src/supervisedoutputconnector.h | 106 +++++++++++++++++++++++++++++--- tests/ut-torchapi.cc | 63 +++++++++++++++---- 4 files changed, 181 insertions(+), 30 deletions(-) diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index 9ebf4882e..bccef07a6 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -2017,6 +2017,17 @@ namespace dd ad_out.add("measure", meas); } + std::vector iou_thresholds; + std::map ad_bbox_per_iou; + if (_bbox) + { + auto meas = ad_out.get("measure").get>(); + SupervisedOutput::find_ap_iou_thresholds(meas, iou_thresholds); + + for (int i : iou_thresholds) + ad_bbox_per_iou[i] = APIData(); + } + auto dataloader = torch::data::make_data_loader( dataset, data::DataLoaderOptions(batch_size)); torch::Device cpu("cpu"); @@ -2120,11 +2131,19 @@ namespace dd ++stop; } - auto vbad = get_bbox_stats( - targ_bboxes.index({ torch::indexing::Slice(start, stop) }), - targ_labels.index({ torch::indexing::Slice(start, stop) }), - bboxes_tensor, labels_tensor, score_tensor); - ad_bbox.add(std::to_string(entry_id), vbad); + for (int iou_thres : iou_thresholds) + { + double iou_thres_d = static_cast(iou_thres) / 100; + std::vector vbad = get_bbox_stats( + targ_bboxes.index( + { torch::indexing::Slice(start, stop) }), + targ_labels.index( + { torch::indexing::Slice(start, stop) }), + bboxes_tensor, labels_tensor, score_tensor, + iou_thres_d); + ad_bbox_per_iou[iou_thres].add(std::to_string(entry_id), + vbad); + } ++entry_id; } } @@ -2299,6 +2318,12 @@ namespace dd { ad_res.add("bbox", true); ad_res.add("pos_count", entry_id); + + for (int iou_thres : iou_thresholds) + { + ad_bbox.add("map-" + std::to_string(iou_thres), + ad_bbox_per_iou[iou_thres]); + } ad_res.add("0", ad_bbox); } else if (_segmentation) @@ -2318,7 +2343,8 @@ namespace dd const at::Tensor &targ_labels, const at::Tensor &bboxes_tensor, const at::Tensor &labels_tensor, - const at::Tensor &score_tensor) + const at::Tensor &score_tensor, + float overlap_threshold) { auto targ_bboxes_acc = targ_bboxes.accessor(); auto targ_labels_acc = targ_labels.accessor(); @@ -2348,7 +2374,6 @@ namespace dd }; std::vector eval_infos(_nclasses); - float overlap_threshold = 0.5; // TODO: parameter for (int j = 0; j < pred_bbox_count; ++j) { diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index e3604bde0..8d22e203c 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -90,7 +90,8 @@ namespace dd const at::Tensor &targ_labels, const at::Tensor &bboxes_tensor, const at::Tensor &labels_tensor, - const at::Tensor &score_tensor); + const at::Tensor &score_tensor, + float overlap_threshold); public: unsigned int _nclasses = 0; /**< number of classes*/ diff --git a/src/supervisedoutputconnector.h b/src/supervisedoutputconnector.h index 7daef24a6..86d2bdbb6 100644 --- a/src/supervisedoutputconnector.h +++ b/src/supervisedoutputconnector.h @@ -23,6 +23,9 @@ #define SUPERVISEDOUTPUTCONNECTOR_H #define TS_METRICS_EPSILON 1E-2 +#include +#include + #include "dto/output_connector.hpp" template @@ -845,19 +848,64 @@ namespace dd } if (bbox) { - bool bbmap = (std::find(measures.begin(), measures.end(), "map") - != measures.end()); - if (bbmap) + // required iou thresholds for map. If there are more than one + // threshold, the global map is the mean over the different iou + // thresholds. + std::vector thresholds; + bool has_map = find_ap_iou_thresholds(measures, thresholds); + + if (has_map) { - std::map aps; - double bmap = ap(ad_res, aps); - meas_out.add("map", bmap); - for (auto ap : aps) + double sum_map = 0; + std::map sum_aps; + int ap_count = 0; + + // map for each threshold + for (int iou_thres : thresholds) { - std::string s = "map_" + std::to_string(ap.first); - meas_out.add(s, static_cast(ap.second)); + std::map aps; + double bmap = ap(ad_res, aps, iou_thres); + std::string map_key = "map"; + if (iou_thres != 0) + { + std::stringstream ss; + ss << map_key << "-" << std::setfill('0') + << std::setw(2) << iou_thres; + map_key = ss.str(); + } + meas_out.add(map_key, bmap); + for (auto ap : aps) + { + std::string s + = map_key + "_" + std::to_string(ap.first); + meas_out.add(s, static_cast(ap.second)); + } + + sum_map += bmap; + if (sum_aps.size() == 0) + sum_aps = aps; + else + { + for (auto ap : aps) + sum_aps[ap.first] += ap.second; + } + ap_count++; + } + + // mean of all thresholds + if (thresholds.size() > 0) + { + meas_out.add("map", sum_map / ap_count); + for (auto sum_ap : sum_aps) + { + std::string s + = "map_" + std::to_string(sum_ap.first); + meas_out.add(s, static_cast(sum_ap.second + / ap_count)); + } } } + bool raw = (std::find(measures.begin(), measures.end(), "raw") != measures.end()); if (raw) @@ -1608,6 +1656,35 @@ namespace dd } } + /** \param thresholds the requested iou thresholds in percent (int) + * \return true if map is requested, false otherwise */ + static bool + find_ap_iou_thresholds(const std::vector &measures, + std::vector &thresholds) + { + bool has_map = false; + for (std::string s : measures) + { + if (s.find("map") != std::string::npos) + { + has_map = true; + std::vector sv = dd_utils::split(s, '-'); + int iou_thres = 0; + + if (sv.size() == 2) + { + iou_thres = std::atoi(sv.at(1).c_str()); + thresholds.push_back(iou_thres); + } + } + } + + // Default threshold is 0.5 (map 50) + if (thresholds.empty()) + thresholds.push_back(50); + return has_map; + } + static double straight_meas(const APIData &ad) { APIData bad = ad.getobj("0"); @@ -2407,13 +2484,22 @@ namespace dd return ap; } - static double ap(const APIData &ad, std::map &APs) + /** + * Compute AP for all classes and mean AP + * \param APs std::map containing AP for each class + * \param thres iou threshold for map in percent + */ + static double ap(const APIData &ad, std::map &APs, int thres) { double mmAP = 0.0; std::map APs_count; int APs_count_all = 0; // extract tp, fp, labels APIData bad = ad.getobj("0"); + std::string map_key = "map-" + std::to_string(thres); + if (bad.has(map_key)) + bad = bad.getobj(map_key); + // else: default threshold (legacy) int pos_count = ad.get("pos_count").get(); for (int i = 0; i < pos_count; i++) { diff --git a/tests/ut-torchapi.cc b/tests/ut-torchapi.cc index c2c6c6956..bebfaf017 100644 --- a/tests/ut-torchapi.cc +++ b/tests/ut-torchapi.cc @@ -608,7 +608,7 @@ TEST(torchapi, compute_bbox_stats) 11, 11, 101, 101, // matching 900, 10, 950, 100, // false positive 510, 510, 610, 610, // 2 preds for 1 targets - 490, 490, 590, 590, // -- + 490, 490, 590, 590, // (second pred) 940, 940, 990, 990, // overlapping but iou < 0.5 -> false positive }; at::Tensor bboxes_tensor = torch::from_blob(bboxes_data, { 5, 4 }); @@ -619,10 +619,10 @@ TEST(torchapi, compute_bbox_stats) float score_data[] = { 0.9, 0.8, 0.7, 0.6, 0.5 }; at::Tensor score_tensor = torch::from_blob(score_data, 5); - auto vbad = torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor, - labels_tensor, score_tensor); - - auto lbad = vbad.at(0); + auto vbad50 + = torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor, + labels_tensor, score_tensor, 0.5); + auto lbad = vbad50.at(0); auto tp_i = lbad.get("tp_i").get>(); auto tp_d = lbad.get("tp_d").get>(); auto fp_i = lbad.get("fp_i").get>(); @@ -643,6 +643,19 @@ TEST(torchapi, compute_bbox_stats) } ASSERT_EQ(lbad.get("num_pos").get(), 4); ASSERT_EQ(lbad.get("label").get(), 1); + APIData ad_bbox_50; + ad_bbox_50.add("0", vbad50); + + // with map 90 the third bbox is no longer matching + auto vbad90 + = torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor, + labels_tensor, score_tensor, 0.9); + lbad = vbad90.at(0); + tp_i = lbad.get("tp_i").get>(); + ASSERT_EQ(std::accumulate(tp_i.begin(), tp_i.end(), 0), 1); + ASSERT_FALSE(tp_i[2]); + APIData ad_bbox_90; + ad_bbox_90.add("0", vbad90); // Get MAP APIData ad_res; @@ -651,15 +664,25 @@ TEST(torchapi, compute_bbox_stats) ad_res.add("bbox", true); ad_res.add("pos_count", 1); APIData ad_bbox; - ad_bbox.add("0", vbad); + ad_bbox.add("map-50", ad_bbox_50); + ad_bbox.add("map-90", ad_bbox_90); ad_res.add("0", ad_bbox); ad_res.add("batch_size", 1); APIData ad_out; - ad_out.add("measure", std::vector{ "map" }); + ad_out.add("measure", std::vector{ "map", "map-50", "map-90" }); APIData out; SupervisedOutput::measure(ad_res, ad_out, out, 0, "test"); - ASSERT_NEAR(out.getobj("measure").get("map").get(), 5. / 12., + JsonAPI japi; + JDoc jdoc; + jdoc.SetObject(); + out.toJDoc(jdoc); + std::cout << japi.jrender(jdoc) << std::endl; + ASSERT_NEAR(out.getobj("measure").get("map-50").get(), 5. / 12., std::numeric_limits::epsilon()); + ASSERT_NEAR(out.getobj("measure").get("map-90").get(), 0.25, + std::numeric_limits::epsilon()); + ASSERT_NEAR(out.getobj("measure").get("map").get(), + (5. / 12 + 0.25) / 2., std::numeric_limits::epsilon()); } TEST(torchapi, map_false_negative) @@ -691,7 +714,7 @@ TEST(torchapi, map_false_negative) at::Tensor score_tensor = torch::from_blob(score_data, 1); auto vbad = torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor, - labels_tensor, score_tensor); + labels_tensor, score_tensor, 0.5); // Get MAP APIData ad_res; @@ -2360,7 +2383,8 @@ TEST(torchapi, service_train_object_detection_yolox) "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true," "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{" "\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true," - "\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\"" + "\"shuffle\":true},\"output\":{\"measure\":[\"map-05\",\"map-50\"," + "\"map-90\"]}},\"data\":[\"" + fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}"; joutstr = japi.jrender(japi.service_train(jtrainstr)); @@ -2372,6 +2396,9 @@ TEST(torchapi, service_train_object_detection_yolox) // ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations"; ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map"; + ASSERT_TRUE(jd["body"]["measure"]["map-05"].GetDouble() <= 1.0) << "map-05"; + ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50"; + ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90"; // ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map"; // check metrics @@ -2456,7 +2483,8 @@ TEST(torchapi, service_train_object_detection_yolox_no_db) "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true," "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{" "\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":false," - "\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\"" + "\"shuffle\":true},\"output\":{\"measure\":[\"map-90\",\"map\"]}}," + "\"data\":[\"" + fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}"; joutstr = japi.jrender(japi.service_train(jtrainstr)); @@ -2468,6 +2496,8 @@ TEST(torchapi, service_train_object_detection_yolox_no_db) // ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations"; ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map"; + ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90"; + ASSERT_FALSE(jd["body"]["measure"].HasMember("map-50")); // ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map"; // check metrics @@ -2551,7 +2581,8 @@ TEST(torchapi, service_train_object_detection_yolox_multigpu) "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true," "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{" "\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true," - "\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\"" + "\"shuffle\":true},\"output\":{\"measure\":[\"map-50\",\"map-90\"]}}" + ",\"data\":[\"" + fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}"; joutstr = japi.jrender(japi.service_train(jtrainstr)); @@ -2563,6 +2594,14 @@ TEST(torchapi, service_train_object_detection_yolox_multigpu) ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations"; ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map"; + ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50"; + ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90"; + ASSERT_LE(jd["body"]["measure"]["map-90"].GetDouble(), + jd["body"]["measure"]["map-50"].GetDouble()); + ASSERT_NEAR((jd["body"]["measure"]["map-90"].GetDouble() + + jd["body"]["measure"]["map-50"].GetDouble()) + / 2, + jd["body"]["measure"]["map"].GetDouble(), 0.001); // ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map"; // check metrics