-
Notifications
You must be signed in to change notification settings - Fork 73
/
coco.py
77 lines (65 loc) · 2.36 KB
/
coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import cv2
from torchvision.datasets import CocoDetection
from copy_paste import copy_paste_class
# Minimum total number of visible keypoints (visibility flag > 0, summed over
# all of an image's annotations) that has_valid_annotation requires before it
# accepts an image whose annotations carry a "keypoints" field.
min_keypoints_per_image: int = 10
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def has_valid_annotation(anno):
    """Decide whether an image's annotation list should be kept.

    An image is rejected when it has no annotations, or when every bounding
    box is degenerate (see ``_has_only_empty_bbox``). If the first
    annotation has a ``"keypoints"`` field, the image is kept only when the
    total count of visible keypoints reaches ``min_keypoints_per_image``.
    Otherwise any image that passes the box checks is kept.

    Returns:
        bool: True to keep the image, False to drop it.
    """
    # No annotations at all, or only near-zero-area boxes: nothing to learn from.
    if len(anno) == 0 or _has_only_empty_bbox(anno):
        return False

    # Plain detection/segmentation annotations need no further checks.
    # Only the first entry is inspected to decide whether this is keypoint data.
    if "keypoints" not in anno[0]:
        return True

    # Keypoint data: require enough visible keypoints across the whole image.
    return _count_visible_keypoints(anno) >= min_keypoints_per_image
@copy_paste_class
class CocoDetectionCP(CocoDetection):
    """COCO detection dataset that emits image/masks/bboxes dicts for copy-paste.

    Images whose annotations fail ``has_valid_annotation`` are dropped when
    the dataset is built. ``load_example`` decodes one image, rasterises
    every annotation into a binary mask, and passes
    ``image=``/``masks=``/``bboxes=`` to ``self.transforms``. The
    ``copy_paste_class`` decorator is defined elsewhere. It presumably wires
    ``load_example`` into ``__getitem__`` (TODO confirm in copy_paste.py).
    """

    def __init__(
        self,
        root,
        annFile,
        transforms
    ):
        """Build the dataset and filter out images without usable annotations.

        Args:
            root: Directory that image ``file_name`` entries are joined onto.
            annFile: Path to the COCO annotation file, forwarded to CocoDetection.
            transforms: Callable invoked as ``transforms(image=, masks=, bboxes=)``
                by ``load_example``. Presumably an albumentations ``Compose``
                (TODO confirm).
        """
        # Only the joint `transforms` slot is used; the separate per-image
        # `transform` and `target_transform` slots are left as None.
        super(CocoDetectionCP, self).__init__(
            root, annFile, None, None, transforms
        )

        # filter images without detection annotations
        ids = []
        for img_id in self.ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if has_valid_annotation(anno):
                ids.append(img_id)

        self.ids = ids

    def load_example(self, index):
        """Load image ``index`` with its instance masks and boxes, then transform it.

        Args:
            index: Position in ``self.ids`` (the filtered image-id list).

        Returns:
            Whatever ``self.transforms(image=..., masks=..., bboxes=...)`` returns.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file (missing
                or undecodable).
        """
        img_id = self.ids[index]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        target = self.coco.loadAnns(ann_ids)

        path = self.coco.loadImgs(img_id)[0]['file_name']
        image_path = os.path.join(self.root, path)
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        # Without this check the failure surfaces as an opaque cvtColor
        # assertion that does not name the offending file.
        if image is None:
            raise FileNotFoundError(
                f"Could not read image (missing or undecodable): {image_path}"
            )
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert every annotation's segmentation into a binary mask.
        # Each bboxes entry is the annotation's raw COCO 'bbox' (COCO stores
        # [x, y, width, height]) followed by its category_id and its index
        # `ix` in `target`. `ix` presumably lets downstream code pair boxes
        # with masks after transforms drop some (TODO confirm).
        masks = []
        bboxes = []
        for ix, obj in enumerate(target):
            masks.append(self.coco.annToMask(obj))
            bboxes.append(obj['bbox'] + [obj['category_id']] + [ix])

        # pack outputs into a dict
        output = {
            'image': image,
            'masks': masks,
            'bboxes': bboxes
        }

        return self.transforms(**output)