-
Notifications
You must be signed in to change notification settings - Fork 867
/
object_detector.py
59 lines (50 loc) · 1.92 KB
/
object_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Module for object detection default handler
"""
import torch
from torchvision import transforms
from torchvision import __version__ as torchvision_version
from packaging import version
from .vision_handler import VisionHandler
from ..utils.util import map_class_to_label
class ObjectDetector(VisionHandler):
    """
    Default handler for object-detection models.

    Takes an image and, for every request in a batch, returns a list of
    detections: each detection carries the detected class label, its bounding
    box, and a ``"score"`` entry with the detection confidence.
    """

    # Preprocessing pipeline applied to incoming images: convert to a tensor.
    image_processing = transforms.Compose([transforms.ToTensor()])
    # Detections whose score is below this value are dropped in postprocess.
    threshold = 0.5

    def initialize(self, context):
        """
        Run the parent initialization, then apply a legacy-torchvision fix-up.

        On torchvision releases older than 0.6.0 the model is re-placed on a
        plain ``"cuda"`` device (when CUDA is available and a ``"gpu_id"`` is
        configured) or on ``"cpu"`` otherwise, and switched to eval mode.

        Args:
            context: serving context; its ``system_properties`` mapping is
                consulted for the ``"gpu_id"`` key.
        """
        super().initialize(context)
        props = context.system_properties

        # Torchvision breaks with object detector models before 0.6.0
        legacy_torchvision = version.parse(torchvision_version) < version.parse("0.6.0")
        if not legacy_torchvision:
            return

        # Flag the handler as not ready while the model is being moved; it is
        # only marked ready again once placement and eval() have completed.
        self.initialized = False
        want_cuda = torch.cuda.is_available() and props.get("gpu_id") is not None
        self.device = torch.device("cuda" if want_cuda else "cpu")
        self.model.to(self.device)
        self.model.eval()
        self.initialized = True

    def postprocess(self, data):
        """
        Convert raw detector output into per-request lists of detections.

        Args:
            data: batch of per-request outputs. Each element is indexed by
                ``"boxes"``, ``"labels"`` and ``"scores"``, and those values
                support boolean-mask indexing and ``.tolist()`` (presumably
                torch tensors — confirm against the model's forward output).

        Returns:
            A list with one entry per request. Each entry is a list of dicts,
            one per detection whose score is ``>= self.threshold``: the first
            dict returned by ``map_class_to_label`` for that single
            (box, label) pair, with a ``"score"`` key added.
        """
        keep_masks = [row["scores"] >= self.threshold for row in data]

        def _kept(field):
            # Apply each request's score mask to one output field.
            return [row[field][mask].tolist() for row, mask in zip(data, keep_masks)]

        all_boxes = _kept("boxes")
        all_labels = _kept("labels")
        all_scores = _kept("scores")

        def _detection(label, box, score):
            # Build the labelled dict for a single box, then attach its score.
            entry = map_class_to_label([[box]], self.mapping, [[label]])[0]
            entry["score"] = score
            return entry

        return [
            [_detection(label, box, score) for label, box, score in zip(labels, boxes, scores)]
            for labels, boxes, scores in zip(all_labels, all_boxes, all_scores)
        ]