-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
62 lines (51 loc) · 2.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from config import device
# Per-class plotting colours for COCO: 81 RGB rows (80 object classes plus one
# extra row — presumably background; confirm against the plotting code).
# Values are divided by 255 because matplotlib expects float RGB in [0, 1].
# NOTE: np.random.seed(1) reseeds NumPy's *global* RNG as an import-time side
# effect; it is kept so the palette (and anything relying on that seed) is
# reproducible across runs.
np.random.seed(1)
coco_color_array = np.random.randint(256, size=(81, 3)) / 255

# The 80 COCO object category names, in contiguous-index order (0 = 'person').
coco_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane',
               'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
               'cat', 'dog', 'horse', 'sheep', 'cow',
               'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
               'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
               'wine glass', 'cup', 'fork', 'knife', 'spoon',
               'bowl', 'banana', 'apple', 'sandwich', 'orange',
               'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
               'cake', 'chair', 'couch', 'potted plant', 'bed',
               'dining table', 'toilet', 'tv', 'laptop', 'mouse',
               'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
               'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

# Category name -> contiguous index: {'person': 0, ..., 'toothbrush': 79}.
coco_label_map = {k: v for v, k in enumerate(coco_labels)}

# The 80 sparse COCO category ids, in the range 1..90. Ids such as 12, 26, 29
# and 30 are absent from the list. Entry i is assumed to be the id of
# coco_labels[i] (both lists have 80 entries) — verify against the dataset
# annotations if this matters.
coco_label_idx_80 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
                     11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
                     22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
                     35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
                     46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
                     56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
                     67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
                     80, 81, 82, 84, 85, 86, 87, 88, 89, 90]

# Inverse of coco_label_idx_80: sparse COCO id (1..90) -> contiguous index (0..79).
coco_label_idx_91 = {coco_id: i for i, coco_id in enumerate(coco_label_idx_80)}
def cxcy_to_xy(cxcy):
    """Convert boxes from centre format (cx, cy, w, h) to corner format (x1, y1, x2, y2).

    Args:
        cxcy: tensor whose last dimension holds (cx, cy, w, h). Any number of
            leading dimensions is allowed, e.g. ``(4,)``, ``(N, 4)`` or ``(B, N, 4)``.

    Returns:
        Tensor of the same shape whose last dimension holds (x1, y1, x2, y2).
    """
    centers = cxcy[..., :2]
    half_sizes = cxcy[..., 2:] / 2
    x1y1 = centers - half_sizes
    x2y2 = centers + half_sizes
    # Join along the last axis. The previous dim=1 only worked for 2-D input:
    # it raised on a single 1-D box and mixed up axes for batched 3-D tensors.
    return torch.cat([x1y1, x2y2], dim=-1)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner (x0, y0, x1, y1) format.

    The last dimension of ``x`` must have size exactly 4 (it is unpacked into
    four components); any leading dimensions are preserved in the output.
    """
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
def detect(pred):
    """Turn raw model outputs into per-query boxes, class labels and scores.

    Args:
        pred: mapping with ``'pred_logits'`` and ``'pred_boxes'`` tensors.
            ``squeeze(0)`` removes a leading dimension only if its size is 1,
            e.g. a batch of one. Boxes are converted with
            ``box_cxcywh_to_xyxy``, so they are expected in (cx, cy, w, h)
            format.

    Returns:
        Tuple ``(boxes, labels, scores)``:
        - boxes: corner format (x0, y0, x1, y1).
        - labels: for each query, the arg-max class index.
        - scores: the softmax probability of that class.
        The arg-max is taken over every class column except the last.
    """
    # TODO (translated from the original Korean note): pred still needs to be
    # converted into pred_bboxes / pred_scores.
    logits = pred['pred_logits'].squeeze(0)
    boxes_cxcywh = pred['pred_boxes'].squeeze(0)

    # The last logit column is excluded from the arg-max. It is presumably a
    # "no object" class — confirm against the model definition.
    class_probs = F.softmax(logits, -1)
    best_scores, best_labels = class_probs[..., :-1].max(-1)

    boxes_xyxy = box_cxcywh_to_xyxy(boxes_cxcywh)
    return boxes_xyxy, best_labels, best_scores