-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluator.py
67 lines (53 loc) · 2.35 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import json
import tempfile
from pycocotools.cocoeval import COCOeval
class Evaluator(object):
    """Collect per-image detections and score them with COCO bbox evaluation.

    Call :meth:`get_info` once per image to record its predictions, then call
    :meth:`evaluate` with the dataset that owns the ground-truth annotations.

    Only ``data_type='coco'`` is implemented.
    """

    # Predictions carrying this label are skipped when results are recorded.
    # FIXME: confirm the background label id once label handling has been
    # debugged (an older note claimed 80, but the code compares against 91).
    BACKGROUND_LABEL = 91

    def __init__(self, data_type='coco'):
        """Create an empty evaluator.

        Args:
            data_type: Name of the evaluation protocol; only ``'coco'`` is
                supported.

        Raises:
            NotImplementedError: If ``data_type`` is anything but ``'coco'``.
        """
        if data_type != 'coco':
            # Fail loudly instead of printing and exiting with status 0.
            raise NotImplementedError(
                "data_type {!r} is not ready yet; only 'coco' is supported".format(data_type)
            )
        self.data_type = data_type
        self.results = []   # accumulated COCO-format result dicts
        self.img_ids = []   # ids of every image passed to get_info

    def get_info(self, info):
        """Record the predictions of one image in COCO result format.

        Args:
            info: A 6-tuple ``(pred_boxes, pred_labels, pred_scores, img_id,
                img_info, coco_ids)``. ``pred_boxes`` is indexed as
                ``[:, 0..3] = x1, y1, x2, y2`` and is multiplied by
                ``img_info['width']`` / ``img_info['height']`` (so the
                coordinates are presumably normalized to the image size --
                confirm against the caller). ``coco_ids`` is not used here.

        The caller's ``pred_boxes`` is left untouched; all conversions are
        applied to a private copy.
        """
        if self.data_type != 'coco':
            return

        pred_boxes, pred_labels, pred_scores, img_id, img_info, _coco_ids = info
        self.img_ids.append(img_id)

        # Work on a copy so the in-place arithmetic below cannot corrupt the
        # caller's data (torch tensors expose clone(), numpy arrays copy()).
        if hasattr(pred_boxes, 'clone'):
            boxes = pred_boxes.clone()
        else:
            boxes = pred_boxes.copy()

        # (x1, y1, x2, y2) -> (x, y, w, h), as expected by the COCO bbox format.
        boxes[:, 2] -= boxes[:, 0]
        boxes[:, 3] -= boxes[:, 1]

        # Scale to pixel units of the image.
        width = img_info['width']
        height = img_info['height']
        boxes[:, 0] *= width
        boxes[:, 2] *= width
        boxes[:, 1] *= height
        boxes[:, 3] *= height

        for box, label, score in zip(boxes, pred_labels, pred_scores):
            category_id = int(label)
            if category_id == self.BACKGROUND_LABEL:
                continue
            self.results.append({
                'image_id': img_id,
                # FIXME: label mapping still needs to be settled
                # (pred_label - 1 or pred_label?).
                'category_id': category_id,
                'score': float(score),
                'bbox': box.tolist(),
            })

    def evaluate(self, dataset):
        """Run COCO bbox evaluation over everything recorded so far.

        Args:
            dataset: Object whose ``coco`` attribute is the pycocotools ground
                truth handle (it must provide ``loadRes``).

        Returns:
            ``coco_eval.stats[0]`` (the primary COCO AP, averaged over
            IoU 0.50:0.95 according to pycocotools' summary), or ``0.0`` when
            no detections have been recorded.
        """
        if self.data_type != 'coco':
            return None

        # pycocotools' loadRes cannot handle an empty result list; with no
        # detections at all the score is simply zero.
        if not self.results:
            return 0.0

        # https://github.com/argusswift/YOLOv4-pytorch/blob/master/eval/cocoapi_evaluator.py
        # workaround: temporarily write data to json file because pycocotools
        # can't process dict in py36.
        fd, tmp_path = tempfile.mkstemp()
        try:
            # fdopen takes ownership of the descriptor, so leaving the `with`
            # block both flushes the JSON and closes the descriptor.
            with os.fdopen(fd, 'w') as handle:
                json.dump(self.results, handle)
            coco_gt = dataset.coco
            coco_dt = coco_gt.loadRes(tmp_path)
        finally:
            # The results are fully loaded (or loading failed); either way the
            # temporary file is no longer needed.
            os.remove(tmp_path)

        coco_eval = COCOeval(cocoGt=coco_gt, cocoDt=coco_dt, iouType='bbox')
        # Restrict scoring to the images that were actually passed to get_info.
        coco_eval.params.imgIds = self.img_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        # stats[1] holds AP at IoU=0.50 should it be needed by a caller later.
        return coco_eval.stats[0]